79374295

Date: 2025-01-21 11:59:45

I found a horrible hack that allows me to make jit functions which operate on masked vectors. I'm hoping that @jakevdp is gonna show up and show me how it's done.

OK, so the idea is this. Even though x2[x2_mask] has a size that depends on the values of x2_mask, the operation x.at[x2[x2_mask][:, 0], x2[x2_mask][:, 1]].add(1), once compiled, is conceptually just a loop over x2 that skips the entries where x2_mask is False. The sizes of the intermediate variables are not known at compile time, but the memory layout needed to carry out the operation is. So if JAX needs intermediate variables of statically known size, we can keep every row of x2 and overwrite the masked-out rows with coordinates we don't care about. Since in principle we care about every coordinate of x1, we first enlarge x1 by one column so those dummy coordinates have somewhere harmless to land.
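For the specific case of .add, a simpler variant of the same "keep the static shape" idea is to scatter every coordinate and make the masked-out ones add zero, so no intermediate array ever changes shape and no dummy column is needed. This is a sketch under that assumption, not something from the original post; the function name masked_scatter_add is hypothetical, and the trick only works because 0 is the identity for addition (it would not carry over to .set or .max as written).

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_scatter_add(x, coords, coords_mask):
    # Every row of coords is scattered, so all shapes are static.
    # Masked-out rows contribute an increment of 0 instead of being removed.
    increments = coords_mask.astype(x.dtype)
    # Duplicate coordinates accumulate under .add, as in the unjitted version.
    return x.at[coords[:, 0], coords[:, 1]].add(increments)

x = jnp.zeros((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
coords_mask = jnp.array([True, True, False, True])
print(masked_scatter_add(x, coords, coords_mask))
```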

Let's create a minimal example that shows the problem:

import jax
import jax.numpy as jnp

x = jnp.zeros((5,5))
coords = jnp.array([
    [1,2],
    [2,3],
    [1,2],
    [1,2],
])
coords_mask = jnp.array([True, True, False, True])

def testfunc(x, coords, coords_mask):
    coords_masked = coords[coords_mask]
    return x.at[coords_masked[:, 0], coords_masked[:, 1]].add(1)

testfunc(x, coords, coords_mask)

This works and outputs:

[[0., 0., 0., 0., 0.],
 [0., 0., 2., 0., 0.],
 [0., 0., 0., 1., 0.],
 [0., 0., 0., 0., 0.],
 [0., 0., 0., 0., 0.]]

Note that one of the three [1,2] entries has been masked out, and the remaining two are both counted, so position [1,2] ends up as 2.

But this doesn't work:

@jax.jit
def testfunc(x, coords, coords_mask):
    coords_masked = coords[coords_mask]
    return x.at[coords_masked[:, 0], coords_masked[:, 1]].add(1)

testfunc(x, coords, coords_mask)  # NonConcreteBooleanIndexError

So here's a horrible workaround:

@jax.jit
def testfunc(x, coords, coords_mask):
    len_0, len_1 = x.shape

    # enlarge x by 1 in axis=1
    x = jnp.concatenate([x, jnp.zeros((len_0, 1))], axis=1)

    # prepare mask coordinates so that the False points to a position in the enlarged x array
    default = jnp.full(coords.shape, fill_value=jnp.array([0, len_1]))
    mask_repeated = jnp.repeat(coords_mask.reshape((coords.shape[0],1)), coords.shape[1], axis=1)
    coords_masked = jnp.where(mask_repeated, coords, default)

    # scatter coords onto enlarged x
    x = x.at[coords_masked[:, 0], coords_masked[:, 1]].add(1)
    
    # take a slice of x
    x = x[:, :len_1]
    return x

testfunc(x, coords, coords_mask)  # works
Posted by: oneloop