jax.profiler.start_server doesn't capture a trace by itself. It lets you start a trace from the TensorBoard UI instead (https://jax.readthedocs.io/en/latest/profiling.html#manual-capture-via-tensorboard). That could be a good way to control how many seconds you capture, since the TensorBoard capture dialog lets you set the profiling duration.
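Here is a minimal sketch of the server approach, assuming a recent JAX version. The port 9999 matches the JAX profiling docs, and step and run_with_profiler_server are hypothetical stand-ins for your own workload. Once the program is running, enter localhost:9999 in TensorBoard's Capture Profile dialog to record a trace of whatever length you choose.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Hypothetical stand-in for one iteration of your real workload.
    return jnp.tanh(x @ x.T)

def run_with_profiler_server(port=9999, num_steps=1000):
    # Start the profiler server. Nothing is recorded until a capture is
    # requested from TensorBoard (Capture Profile -> localhost:<port>).
    # Only one profiler server can be active per process.
    server = jax.profiler.start_server(port)
    x = jnp.ones((4, 4))
    # Raise num_steps so the program runs long enough to click "Capture".
    for _ in range(num_steps):
        x = step(x)
    # Wait for all queued device work to finish.
    x.block_until_ready()
    # Return the server handle so the caller can keep it alive.
    return server, x
```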
It's odd that your trace is under 1 GB, yet you're hitting the 2 GB limit. I can't comment yet to ask the questions that would help debug this, so I suggest filing an issue at https://github.com/jax-ml/jax/issues, where we can help you further.
As a workaround, I suggest capturing many smaller traces instead of one large 300-second trace.
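One way to do that programmatically is to wrap each chunk of work in its own jax.profiler.trace context, writing each chunk to a separate log directory. This is a sketch under assumptions: step, profile_in_chunks, the chunk_NNN directory naming, and splitting by step count (rather than by wall-clock time) are all hypothetical choices for illustration.

```python
import os
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Hypothetical stand-in for one iteration of your real workload.
    return jnp.tanh(x @ x.T)

def profile_in_chunks(x, log_root, num_chunks, steps_per_chunk):
    """Capture one short trace per chunk instead of one long trace."""
    trace_dirs = []
    for i in range(num_chunks):
        # Each chunk gets its own directory, so each trace stays small.
        log_dir = os.path.join(log_root, f"chunk_{i:03d}")
        trace_dirs.append(log_dir)
        with jax.profiler.trace(log_dir):
            for _ in range(steps_per_chunk):
                x = step(x)
            # Block so this chunk's device work finishes inside its own trace.
            x.block_until_ready()
    return x, trace_dirs
```

Pointing TensorBoard at log_root should then let you open each chunk separately. Note that the first chunk also includes jit compilation time.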