Source code for flood_forecast.transformer_xl.itransformer
import torch
import torch.nn as nn
from flood_forecast.transformer_xl.informer import Encoder, EncoderLayer
from flood_forecast.transformer_xl.attn import FullAttention, AttentionLayer
from flood_forecast.transformer_xl.data_embedding import DataEmbedding_inverted
[docs]
class ITransformer(nn.Module):
    """Encoder-only iTransformer forecasting model.

    Each variate's full history is embedded as a single token, the tokens are passed through a stack of
    attention encoder layers, and a linear layer projects every token to the forecast horizon.

    Paper link: https://arxiv.org/abs/2310.06625.
    """

    def __init__(self, forecast_history, forecast_length, d_model, embed, dropout, n_heads=8, use_norm=True,
                 e_layers=3, d_ff=512, freq='h', activation='gelu', factor=1, output_attention=True, targs=1):
        """Build the complete iTransformer model.

        :param forecast_history: The number of historical steps to use for forecasting. Stored as ``seq_len``
            and passed to the inverted embedding as its input length.
        :type forecast_history: int
        :param forecast_length: The length of the forecast the model outputs. Stored as ``pred_len`` and used
            as the output size of the final linear projection.
        :type forecast_length: int
        :param d_model: The embedding dimension of the model. For the paper the authors used 512.
        :type d_model: int
        :param embed: The embedding type to use (forwarded to ``DataEmbedding_inverted``). For the paper the
            authors used 'fixed'.
        :type embed: str
        :param dropout: The dropout rate; shared by the embedding, the attention and the encoder layers.
        :type dropout: float
        :param n_heads: Number of heads for the attention, defaults to 8
        :type n_heads: int, optional
        :param use_norm: Whether to normalize every series over the time axis before encoding and undo that
            normalization on the forecast, defaults to True
        :type use_norm: bool, optional
        :param e_layers: The number of encoder layers that are stacked, defaults to 3
        :type e_layers: int, optional
        :param d_ff: Forwarded to every ``EncoderLayer`` (presumably the hidden size of its feed-forward
            part -- confirm against ``EncoderLayer``), defaults to 512
        :type d_ff: int, optional
        :param freq: The frequency of the time series data, defaults to 'h' for hourly
        :type freq: str, optional
        :param activation: The activation forwarded to every ``EncoderLayer``, defaults to 'gelu'
        :type activation: str, optional
        :param factor: Forwarded to ``FullAttention`` as its second positional argument; its meaning is
            defined there, defaults to 1
        :type factor: int, optional
        :param output_attention: Stored on the instance as ``output_attention``; nothing in this class reads
            it, defaults to True
        :type output_attention: bool, optional
        :param targs: Stored on the instance as ``c_out``; nothing in this class reads it (presumably the
            number of target variates -- confirm with callers), defaults to 1
        :type targs: int, optional
        """
        class_strategy = 'projection'
        super(ITransformer, self).__init__()
        self.seq_len = forecast_history
        self.pred_len = forecast_length
        self.output_attention = output_attention
        self.use_norm = use_norm
        # Embedding: maps the whole history of each variate to one d_model-sized token.
        self.enc_embedding = DataEmbedding_inverted(self.seq_len, d_model, embed, freq,
                                                    dropout)
        self.class_strategy = class_strategy
        # Encoder-only architecture
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        # NOTE(review): the leading ``False`` is presumably FullAttention's mask flag -- confirm.
                        FullAttention(False, factor, attention_dropout=dropout), d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for _ in range(e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )
        # Maps every encoded token (d_model) to the forecast horizon (pred_len).
        self.projector = nn.Linear(d_model, self.pred_len, bias=True)
        self.c_out = targs

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Produce the forecast for a batch of input series.

        :param x_enc: The historical series, laid out as (batch, seq_len, n_variates).
        :type x_enc: torch.Tensor
        :param x_mark_enc: Passed to the embedding together with ``x_enc``; covariates (e.g. timestamps) can
            be embedded as additional tokens, which are dropped again after the projection.
        :type x_mark_enc: torch.Tensor
        :param x_dec: Unused; kept so the signature matches the other models.
        :type x_dec: torch.Tensor
        :param x_mark_dec: Unused; kept so the signature matches the other models.
        :type x_mark_dec: torch.Tensor
        :return: The forecast, laid out as (batch, pred_len, n_variates).
        :rtype: torch.Tensor
        """
        if self.use_norm:
            # Normalization from Non-stationary Transformer
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
            # Divide out-of-place: torch.var saved x_enc for its backward pass, so an in-place ``/=`` would
            # make autograd raise whenever the input requires grad.
            x_enc = x_enc / stdev
        _, _, N = x_enc.shape  # B L N
        # B: batch_size;    E: d_model;
        # L: seq_len;       S: pred_len;
        # N: number of variate (tokens), can also includes covariates
        # Embedding
        # B L N -> B N E                (B L N -> B L E in the vanilla Transformer)
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # covariates (e.g timestamp) can be also embedded as tokens
        # B N E -> B N E                (B L E -> B L E in the vanilla Transformer)
        # the dimensions of embedded time series has been inverted, and then processed by native attn,
        # layernorm and ffn modules
        enc_out = self.encoder(enc_out, attn_mask=None)
        # Only the first element returned by the encoder is used as the encoded tokens.
        # B N E -> B N S -> B S N
        dec_out = self.projector(enc_out[0]).permute(0, 2, 1)[:, :, :N]  # filter the covariates
        if self.use_norm:
            # De-Normalization from Non-stationary Transformer.
            # stdev and means are (B, 1, N), so they broadcast over the S forecast steps.
            dec_out = dec_out * stdev
            dec_out = dec_out + means
        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Run :meth:`forecast` and return its last ``pred_len`` time steps.

        :param x_enc: The historical series, laid out as (batch, seq_len, n_variates).
        :type x_enc: torch.Tensor
        :param x_mark_enc: Passed through to :meth:`forecast`.
        :type x_mark_enc: torch.Tensor
        :param x_dec: Passed through to :meth:`forecast`, which does not use it.
        :type x_dec: torch.Tensor
        :param x_mark_dec: Passed through to :meth:`forecast`, which does not use it.
        :type x_mark_dec: torch.Tensor
        :param mask: Unused, defaults to None
        :type mask: torch.Tensor, optional
        :return: The forecast, laid out as (batch, pred_len, n_variates).
        :rtype: torch.Tensor
        """
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]