Following the insight by @Obaskly and @jakevdp, I went with the following wrapper:
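```python
import jax

# Wrapper around jax.lax.fori_loop that avoids tracing body_fun when upper <= lower.
# lower and upper must be concrete Python ints (not traced values); otherwise the
# `upper <= lower` check cannot be evaluated at trace time.
def fori(lower, upper, body_fun, init_val, unroll=None):
    if upper <= lower:
        out = init_val
    else:
        out = jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)
    return out
```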
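One way to confirm that the wrapper really skips tracing is to count how often `body_fun` is called in Python, since JAX calls the Python function whenever it traces the loop body. This is a minimal sketch, assuming a recent JAX where `fori_loop` accepts the `unroll` keyword; `trace_count` is a hypothetical helper added only for this check.

```python
import jax
import jax.numpy as jnp

# Same behavior as the wrapper above, repeated so this snippet runs on its own.
def fori(lower, upper, body_fun, init_val, unroll=None):
    if upper <= lower:
        return init_val
    return jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)

# Hypothetical counter: incremented every time JAX calls body_fun in Python,
# i.e. every time the loop body is traced.
trace_count = 0

def body_fun(i, acc):
    global trace_count
    trace_count += 1
    return acc + i

print(fori(4, 4, body_fun, jnp.int32(7)))  # prints 7 (empty range, init_val returned)
print(trace_count)                         # prints 0 (body_fun was never traced)
print(fori(0, 5, body_fun, jnp.int32(0)))  # prints 10 (0 + 1 + 2 + 3 + 4)
```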
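This produces the correct behavior: when `upper <= lower`, `init_val` is returned unchanged and `body_fun` is never traced. It might also be possible to do this with `jax.lax.cond`, as long as that doesn't reintroduce the tracing issue.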
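A quick way to test the `jax.lax.cond` idea is to record Python-level calls of `body_fun` again. The sketch below is a hypothetical cond-based variant (`fori_cond` and `calls` are illustrative names, not part of JAX). Outside of `jax.disable_jit`, `lax.cond` traces both branches regardless of the predicate, so the `fori_loop` branch, and with it `body_fun`, still appears to get traced even for an empty range; this suggests the cond approach would not avoid the issue.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical record of each Python-level call (trace) of body_fun

def body_fun(i, acc):
    calls.append("body")
    return acc + i

def fori_cond(lower, upper, body_fun, init_val):
    # Hypothetical cond-based variant: lax.cond traces BOTH branches,
    # so the fori_loop branch is traced even when upper <= lower.
    return jax.lax.cond(
        upper <= lower,
        lambda v: v,                                           # empty range: return init_val
        lambda v: jax.lax.fori_loop(lower, upper, body_fun, v),
        init_val,
    )

out = fori_cond(3, 3, body_fun, jnp.int32(7))
print(out)         # prints 7: the empty-range branch is selected at run time
print(len(calls))  # nonzero: body_fun was still traced
```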