AFAIK, JAX doesn’t compile “Python” directly. It traces your function into a jaxpr, a program built from a small set of primitives, and that jaxpr is then lowered to XLA. High-level NumPy-style functions such as jnp.reshape and jnp.broadcast_to are expressed in terms of those primitives (reshape and broadcast_in_dim), and you can inspect the exact jaxpr with jax.make_jaxpr. For a deeper dive, the JAX paper and the documentation on jaxprs are the authoritative resources.
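Here is a minimal sketch of inspecting the trace, assuming a recent JAX release. The functions f and g are hypothetical examples, and the exact parameters printed for each equation vary between JAX versions. The first jaxpr should contain the reshape and broadcast_in_dim primitives. The second shows that a Python for loop runs at trace time and gets unrolled into three separate mul equations instead of being compiled as a loop.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: two high-level jax.numpy calls.
def f(x):
    y = jnp.reshape(x, (2, 3))             # expected to trace to the `reshape` primitive
    return jnp.broadcast_to(y, (4, 2, 3))  # expected to trace to `broadcast_in_dim`

# make_jaxpr runs f once with abstract tracers and records the primitives it hits.
shape_jaxpr = jax.make_jaxpr(f)(jnp.arange(6.0))
print(shape_jaxpr)

# Python control flow executes during tracing and is not compiled:
# this loop is unrolled into three separate multiplications in the jaxpr.
def g(x):
    for _ in range(3):
        x = x * 2.0
    return x

loop_jaxpr = jax.make_jaxpr(g)(1.0)
print(loop_jaxpr)
```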
Reference:
The most direct explanation is in the JAX documentation and the original JAX paper: