The problem has been solved (thanks to dai on Discord). The issue is that `input_precision` for `tl.dot` defaults to tf32, which has only a 10-bit mantissa, so the inputs are rounded to the nearest representable TF32 value and the trailing digits are lost. The problem was very pronounced with `V = torch.arange(4096, 4096 + 2048, device='cuda', dtype=torch.float32)`, where the output was `[6080., 6080., 6080., 6080., 6084., 6084., 6084., 6084., 6088., ...]`. The step size of 4 matches TF32's spacing exactly: for values in [2^12, 2^13), the gap between representable values is 2^(12-10) = 4. Switching to IEEE input precision, `tl.dot(x, y, input_precision="ieee")`, solved the issue.
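
For reference, here's a minimal sketch of what the fix looks like inside a kernel. It assumes a recent Triton (3.x, where `tl.dot` exposes the `input_precision` argument, as used in the fix above); the kernel name, block size, and layout are illustrative, not from the original code.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def dot_ieee_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    # Row-major offsets for one contiguous BLOCK x BLOCK tile.
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    x = tl.load(x_ptr + idx)
    y = tl.load(y_ptr + idx)
    # Default input_precision on Ampere+ GPUs is "tf32" (10-bit mantissa);
    # "ieee" keeps the full fp32 mantissa and avoids the rounding above.
    out = tl.dot(x, y, input_precision="ieee")
    tl.store(out_ptr + idx, out)

BLOCK = 32  # tl.dot requires block dimensions >= 16
x = torch.randn(BLOCK, BLOCK, device="cuda")
y = torch.randn(BLOCK, BLOCK, device="cuda")
out = torch.empty(BLOCK, BLOCK, device="cuda")
dot_ieee_kernel[(1,)](x, y, out, BLOCK=BLOCK)
torch.testing.assert_close(out, x @ y)  # now matches the fp32 matmul
```

Note that `"ieee"` disables the tensor-core TF32 path for that dot, so expect it to be slower; it's a precision/speed trade-off rather than a free fix.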