For this specific issue, I found that the torch.nn.Linear layer (self.linear) was the source of the randomness, and adding
torch.nn.init.xavier_uniform_(self.linear.weight) # Xavier initialization
torch.nn.init.zeros_(self.linear.bias)
right after self.linear is created (that is, before the layer is used) fixed the randomness.
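Below is a minimal sketch of where those two lines go. The model class TinyModel, its layer sizes, and the build_model helper are hypothetical and only for illustration. One caveat: xavier_uniform_ still samples from PyTorch's global random number generator. To get identical weights on every run, this sketch also assumes a fixed seed set with torch.manual_seed before the model is built.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical model; only the self.linear initialization mirrors the answer."""

    def __init__(self, in_features: int = 4, out_features: int = 2) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Re-initialize once the layer exists, before it is ever used.
        nn.init.xavier_uniform_(self.linear.weight)  # Xavier initialization
        nn.init.zeros_(self.linear.bias)             # bias starts at exactly 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def build_model(seed: int = 0) -> TinyModel:
    # Assumed reproducibility setup: xavier_uniform_ draws from the global RNG,
    # so fixing the seed first makes the sampled weights repeatable.
    torch.manual_seed(seed)
    return TinyModel()


if __name__ == "__main__":
    a = build_model(seed=0)
    b = build_model(seed=0)
    print(torch.equal(a.linear.weight, b.linear.weight))  # True
    print(torch.count_nonzero(a.linear.bias).item())      # 0
```

With the same seed, two freshly built models start from identical weights. A different seed gives different, but still Xavier-distributed, weights.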