So after more searching, I think the failure and the VMEM consumption are unrelated issues. As @teapot418 mentioned above, this is likely a multi-threading issue.
I added the following to my import statements, before importing NumPy, TensorFlow, and Ray:
import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['RAY_num_server_call_thread'] = '1'
os.environ['TF_NUM_INTEROP_THREADS'] = '1'
os.environ['TF_NUM_INTRAOP_THREADS'] = '1'
And it now runs. It still consumes 4 TB of VMEM, but I don't care anymore.
I'm guessing the issue is that one or more of these packages enables multi-threading when what I actually want is multi-core (multi-process) parallelism. Someone more knowledgeable will have to explain the difference, why mixing multi-threading with multi-processing causes problems, and why I never run into this on my laptop.
The above was pieced together from the following discussions:
Just to note that setting only the OPENBLAS_NUM_THREADS and OMP_NUM_THREADS environment variables did not work for the original code, but did work for the simplified code in the post above. My guess is that NumPy uses OpenBLAS to multi-thread operations and was trying to spawn too many threads, causing the failure. The original code does not involve such large matrix multiplications, but it does use TensorFlow, hence the TF_NUM_INTEROP_THREADS and TF_NUM_INTRAOP_THREADS environment variables also needed to be set to 1 to stop multi-threading.
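One detail that I believe matters here (based on how these libraries read their configuration at load time, so treat it as an assumption): the variables have to be set before the first import of the library in question, because e.g. OpenBLAS sizes its thread pool when it is loaded. Setting them afterwards has no effect. So the ordering looks like:

```python
# Order matters: these variables are read when the libraries initialise,
# so they must be set before numpy / tensorflow / ray are imported.
import os

os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TF_NUM_INTEROP_THREADS"] = "1"
os.environ["TF_NUM_INTRAOP_THREADS"] = "1"

import numpy as np  # imported only after the limits are in place

print(os.environ["OMP_NUM_THREADS"])  # 1
```

Alternatively, exporting the variables in the shell (or the job script, for a cluster) before launching Python avoids the ordering question entirely.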
This does raise the question: in more complex code that involves both large matrix multiplications (or other operations that NumPy multi-threads) and TensorFlow operations, should I use multi-threading or multi-processing? Which is faster? But that's a question for another time... I'm sleepy...