An Introduction to LSTM
LSTM Formulas
LSTM is a gated recurrent neural network, so the easiest way in is through its gating units. These are:
- Input gate: I_t
- Forget gate: F_t
- Output gate: O_t
- Candidate cell state: C̃_t
- Cell state: C_t
- Hidden state: H_t
Suppose the hidden state has length h and the input X_t is a minibatch of n samples with feature dimension x. With W_* denoting weights and b_* biases, one step computes:

I_t = σ(X_t W_xi + H_{t-1} W_hi + b_i)
F_t = σ(X_t W_xf + H_{t-1} W_hf + b_f)
O_t = σ(X_t W_xo + H_{t-1} W_ho + b_o)
C̃_t = tanh(X_t W_xc + H_{t-1} W_hc + b_c)
C_t = F_t ⊙ C_{t-1} + I_t ⊙ C̃_t
H_t = O_t ⊙ tanh(C_t)

In the end only two things leave the cell: the output, which is H_t, and the state, which is the pair (C_t, H_t). Everything else is an intermediate quantity. [^2]
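To make the gate equations concrete, here is a minimal NumPy sketch of a single LSTM step. The shapes follow the text (X_t is (n, x), H_{t-1} and C_{t-1} are (n, h)); all names and the smoke-test sizes are illustrative:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(X, H_prev, C_prev, params):
    """One LSTM step, one line per equation above."""
    Wxi, Whi, bi, Wxf, Whf, bf, Wxo, Who, bo, Wxc, Whc, bc = params
    I = sigmoid(X @ Wxi + H_prev @ Whi + bi)        # input gate
    F = sigmoid(X @ Wxf + H_prev @ Whf + bf)        # forget gate
    O = sigmoid(X @ Wxo + H_prev @ Who + bo)        # output gate
    C_tilde = np.tanh(X @ Wxc + H_prev @ Whc + bc)  # candidate cell
    C = F * C_prev + I * C_tilde                    # new cell state
    H = O * np.tanh(C)                              # new hidden state
    return H, C

# smoke run: n=4 samples, x=3 features, h=5 hidden units
n, x, h = 4, 3, 5
rng = np.random.default_rng(0)
params = [rng.normal(size=s) for s in [(x, h), (h, h), (h,)] * 4]
H, C = lstm_step(rng.normal(size=(n, x)),
                 np.zeros((n, h)), np.zeros((n, h)), params)
print(H.shape, C.shape)  # (4, 5) (4, 5)
```

Note that H is bounded in (-1, 1) by construction, since it is a sigmoid gate times a tanh; C is unbounded, which is what lets the cell state carry information across many steps.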
LSTM Illustrated
- Forget gate
- Input gate
- Current state
- Output layer
TensorFlow LSTM
TensorFlow provides a basic implementation of the LSTM cell, without the more advanced extensions, as well as a standard interface that does include them: tf.nn.rnn_cell.BasicLSTMCell() and tf.nn.rnn_cell.LSTMCell(), respectively. Here we implement the basic version. [^1]
Implementing LSTM in TensorFlow
```python
from __future__ import print_function
```
Import the dataset
```python
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/", one_hot=True)
```
Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz
Set the parameters
```python
# Training parameters
```
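From the training log further down, the run prints every 200 steps and stops at 10,000; those two values can be read off directly. The rest of this parameter sketch (learning rate, batch size, hidden size) is assumed typical values, not taken from the post:

```python
# Training parameters -- training_steps and display_step match the log;
# learning_rate and batch_size are assumptions
learning_rate = 0.001
training_steps = 10000
batch_size = 128
display_step = 200

# Network parameters -- each 28x28 MNIST image is fed row by row:
# 28 timesteps of 28-pixel inputs; num_hidden is an assumption
num_input = 28
timesteps = 28
num_hidden = 128
num_classes = 10
```

Treating each image row as one timestep is the standard trick for running an RNN on MNIST: the sequence length and the per-step input size are both 28.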
Build the LSTM network
```python
# Define the input
```
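It helps to see what the unrolled network actually computes. Below is a framework-free NumPy sketch of the forward pass a BasicLSTMCell-based graph performs: a single fused kernel maps [x, h] to the four gate pre-activations (BasicLSTMCell stores its weights in this fused i, j, f, o layout, with a forget-gate bias added at run time), the cell is unrolled over the 28 timesteps, and the last hidden state goes through a linear output layer. Sizes and names are illustrative, not the post's exact code:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(X_seq, kernel, bias, W_out, b_out, forget_bias=1.0):
    """X_seq: (batch, timesteps, num_input) -> (batch, num_classes) logits."""
    batch, T, _ = X_seq.shape
    h = W_out.shape[0]
    H = np.zeros((batch, h))
    C = np.zeros((batch, h))
    for t in range(T):
        # fused kernel: [x, h] -> 4h pre-activations, split into i, j, f, o
        z = np.concatenate([X_seq[:, t, :], H], axis=1) @ kernel + bias
        i, j, f, o = np.split(z, 4, axis=1)
        C = C * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
        H = sigmoid(o) * np.tanh(C)
    return H @ W_out + b_out  # logits from the last hidden state only

# smoke run with the MNIST shapes used in this post
rng = np.random.default_rng(1)
num_input, timesteps, num_hidden, num_classes = 28, 28, 128, 10
kernel = rng.normal(scale=0.1, size=(num_input + num_hidden, 4 * num_hidden))
bias = np.zeros(4 * num_hidden)
W_out = rng.normal(scale=0.1, size=(num_hidden, num_classes))
b_out = np.zeros(num_classes)
logits = lstm_forward(rng.normal(size=(2, timesteps, num_input)),
                      kernel, bias, W_out, b_out)
print(logits.shape)  # (2, 10)
```

Only the final hidden state feeds the classifier; the 27 earlier outputs exist solely to propagate state, which is why the TF version takes outputs[-1] from the unrolled cell.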
Train and test
```python
# Start training
```
Step 1, Minibatch Loss= 2.8645, Training Accuracy= 0.062
Step 200, Minibatch Loss= 2.1180, Training Accuracy= 0.227
Step 400, Minibatch Loss= 1.9726, Training Accuracy= 0.344
Step 600, Minibatch Loss= 1.7784, Training Accuracy= 0.445
Step 800, Minibatch Loss= 1.5500, Training Accuracy= 0.547
Step 1000, Minibatch Loss= 1.5882, Training Accuracy= 0.453
Step 1200, Minibatch Loss= 1.5326, Training Accuracy= 0.555
Step 1400, Minibatch Loss= 1.3682, Training Accuracy= 0.570
Step 1600, Minibatch Loss= 1.3374, Training Accuracy= 0.594
Step 1800, Minibatch Loss= 1.1551, Training Accuracy= 0.648
Step 2000, Minibatch Loss= 1.2116, Training Accuracy= 0.633
Step 2200, Minibatch Loss= 1.1292, Training Accuracy= 0.609
Step 2400, Minibatch Loss= 1.0862, Training Accuracy= 0.680
Step 2600, Minibatch Loss= 1.0501, Training Accuracy= 0.672
Step 2800, Minibatch Loss= 1.0487, Training Accuracy= 0.688
Step 3000, Minibatch Loss= 1.0223, Training Accuracy= 0.727
Step 3200, Minibatch Loss= 1.0418, Training Accuracy= 0.695
Step 3400, Minibatch Loss= 0.8273, Training Accuracy= 0.719
Step 3600, Minibatch Loss= 0.9088, Training Accuracy= 0.727
Step 3800, Minibatch Loss= 0.9243, Training Accuracy= 0.750
Step 4000, Minibatch Loss= 0.8085, Training Accuracy= 0.703
Step 4200, Minibatch Loss= 0.8466, Training Accuracy= 0.711
Step 4400, Minibatch Loss= 0.8973, Training Accuracy= 0.734
Step 4600, Minibatch Loss= 0.7647, Training Accuracy= 0.750
Step 4800, Minibatch Loss= 0.9088, Training Accuracy= 0.742
Step 5000, Minibatch Loss= 0.7906, Training Accuracy= 0.742
Step 5200, Minibatch Loss= 0.7275, Training Accuracy= 0.781
Step 5400, Minibatch Loss= 0.7488, Training Accuracy= 0.789
Step 5600, Minibatch Loss= 0.7517, Training Accuracy= 0.758
Step 5800, Minibatch Loss= 0.7778, Training Accuracy= 0.797
Step 6000, Minibatch Loss= 0.6736, Training Accuracy= 0.742
Step 6200, Minibatch Loss= 0.6552, Training Accuracy= 0.773
Step 6400, Minibatch Loss= 0.5746, Training Accuracy= 0.828
Step 6600, Minibatch Loss= 0.8102, Training Accuracy= 0.727
Step 6800, Minibatch Loss= 0.6669, Training Accuracy= 0.773
Step 7000, Minibatch Loss= 0.6524, Training Accuracy= 0.766
Step 7200, Minibatch Loss= 0.6481, Training Accuracy= 0.805
Step 7400, Minibatch Loss= 0.5743, Training Accuracy= 0.828
Step 7600, Minibatch Loss= 0.6983, Training Accuracy= 0.773
Step 7800, Minibatch Loss= 0.5552, Training Accuracy= 0.828
Step 8000, Minibatch Loss= 0.5728, Training Accuracy= 0.820
Step 8200, Minibatch Loss= 0.5587, Training Accuracy= 0.789
Step 8400, Minibatch Loss= 0.5205, Training Accuracy= 0.836
Step 8600, Minibatch Loss= 0.4266, Training Accuracy= 0.906
Step 8800, Minibatch Loss= 0.7197, Training Accuracy= 0.812
Step 9000, Minibatch Loss= 0.4216, Training Accuracy= 0.852
Step 9200, Minibatch Loss= 0.4448, Training Accuracy= 0.844
Step 9400, Minibatch Loss= 0.3577, Training Accuracy= 0.891
Step 9600, Minibatch Loss= 0.4034, Training Accuracy= 0.883
Step 9800, Minibatch Loss= 0.4747, Training Accuracy= 0.828
Step 10000, Minibatch Loss= 0.5763, Training Accuracy= 0.805
Optimization Finished!
Testing Accuracy: 0.875