Could you elaborate on this? I found the brief explanation in the docs unsatisfying:

stateful: Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.

Also, when should stateful = True be chosen? What are some practical use cases?

Leevo

1 Answer

This flag is used to perform truncated back-propagation through time: the gradient is propagated through the hidden states of the LSTM across the time dimension within the batch, and then, in the next batch, the last hidden states are used as the input states of the LSTM.

This allows the LSTM to use longer context at training time while constraining the number of steps back for the gradient computation.
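For concreteness, here is a minimal sketch (assuming tf.keras; the layer size, batch size, sequence length and feature count are made-up values) of how a stateful LSTM is typically declared. The key point is that stateful=True requires a fixed batch size, which is declared on the Input:

```python
import tensorflow as tf

# Hypothetical dimensions: 4 sequences per batch, 25 timesteps per batch, 10 features.
batch_size, timesteps, features = 4, 25, 10

# stateful=True requires a fixed batch size, so it is declared on the Input.
inputs = tf.keras.Input(shape=(timesteps, features), batch_size=batch_size)
x = tf.keras.layers.LSTM(32, stateful=True, return_sequences=True)(inputs)
outputs = tf.keras.layers.Dense(features)(x)

model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")
```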

I know of two scenarios where this is common:

  • Language modeling (LM).
  • Time series modeling.

The training set is a list of sequences, potentially coming from a few documents (LM) or complete time series. During data preparation, the batches are created so that each sequence in a batch is the continuation of the sequence at the same position in the previous batch. This allows having document-level/long time series context when computing predictions.

In these cases, your data is longer than the sequence-length dimension of the batch. This may be due to constraints on the available GPU memory (which limit the maximum batch size) or simply by design.
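Continuing the sketch above (and reusing model, batch_size, timesteps and features from it; the series, its length, the number of epochs and the windowing loop are purely illustrative), the data preparation could look like this: one long series is split into batch_size contiguous streams, so that position i of each batch continues position i of the previous batch, and the hidden states are reset only at epoch boundaries:

```python
import numpy as np

# A hypothetical long series: 100,000 timesteps, 10 features.
series = np.random.randn(100_000, features).astype("float32")

# Split the series into `batch_size` contiguous streams, one per batch position.
steps_per_stream = len(series) // batch_size
streams = series[: steps_per_stream * batch_size].reshape(batch_size, steps_per_stream, features)

for epoch in range(3):
    model.reset_states()  # start each epoch from zeroed hidden states
    for start in range(0, steps_per_stream - timesteps, timesteps):
        x = streams[:, start : start + timesteps]          # current window
        y = streams[:, start + 1 : start + timesteps + 1]  # next-step targets
        model.train_on_batch(x, y)  # hidden states carry over to the next window
```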

Update: Note that the stateful flag affects both training and inference time. If you disable it, you must ensure that at inference time each prediction gets the previous hidden state. For this, you can either create a new model with stateful=True and copy the parameters from the trained model with model.set_weights(), or pass the state manually. Due to this inconvenience, some people simply set stateful = True always and force the model not to use the stored hidden state during training by invoking model.reset_states().
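To illustrate the weight-copying route, a rough sketch (the build_model helper, architecture and sizes are my own illustrative choices, not a prescribed recipe): train a non-stateful model, then rebuild the same architecture with stateful=True and batch size 1 for step-by-step inference, and copy the weights across with set_weights():

```python
import tensorflow as tf

def build_model(stateful, batch_size=None, timesteps=None, features=10):
    # Same architecture in both cases, so the weight lists are interchangeable.
    inputs = tf.keras.Input(shape=(timesteps, features), batch_size=batch_size)
    x = tf.keras.layers.LSTM(32, stateful=stateful, return_sequences=True)(inputs)
    outputs = tf.keras.layers.Dense(features)(x)
    return tf.keras.Model(inputs, outputs)

train_model = build_model(stateful=False)  # batch size and length can stay flexible
# ... train `train_model` as usual ...

# Inference model: one sample at a time, one step per call, state kept between calls.
infer_model = build_model(stateful=True, batch_size=1, timesteps=1)
infer_model.set_weights(train_model.get_weights())  # weights do not depend on batch size

infer_model.reset_states()  # clear the state before starting a new sequence
# Each predict() call now receives the hidden state left behind by the previous call:
# next_step = infer_model.predict(current_step)  # current_step has shape (1, 1, 10)
```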

noe
  • I am training a text-generator RNN. Since the model is too large for my current hardware, I am sampling input sequences from the corpus instead of feeding it the whole text batch by batch. Does it make sense to set stateful = True for my LSTM layers, knowing that input sequences are sampled randomly? Or does stateful = True only work if I feed text sequences sequentially, i.e. in the same order as the original text? – Leevo Jan 07 '20 at 13:08
  • 1
    Using stateful=True does not make much sense in your case, because you are not providing batches that maintain the sequence continuity. – noe Jan 07 '20 at 13:24
  • Do you think I should feed batches with textual continuity and set stateful=True? Could this improve my training, or does it not make a difference? – Leevo Jan 07 '20 at 13:31
  • 1
    In LM it normally helps a bit, but it really depends on the data, i.e. whether having a long context actually provides a stronger predictive signal than just local context. – noe Jan 07 '20 at 13:36
  • I tried to reproduce the text generation code from the TF website. There, too, input sequences are sampled randomly from the whole corpus. Still, if I drop stateful = True, the quality of the generated text decreases A LOT. What do you think? – Leevo Jan 09 '20 at 08:55
  • 1
    In that example, the problem is that the code generating text (inference time) relies on the fact that stateful = True to pass the previous hidden state to the computation of the next prediction. When you made stateful = False, you invalidated such an assumption. Therefore, all predictions are generated with an input hidden state of 0, effectively making the predictions unconditional from the previous text. – noe Jan 09 '20 at 10:33
  • 1
    Also note that, while the GRU is created with stateful = True, the training does not use the stored hidden because model.reset_states() is invoked just before every training iteration. – noe Jan 09 '20 at 10:36
  • 1
    I updated my answer with a remark on this stateful = True + model.reset_states() idiom. – noe Jan 09 '20 at 10:42
  • Upvoted for the last paragraph ("Update"). It really is an accurate use case of stateful LSTMs. – pcko1 Jan 09 '20 at 15:01