使用 PyTorch 的一些笔记,以防写完就忘,看完 API 又想起来,长此以往。
torch.nn
torch.nn.LSTM
LSTM 中的 hidden state
其实就是指每一个 LSTM cell 的输出,而 cell state
则是每次传递到下一层的「长时记忆」,我总觉得这个名字起的特别别扭,所以总不能很好的理解。下面这张图能更好的说明这些变量的意义。

再来简单的回顾一下 LSTM 的几个公式
其中 $h_t$ 和 $c_t$ 就是所谓的 hidden state
和 cell state
了。可以看到 LSTM 中所谓的 output gate
,即 $o_t$ 其实是中间状态,它和 cell state
经过 $\tanh$ 相乘,得到了 hidden state
,也就是输出值。
PyTorch 中 LSTM 的输出结果是一个二元组套二元组 (output, (h_n, c_n))
。第一个 output
是每一个 timestamp 的输出,也就是每一个 cell 的 hidden state
。第二个输出是一个二元组,分别表示最后一个 timestamp 的 hidden state
和 cell state
。因此,如果把 h_n
和 c_n
记录下来,就可以保留整个 LSTM 的状态了。
PyTorch 中可以通过 bidirectional=True
来方便的将 LSTM 设置为双向,此时 output
会自动把每一个 timestamp 的正向和反向 LSTM 拼在一起。而 h_n
和 c_n
的第一维长度会变为 2(单向是长度为 1)。而且此时有
即正向 output
的最后一个 timestemp(对应 LSTM 的最后一个 cell)的输出和正向的 hidden state
相同,反向 output
的最后一个 timestamp(对应 LSTM 的第一个 cell)的输出和反向的 hidden state
相同。
此外,在 PyTorch 中,LSTM 输出的形状和别的框架不太一样,它是序列长度优先的,(seq_len, batch_size, hz),如果觉得不习惯,可以通过 batch_first=True
来设定为 batch_size
优先。