This happens because of extreme logits in my model: the imbalanced dataset and small pos_weight values make the logits explode (e.g., up to 1e20), which causes the loss to become NaN. First, I stabilized the gradients with mixed precision and a GradScaler:
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for batch in dataloader:
    input_ids, attention_mask, labels = batch  # assuming the dataloader yields tuples
    optimizer.zero_grad()
    # Run the forward pass in mixed precision
    with autocast():
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
    # Scale the loss so half-precision gradients don't underflow, then step and update the scale
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
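For reference, here is roughly the loss setup these snippets assume (a BCEWithLogitsLoss, since that is where pos_weight lives; the value shown is just a placeholder, not my real number):

import torch
import torch.nn as nn

# Placeholder pos_weight; in practice it comes from the positive/negative ratio in the data
pos_weight = torch.tensor([0.2])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)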
Then, I added a bit of label smoothing to reduce over-confident predictions:
def smooth_labels(labels, smoothing=0.1):
    # Pull hard 0/1 targets toward 0.5 by the smoothing amount
    return labels * (1 - smoothing) + 0.5 * smoothing

smoothed_labels = smooth_labels(labels)
loss = criterion(logits, smoothed_labels)
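For example, with the default smoothing=0.1 the hard targets move off the extremes:

import torch

hard = torch.tensor([0.0, 1.0, 1.0])
print(smooth_labels(hard))  # tensor([0.0500, 0.9500, 0.9500])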
Then, to avoid exploding gradients, I added an L2 regularization term to the loss:
reg_lambda = 0.01
# L2 penalty: sum of squared parameter norms
l2_reg = sum(p.pow(2).sum() for p in model.parameters())
loss = loss + reg_lambda * l2_reg
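If you prefer the built-in route, the optimizer's weight_decay argument achieves much the same thing; the optimizer choice and values below are only an illustration, not what I actually trained with:

from torch.optim import AdamW

# weight_decay applies an L2-style penalty on the weights inside the optimizer step
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)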
And finally, I normalized the logits with a BatchNorm layer after nn.Linear:
self.classifier = nn.Sequential(
    nn.Linear(self.bert.config.hidden_size, num_labels),
    nn.BatchNorm1d(num_labels)  # keeps the logits in a bounded, normalized range
)
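For completeness, here is a sketch of how such a head gets applied in forward; the class name, checkpoint, and [CLS]-token pooling are assumptions for illustration and may differ from your setup:

import torch.nn as nn
from transformers import BertModel

class BertClassifier(nn.Module):
    def __init__(self, num_labels, model_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, num_labels),
            nn.BatchNorm1d(num_labels),
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]  # [CLS] token representation
        logits = self.classifier(pooled)          # normalized logits from the head above
        return logits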
Problem solved. Everything seems fine now. Thanks.