So I want to understand exactly how the outputs and hidden state of a GRU cell are calculated.
I obtained the pre-trained model from here and the GRU layer has been defined as nn.GRU(96, 96, bias=True).
I looked at the PyTorch documentation and confirmed the dimensions of the weights and biases as:
- `weight_ih_l0`: (288, 96)
- `weight_hh_l0`: (288, 96)
- `bias_ih_l0`: (288)
- `bias_hh_l0`: (288)
My input and output both have shape (1000, 8, 96); I understand this as 1000 tensors, each of size (8, 96). The hidden state has shape (1, 8, 96), i.e. one tensor of size (8, 96).
I have also printed the variable `batch_first` and found it to be `False`. This means that:

- Sequence length: `L = 1000`
- Batch size: `B = 8`
- Input size: `H_in = 96`
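For reference, this is roughly how I checked those shapes and the `batch_first` flag (a sketch, with `gru` and `x` standing in for my layer and input):

```python
for name, p in gru.named_parameters():
    print(name, tuple(p.shape))   # e.g. weight_ih_l0 (288, 96), bias_ih_l0 (288,)

print(gru.batch_first)            # False
print(x.shape)                    # torch.Size([1000, 8, 96])
```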
Now, going by the equations from the documentation, for the reset gate I need to multiply the weight by the input x. But my weights are two-dimensional while my input has three dimensions.
Here is what I've tried: I took the first (8, 96) matrix from my input and multiplied it with the transpose of my weight matrix:
Input (8, 96) x Weight (96, 288) = (8, 288)
Then I add the bias by replicating the (288) bias vector eight times to give (8, 288). This would give r(t) a size of (8, 288). Similarly, z(t) would also be (8, 288).
This r(t) is used in n(t). Since a Hadamard product is used, both matrices being multiplied have to be the same size, that is (8, 288). This implies that n(t) is also (8, 288).
Finally, h(t) is computed with a Hadamard product and a matrix addition, which would give h(t) a size of (8, 288), which is wrong.
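Roughly, in code, this attempt looks like the following sketch (random tensors stand in for my actual data, and `W_ih`/`b_ih` refer to the full concatenated parameters):

```python
import torch
import torch.nn as nn

gru = nn.GRU(96, 96, bias=True)
x = torch.randn(1000, 8, 96)            # (L, B, H_in)

x_t = x[0]                              # first timestep, shape (8, 96)
W_ih = gru.weight_ih_l0                 # full concatenated weight, shape (288, 96)
b_ih = gru.bias_ih_l0                   # full concatenated bias, shape (288,)

r = torch.sigmoid(x_t @ W_ih.T + b_ih)  # shape (8, 288), which cannot be right
print(r.shape)
```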
Where am I going wrong in this process?
TL;DR: The confusion comes from the fact that the layer's weights are the concatenation of the three gates' weight matrices, for the input-hidden and hidden-hidden parameters respectively.
- **`nn.GRU` layer weight/bias layout**

You can take a closer look at what's inside the GRU layer implementation `torch.nn.GRU` by peeking through the weights and biases.

First, the parameters of the GRU layer: you can look at `gru.state_dict()` to get the dictionary of weights of the layer. We have two weights and two biases, where `_ih` stands for 'input-hidden' and `_hh` stands for 'hidden-hidden'. For more efficient computation the parameters have been concatenated together, as the documentation page explains (`|` means concatenation). In this particular example `num_layers=1` and `k=0`:

- `~GRU.weight_ih_l[k]` – the learnable input-hidden weights of the layer `(W_ir | W_iz | W_in)`, of shape `(3*hidden_size, input_size)`.
- `~GRU.weight_hh_l[k]` – the learnable hidden-hidden weights of the layer `(W_hr | W_hz | W_hn)`, of shape `(3*hidden_size, hidden_size)`.
- `~GRU.bias_ih_l[k]` – the learnable input-hidden bias of the layer `(b_ir | b_iz | b_in)`, of shape `(3*hidden_size)`.
- `~GRU.bias_hh_l[k]` – the learnable hidden-hidden bias of the layer `(b_hr | b_hz | b_hn)`, of shape `(3*hidden_size)`.

For further inspection we can get those split up with the following code.
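A minimal sketch of how to split them, assuming a layer like the one in the question (`gru = nn.GRU(96, 96, bias=True)`):

```python
import torch.nn as nn

gru = nn.GRU(96, 96, bias=True)                # same layout as the layer in the question
H = gru.hidden_size                            # 96

# Split each concatenated parameter into its three per-gate tensors
W_ir, W_iz, W_in = gru.weight_ih_l0.split(H)   # each (96, 96)
W_hr, W_hz, W_hn = gru.weight_hh_l0.split(H)   # each (96, 96)
b_ir, b_iz, b_in = gru.bias_ih_l0.split(H)     # each (96,)
b_hr, b_hz, b_hn = gru.bias_hh_l0.split(H)     # each (96,)
```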
Now we have the 12 tensor parameters sorted out.
- **Expressions**
The four expressions for a GRU layer, `r_t`, `z_t`, `n_t`, and `h_t`, are computed at each timestep.

The first operation is `r_t = σ(W_ir@x_t + b_ir + W_hr@h + b_hr)`. I used the `@` sign to designate the matrix multiplication operator (`__matmul__`). Remember that after splitting, `W_ir` is shaped `(hidden_size, H_in=input_size)`, while `x_t` contains the element at step `t` of the `x` sequence: the tensor `x_t = x[t]` is shaped `(N=batch_size, H_in=input_size)`. At this point, it's simply a matrix multiplication between the input `x[t]` and the (transposed) weight matrix, i.e. `x[t] @ W_ir.T`, and the resulting tensor `r` is shaped `(N, hidden_size=H_out)`. The same is true for all other weight multiplications performed; as a result, each of these terms ends up as a tensor shaped `(N, H_out=hidden_size)`.

In these expressions, `h` is the tensor containing the hidden state of the previous step for each element in the batch, i.e. shaped `(N, hidden_size=H_out)`, since `num_layers=1`, i.e. there is a single hidden layer.

The output of the layer is the concatenation of the computed `h` tensors at consecutive timesteps `t` (between `0` and `L-1`).
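For completeness, the remaining three expressions from the documentation, in the same notation (`*` denotes the Hadamard, i.e. element-wise, product, and `h` is the previous hidden state):

```
z_t = σ(W_iz@x_t + b_iz + W_hz@h + b_hz)
n_t = tanh(W_in@x_t + b_in + r_t*(W_hn@h + b_hn))
h_t = (1 - z_t)*n_t + z_t*h
```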
- **Demonstration**
Here is a minimal example of an `nn.GRU` inference manually computed, with the following dimensions:

- `H_in` (input size): 3
- `H_out` (hidden size): 2
- `L` (sequence length): 3
- `N` (batch size): 1
- `num_layers`: 1

Setup:
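A minimal sketch, reusing the splitting approach from above (the variable names are my own):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)                     # arbitrary seed, for reproducibility

H_in, H_out, L, N = 3, 2, 3, 1           # input size, hidden size, sequence length, batch size
gru = nn.GRU(input_size=H_in, hidden_size=H_out, bias=True)

# Split the concatenated parameters into the twelve per-gate tensors
W_ir, W_iz, W_in = gru.weight_ih_l0.split(H_out)   # each (H_out, H_in)
W_hr, W_hz, W_hn = gru.weight_hh_l0.split(H_out)   # each (H_out, H_out)
b_ir, b_iz, b_in = gru.bias_ih_l0.split(H_out)     # each (H_out,)
b_hr, b_hz, b_hn = gru.bias_hh_l0.split(H_out)     # each (H_out,)
```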
Random input:
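For instance (continuing the sketch above):

```python
# Sequence of L steps, batch of N elements, H_in features each: shape (L, N, H_in)
x = torch.randn(L, N, H_in)
```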
Inference loop:
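A sketch of the loop, applying the four expressions one timestep at a time (the transposes are needed because the per-gate weights are stored as `(H_out, H_in)` and `(H_out, H_out)`):

```python
h = torch.zeros(N, H_out)   # initial hidden state, one row per batch element

outputs = []
for t in range(L):
    # Reset and update gates, each shaped (N, H_out)
    r = torch.sigmoid(x[t] @ W_ir.T + b_ir + h @ W_hr.T + b_hr)
    z = torch.sigmoid(x[t] @ W_iz.T + b_iz + h @ W_hz.T + b_hz)
    # Candidate hidden state, using the Hadamard product with r
    n = torch.tanh(x[t] @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))
    # New hidden state, shaped (N, H_out)
    h = (1 - z) * n + z * h
    outputs.append(h)
```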
The final output is given by stacking the tensors `h` at consecutive timesteps. In this case the output shape is `(L, N, H_out)`, i.e. `(3, 1, 2)`, which you can compare with `output, _ = gru(x)` (see the sketch below).
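Continuing the sketch, the stacking and the comparison against the built-in forward pass:

```python
output = torch.stack(outputs)            # shape (L, N, H_out) == (3, 1, 2)

expected, _ = gru(x)                     # built-in forward pass, zero initial hidden state
print(torch.allclose(output, expected))  # True, up to floating-point tolerance
```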