In a CNN, the mean and variance should be computed per channel, over all pixels and all examples in the batch. In other words, if your input has shape (B, C, H, W), the mean and variance will each have shape (1, C, 1, 1). The reason is that the weights of a kernel in a CNN are shared across the spatial dimensions (H x W), so every spatial position of a channel is treated the same way. Across the channel dimension C, however, the weights are not shared.
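To make the shapes concrete, here is a minimal NumPy sketch of the per-channel statistics. The sizes, the random input, and the eps value are arbitrary choices for illustration, and the learnable scale and shift of a full batch-norm layer are left out.

```python
import numpy as np

# Illustrative sizes only; any (B, C, H, W) input works the same way.
B, C, H, W = 8, 3, 4, 4
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(B, C, H, W))

# Reduce over batch (0) and spatial axes (2, 3); keep the channel axis.
mean = x.mean(axis=(0, 2, 3), keepdims=True)  # shape (1, C, 1, 1)
var = x.var(axis=(0, 2, 3), keepdims=True)    # shape (1, C, 1, 1), biased estimate

eps = 1e-5  # assumed small constant for numerical stability
x_hat = (x - mean) / np.sqrt(var + eps)       # broadcasts back to (B, C, H, W)
```

Each of the C channels ends up with one mean and one variance, estimated from B * H * W values.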
In contrast, for a fully-connected network the input has shape (B, D) and the mean and variance will have shape (1, D), as there is no notion of spatial dimensions or channels, only features. Each feature is normalized independently over the batch.
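The fully-connected case can be sketched the same way. Again, the sizes, the random input, and eps are assumptions made for the example, not part of any particular library's implementation.

```python
import numpy as np

# Illustrative sizes only: B examples, D features.
B, D = 16, 5
rng = np.random.default_rng(0)
feats = rng.normal(loc=-1.0, scale=2.0, size=(B, D))

# Reduce over the batch axis only; every feature keeps its own statistics.
feat_mean = feats.mean(axis=0, keepdims=True)  # shape (1, D)
feat_var = feats.var(axis=0, keepdims=True)    # shape (1, D), biased estimate

eps = 1e-5  # assumed small constant for numerical stability
feats_hat = (feats - feat_mean) / np.sqrt(feat_var + eps)
```

Here each statistic is estimated from only B values, one per example in the batch.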
References: