I suggest casting mtl to float32 before moving it to your device.
Replace mtl.to(device) with:
mtl.to(torch.float32).to(device)
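Here is a minimal sketch of the same fix. It assumes mtl is a float64 tensor and that the error comes from a backend without float64 support, such as Apple's MPS. The helper name to_device_as_float32 and the random input are hypothetical and only for illustration:

```python
import torch

def to_device_as_float32(x: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Cast to float32 on the current device first, then move the result.
    # Assumption: the target backend (e.g. MPS) cannot hold float64 tensors,
    # so moving a float64 tensor there directly would fail.
    return x.to(torch.float32).to(device)

# Use MPS if it is available, otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

mtl = torch.rand(4, 4, dtype=torch.float64)  # hypothetical float64 input
mtl = to_device_as_float32(mtl, device)      # Tensor.to returns a new tensor
print(mtl.dtype, mtl.device)
```

Note that Tensor.to does not modify the tensor in place. It returns a new tensor, so assign the result (mtl = ...) rather than calling it on its own.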
Best,
Arnaud