The error shows you're trying to allocate ~69 GB (69,271,363,584 bytes), which is far more memory than most GPUs or CPUs have available. A few things to check:
a) Input Dimensions:
You haven't shared the dimensions of your input data, but this could be a key issue. Check your batch size, sequence length, and d_model values: memory usage grows quadratically with sequence length because of the attention mechanism (see the rough estimate below).
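For a sense of scale, the failed allocation (69,271,363,584 bytes) corresponds to roughly 17.3 billion float32 values, which the attention score tensors alone can reach at long sequence lengths. A minimal back-of-the-envelope sketch, where batch_size, num_heads, and seq_len are placeholder values to replace with your own:

# Rough size of the attention score tensor for a single layer, assuming float32 (4 bytes)
def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_elem=4):
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem

print(attention_scores_bytes(32, 8, 4096) / 1e9)  # ~17.2 GB -- for just one layer's scores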
b) Positional Encoding Implementation:
self.encoding = torch.zeros(max_len, d_model)
Suggestions for fixing the issue:
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create more memory-efficient positional encoding
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model)

        # Register as a buffer instead of a parameter: saved with the module,
        # but excluded from gradients and optimizer state
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add encodings for the first seq_len positions
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
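A quick sanity check of the class above (with arbitrary, illustrative dimensions) confirms the encoding broadcasts over the batch and adds no trainable parameters:

pos_enc = PositionalEncoding(d_model=512, max_len=5000)
x = torch.zeros(2, 100, 512)                          # (batch, seq_len, d_model)
print(pos_enc(x).shape)                               # torch.Size([2, 100, 512])
print(sum(p.numel() for p in pos_enc.parameters()))   # 0 -- 'pe' is a buffer, not a parameter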
Additional optimization suggestions:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, max_seq_length):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        # Add dropout
        self.dropout = nn.Dropout(0.1)

        # Memory-efficient decoder layer
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4 * d_model,  # Standard size
            dropout=0.1,
            batch_first=True              # Avoid permute operations
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, target_seqs, memory, target_mask):
        embedded = self.embedding(target_seqs)
        embedded = self.positional_encoding(embedded)
        embedded = self.dropout(embedded)

        # No need for permute if batch_first=True
        output = self.transformer_decoder(embedded, memory, tgt_mask=target_mask)
        return self.fc_out(output)
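As a shape check, here is how this decoder might be wired up with deliberately small, made-up hyperparameters and a standard causal mask (the numbers are placeholders, not recommendations):

model = TransformerDecoder(vocab_size=10000, d_model=256, num_heads=8,
                           num_layers=4, max_seq_length=512)
tgt = torch.randint(0, 10000, (4, 128))       # (batch, seq_len) token ids
memory = torch.zeros(4, 64, 256)              # (batch, src_len, d_model), batch_first
# Causal mask: -inf above the diagonal blocks attention to future positions
tgt_mask = torch.triu(torch.full((128, 128), float('-inf')), diagonal=1)
print(model(tgt, memory, tgt_mask).shape)     # torch.Size([4, 128, 10000])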
To debug this issue:
a) Print the shapes of your tensors:
def forward(self, target_seqs, memory, target_mask):
    print(f"Target seqs shape: {target_seqs.shape}")
    print(f"Memory shape: {memory.shape}")
    print(f"Target mask shape: {target_mask.shape}")
    # ... rest of the forward method
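If you're running on CUDA, PyTorch's built-in memory introspection can also show where the memory is going; these calls can be dropped in anywhere, for example right before the line that crashes:

if torch.cuda.is_available():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")
    print(torch.cuda.memory_summary())  # detailed per-allocator breakdown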
b) Try reducing these parameters: batch size, sequence length (max_seq_length), d_model, and num_layers.
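Purely as an illustration (these are guesses to tune against your data and hardware, not recommendations):

batch_size = 8        # halve repeatedly until the OOM disappears
max_seq_length = 512  # attention memory scales with seq_len ** 2
d_model = 256
num_layers = 4
num_heads = 8         # must divide d_model evenly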
c) Add gradient checkpointing more strategically:

from torch.utils.checkpoint import checkpoint

class TransformerDecoder(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_checkpointing = True  # Add a flag for checkpointing

    def forward(self, target_seqs, memory, target_mask):
        # Trade compute for memory: activations are recomputed during backward
        if self.use_checkpointing and self.training:
            return checkpoint(self._forward, target_seqs, memory, target_mask)
        return self._forward(target_seqs, memory, target_mask)

    def _forward(self, target_seqs, memory, target_mask):
        # Original forward pass code here
        pass
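Here _forward would simply hold the body of the forward method from the decoder above (embedding, positional encoding, dropout, decoder, output projection), for example:

    def _forward(self, target_seqs, memory, target_mask):
        embedded = self.dropout(self.positional_encoding(self.embedding(target_seqs)))
        output = self.transformer_decoder(embedded, memory, tgt_mask=target_mask)
        return self.fc_out(output)

On recent PyTorch versions you may also want to pass use_reentrant=False to checkpoint(...) to avoid the warning about the reentrant implementation.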
Could you share the shapes of target_seqs, memory, and target_mask, your batch size and maximum sequence length, the values of d_model, num_heads, num_layers, and vocab_size, and how much GPU memory you have available? That would help pinpoint the exact cause of the memory issue.