79682996

Date: 2025-06-28 12:29:40
Score: 0.5
Natty:
Report link

@Vikas's solution only works for LongTensors. Here is an efficient version that works for any dtype:

import torch

def batch_histogram(scores, num_bins, eps=1e-6):
    """
    Compute histograms for a batch of score arrays.
    
    Args:
        scores (torch.Tensor): Input tensor of shape (batch_size, num_scores)
                               containing the values to histogram
        num_bins (int): Number of histogram bins
        eps (float, optional): Small epsilon added to the per-row max so the
                              largest value stays strictly below the top bin
                              edge (and to avoid division by zero when a row
                              is constant). Defaults to 1e-6.
    
    Returns:
        torch.Tensor: Histogram tensor of shape (batch_size, num_bins)
                     where each row contains the bin counts for the 
                     corresponding input row
    
    Example:
        >>> scores = torch.tensor([[ 0.7338,  1.2722, -1.0576,  0.2836, -0.5123],
        ...                        [ 1.0205, -0.6672, -1.0974, -0.1666, -0.6787]])
        >>> hist = batch_histogram(scores, num_bins=3)
        >>> hist
        tensor([[2., 1., 2.],
                [3., 1., 1.]])
    
    Note:
        This is equivalent to:
        torch.stack([torch.histc(scores[i], bins=num_bins) for i in range(scores.shape[0])])
        but is more efficient for batch processing.
    """
    batch_size = scores.shape[0]
    
    # Initialize histogram tensor and ones for counting
    hist = torch.zeros((batch_size, num_bins), dtype=scores.dtype, device=scores.device)
    ones = torch.ones_like(scores)
    
    # Find min and max values for each batch element
    batch_min = torch.min(scores, dim=1, keepdim=True)[0]
    batch_max = torch.max(scores, dim=1, keepdim=True)[0] + eps
    
    # Normalize scores to [0, 1] range
    normalized_scores = (scores - batch_min) / (batch_max - batch_min)
    
    # Convert to bin indices (.long() truncates, which equals floor here
    # since the normalized scores are non-negative)
    bin_indices = (normalized_scores * num_bins).long()

    # Accumulate counts in histogram
    hist.scatter_add_(1, bin_indices, ones)
    
    return hist

You might want to add bin_indices = torch.clamp(bin_indices, 0, num_bins - 1) right after computing bin_indices, to guard against out-of-bounds indices from floating-point rounding at the top edge.
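As a quick sanity check, the per-row logic can be inlined and verified end to end (the variable names below mirror the function above but this is a standalone sketch, not a library API):

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 100)   # batch of 4 rows, 100 scores each
num_bins, eps = 10, 1e-6

# Per-row min/max, with eps keeping the max strictly inside the last bin
batch_min = torch.min(scores, dim=1, keepdim=True)[0]
batch_max = torch.max(scores, dim=1, keepdim=True)[0] + eps
bin_indices = ((scores - batch_min) / (batch_max - batch_min) * num_bins).long()
bin_indices = torch.clamp(bin_indices, 0, num_bins - 1)  # guard the top edge

hist = torch.zeros(scores.shape[0], num_bins)
hist.scatter_add_(1, bin_indices, torch.ones_like(scores))

# Every score falls into exactly one bin, so each row sums to 100
print(hist.sum(dim=1))  # tensor([100., 100., 100., 100.])
```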

Reasons:
  • Long answer (-1):
  • Has code block (-0.5):
  • User mentioned (1): @Vikas
  • Low reputation (1):
Posted by: Marcos Treviso