To resolve the device mismatch error, let RLlib and PyTorch manage device placement automatically:

- Layers are no longer explicitly moved with `.to(self.device)` during initialization.
- The device is detected dynamically from the input: `self.device = input_dict["obs"].device`.
- Only the inputs in `forward` and `values_out` in `value_function` are moved to the model's device manually.

It is also important to override the `forward` and `value_function` methods, as suggested by @Marzi Heifari.

Here is the modified version:
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.models.modelv2 import ModelV2


@DeveloperAPI
class SimpleTransformer(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Configuration
        custom_config = model_config["custom_model_config"]
        self.input_dim = 76
        self.seq_len = custom_config["seq_len"]
        self.embed_size = custom_config["embed_size"]
        self.nheads = custom_config["nhead"]
        self.nlayers = custom_config["nlayers"]
        self.dropout = custom_config["dropout"]
        self.values_out = None
        self.device = None

        # Input layer
        self.input_embed = nn.Linear(self.input_dim, self.embed_size)

        # Positional encoding
        self.pos_encoding = nn.Embedding(self.seq_len, self.embed_size)
        # Transformer encoder; batch_first=True keeps inputs as [batch, seq_len, embed_size],
        # which matches how the tensors are built and indexed in forward()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.embed_size,
                nhead=self.nheads,
                dropout=self.dropout,
                activation='gelu',
                batch_first=True),
            num_layers=self.nlayers
        )
        # Policy and value heads
        self.policy_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64),  # + 2 dynamic features (wallet balance, unrealized PnL)
            nn.ReLU(),
            nn.Linear(64, num_outputs)  # Action space size
        )
        self.value_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Detect the device dynamically from the incoming observation batch.
        self.device = input_dict["obs"].device
        x = input_dict["obs"].view(-1, self.seq_len, self.input_dim).to(self.device)
        dynamic_features = x[:, -1, 2:4].clone()
        x = self.input_embed(x)
        position = torch.arange(0, self.seq_len, device=self.device).unsqueeze(0).expand(x.size(0), -1)
        x = x + self.pos_encoding(position)
        transformer_out = self.transformer(x)
        last_out = transformer_out[:, -1, :]
        combined = torch.cat((last_out, dynamic_features), dim=1)
        logits = self.policy_head(combined)
        self.values_out = self.value_head(combined).squeeze(1)
        return logits, state

    @override(ModelV2)
    def value_function(self):
        return self.values_out.to(self.device)
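
In case it helps, here is a minimal sketch of how a custom ModelV2 like this is typically registered and wired into a training config. It assumes PPO and the old ModelV2/ModelCatalog API stack; the model name, environment id, and the `custom_model_config` values are placeholders for illustration, not taken from your setup (on newer Ray versions you may also need to explicitly enable the old API stack):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog

# Register the custom model under an arbitrary name (placeholder).
ModelCatalog.register_custom_model("simple_transformer", SimpleTransformer)

config = (
    PPOConfig()
    .environment("YourTradingEnv")  # placeholder env id
    .framework("torch")
    .training(
        model={
            "custom_model": "simple_transformer",
            "custom_model_config": {
                # Example values only; match these to your observation layout.
                "seq_len": 32,
                "embed_size": 64,
                "nhead": 4,
                "nlayers": 2,
                "dropout": 0.1,
            },
        }
    )
    .resources(num_gpus=1)
)
algo = config.build()

With `num_gpus` set, RLlib moves the model's parameters to the GPU for you, and the dynamic device detection in `forward` keeps the inputs and positional indices on that same device.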