Thanks to jakevdp's comment, I got a significant speedup by replacing the gather with a one-hot matrix multiplication. I changed my code to the following:
@jax.jit
def index_points_3d(features, indices):
    """Gather per-point feature vectors via one-hot matrix multiplication.

    Encodes ``indices`` as a one-hot tensor and contracts it with
    ``features`` using ``einsum``; on accelerators this matmul-style
    gather is typically much faster than fancy indexing.

    Args:
        features: array of shape (B, N, C) — per-point features.
        indices: integer array of shape (B, npoint, nsample) with values
            in [0, N).

    Returns:
        Array of shape (B, npoint, nsample, C) where
        ``out[b, s, k] == features[b, indices[b, s, k]]``.

    Note:
        Materializes a (B, npoint, nsample, N) one-hot tensor, so memory
        use grows linearly with N — the price paid for the speedup.
    """
    # Only the point count N is needed; avoid unpacking unused dims.
    num_points = features.shape[1]
    # One-hot in features' dtype so the einsum contraction needs no cast.
    one_hot = jax.nn.one_hot(indices, num_classes=num_points, dtype=features.dtype)
    # (B, S, K, N) x (B, N, C) -> (B, S, K, C): selects rows of features.
    return jnp.einsum('bskn,bnc->bskc', one_hot, features)