This was solved by passing `allow_unused=True` and `materialize_grads=True` to `grad`. That is:
```python
from torch.autograd import grad

d_loss_params = grad(loss, tuple(model.parameters()), retain_graph=True,
                     allow_unused=True, materialize_grads=True)
```
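For context, here is a minimal self-contained sketch of the overall pattern. The `model`, input `x`, and squared-gradient penalty `loss` are hypothetical stand-ins, not the original code; the key points are `create_graph=True` on the inner call (so the parameter gradients stay differentiable) and the two flags above on the outer call:

```python
import torch
from torch.autograd import grad

# Hypothetical setup: a tiny network and an input batch.
model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
params = tuple(model.parameters())

out = model(x).sum()

# Inner gradient: create_graph=True keeps the graph so the
# parameter gradients are themselves differentiable.
g = grad(out, params, create_graph=True)

# A loss that depends on the gradient of the network w.r.t. its
# parameters, e.g. a squared-gradient penalty (illustrative choice).
loss = sum((gi ** 2).sum() for gi in g)

# Outer gradient: with allow_unused/materialize_grads, parameters the
# loss does not reach get zero-filled gradients instead of an error/None.
d_loss_params = grad(loss, params, retain_graph=True,
                     allow_unused=True, materialize_grads=True)
```

In this linear toy example the inner gradients do not actually depend on the weights, so the outer `grad` call would error without `allow_unused=True`; `materialize_grads=True` then returns zero tensors instead of `None` for those parameters. Note that `materialize_grads` only exists in recent PyTorch releases (added around 2.0), and per the docs `allow_unused` defaults to the value of `materialize_grads`, so passing both is explicit but redundant.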
See the discussion at https://discuss.pytorch.org/t/gradient-of-loss-that-depends-on-gradient-of-network-with-respect-to-parameters/217275 for more background.