Why does calling loss.backward() cause such a significant increase in GPU memory usage when the model's tensors are already on the GPU?
When you call loss.backward(), the gradients of all the parameters in your network are computed. There is one gradient tensor per parameter tensor, stored in each parameter's .grad attribute, so this roughly doubles the memory taken up by the model's parameters.
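A minimal sketch of how to verify this yourself, assuming a small illustrative model (the model and sizes below are made up for the example, not taken from the original post): measure the memory held by the parameters and by their gradients before and after the backward pass.

```python
import torch
import torch.nn as nn

# Illustrative model and input, placed on the GPU.
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 10)).cuda()
x = torch.randn(64, 4096, device="cuda")

def param_bytes(m):
    # Bytes held by the parameter tensors themselves.
    return sum(p.numel() * p.element_size() for p in m.parameters())

def grad_bytes(m):
    # Bytes held by gradient tensors (p.grad is None before backward()).
    return sum(p.grad.numel() * p.grad.element_size()
               for p in m.parameters() if p.grad is not None)

print(f"params: {param_bytes(model) / 2**20:.1f} MiB, grads: {grad_bytes(model) / 2**20:.1f} MiB")

loss = model(x).sum()
loss.backward()  # allocates one .grad tensor per parameter

print(f"params: {param_bytes(model) / 2**20:.1f} MiB, grads: {grad_bytes(model) / 2**20:.1f} MiB")
```

After backward(), the gradient memory matches the parameter memory, which is where the roughly 2x figure comes from (activations saved for the backward pass add to this on top).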
How can I optimize this setup to run on GPUs without running out of memory? Are there specific strategies or PyTorch functionalities that can reduce GPU memory usage during the backward pass?
You could try mixed-precision training, reducing the dimensions of your input, using SGD instead of Adam (Adam keeps extra per-parameter state on the GPU, SGD does not), using a network with fewer FLOPs (for instance one that downsamples your input quite early), or moving to a GPU with more memory.
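As a sketch of the mixed-precision and SGD suggestions, using PyTorch's torch.cuda.amp utilities and the same illustrative model and data as above (the model, sizes, and hyperparameters are assumptions for the example), one training step might look like this:

```python
import torch
import torch.nn as nn

# Illustrative model, data, and targets (not from the original post).
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 10)).cuda()
x = torch.randn(64, 4096, device="cuda")
target = torch.randint(0, 10, (64,), device="cuda")

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # SGD: no extra per-parameter state
scaler = torch.cuda.amp.GradScaler()
criterion = nn.CrossEntropyLoss()

optimizer.zero_grad(set_to_none=True)  # frees old .grad tensors instead of zeroing them
with torch.cuda.amp.autocast():
    # Forward pass runs in float16 where safe, shrinking most saved activations.
    loss = criterion(model(x), target)
scaler.scale(loss).backward()  # scale the loss to avoid float16 gradient underflow
scaler.step(optimizer)
scaler.update()
```

Using zero_grad(set_to_none=True) between steps also releases the gradient tensors rather than filling them with zeros, which can slightly reduce peak memory.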