Understanding LSTM & GRU

RNNs are now among the shining stars of deep learning: whether used as standalone models, inside encoder-decoder architectures, or in hybrids with adversarial networks, they deliver strong results. Papers keep pouring out of the major conferences, and industry is pushing hard to deploy these models in real business scenarios; last year I worked with colleagues from another department to ship a bidirectional-GRU tagging model, with reasonably good results. Because the plain RNN suffers from the vanishing gradient problem, information from earlier time steps tends to get lost during BPTT, so the central question for RNN variants is how to preserve that information as it propagates.

Long Short-Term Memory

LSTM was proposed a long time ago. As the name suggests, its job is to remember, i.e., to store in some form the information that would otherwise fade away during propagation. How? By adding a dedicated memory unit: the cell state. The cell state carries information from the initial time step up to the present, continually updating itself as the sequence unfolds, much like a cell dividing and renewing. The structure of the standard LSTM is shown below.

This post, however, covers the LSTM with peephole connections. Let's first look at how the cell state is updated. Let $S^t$ denote the cell state at time $t$; it is updated as follows:

$$\begin{align}
S^t & = f_\iota^t * S^{t-1} + i_\phi^t * \tilde{S^{t}} \\
\end{align}$$

Here $\tilde{S^{t}}$ is the new information generated at the current step, $f_\iota^t * S^{t-1}$ is the portion of the previous cell state that is retained, and $i_\phi^t * \tilde{S^{t}}$ is how much of the newly generated information enters the current cell state. $f_\iota^t$ and $i_\phi^t$ control the retention and admission ratios respectively; each is a value in $(0,1)$. The hidden state is then obtained from the cell state via a nonlinear transformation:

$$ \begin{align}
H^t = o_\omega^t h(S^t)
\end{align} $$

Here $o_\omega^t$ is likewise a value in $(0,1)$. The key question, then, is how to produce $f_\iota^t$, $i_\phi^t$, and $o_\omega^t$. LSTM uses a gating mechanism: these three intermediate variables are controlled by the forget gate, the input gate, and the output gate respectively. The input and forget gates jointly determine the cell state update, while the output gate controls the output. A gate can also be understood as a neuron inside the hidden layer that takes inputs and passes them through an activation function:

$$
\begin{align}
f_\iota^t & = \sigma ( \sum_i W_{i\iota}X_i^t + \sum_h W_{h\iota} H_h^{t-1} + \sum_c W_{c\iota} S_c^{t-1}) \\
i_\phi^t & = \sigma ( \sum_i W_{i\phi}X_i^t + \sum_h W_{h\phi} H_h^{t-1} + \sum_c W_{c\phi} S_c^{t-1}) \\
o_\omega^t & = \sigma (\sum_i W_{i\omega} X_i^t + \sum_h W_{h\omega} H_h^{t-1} + \sum_c W_{c\omega} S_c^t) \\
\tilde{S_c^{t}} & = h( \sum_i W_{ic} X_i^t + \sum_h W_{hc} H_h^{t-1}) \\
\end{align}
$$
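The equations above can be sketched in plain NumPy. This is a minimal sketch, not a library implementation: all weight names (`W_xf`, `w_sf`, etc.) are hypothetical, and it follows the peephole convention in which the forget and input gates read the previous cell state while the output gate reads the updated one.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def peephole_lstm_step(x, h_prev, s_prev, params):
    """One peephole LSTM step following the equations above.

    All parameter names are hypothetical. Shapes: x is (input_dim,),
    h_prev and s_prev are (hidden_dim,); W_x* are (hidden_dim, input_dim),
    W_h* are (hidden_dim, hidden_dim), w_s* are diagonal peephole weights
    of shape (hidden_dim,).
    """
    # Forget and input gates peek at the *previous* cell state s_prev.
    f = sigmoid(params['W_xf'] @ x + params['W_hf'] @ h_prev + params['w_sf'] * s_prev)
    i = sigmoid(params['W_xi'] @ x + params['W_hi'] @ h_prev + params['w_si'] * s_prev)
    # Candidate cell content (no peephole term).
    s_tilde = np.tanh(params['W_xc'] @ x + params['W_hc'] @ h_prev)
    # Cell state update: keep part of the old state, admit part of the new.
    s = f * s_prev + i * s_tilde
    # Output gate peeks at the *updated* cell state s.
    o = sigmoid(params['W_xo'] @ x + params['W_ho'] @ h_prev + params['w_so'] * s)
    h = o * np.tanh(s)
    return h, s
```

Since $h$ is a tanh of the cell state scaled by a sigmoid, each component of the hidden state stays within $(-1, 1)$.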

Both TensorFlow and Keras provide a basic LSTM implementation; the code is as follows.

TensorFlow:

class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.
 
  The implementation is based on: http://arxiv.org/abs/1409.2329.
 
  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.
 
  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.
 
  For advanced models, please use the full LSTMCell that follows.
  """
 
  def __init__(self, num_units, forget_bias=1.0, input_size=None,
               state_is_tuple=True, activation=tanh):
    """Initialize the basic LSTM cell.
 
    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.
    """
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
 
  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)
 
  @property
  def output_size(self):
    return self._num_units
 
  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM)."""
    with vs.variable_scope(scope or "basic_lstm_cell"):
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
        c, h = state
      else:
        c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
 
      # Linear map: concat = [inputs, h] W + b.
      # _linear allocates a single fused W of shape
      # (input_depth + num_units, 4 * num_units) and b of shape
      # (4 * num_units,), i.e. the parameters of all four gates packed
      # into one matmul; concat has shape (batch_size, 4 * num_units).
      concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
 
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
 
      new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
               self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)
 
      if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
      else:
        new_state = array_ops.concat([new_c, new_h], 1)
      return new_h, new_state
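The core of `__call__` above (the fused matmul, the four-way split, and the `forget_bias` term) can be reproduced in a few lines of NumPy; a minimal sketch, with variable names mirroring the TensorFlow code:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def basic_lstm_step(x, c, h, W, b, forget_bias=1.0):
    """Mirrors BasicLSTMCell.__call__ with state_is_tuple=True.

    W has shape (input_dim + num_units, 4 * num_units) and b has shape
    (4 * num_units,), matching what _linear allocates.
    """
    # concat = [inputs, h] W + b, one multiply for all four gates
    concat = np.concatenate([x, h], axis=-1) @ W + b
    # i = input gate, j = new input, f = forget gate, o = output gate
    i, j, f, o = np.split(concat, 4, axis=-1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_c, new_h
```

Note how `forget_bias` is added inside the sigmoid of the forget gate only, pushing it toward 1 early in training so the cell does not forget too aggressively before the weights have learned anything.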

Keras:

def step(self, inputs, states):
        h_tm1 = states[0]
        c_tm1 = states[1]
        dp_mask = states[2]
        rec_dp_mask = states[3]
 
        if self.implementation == 2:
            z = K.dot(inputs * dp_mask[0], self.kernel)
            z += K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)
 
            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]
 
            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)
        else:
            if self.implementation == 0:
                x_i = inputs[:, :self.units]
                x_f = inputs[:, self.units: 2 * self.units]
                x_c = inputs[:, 2 * self.units: 3 * self.units]
                x_o = inputs[:, 3 * self.units:]
            elif self.implementation == 1:
                x_i = K.dot(inputs * dp_mask[0], self.kernel_i) + self.bias_i
                x_f = K.dot(inputs * dp_mask[1], self.kernel_f) + self.bias_f
                x_c = K.dot(inputs * dp_mask[2], self.kernel_c) + self.bias_c
                x_o = K.dot(inputs * dp_mask[3], self.kernel_o) + self.bias_o
            else:
                raise ValueError('Unknown `implementation` mode.')
 
            i = self.recurrent_activation(x_i + K.dot(h_tm1 * rec_dp_mask[0],
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1 * rec_dp_mask[1],
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1 * rec_dp_mask[2],
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1 * rec_dp_mask[3],
                                                      self.recurrent_kernel_o))
        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            h._uses_learning_phase = True
        return h, [h, c]

Another popular variant is the GRU. Unlike the LSTM, the GRU drops the separate cell state and reduces the three gates to two: an update gate and a reset gate. The overall idea is much the same: by selectively adding and updating information, it maintains a global memory across time. In the GRU, the reset gate controls how much of the previous state flows into the newly generated candidate $\tilde{H^{t}}$, while the update gate balances the newly generated information against the previous state. The formulas are:
$$
\begin{align}
u_\gamma^t & = \sigma( \sum_i W_{i\gamma} X_i^t + \sum_h W_{h\gamma} H_h^{t-1}) \\
r_\delta^t & = \sigma( \sum_i W_{i\delta} X_i^t + \sum_h W_{h\delta} H_h^{t-1}) \\
\tilde{H_c^{t}} & = h( \sum_i W_{ic} X_i^t + r_\delta^t* \sum_h H_h^{t-1} W_{hc}) \\
H^t & = (1-u_\gamma^t) * H^{t-1} + u_\gamma^t * \tilde{H_c^{t}}
\end{align}
$$
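A NumPy sketch of one GRU step following these equations (weight names are hypothetical; note how the final state is a convex combination of the previous state and the candidate, gated by $u$):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h_prev, params):
    """One GRU step matching the equations above.

    Parameter names are hypothetical. Shapes: x is (input_dim,), h_prev is
    (hidden_dim,); W_x* are (hidden_dim, input_dim), W_h* are
    (hidden_dim, hidden_dim).
    """
    # Update gate u and reset gate r
    u = sigmoid(params['W_xu'] @ x + params['W_hu'] @ h_prev)
    r = sigmoid(params['W_xr'] @ x + params['W_hr'] @ h_prev)
    # Candidate state: the reset gate scales the recurrent contribution
    h_tilde = np.tanh(params['W_xc'] @ x + r * (params['W_hc'] @ h_prev))
    # Interpolate between the old state and the candidate
    return (1.0 - u) * h_prev + u * h_tilde
```

Because the same gate $u$ weights both terms, anything forgotten from $H^{t-1}$ is replaced by candidate content in the same proportion, which is how the GRU gets away with one fewer gate than the LSTM.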

References

Alex Graves, *Supervised Sequence Labelling with Recurrent Neural Networks*
Christopher Olah, *Understanding LSTM Networks*
