Another concise approach is to use torch.where to replace the values you want to exclude (here, 2) with NaN:
x = torch.where(x == 2, torch.nan, x)
Or, equivalently, with the condition inverted:
x = torch.where(x != 2, x, torch.nan)
Copied from the question at: https://discuss.pytorch.org/t/filtered-mean-and-std-from-a-tensor/147258
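
To get from the NaN-masked tensor to the filtered mean and std that the linked thread is about, something like the following should work. This is only a sketch: the sample values are made up for illustration, and it assumes a reasonably recent PyTorch where torch.nan exists and torch.where accepts a Python scalar as an argument.

```python
import torch

# Made-up sample data (not from the original post); 2 is the value to exclude.
# x is a floating-point tensor, so it can hold NaN.
x = torch.tensor([1.0, 2.0, 3.0, 2.0, 5.0])

# Replace every 2 with NaN, as in the one-liner above.
masked = torch.where(x == 2, torch.nan, x)

# torch.nanmean averages only the non-NaN entries.
mean = torch.nanmean(masked)

# For the standard deviation, one option is to drop the NaN entries
# and call .std() on what remains (unbiased estimate by default).
kept = masked[~torch.isnan(masked)]
std = kept.std()

print(mean, std)
```

Note that indexing with a boolean mask flattens the result, so the last step gives a single std over all kept elements rather than a per-row or per-column one.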