79391584

Date: 2025-01-27 17:07:18
Score: 0.5
Natty:
Report link

I also had similar issues with GPU memory when loading a dataset into memory. I see you're using a generator, which is fine in your example, but why use model.predict(d1024.take(n)) instead of simply model.predict(d1024)?

When you use dataset.take(n), it creates a new dataset containing only the first n batches, so it will not process the entire dataset. Furthermore, it will try to load all n batches into GPU memory at once, which explains the memory problems you're seeing.

I found the best approach for me is a custom generator that yields batches of data, so only one batch at a time is loaded into memory.

Something like:

def gen():
    while True:
        # load or compute one batch of data here
        ...
        yield X, Y

This way you can be sure of avoiding memory problems, and by working with X and Y as NumPy arrays instead of a tf.data.Dataset you also get more flexibility.
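A minimal, self-contained sketch of such a generator (plain Python lists are used here for illustration; in practice X and Y would typically be NumPy arrays, and the model.predict call and batch size of 1024 are assumptions, not taken from your code):

```python
def batch_generator(features, labels, batch_size):
    """Yield (X, Y) one batch at a time, looping over the data forever."""
    n = len(features)
    while True:
        for start in range(0, n, batch_size):
            # Only this slice of the data is held in memory at once.
            yield features[start:start + batch_size], labels[start:start + batch_size]

# Hypothetical usage with a Keras model (names assumed):
# model.predict(batch_generator(features, labels, 1024),
#               steps=len(features) // 1024)
```

Because the generator loops forever, remember to pass steps (or steps_per_epoch when training) so Keras knows when to stop.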

Reasons:
  • Long answer (-0.5):
  • Has code block (-0.5):
  • Contains question mark (0.5):
  • Low reputation (1):
Posted by: Ruben Bento