This post is about: understanding the relation between hidden_size & output_size
+ h_t vs out
+ visualizing the values
in an RNN.
(Disclaimer: this is just my personal understanding. I can be wrong.)
Even though, in the math, the output size of y should be customizable, based on the shape of the output weight matrix.
Math:
\[\mathbf{h}_t = \operatorname{activation}\left( \mathbf{x}_t \mathbf{W}_i + \mathbf{h}_{t-1} \mathbf{W}_h + \mathbf{b}_h \right)\]
\[\mathbf{y}_t = \mathbf{h}_t \mathbf{W}_o + \mathbf{b}_o\]
It seems that PyTorch decided to treat the output y == hidden state h
-- (i.e. the output is reused as the hidden state for the next time step).
Which is OK; some people use this design for RNNs.
So, in PyTorch, output_size == hidden_size.
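To make this concrete, here is a minimal single-time-step sketch of the math above (toy sizes and weight names of my own, not PyTorch's implementation): with an extra W_o, y_t could mathematically have any size, but nn.RNN has no such projection, so its per-step output is just h_t.

import torch

torch.manual_seed(0)
input_size, hidden_size, output_size = 28, 128, 10   # output_size here is hypothetical

x_t    = torch.randn(1, input_size)    # input at one time step
h_prev = torch.zeros(1, hidden_size)   # previous hidden state h_{t-1}

W_i = torch.randn(input_size, hidden_size)
W_h = torch.randn(hidden_size, hidden_size)
b_h = torch.zeros(hidden_size)

# h_t = activation(x_t W_i + h_{t-1} W_h + b_h)
h_t = torch.tanh(x_t @ W_i + h_prev @ W_h + b_h)

# In the math, y_t = h_t W_o + b_o could have any size you like ...
W_o = torch.randn(hidden_size, output_size)
b_o = torch.zeros(output_size)
y_t = h_t @ W_o + b_o

# ... but nn.RNN returns h_t itself as the per-step output, so output_size == hidden_size.
print(h_t.shape, y_t.shape)   # torch.Size([1, 128]) torch.Size([1, 10])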
Related:
@note: [h_t vs out]
Though I said, >"PyTorch decided to treat the output y == hidden state h",
the two returned values differ slightly:
h_t : the final hidden state of every layer, after processing all time steps.
out : the hidden state of the last layer at every time step.
You can try it and see that the overlapping values are indeed the same. (Many other answer posts have done this.)
out, h_t = self.rnn(x, h0)
for i in range(h_t.shape[1]):  # loop over the batch dimension
    # For every sample in the batch, the tensor at the last time step of `out`
    # == the tensor in the last layer of the hidden state `h_t`.
    if not torch.equal(out[i, -1], h_t[-1, i]):
        raise ValueError("out[:, -1] and h_t[-1] should be identical")
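Outside the model class, the same check can be done on a bare nn.RNN (a standalone sketch; the sizes 28/128/2 match the example below):

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=28, hidden_size=128, num_layers=2, batch_first=True)
x = torch.randn(100, 28, 28)     # (batch_size, sequence_length, input_size)
h0 = torch.zeros(2, 100, 128)    # (num_layers, batch_size, hidden_size)

out, h_t = rnn(x, h0)
# out: (batch_size, sequence_length, hidden_size); h_t: (num_layers, batch_size, hidden_size)
assert torch.equal(out[:, -1, :], h_t[-1])   # last time step of out == last layer of h_t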
The following is a design of an RNN;
for the complete code, see https://www.youtube.com/watch?v=0_PgWWmauHk
-> https://github.com/patrickloeber/pytorch-examples/blob/master/rnn-lstm-gru/main.py
It uses an RNN to classify MNIST digits.
Each 28x28 MNIST image is treated row-wise: each row is one input vector, and the 28 rows form the sequence.
With
input_size = 28       # H_in - input_size – the number of expected features in the input x # the size of each vector (one image row)
sequence_length = 28  # L - the sequence length, i.e. the number of time steps (one per image row)
It uses 2 stacked layers in the RNN.
The output size is transformed after the RNN, at
self.fc = nn.Linear(hidden_size, num_classes)
# @shape: (batch_size, 128) -> (batch_size, 10)
(//...)
out = self.fc(out)
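On its own, that final projection just maps the last hidden state to class scores; a tiny sketch with the sizes above (names are mine):

import torch
import torch.nn as nn

fc = nn.Linear(128, 10)             # hidden_size -> num_classes
last_step = torch.randn(100, 128)   # i.e. out[:, -1, :] for a batch of 100
logits = fc(last_step)
print(logits.shape)                 # torch.Size([100, 10])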
import torch
import torch.nn as nn
from torch import Tensor

# as in the linked complete code
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # RNN — PyTorch 2.5 documentation
        # https://pytorch.org/docs/stable/generated/torch.nn.RNN.html
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True, nonlinearity="tanh", bidirectional=False)
        # @shape: 28, 128, 2
        # or:
        # self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
        # @shape: (batch_size, 128) -> (batch_size, 10)

    def forward(self, x: Tensor):
        # Set initial hidden states (and cell states for LSTM)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # @note: n = batch_size = 100
        # @shape: x: (batch_size, sequence_length, input_size) = (n, 28, 28)
        # @shape: h0: (2, n, 128)
        # or with:
        # c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward propagate the RNN (the recurrence over the time steps happens inside nn.RNN; no explicit loop needed here)
        out, h_t = self.rnn(x, h0)
        out: Tensor
        # @shape: out: (batch_size, seq_length, hidden_size) = (n, 28, 128)
        # @shape: h_t: (num_layers, batch_size, hidden_size) = (2, n, 128)
        # print(">--")
        # print(out)
        # print(h_t)
        # for i in range(h_t.shape[1]):
        #     # for every sample in the batch, the tensor at the last time step of out == the tensor in the last layer of h_t
        #     if not torch.equal(out[i, -1], h_t[-1, i]):
        #         raise ValueError("")
        # or:
        # out, _ = self.lstm(x, (h0, c0))

        # Decode the hidden state of the last time step.
        # The `-1` index selects the last element along the second dimension (the sequence length).
        out = out[:, -1, :]
        # @shape: out: (n, 128)  # output of the 28th/last time step
        out = self.fc(out)
        # @shape: out: (n, 10)
        return out
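A quick shape sanity check after the class (the variable names here are mine; the reshape from (batch, 1, 28, 28) to (batch, 28, 28) is how the linked training loop feeds MNIST images to the model):

model = RNN(input_size=28, hidden_size=128, num_layers=2, num_classes=10).to(device)

# MNIST batches come as (batch, 1, 28, 28); dropping the channel dim turns each
# image into a sequence of 28 rows, each row a 28-dim input vector.
images = torch.randn(100, 1, 28, 28).to(device)
x = images.reshape(-1, 28, 28)   # (batch_size, sequence_length, input_size)

logits = model(x)
print(logits.shape)              # torch.Size([100, 10])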
Here is one real output from out, h_t = self.rnn(x, h0)  # print(out)  # print(h_t)
Watch this tensor: [ 0.118, 0.049, 0.117, -0.012, 0.138, ..., -0.074, -0.069, 0.046, -0.004, 0.098]  // << e.g. see this line, they are the same
Each row tensor in the last-layer block of h_t
can be found as the last row tensor of the corresponding batch block of out.
>--
tensor([[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.093, 0.082, 0.231, -0.073, 0.143, ..., -0.067, -0.138, 0.016, 0.007, 0.150],
[ 0.098, 0.067, 0.218, -0.055, 0.120, ..., -0.082, -0.096, 0.034, -0.005, 0.149],
[ 0.158, 0.026, 0.242, -0.053, 0.095, ..., -0.100, -0.105, 0.032, -0.022, 0.121],
[ 0.184, 0.051, 0.244, -0.048, 0.115, ..., -0.125, -0.123, 0.023, 0.007, 0.129],
[ 0.147, 0.050, 0.239, -0.025, 0.126, ..., -0.100, -0.135, 0.031, -0.027, 0.131]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.122, 0.021, 0.088, -0.125, 0.089, ..., -0.045, -0.003, 0.064, 0.024, 0.040],
...,
[ 0.071, -0.062, 0.179, -0.167, -0.020, ..., -0.195, -0.090, 0.108, 0.166, 0.113],
[ 0.137, -0.026, 0.259, -0.002, 0.045, ..., -0.156, -0.108, 0.046, 0.049, 0.138],
[ 0.122, -0.019, 0.159, 0.013, 0.085, ..., -0.098, -0.046, 0.006, -0.002, 0.072],
[ 0.133, 0.047, 0.095, 0.011, 0.123, ..., -0.081, -0.055, 0.033, -0.026, 0.107],
[ 0.090, 0.078, 0.068, -0.047, 0.153, ..., -0.052, -0.063, 0.069, -0.030, 0.119]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.064, 0.077, -0.109, -0.184, 0.064, ..., 0.100, -0.053, 0.139, 0.047, 0.058],
[ 0.062, 0.056, -0.136, -0.172, 0.081, ..., 0.073, -0.063, 0.159, 0.054, 0.100],
[ 0.026, 0.006, -0.107, -0.160, 0.105, ..., 0.041, -0.043, 0.172, 0.049, 0.113],
[ 0.087, -0.002, 0.011, -0.106, 0.083, ..., -0.021, -0.031, 0.087, 0.058, 0.047],
[ 0.099, 0.001, 0.045, -0.068, 0.126, ..., -0.043, -0.028, 0.099, 0.008, 0.101]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.046, 0.065, -0.045, -0.082, 0.068, ..., 0.069, -0.037, 0.106, 0.019, 0.102],
[ 0.040, 0.045, -0.063, -0.106, 0.046, ..., 0.068, -0.051, 0.122, 0.023, 0.072],
[ 0.101, 0.050, -0.019, -0.106, 0.068, ..., 0.007, -0.035, 0.119, 0.045, 0.034],
[ 0.126, 0.035, 0.041, -0.078, 0.091, ..., -0.039, -0.030, 0.111, 0.004, 0.087],
[ 0.125, 0.028, 0.066, -0.058, 0.122, ..., -0.053, -0.031, 0.084, -0.016, 0.088]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.090, 0.034, -0.028, -0.124, 0.030, ..., 0.071, -0.059, 0.109, 0.089, 0.058],
[ 0.109, 0.058, -0.057, -0.142, 0.042, ..., 0.064, -0.043, 0.122, 0.082, 0.047],
[ 0.091, 0.036, -0.018, -0.094, 0.039, ..., 0.036, -0.003, 0.126, 0.028, 0.066],
[ 0.122, 0.046, 0.041, -0.068, 0.088, ..., -0.023, -0.023, 0.114, 0.007, 0.052],
[ 0.127, 0.034, 0.056, -0.061, 0.126, ..., -0.047, -0.022, 0.102, -0.010, 0.079]],
...,
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[-0.042, -0.062, 0.058, -0.124, 0.102, ..., 0.064, -0.063, 0.070, 0.269, 0.059],
[ 0.005, -0.011, -0.033, -0.105, 0.090, ..., 0.051, -0.037, 0.088, 0.175, 0.086],
[ 0.050, -0.022, -0.007, -0.003, 0.033, ..., 0.051, -0.002, 0.010, 0.080, 0.041],
[ 0.107, 0.045, 0.001, -0.022, 0.127, ..., -0.013, -0.072, 0.049, 0.027, 0.059],
[ 0.121, 0.076, 0.010, -0.044, 0.148, ..., -0.042, -0.076, 0.044, 0.004, 0.082]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.106, 0.021, 0.173, -0.125, 0.190, ..., -0.045, -0.106, 0.089, 0.073, 0.136],
[ 0.077, 0.070, 0.030, -0.137, 0.173, ..., 0.030, -0.067, 0.129, 0.023, 0.162],
[ 0.067, 0.056, 0.031, -0.074, 0.039, ..., 0.057, -0.016, 0.059, 0.012, 0.073],
[ 0.095, 0.077, 0.036, -0.066, 0.120, ..., 0.021, -0.044, 0.117, -0.004, 0.073],
[ 0.125, 0.057, 0.063, -0.060, 0.109, ..., -0.030, -0.051, 0.102, -0.016, 0.085]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.048, 0.031, -0.103, -0.133, 0.037, ..., 0.069, -0.058, 0.122, 0.030, 0.073],
[ 0.043, -0.006, -0.082, -0.113, 0.040, ..., 0.040, -0.053, 0.122, 0.032, 0.083],
[ 0.088, 0.013, -0.004, -0.104, 0.083, ..., -0.036, -0.046, 0.100, 0.048, 0.055],
[ 0.115, 0.017, 0.061, -0.080, 0.102, ..., -0.054, -0.037, 0.096, -0.002, 0.102],
[ 0.113, 0.023, 0.075, -0.052, 0.118, ..., -0.050, -0.031, 0.078, -0.022, 0.093]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.091, -0.014, 0.236, -0.042, 0.136, ..., -0.139, -0.112, -0.028, 0.070, 0.103],
[ 0.095, -0.008, 0.239, -0.029, 0.136, ..., -0.127, -0.120, -0.014, 0.055, 0.126],
[ 0.123, 0.026, 0.213, -0.024, 0.147, ..., -0.134, -0.105, 0.010, 0.061, 0.121],
[ 0.127, 0.038, 0.211, -0.009, 0.147, ..., -0.092, -0.124, 0.040, 0.005, 0.114],
[ 0.118, 0.049, 0.117, -0.012, 0.138, ..., -0.074, -0.069, 0.046, -0.004, 0.098]], // << eg see this line, they are same
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.152, -0.010, 0.118, -0.057, 0.071, ..., -0.202, 0.018, 0.039, 0.035, 0.105],
[ 0.118, -0.042, 0.180, -0.128, 0.109, ..., -0.099, -0.110, 0.044, 0.063, 0.076],
[ 0.073, -0.049, 0.173, -0.136, 0.087, ..., -0.036, -0.123, 0.002, 0.133, 0.087],
[ 0.069, -0.013, 0.158, -0.002, 0.055, ..., -0.063, -0.058, 0.019, 0.025, 0.139],
[ 0.124, 0.043, 0.126, 0.003, 0.102, ..., -0.062, -0.050, 0.027, 0.027, 0.043]]], device='cuda:0', grad_fn=<CudnnRnnBackward0>)
tensor([[[ 0.183, -0.012, 0.073, -0.071, -0.109, ..., 0.137, -0.035, -0.130, -0.083, 0.095],
[ 0.180, 0.006, -0.001, -0.072, -0.091, ..., 0.107, -0.048, -0.102, -0.115, 0.042],
[ 0.172, -0.030, -0.035, -0.092, -0.117, ..., 0.134, -0.004, -0.097, -0.090, 0.047],
[ 0.174, -0.006, -0.010, -0.082, -0.100, ..., 0.104, -0.027, -0.120, -0.122, 0.045],
[ 0.185, 0.012, -0.005, -0.058, -0.087, ..., 0.119, -0.025, -0.112, -0.108, 0.044],
...,
[ 0.133, 0.020, -0.024, -0.084, -0.078, ..., 0.120, -0.073, -0.125, -0.093, 0.026],
[ 0.162, -0.003, -0.011, -0.075, -0.095, ..., 0.102, -0.036, -0.110, -0.100, 0.031],
[ 0.172, -0.020, -0.010, -0.087, -0.097, ..., 0.100, -0.028, -0.122, -0.131, 0.044],
[ 0.165, 0.015, 0.037, -0.078, -0.044, ..., 0.125, -0.020, -0.113, -0.134, 0.039],
[ 0.124, 0.012, -0.002, -0.105, 0.006, ..., 0.146, -0.070, -0.155, -0.155, 0.038]],
[[ 0.147, 0.050, 0.239, -0.025, 0.126, ..., -0.100, -0.135, 0.031, -0.027, 0.131],
[ 0.090, 0.078, 0.068, -0.047, 0.153, ..., -0.052, -0.063, 0.069, -0.030, 0.119],
[ 0.099, 0.001, 0.045, -0.068, 0.126, ..., -0.043, -0.028, 0.099, 0.008, 0.101],
[ 0.125, 0.028, 0.066, -0.058, 0.122, ..., -0.053, -0.031, 0.084, -0.016, 0.088],
[ 0.127, 0.034, 0.056, -0.061, 0.126, ..., -0.047, -0.022, 0.102, -0.010, 0.079],
...,
[ 0.121, 0.076, 0.010, -0.044, 0.148, ..., -0.042, -0.076, 0.044, 0.004, 0.082],
[ 0.125, 0.057, 0.063, -0.060, 0.109, ..., -0.030, -0.051, 0.102, -0.016, 0.085],
[ 0.113, 0.023, 0.075, -0.052, 0.118, ..., -0.050, -0.031, 0.078, -0.022, 0.093],
[ 0.118, 0.049, 0.117, -0.012, 0.138, ..., -0.074, -0.069, 0.046, -0.004, 0.098], // << eg see this line, they are same
[ 0.124, 0.043, 0.126, 0.003, 0.102, ..., -0.062, -0.050, 0.027, 0.027, 0.043]]], device='cuda:0', grad_fn=<CudnnRnnBackward0>)