Your energy function contains NumPy operations (sum, linspace, etc.) and a custom deriv function that likely uses NumPy as well. NumPy functions cannot operate on the tracer objects that JAX passes through a function while tracing it, which is why you get the error.
Try replacing all NumPy operations with their jax.numpy equivalents in both the energy function and deriv.
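Since your actual code isn't shown here, the following is only a minimal sketch of what the converted functions might look like. The bodies of `energy` and `deriv`, the central-difference scheme, the grid, and the test field `u` are all assumptions made for illustration; only the pattern of using `jnp` instead of `np` carries over to your code.

```python
import jax
import jax.numpy as jnp


def deriv(f, dx):
    # Hypothetical stand-in for the custom deriv function:
    # central differences on the interior points, built only from
    # array slicing and arithmetic, so it works on JAX tracers.
    return (f[2:] - f[:-2]) / (2.0 * dx)


def energy(u, dx):
    # Hypothetical energy: 0.5 * integral of (du/dx)^2, approximated
    # by a sum. jnp.sum replaces np.sum so the function stays traceable.
    du = deriv(u, dx)
    return jnp.sum(0.5 * du**2) * dx


n = 101                          # assumed grid size
x = jnp.linspace(0.0, 1.0, n)    # jnp.linspace instead of np.linspace
dx = 1.0 / (n - 1)
u = jnp.sin(jnp.pi * x)          # assumed test field

# jax.grad differentiates with respect to the first argument (u).
grad_energy = jax.grad(energy)(u, dx)
print(grad_energy.shape)
```

If the error persists after the switch, it is worth checking whether anything else called inside energy (for example a SciPy routine or a conversion such as float(...) applied to a traced value) still forces a tracer into a concrete NumPy array.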