Source code for flood_forecast.transformer_xl.transformer_basic
import torch
import math
from torch.nn.modules import Transformer, TransformerEncoder, TransformerEncoderLayer, LayerNorm
from torch.autograd import Variable
[docs]class SimpleTransformer(torch.nn.Module):
def __init__(
self,
number_time_series: int,
seq_length: int = 48,
output_seq_len: int = None,
d_model: int = 128,
n_heads: int = 8,
dropout=0.1,
forward_dim=2048,
sigmoid=False):
"""
Full transformer model
"""
super().__init__()
if output_seq_len is None:
output_seq_len = seq_length
self.out_seq_len = output_seq_len
self.mask = generate_square_subsequent_mask(seq_length)
self.dense_shape = torch.nn.Linear(number_time_series, d_model)
self.pe = SimplePositionalEncoding(d_model)
self.transformer = Transformer(d_model, nhead=n_heads)
self.final_layer = torch.nn.Linear(d_model, 1)
self.sequence_size = seq_length
self.tgt_mask = generate_square_subsequent_mask(output_seq_len)
self.sigmoid = None
if sigmoid:
self.sigmoid = torch.nn.Sigmoid()
[docs] def forward(self, x: torch.Tensor, t: torch.Tensor, tgt_mask=None, src_mask=None):
if src_mask:
x = self.encode_sequence(x, src_mask)
else:
x = self.encode_sequence(x, src_mask)
return self.decode_seq(x, t, tgt_mask)
[docs] def basic_feature(self, x: torch.Tensor):
x = self.dense_shape(x)
x = self.pe(x)
x = x.permute(1, 0, 2)
return x
[docs] def encode_sequence(self, x, src_mask=None):
x = self.basic_feature(x)
x = self.transformer.encoder(x, src_mask)
return x
[docs] def decode_seq(self, mem, t, tgt_mask=None, view_number=None) -> torch.Tensor:
if view_number is None:
view_number = self.out_seq_len
if tgt_mask is None:
tgt_mask = self.tgt_mask
t = self.basic_feature(t)
x = self.transformer.decoder(t, mem, tgt_mask=tgt_mask)
x = self.final_layer(x)
if self.sigmoid:
x = self.sigmoid(x)
return x.view(-1, view_number)
[docs]class CustomTransformerDecoder(torch.nn.Module):
def __init__(
self,
seq_length: int,
output_seq_length: int,
n_time_series: int,
d_model=128,
output_dim=1,
n_layers_encoder=6,
forward_dim=2048,
dropout=0.1,
use_mask=False,
n_heads=8):
"""
Uses a number of encoder layers with simple linear decoder layer
"""
super().__init__()
self.dense_shape = torch.nn.Linear(n_time_series, d_model)
self.pe = SimplePositionalEncoding(d_model)
encoder_layer = TransformerEncoderLayer(d_model, 8, forward_dim, dropout)
encoder_norm = LayerNorm(d_model)
self.transformer_enc = TransformerEncoder(encoder_layer, n_layers_encoder, encoder_norm)
self.output_dim_layer = torch.nn.Linear(d_model, output_dim)
self.output_seq_length = output_seq_length
self.out_length_lay = torch.nn.Linear(seq_length, output_seq_length)
self.mask = generate_square_subsequent_mask(seq_length)
self.mask_it = use_mask
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs forward pass on tensor of (batch_size, sequence_length, n_time_series)
Return tensor of dim (batch_size, output_seq_length)
"""
x = self.dense_shape(x)
x = self.pe(x)
x = x.permute(1, 0, 2)
if self.mask_it:
x = self.transformer_enc(x, self.mask)
else:
# Allow no mask
x = self.transformer_enc(x)
x = self.output_dim_layer(x)
x = x.permute(1, 2, 0)
x = self.out_length_lay(x)
return x.view(-1, self.output_seq_length)
[docs]class SimplePositionalEncoding(torch.nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(SimplePositionalEncoding, self).__init__()
self.dropout = torch.nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Creates a basic positional encoding"""
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
[docs]def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
[docs]def greedy_decode(
model,
src: torch.Tensor,
max_len: int,
real_target: torch.Tensor,
unsqueeze_dim=1,
device='cpu'):
"""
Mechanism to sequentially decode the model
:src Historical time series values
:real_target The real values (they should be masked), however if want can include known real values.
:returns tensor
"""
src = src.float()
real_target = real_target.float()
if hasattr(model, "mask"):
src_mask = model.mask
memory = model.encode_sequence(src, src_mask)
# Get last element of src array to forecast from
ys = src[:, -1, :].unsqueeze(unsqueeze_dim)
for i in range(max_len):
mask = generate_square_subsequent_mask(i + 1).to(device)
with torch.no_grad():
out = model.decode_seq(memory,
Variable(ys),
Variable(mask), i + 1)
real_target[:, i, 0] = out[:, i]
src = torch.cat((src, real_target[:, i, :].unsqueeze(1)), 1)
ys = torch.cat((ys, real_target[:, i, :].unsqueeze(1)), 1)
memory = model.encode_sequence(src[:, i + 1:, :], src_mask)
return ys[:, 1:, :]