79783914

Date: 2025-10-06 17:01:09

Following the insight by @Obaskly and @jakevdp, I went with the following wrapper:

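import jax

# Wrapper around jax.lax.fori_loop that avoids tracing body_fun when upper <= lower.
# lower and upper must be concrete Python values (not tracers) for the `if` to work.
def fori(lower, upper, body_fun, init_val, unroll=None):
    if upper <= lower:
        # Empty range: return the initial value without tracing body_fun
        out = init_val
    else:
        out = jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)
    return out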

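This produces the correct behavior, as long as lower and upper are concrete Python values: if they are tracers (for example, non-static arguments of a jitted function), the Python `if` raises a concretization error. It might also be possible to do this with jax.lax.cond, but since lax.cond traces both branches, including the one containing the fori_loop body, that would probably reintroduce the tracing issue.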
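A minimal sketch, assuming a recent JAX version, to check these points. It counts traces with a Python side effect inside body_fun (Python code in body_fun only runs while JAX traces it). The names trace_count, body_fun, run and fori_cond are hypothetical and only for illustration; fori_cond is not part of the answer above, just a quick test of the lax.cond idea. The wrapper is repeated so the snippet runs on its own.

```python
from functools import partial

import jax
import jax.numpy as jnp


def fori(lower, upper, body_fun, init_val, unroll=None):
    # Same logic as the wrapper above, repeated so this sketch runs on its own.
    if upper <= lower:
        return init_val
    return jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)


# Hypothetical counter: body_fun's Python code runs only while JAX traces it,
# so this counts traces, not iterations of the compiled loop.
trace_count = {"body": 0}


def body_fun(i, x):
    trace_count["body"] += 1
    return x + 1.0  # add 1.0 per iteration; the index i is unused here


# 1) Empty range: the wrapper returns init_val and never traces body_fun.
empty = fori(3, 3, body_fun, jnp.float32(0.0))
traces_after_empty = trace_count["body"]

# 2) Non-empty range: delegates to lax.fori_loop (5 iterations).
total = fori(0, 5, body_fun, jnp.float32(0.0))


# 3) Under jit, the bounds must be static so the Python `if` sees plain ints.
@partial(jax.jit, static_argnums=(0, 1))
def run(lower, upper, x0):
    return fori(lower, upper, body_fun, x0)


before = trace_count["body"]
jitted_empty = run(5, 2, jnp.float32(7.0))  # upper <= lower: x0 comes back unchanged
traces_in_jitted_empty = trace_count["body"] - before


# 4) Hypothetical lax.cond variant: cond traces both branches, so body_fun is
#    traced even though the predicate selects the identity branch at run time.
def fori_cond(lower, upper, body_fun, init_val):
    return jax.lax.cond(
        upper <= lower,
        lambda v: v,
        lambda v: jax.lax.fori_loop(lower, upper, body_fun, v),
        init_val,
    )


before = trace_count["body"]
cond_result = jax.jit(lambda lo, hi, v: fori_cond(lo, hi, body_fun, v))(
    3, 3, jnp.float32(0.0)
)
traces_in_cond = trace_count["body"] - before

print(float(empty), traces_after_empty)             # 0.0 0
print(float(total))                                 # 5.0
print(float(jitted_empty), traces_in_jitted_empty)  # 7.0 0
print(float(cond_result), traces_in_cond >= 1)      # 0.0 True
```

The exact number of times JAX traces a loop or cond body is an implementation detail that may differ between versions, which is why the lax.cond case is only checked for "at least once" rather than an exact count.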

Posted by: Ben