To resolve the device mismatch error, you should let RLlib and PyTorch manage device placement automatically.
Layers are no longer explicity moved to to(self.device) during initialization
Used dynamic device detection of the input self.device = input_dict["obs"].device
Only inputs in the forward method and values_out in the value_function are moved to the model's device manually.
It's also important to override the forward and value_function methods, as suggested by @Marzi Heifari.
Here is the modified version:
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.models.modelv2 import ModelV2
@DeveloperAPI
class SimpleTransformer(TorchModelV2, nn.Module):
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
nn.Module.__init__(self)
# Configuration
custom_config = model_config["custom_model_config"]
self.input_dim = 76
self.seq_len = custom_config["seq_len"]
self.embed_size = custom_config["embed_size"]
self.nheads = custom_config["nhead"]
self.nlayers = custom_config["nlayers"]
self.dropout = custom_config["dropout"]
self.values_out = None
self.device = None
# Input layer
self.input_embed = nn.Linear(self.input_dim, self.embed_size)
# Positional encoding
self.pos_encoding = nn.Embedding(self.seq_len, self.embed_size)
# Transformer
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=self.embed_size,
nhead=self.nheads,
dropout=self.dropout,
activation='gelu'),
num_layers=self.nlayers
)
# Policy and value heads
self.policy_head = nn.Sequential(
nn.Linear(self.embed_size + 2, 64), # Add dynamic features (wallet balance, unrealized PnL)
nn.ReLU(),
nn.Linear(64, num_outputs) # Action space size
)
self.value_head = nn.Sequential(
nn.Linear(self.embed_size + 2, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
self.device = input_dict["obs"].device
x = input_dict["obs"].view(-1, self.seq_len, self.input_dim).to(self.device)
dynamic_features = x[:, -1, 2:4].clone()
x = self.input_embed(x)
position = torch.arange(0, self.seq_len, device=self.device).unsqueeze(0).expand(x.size(0), -1)
x = x + self.pos_encoding(position)
transformer_out = self.transformer(x)
last_out = transformer_out[:, -1, :]
combined = torch.cat((last_out, dynamic_features), dim=1)
logits = self.policy_head(combined)
self.values_out = self.value_head(combined).squeeze(1)
return logits, state
@override(ModelV2)
def value_function(self):
return self.values_out.to(self.device)