79286273

Date: 2024-12-16 23:00:59

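I realized I needed to set the number of batch dimensions to 1 (batch_dims=1), so that the leading dimension of size 2 shared by W and V is treated as a batch dimension. This now works: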

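    import tensorflow as tf

    # V: indices, shape (2, 3, 4), values in {0, 1} (maxval is exclusive)
    V = tf.random.uniform((2, 3, 4), minval=0, maxval=2, dtype=tf.int32)
    print(V[0, :])
    # W: params, shape (2, 5); its leading dimension (2) is the batch dimension shared with V
    W = tf.random.uniform((2, 5), minval=0, maxval=4, dtype=tf.int32)
    print('W =', W)
    # batch_dims=1: for each batch b, gather from W[b] along axis 1 using the indices in V[b]
    Z = tf.gather(params=W, indices=V, axis=1, batch_dims=1)
    print(Z.shape)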

The result has shape (2, 3, 4), as desired, with Z[b, i, j] = W[b, V[b, i, j]]. Simple in the end, but a pain to figure out!
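To see why the shape comes out as (2, 3, 4): for tf.gather the output shape is params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis+1:], which here is (2,) + (3, 4) + () = (2, 3, 4). A minimal cross-check of the same semantics is sketched below in NumPy rather than TensorFlow. This is an assumption for illustration: NumPy's take_along_axis is used as a stand-in, not TensorFlow's implementation, and the helper name gather_batch_dims_1 is hypothetical.

```python
import numpy as np

def gather_batch_dims_1(W, V):
    """Hypothetical NumPy stand-in for
    tf.gather(params=W, indices=V, axis=1, batch_dims=1)
    when W has shape (B, N) and V has shape (B, ...):
    out[b, ...] = W[b, V[b, ...]]."""
    B = W.shape[0]
    flat = V.reshape(B, -1)                    # (B, K) with K = prod(V.shape[1:])
    out = np.take_along_axis(W, flat, axis=1)  # out[b, k] = W[b, flat[b, k]]
    return out.reshape(V.shape)                # restore (B, ...) layout

rng = np.random.default_rng(0)
V = rng.integers(0, 2, size=(2, 3, 4))  # indices in {0, 1}, as in the answer
W = rng.integers(0, 4, size=(2, 5))     # params with batch dimension 2
Z = gather_batch_dims_1(W, V)
print(Z.shape)  # (2, 3, 4)

# Same result with an explicit loop over the batch dimension
Z_loop = np.stack([W[b][V[b]] for b in range(W.shape[0])])
assert (Z == Z_loop).all()
```

Given the same W and V, the TensorFlow call in the answer should produce the same values as gather_batch_dims_1(W.numpy(), V.numpy()); the loop version makes the per-batch indexing explicit.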

Posted by: user3433489