I also had similar issues with GPU memory when loading the dataset into memory.
I see you're using a generator, which is fine in your example, but why would you use model.predict(d1024.take(n)) instead of simply model.predict(d1024)?
When you use dataset.take(n), it creates a new dataset containing only the first n batches, so it will not process the entire dataset. Furthermore, it will try to load those n batches onto the GPU at once, which explains why you are running into memory problems.
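For example (a minimal sketch; the dataset and batch size below are made up to stand in for your d1024):

import tensorflow as tf

# Stand-in for a dataset already batched with batch size 1024
d1024 = tf.data.Dataset.range(10_000).batch(1024)

# take(3) builds a new dataset containing only the first 3 batches,
# so anything run on it never sees the rest of the data
for batch in d1024.take(3):
    print(batch.shape)  # (1024,)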
I found that the best approach for me is a custom generator that yields batches of data, so only one batch at a time is loaded into memory.
Something like:
def gen():
    while True:
        # build or load one batch here
        ...
        yield X, Y
This way you can be sure you won't run into memory problems, and by working with X and Y as NumPy arrays instead of a tf.data.Dataset you also have more flexibility.
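For reference, here is a minimal self-contained sketch of that idea; the array shapes, batch size, and the toy model are made up for illustration, and the generator yields (X, Y) tuples so the same gen() can be reused for fit (predict simply ignores the targets):

import numpy as np
import tensorflow as tf

# Hypothetical data kept in host RAM (or memory-mapped from disk) as NumPy arrays
X_all = np.random.rand(10_000, 32).astype("float32")
Y_all = np.random.rand(10_000, 1).astype("float32")

batch_size = 1024
steps = int(np.ceil(len(X_all) / batch_size))

def gen():
    # Yield one batch at a time so only that batch has to fit in GPU memory
    while True:
        for i in range(steps):
            X = X_all[i * batch_size:(i + 1) * batch_size]
            Y = Y_all[i * batch_size:(i + 1) * batch_size]
            yield X, Y

# Toy model, just to make the example runnable
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# steps is required because the generator loops forever
model.fit(gen(), steps_per_epoch=steps, epochs=1)
predictions = model.predict(gen(), steps=steps)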