You might try calling your training script like so:
PYTORCH_ENABLE_MPS_FALLBACK=1 python train_me.py
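That way the variable is already set when PyTorch goes looking for it :) PyTorch reads it when torch is first imported, so setting it from inside the script after that point may be too late.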
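If you'd rather not type the prefix on every run, or just want to confirm the variable actually reaches child processes, here is a quick sketch. It assumes a POSIX-style shell such as bash or zsh and uses the standard `printenv` utility to show the value a launched program would see:

```shell
# One-off: prints "1"; the variable exists only in the environment of this single command
PYTORCH_ENABLE_MPS_FALLBACK=1 printenv PYTORCH_ENABLE_MPS_FALLBACK

# Session-wide: export it once and every later command in this shell inherits it
export PYTORCH_ENABLE_MPS_FALLBACK=1
printenv PYTORCH_ENABLE_MPS_FALLBACK
```

After the export, a plain `python train_me.py` in the same shell session picks up the variable too. The prefix form is handy when you only want the fallback for one particular run.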