Source code for flood_forecast.da_rnn.modules

import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as tf


def init_hidden(x, hidden_size: int):
    """
    Initialize the hidden state with zeros, on the same device as x.
    On training the initial value instead of zeroing it, see:
    https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html
    """
    return Variable(torch.zeros(1, x.size(0), hidden_size)).to(x.device)
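
# Illustrative usage of init_hidden (the sizes below are assumptions for the
# sketch, not values fixed by this module):
#
#     x = torch.randn(16, 9, 81)    # (batch_size, T - 1, input_size)
#     h0 = init_hidden(x, 64)       # shape (1, 16, 64), on x's device
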
class Encoder(nn.Module):
    """Encoder of the DA-RNN: applies input attention over the driving series."""

    def __init__(self, input_size: int, hidden_size: int, T: int):
        """
        input_size: number of underlying factors (81)
        T: number of time steps (10)
        hidden_size: dimension of the hidden state
        """
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.T = T

        self.lstm_layer = nn.LSTM(input_size=input_size,
                                  hidden_size=hidden_size,
                                  num_layers=1)
        self.attn_linear = nn.Linear(in_features=2 * hidden_size + T - 1,
                                     out_features=1)
    def forward(self, input_data: torch.Tensor):
        # input_data: (batch_size, T - 1, input_size)
        device = input_data.device
        input_weighted = Variable(torch.zeros(
            input_data.size(0), self.T - 1, self.input_size)).to(device)
        input_encoded = Variable(torch.zeros(
            input_data.size(0), self.T - 1, self.hidden_size)).to(device)
        # hidden, cell: initial states with dimension hidden_size
        hidden = init_hidden(input_data, self.hidden_size)  # (1, batch_size, hidden_size)
        cell = init_hidden(input_data, self.hidden_size)

        for t in range(self.T - 1):
            # Eqn. 8: concatenate the hidden states with each predictor
            x = torch.cat((hidden.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                           cell.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                           input_data.permute(0, 2, 1)),
                          dim=2)  # (batch_size, input_size, 2 * hidden_size + T - 1)
            # Eqn. 8: get the attention scores
            x = self.attn_linear(
                x.view(-1, self.hidden_size * 2 + self.T - 1))  # (batch_size * input_size, 1)
            # Eqn. 9: softmax the attention weights
            attn_weights = tf.softmax(x.view(-1, self.input_size), dim=1)  # (batch_size, input_size)
            # Eqn. 10: weight each driving series and feed the LSTM
            weighted_input = torch.mul(attn_weights, input_data[:, t, :])  # (batch_size, input_size)
            # Fix the warning about non-contiguous memory
            # see https://discuss.pytorch.org/t/dataparallel-issue-with-flatten-parameter/8282
            self.lstm_layer.flatten_parameters()
            _, lstm_states = self.lstm_layer(weighted_input.unsqueeze(0), (hidden, cell))
            hidden = lstm_states[0]
            cell = lstm_states[1]
            # Save output
            input_weighted[:, t, :] = weighted_input
            input_encoded[:, t, :] = hidden

        return input_weighted, input_encoded
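
# Shape walk-through for Encoder.forward, assuming batch_size=16,
# input_size=81, hidden_size=64, T=10 (illustrative numbers only):
#   hidden, cell:        (1, 16, 64)
#   x after torch.cat:   (16, 81, 2 * 64 + 9) = (16, 81, 137)
#   attn_weights:        (16, 81)   -- one weight per driving series
#   weighted_input:      (16, 81)
#   input_weighted:      (16, 9, 81)  after the loop
#   input_encoded:       (16, 9, 64)  after the loop
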
class Decoder(nn.Module):
    """Decoder of the DA-RNN: applies temporal attention over the encoder states."""

    def __init__(self, encoder_hidden_size: int, decoder_hidden_size: int,
                 T: int, out_feats=1):
        super(Decoder, self).__init__()
        self.T = T
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size

        self.attn_layer = nn.Sequential(
            nn.Linear(2 * decoder_hidden_size + encoder_hidden_size,
                      encoder_hidden_size),
            nn.Tanh(),
            nn.Linear(encoder_hidden_size, 1))
        self.lstm_layer = nn.LSTM(input_size=out_feats,
                                  hidden_size=decoder_hidden_size)
        self.fc = nn.Linear(encoder_hidden_size + out_feats, out_feats)
        self.fc_final = nn.Linear(decoder_hidden_size + encoder_hidden_size,
                                  out_feats)

        self.fc.weight.data.normal_()
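    # Dimension bookkeeping for the layers above (illustrative, assuming
    # encoder_hidden_size = decoder_hidden_size = 64, T = 10, out_feats = 1):
    #   attn_layer input:  2 * 64 + 64 = 192  (decoder hidden + cell + one encoder step)
    #   fc input:          64 + 1 = 65        (context vector + previous target value)
    #   fc_final input:    64 + 64 = 128      (last decoder hidden state + last context)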
    def forward(self, input_encoded, y_history):
        # input_encoded: (batch_size, T - 1, encoder_hidden_size)
        # y_history: (batch_size, T - 1, out_feats); y_history[:, t] must be 2-D
        # so it can be concatenated with the context vector in Eqn. 15
        # Initialize hidden and cell, (1, batch_size, decoder_hidden_size)
        hidden = init_hidden(input_encoded, self.decoder_hidden_size)
        cell = init_hidden(input_encoded, self.decoder_hidden_size)
        # Keep the context on the same device as the encoder output
        context = Variable(torch.zeros(
            input_encoded.size(0), self.encoder_hidden_size)).to(input_encoded.device)

        for t in range(self.T - 1):
            # (batch_size, T - 1, 2 * decoder_hidden_size + encoder_hidden_size)
            x = torch.cat((hidden.repeat(self.T - 1, 1, 1).permute(1, 0, 2),
                           cell.repeat(self.T - 1, 1, 1).permute(1, 0, 2),
                           input_encoded), dim=2)
            # Eqn. 12 & 13: softmax on the computed attention weights
            x = tf.softmax(
                self.attn_layer(
                    x.view(-1, 2 * self.decoder_hidden_size + self.encoder_hidden_size)
                ).view(-1, self.T - 1),
                dim=1)  # (batch_size, T - 1)
            # Eqn. 14: compute the context vector
            context = torch.bmm(x.unsqueeze(1), input_encoded)[:, 0, :]  # (batch_size, encoder_hidden_size)
            # Eqn. 15: combine the context vector with the previous target value
            y_tilde = self.fc(torch.cat((context, y_history[:, t]), dim=1))  # (batch_size, out_feats)
            # Eqn. 16: LSTM
            self.lstm_layer.flatten_parameters()
            _, lstm_output = self.lstm_layer(y_tilde.unsqueeze(0), (hidden, cell))
            hidden = lstm_output[0]  # (1, batch_size, decoder_hidden_size)
            cell = lstm_output[1]  # (1, batch_size, decoder_hidden_size)

        # Eqn. 22: final output
        return self.fc_final(torch.cat((hidden[0], context), dim=1))
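

if __name__ == "__main__":
    # Minimal smoke test sketching how Encoder and Decoder fit together.
    # The sizes here (batch of 16, 81 driving series, T = 10, hidden size 64)
    # are illustrative assumptions, not values required by these modules.
    batch_size, T, input_size, hidden_size = 16, 10, 81, 64
    encoder = Encoder(input_size=input_size, hidden_size=hidden_size, T=T)
    decoder = Decoder(encoder_hidden_size=hidden_size,
                      decoder_hidden_size=hidden_size, T=T, out_feats=1)
    x = torch.randn(batch_size, T - 1, input_size)   # driving series
    y_history = torch.randn(batch_size, T - 1, 1)    # past target values
    _, input_encoded = encoder(x)
    y_pred = decoder(input_encoded, y_history)
    print(y_pred.shape)  # expected: torch.Size([16, 1])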