79146829

Date: 2024-11-01 03:14:27
Score: 1
Natty:
Report link

Using the method provided by @trialNerror, the slicing version runs about 8 times faster than the for loop on my machine.

import torch
import random
import time

B, N, D = 1024, 40, 512
src_tensor = torch.randn((B, N, D))
# Each index i is paired with i + 1, so draw from [0, N - 2] to stay in range
indices = [random.randint(0, N - 2) for _ in range(B)]
t1 = torch.empty((B, 2, D))
t2 = torch.empty((B, 2, D))

# Vectorized: one advanced-indexing slice per offset
start = time.perf_counter()
t1[:, 0, :] = src_tensor[range(B), indices]
t1[:, 1, :] = src_tensor[:, 1:, :][range(B), indices]
end1 = time.perf_counter()

# Baseline: per-sample Python loop
for i in range(B):
    p1, p2 = indices[i], indices[i] + 1
    t2[i, :, :] = src_tensor[i, [p1, p2], :]
end2 = time.perf_counter()

print(f't1 == t2 ? :{t1.equal(t2)}')
print(f't1: {end1-start}')
print(f't2: {end2-end1}')
print(f't2/t1: {(end2-end1)/(end1-start)}')

# t1 == t2 ? :True
# t1: 0.002891700016334653
# t2: 0.023338499944657087
# t2/t1: 8.070857908089506
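As a side note, the two slices can also be collapsed into a single advanced-indexing call by broadcasting a row-index column against a (B, 2) column-index tensor. This is a minimal sketch of that idea (the variable names `rows`/`cols` are my own, not from the answer above):

```python
import torch

B, N, D = 1024, 40, 512
src = torch.randn(B, N, D)
idx = torch.randint(0, N - 1, (B,))  # idx and idx + 1 both stay in range

# rows has shape (B, 1), cols has shape (B, 2);
# broadcasting yields a (B, 2) index grid, so out has shape (B, 2, D)
rows = torch.arange(B).unsqueeze(1)
cols = torch.stack([idx, idx + 1], dim=1)
out = src[rows, cols]

# Reference result built with a per-sample loop
ref = torch.stack([src[i, [idx[i], idx[i] + 1]] for i in range(B)])
print(out.equal(ref))  # True
```

Whether one indexing call beats two slice assignments here would need its own timing run; the point is only that the pair gather fits in a single expression.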
Reasons:
  • Long answer (-0.5):
  • Has code block (-0.5):
  • User mentioned (1): @trialNerror
  • Self-answer (0.5):
  • Low reputation (0.5):
Posted by: MasterLu