import torch
from torch import nn
from torch.nn import functional as F
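
# The "Eqn." comments below follow the equation numbering of the DA-RNN paper
# (Qin et al., 2017, "A Dual-Stage Attention-Based Recurrent Neural Network
# for Time Series Prediction"), which this encoder/decoder pair implements.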
def init_hidden(x: torch.Tensor, hidden_size: int) -> torch.Tensor:
    """
    Initialize an LSTM hidden/cell state with zeros, shaped
    (num_layers=1, batch_size, hidden_size), on the same device as the input.

    Training the initial state instead of fixing it at zero is also possible:
    https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html
    """
    return torch.zeros(1, x.size(0), hidden_size, device=x.device)
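
# Shape sketch for init_hidden (the sizes below are illustrative assumptions):
#     x = torch.randn(16, 9, 81)           # (batch_size, T - 1, input_size)
#     h0 = init_hidden(x, hidden_size=64)  # -> (1, 16, 64), zeros, on x.device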
class Encoder(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, T: int):
        """
        input_size: number of underlying factors, i.e. driving series (e.g. 81)
        T: number of time steps in the window (e.g. 10)
        hidden_size: dimension of the encoder's hidden state
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.T = T
        self.lstm_layer = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1)
        # Input attention: scores one driving series from [hidden; cell; x^k],
        # i.e. 2 * hidden_size state features plus the T - 1 values of the series.
        self.attn_linear = nn.Linear(in_features=2 * hidden_size + T - 1, out_features=1)
    def forward(self, input_data: torch.Tensor):
        # input_data: (batch_size, T - 1, input_size)
        device = input_data.device
        input_weighted = torch.zeros(input_data.size(0), self.T - 1, self.input_size, device=device)
        input_encoded = torch.zeros(input_data.size(0), self.T - 1, self.hidden_size, device=device)
        # hidden, cell: initial LSTM states, each (1, batch_size, hidden_size)
        hidden = init_hidden(input_data, self.hidden_size)
        cell = init_hidden(input_data, self.hidden_size)
        for t in range(self.T - 1):
            # Eqn. 8: concatenate the hidden states with each predictor,
            # (batch_size, input_size, 2 * hidden_size + T - 1)
            x = torch.cat((hidden.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                           cell.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                           input_data.permute(0, 2, 1)), dim=2)
            # Eqn. 8: get attention scores, (batch_size * input_size, 1)
            x = self.attn_linear(x.view(-1, self.hidden_size * 2 + self.T - 1))
            # Eqn. 9: softmax the attention weights, (batch_size, input_size)
            attn_weights = F.softmax(x.view(-1, self.input_size), dim=1)
            # Eqn. 10: weight each driving series, (batch_size, input_size)
            weighted_input = torch.mul(attn_weights, input_data[:, t, :])
            # Flatten parameters to fix the warning about non-contiguous memory, see
            # https://discuss.pytorch.org/t/dataparallel-issue-with-flatten-parameter/8282
            self.lstm_layer.flatten_parameters()
            # Eqn. 11: update the LSTM state with the attention-weighted input
            _, (hidden, cell) = self.lstm_layer(weighted_input.unsqueeze(0), (hidden, cell))
            # Save output
            input_weighted[:, t, :] = weighted_input
            input_encoded[:, t, :] = hidden[0]
        return input_weighted, input_encoded
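
# A minimal shape check for the encoder (batch_size=16, input_size=81,
# hidden_size=64, T=10 are assumed values for illustration only):
#     encoder = Encoder(input_size=81, hidden_size=64, T=10)
#     weighted, encoded = encoder(torch.randn(16, 9, 81))
#     # weighted: (16, 9, 81), encoded: (16, 9, 64)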
class Decoder(nn.Module):
    def __init__(self, encoder_hidden_size: int, decoder_hidden_size: int, T: int, out_feats=1):
        """
        encoder_hidden_size: dimension of the encoder's hidden state
        decoder_hidden_size: dimension of the decoder's hidden state
        T: number of time steps in the window
        out_feats: number of target features to predict
        """
        super().__init__()
        self.T = T
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        # Eqn. 12: temporal attention over the encoder's hidden states
        self.attn_layer = nn.Sequential(
            nn.Linear(2 * decoder_hidden_size + encoder_hidden_size, encoder_hidden_size),
            nn.Tanh(),
            nn.Linear(encoder_hidden_size, 1))
        self.lstm_layer = nn.LSTM(input_size=out_feats, hidden_size=decoder_hidden_size)
        self.fc = nn.Linear(encoder_hidden_size + out_feats, out_feats)
        self.fc_final = nn.Linear(decoder_hidden_size + encoder_hidden_size, out_feats)
        self.fc.weight.data.normal_()
    def forward(self, input_encoded: torch.Tensor, y_history: torch.Tensor):
        # input_encoded: (batch_size, T - 1, encoder_hidden_size)
        # y_history: (batch_size, T - 1, out_feats); y_history[:, t] must be
        # 2-D so it can be concatenated with the context vector in Eqn. 15
        # Initialize hidden and cell states, (1, batch_size, decoder_hidden_size)
        hidden = init_hidden(input_encoded, self.decoder_hidden_size)
        cell = init_hidden(input_encoded, self.decoder_hidden_size)
        context = torch.zeros(input_encoded.size(0), self.encoder_hidden_size,
                              device=input_encoded.device)
        for t in range(self.T - 1):
            # (batch_size, T - 1, 2 * decoder_hidden_size + encoder_hidden_size)
            x = torch.cat((hidden.repeat(self.T - 1, 1, 1).permute(1, 0, 2),
                           cell.repeat(self.T - 1, 1, 1).permute(1, 0, 2),
                           input_encoded), dim=2)
            # Eqn. 12 & 13: softmax over the computed attention scores
            x = F.softmax(
                self.attn_layer(
                    x.view(-1, 2 * self.decoder_hidden_size + self.encoder_hidden_size)
                ).view(-1, self.T - 1),
                dim=1)  # (batch_size, T - 1)
            # Eqn. 14: compute the context vector, (batch_size, encoder_hidden_size)
            context = torch.bmm(x.unsqueeze(1), input_encoded)[:, 0, :]
            # Eqn. 15: combine context and past target, (batch_size, out_feats)
            y_tilde = self.fc(torch.cat((context, y_history[:, t]), dim=1))
            # Eqn. 16: LSTM update
            self.lstm_layer.flatten_parameters()
            _, (hidden, cell) = self.lstm_layer(y_tilde.unsqueeze(0), (hidden, cell))
            # hidden, cell: (1, batch_size, decoder_hidden_size)
        # Eqn. 22: final prediction from the last hidden state and context
        return self.fc_final(torch.cat((hidden[0], context), dim=1))
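
# A minimal end-to-end sketch wiring the encoder and decoder together. The
# sizes below (batch_size=16, input_size=81, T=10, hidden sizes of 64) are
# illustrative assumptions, not values required by the modules.
if __name__ == "__main__":
    batch_size, input_size, T = 16, 81, 10
    encoder = Encoder(input_size=input_size, hidden_size=64, T=T)
    decoder = Decoder(encoder_hidden_size=64, decoder_hidden_size=64, T=T)

    driving = torch.randn(batch_size, T - 1, input_size)  # exogenous series
    y_history = torch.randn(batch_size, T - 1, 1)         # past target values

    _, input_encoded = encoder(driving)
    y_pred = decoder(input_encoded, y_history)
    print(y_pred.shape)  # torch.Size([16, 1])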