79705081

Date: 2025-07-17 15:49:00
Score: 1
Natty:
Report link

Yes, the behavior difference makes sense in JAX. When var was a static class member (static data in the pytree, not a traced leaf), JAX treated it as a "compile-time constant": it gets baked into any JIT-compiled code and isn't part of the differentiable computation graph.
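To illustrate the "baked in" part, here is a minimal sketch. The class `Scaler`, its attribute `var`, and the method `loss` are hypothetical names I made up, since I don't know what your class looks like; I'm assuming `var` is a plain Python attribute that the jitted method reads through `self`.

```python
import jax
import jax.numpy as jnp


class Scaler:
    """Hypothetical class: `var` is a plain attribute, not a traced input."""

    def __init__(self, var):
        self.var = var

    def loss(self, x):
        # `self.var` is read during tracing and stored as a constant.
        return jnp.sum(self.var * x ** 2)


x = jnp.array([1.0, 2.0, 3.0])
model = Scaler(var=2.0)
jitted_loss = jax.jit(model.loss)

first = jitted_loss(x)   # traces and compiles with var = 2.0
model.var = 3.0          # mutate the attribute after compilation
second = jitted_loss(x)  # same input shape and dtype, so the cached program is reused

print(first, second)     # both calls use var = 2.0
```

The second call does not see the new value because nothing about the traced inputs changed, so JAX has no reason to retrace.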

But when you move it to a function argument, JAX sees it as a "dynamic value" that can change between calls. It then becomes part of the computation graph, gradients can flow through it, and it is subject to transformations (like vmap).

So use class members for true constants that never change, and use function arguments for values you might differentiate through or batch over.
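A rough sketch of the argument version, again with made-up names (`loss_with_arg`, `var`, `x`) and a toy loss standing in for whatever your code computes. Once `var` is an explicit argument, both `jax.grad` and `jax.vmap` can act on it:

```python
import jax
import jax.numpy as jnp


def loss_with_arg(var, x):
    # `var` is an explicit argument, so JAX traces it like any other input.
    return jnp.sum(var * x ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# Differentiate with respect to var (argument 0); the result is sum(x ** 2).
grad_var = jax.grad(loss_with_arg, argnums=0)(2.0, x)

# Batch over several values of var while keeping x fixed.
vars_batch = jnp.array([1.0, 2.0])
batched = jax.vmap(loss_with_arg, in_axes=(0, None))(vars_batch, x)

print(grad_var)  # 14.0
print(batched)   # [14. 28.]
```

If the value really never changes, keeping it on the class is simpler; the argument form is only worth it when you need gradients or batching over it.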

Reasons:
  • Long answer (-0.5):
  • No code block (0.5):
  • Low reputation (1):
Posted by: Vinayak Narayan