Following the insight by @Obaskly and @jakevdp, I went with the following wrapper:
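```python
import jax

# Wrapped fori_loop that avoids tracing body_fun when upper <= lower.
# lower and upper must be concrete Python ints: the `if` cannot branch on traced values.
def fori(lower, upper, body_fun, init_val, unroll=None):
    if upper <= lower:
        # Empty range: return the initial value without tracing body_fun
        out = init_val
    else:
        out = jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)
    return out
```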
This produces the correct behavior. I don't think jax.lax.cond would work as a replacement, though: lax.cond traces both branches, so the fori_loop branch (and therefore body_fun) would still be traced even when upper<=lower.
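A quick sanity check of the wrapper, assuming concrete Python int bounds and the default 32-bit mode (x64 disabled). The `body` function and the `calls` list are hypothetical, for illustration only. The Python side effect in `body` runs only while JAX traces it, so `calls` shows whether the body was traced at all.

```python
import jax
import jax.numpy as jnp

def fori(lower, upper, body_fun, init_val, unroll=None):
    # Skip jax.lax.fori_loop entirely for an empty range, so body_fun is never traced
    if upper <= lower:
        return init_val
    return jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)

calls = []  # hypothetical trace log

def body(i, acc):
    calls.append("traced")  # Python side effect: runs only when JAX traces body
    return acc + i

# Empty range: init_val comes straight back and body is never traced
empty = fori(5, 0, body, jnp.int32(0))
traces_for_empty = len(calls)
print(empty, traces_for_empty)

# Non-empty range: sums 0 + 1 + 2 + 3 + 4, tracing body along the way
total = fori(0, 5, body, jnp.int32(0))
print(total, len(calls) > 0)
```

Because the check happens in Python before jax.lax.fori_loop is ever called, this only helps when the bounds are known at trace time; with traced bounds the `upper <= lower` comparison would raise a concretization error instead.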