If all of your spans have the same length, you can do this:
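# Column positions, shape (1, seq_len), created on the same device as tokens
idx = torch.arange(tokens.size(1), device=tokens.device).unsqueeze(0)
# True where start_index <= position < end_index for each row
mask = (start_index.unsqueeze(1) <= idx) & (idx < end_index.unsqueeze(1))
spans = tokens[mask].view(tokens.size(0), -1)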
Otherwise, the spans cannot be stored in a single tensor, as @dennlinger mentioned, because each row would have a different length.
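
Here is a small self-contained sketch of the idea, using made-up values. It assumes `tokens` is a 2-D tensor of shape `(batch, seq_len)` and that `start_index` / `end_index` are 1-D tensors of per-row bounds, with the end exclusive. The list-of-tensors fallback for unequal lengths is just one possible workaround, not something from the answer above.

```python
import torch

# Hypothetical toy data: a batch of 2 sequences with 6 tokens each.
tokens = torch.tensor([[10, 11, 12, 13, 14, 15],
                       [20, 21, 22, 23, 24, 25]])
start_index = torch.tensor([1, 3])  # inclusive start per row (assumed)
end_index = torch.tensor([3, 5])    # exclusive end per row (assumed)

# Column positions, shape (1, seq_len), broadcast against the per-row bounds.
idx = torch.arange(tokens.size(1), device=tokens.device).unsqueeze(0)
mask = (start_index.unsqueeze(1) <= idx) & (idx < end_index.unsqueeze(1))

# Equal-length spans: boolean indexing flattens row by row,
# so view() can restore the batch dimension.
spans = tokens[mask].view(tokens.size(0), -1)

# Unequal-length spans: keep one tensor per row instead of a single tensor.
ragged_end = torch.tensor([4, 5])
ragged_mask = (start_index.unsqueeze(1) <= idx) & (idx < ragged_end.unsqueeze(1))
ragged_spans = [row[m] for row, m in zip(tokens, ragged_mask)]
```

With equal span lengths, `spans` has shape `(batch, span_len)`. In the unequal case, `view` would either fail or silently mix rows, which is why the sketch falls back to a Python list with one tensor per row.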