Press "Enter" to skip to content

LSTM in Databricks

Vedant Jain shows us an example of solving a multivariate time series forecasting problem using LSTM networks:

LSTM is a type of Recurrent Neural Network (RNN) that allows the network to retain long-term dependencies at a given time from many timesteps before. RNNs were designed to that effect using a simple feedback approach for neurons where the output sequence of data serves as one of the inputs. However, long term dependencies can make the network untrainable due to the vanishing gradient problem. LSTM is designed precisely to solve that problem.

Sometimes accurate time series predictions depend on a combination of both bits of old and recent data. We have to efficiently learn even what to pay attention to, accepting that there may be a long history of data to learn from. LSTMs combine simple DNN architectures with clever mechanisms to learn what parts of history to ‘remember’ and what to ‘forget’ over long periods. The ability of LSTM to learn patterns in data over long sequences makes them suitable for time series forecasting.

This is a nice overview and as a bonus, there’s a notebook as well where you can try it on your own.