@Vikas's solution only works for LongTensors. Here is an efficient version that works for any dtype:
import torch

def batch_histogram(scores, num_bins, eps=1e-6):
    """
    Compute histograms for a batch of score arrays.

    Args:
        scores (torch.Tensor): Input tensor of shape (batch_size, num_scores)
            containing the values to histogram.
        num_bins (int): Number of histogram bins.
        eps (float, optional): Small epsilon value to prevent division by zero.
            Defaults to 1e-6.

    Returns:
        torch.Tensor: Histogram tensor of shape (batch_size, num_bins)
            where each row contains the bin counts for the
            corresponding input row.

    Example:
        >>> scores = torch.tensor([[ 0.7338,  1.2722, -1.0576,  0.2836, -0.5123],
        ...                        [ 1.0205, -0.6672, -1.0974, -0.1666, -0.6787]])
        >>> hist = batch_histogram(scores, num_bins=3)
        >>> hist
        tensor([[2., 1., 2.],
                [3., 1., 1.]])

    Note:
        This is equivalent to:
            torch.stack([torch.histc(scores[i], bins=num_bins) for i in range(scores.shape[0])])
        but is more efficient for batch processing.
    """
    batch_size = scores.shape[0]
    # Initialize histogram tensor and ones for counting
    hist = torch.zeros((batch_size, num_bins), dtype=scores.dtype, device=scores.device)
    ones = torch.ones_like(scores)
    # Find min and max values for each batch element
    batch_min = torch.min(scores, dim=1, keepdim=True)[0]
    batch_max = torch.max(scores, dim=1, keepdim=True)[0] + eps
    # Normalize scores to [0, 1) using each row's own range
    normalized_scores = (scores - batch_min) / (batch_max - batch_min)
    # Convert to bin indices (floor)
    bin_indices = (normalized_scores * num_bins).long()
    # Guard against out-of-bounds indices when floating-point rounding
    # pushes a normalized score to exactly 1.0
    bin_indices = torch.clamp(bin_indices, 0, num_bins - 1)
    # Accumulate counts in histogram
    hist.scatter_add_(1, bin_indices, ones)
    return hist
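As a quick sanity check of the equivalence claimed in the docstring, you can compare against the per-row torch.histc loop (torch.histc with the default min=max=0 uses each row's own data range). This sketch assumes the function above is in scope:

scores = torch.randn(4, 100)
hist = batch_histogram(scores, num_bins=10)
# Reference: one torch.histc call per row
reference = torch.stack([torch.histc(scores[i], bins=10) for i in range(scores.shape[0])])
# The two can differ for values landing exactly on an internal bin edge,
# since batch_histogram widens the top edge by eps; for random floats
# this should print True
print(torch.allclose(hist, reference))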
Note the torch.clamp call after computing bin_indices: without it, floating-point rounding can occasionally produce an index equal to num_bins (for example, when eps is negligible relative to the data range), which would cause an out-of-bounds error in scatter_add_.
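And since the function never calls torch.histc, it also works on integer tensors. A minimal sketch with made-up values:

int_scores = torch.randint(0, 10, (2, 1000))
int_hist = batch_histogram(int_scores, num_bins=10)
# Integer inputs produce integer counts; each row sums to num_scores
print(int_hist.sum(dim=1))  # tensor([1000, 1000])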