79632917

Date: 2025-05-22 00:57:57
Score: 1

Thanks to jakevdp's comment, I got a significant speedup by using a one-hot matrix multiplication. I changed the code to the following:

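import jax
import jax.numpy as jnp

@jax.jit
def index_points_3d(features, indices):
    """
    Gather per-batch feature rows via a one-hot matrix multiplication.

    Args:
        features: shape (B, N, C)
        indices: shape (B, npoint, nsample)

    Returns:
        shape (B, npoint, nsample, C)
    """
    N = features.shape[1]
    # one_hot has shape (B, npoint, nsample, N); each row selects one of the N points
    one_hot = jax.nn.one_hot(indices, num_classes=N, dtype=features.dtype)
    # Contract over N: out[b, s, k, c] = sum_n one_hot[b, s, k, n] * features[b, n, c]
    return jnp.einsum('bskn,bnc->bskc', one_hot, features)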
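
A quick way to sanity-check the one-hot version is to compare it against plain advanced indexing on a small input. The sketch below is not part of my original code: the helper names `gather_reference` and `one_hot_gather` and the tiny random shapes are made up for illustration, and I am assuming all indices lie in `[0, N)`.

```python
import jax
import jax.numpy as jnp

def gather_reference(features, indices):
    # Plain advanced indexing: out[b, s, k, :] = features[b, indices[b, s, k], :]
    batch = jnp.arange(features.shape[0])[:, None, None]  # shape (B, 1, 1)
    return features[batch, indices]

def one_hot_gather(features, indices):
    # Same gather expressed as a one-hot matrix multiplication over the N axis
    one_hot = jax.nn.one_hot(indices, features.shape[1], dtype=features.dtype)
    return jnp.einsum('bskn,bnc->bskc', one_hot, features)

# Small illustrative shapes: B=2, N=16, C=4, npoint=5, nsample=3
key_f, key_i = jax.random.split(jax.random.PRNGKey(0))
features = jax.random.normal(key_f, (2, 16, 4))
indices = jax.random.randint(key_i, (2, 5, 3), 0, 16)

out_ref = gather_reference(features, indices)
out_oh = one_hot_gather(features, indices)

# Compare with a tolerance instead of exact equality
match = bool(jnp.allclose(out_ref, out_oh, atol=1e-5))
print(out_oh.shape, match)
```

Two things to keep in mind with this approach: the intermediate one-hot tensor has shape (B, npoint, nsample, N), so memory grows with N, and on accelerators the default matmul precision may be lower than float32, so the result may differ slightly from a true gather.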
Reasons:
  • Blacklisted phrase (0.5): Thanks
  • Long answer (-0.5):
  • Has code block (-0.5):
  • Self-answer (0.5):
  • Low reputation (1):
Posted by: minchou323