import torch
import torch.nn as nn
from einops import rearrange, repeat
from math import ceil, sqrt
class SegMerging(nn.Module):
    """Segment merging layer.

    Every ``win_size`` adjacent segments along the segment axis are
    concatenated feature-wise, layer-normalised and projected back to
    ``d_model`` to obtain a coarser-scale representation.
    The original authors note that ``win_size = 2`` is used in their paper.

    :param d_model: size of the feature (last) axis of the input
    :param win_size: number of adjacent segments merged into one
    :param norm_layer: normalisation layer factory, called with
        ``win_size * d_model``
    """

    def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
        super().__init__()
        self.d_model = d_model
        self.win_size = win_size
        self.linear_trans = nn.Linear(win_size * d_model, d_model)
        self.norm = norm_layer(win_size * d_model)

    def forward(self, x):
        """Merge adjacent segments.

        :param x: tensor of shape ``[B, ts_d, seg_num, d_model]``
        :return: tensor of shape
            ``[B, ts_d, ceil(seg_num / win_size), d_model]``
        """
        # Number of segments missing to reach a multiple of win_size.
        pad_num = -x.shape[2] % self.win_size
        # Pad by repeating the trailing segments.  A single slice is not
        # enough when fewer segments exist than are needed, so keep
        # appending until the segment axis is divisible by win_size.
        while pad_num > 0:
            take = min(pad_num, x.shape[2])
            x = torch.cat((x, x[:, :, -take:, :]), dim=-2)
            pad_num -= take

        # The i-th strided slice holds the i-th member of every window;
        # concatenating them on the feature axis yields
        # [B, ts_d, seg_num / win_size, win_size * d_model].
        merged = torch.cat(
            [x[:, :, i :: self.win_size, :] for i in range(self.win_size)],
            dim=-1,
        )
        return self.linear_trans(self.norm(merged))
class scale_block(nn.Module):
    """One encoder scale: optional segment merging followed by TSA layers.

    A segment merging layer is created only when ``win_size > 1``; it is
    followed by ``depth`` two-stage-attention (TSA) layers.
    The original authors note that ``depth = 1`` is used in their paper.

    :param win_size: merge window; values ``<= 1`` disable merging
    :param d_model: feature size
    :param n_heads: number of attention heads passed to each TSA layer
    :param d_ff: feed-forward width passed to each TSA layer
    :param depth: number of TSA layers in this scale
    :param dropout: dropout probability passed to each TSA layer
    :param seg_num: segment count passed to each TSA layer
    :param factor: router factor passed to each TSA layer
    """

    def __init__(
        self, win_size, d_model, n_heads, d_ff, depth, dropout, seg_num=10, factor=10
    ):
        super().__init__()
        if win_size > 1:
            self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
        else:
            self.merge_layer = None

        self.encode_layers = nn.ModuleList(
            [
                TwoStageAttentionLayer(
                    seg_num, factor, d_model, n_heads, d_ff, dropout
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """Apply the merge layer (if any) and then every TSA layer in order.

        :param x: input tensor, forwarded unchanged to the first sub-layer
        :return: output of the last sub-layer (``x`` itself when the block
            has no merge layer and ``depth`` is 0)
        """
        if self.merge_layer is not None:
            x = self.merge_layer(x)
        for layer in self.encode_layers:
            x = layer(x)
        return x
class Encoder(nn.Module):
    """The Encoder of Crossformer model.

    Stacks ``e_blocks`` scale blocks.  The first block uses a merge window
    of 1 (no merging); block ``i`` (``i >= 1``) merges with ``win_size`` and
    is configured for ``ceil(in_seg_num / win_size ** i)`` segments.

    :param e_blocks: total number of scale blocks
    :param win_size: merge window used by every block after the first
    :param d_model: feature size
    :param n_heads: number of attention heads
    :param d_ff: feed-forward width
    :param block_depth: number of TSA layers per block
    :param dropout: dropout probability
    :param in_seg_num: number of segments in the encoder input
    :param factor: router factor
    """

    def __init__(
        self,
        e_blocks,
        win_size,
        d_model,
        n_heads,
        d_ff,
        block_depth,
        dropout,
        in_seg_num=10,
        factor=10,
    ):
        super().__init__()
        # First scale works on the input resolution (window of 1).
        blocks = [
            scale_block(
                1, d_model, n_heads, d_ff, block_depth, dropout, in_seg_num, factor
            )
        ]
        # Each further scale sees the segment count shrunk by win_size again.
        blocks.extend(
            scale_block(
                win_size,
                d_model,
                n_heads,
                d_ff,
                block_depth,
                dropout,
                ceil(in_seg_num / win_size ** i),
                factor,
            )
            for i in range(1, e_blocks)
        )
        self.encode_blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Run every block and collect the representation at each scale.

        :param x: encoder input
        :return: list whose first entry is the unmodified input ``x``,
            followed by the output of every block in order
        """
        encode_x = [x]
        for block in self.encode_blocks:
            x = block(x)
            encode_x.append(x)
        return encode_x
class DecoderLayer(nn.Module):
    """The decoder layer of Crossformer; each layer makes a prediction at its scale.

    :param seg_len: length of one predicted segment (output size of the
        prediction head)
    :param d_model: feature size
    :param n_heads: number of attention heads
    :param d_ff: feed-forward width, forwarded to the two-stage attention
        layer only (the layer's own MLP keeps a ``d_model`` hidden size)
    :param dropout: dropout probability
    :param out_seg_num: segment count passed to the two-stage attention layer
    :param factor: router factor passed to the two-stage attention layer
    """

    def __init__(
        self,
        seg_len,
        d_model,
        n_heads,
        d_ff=None,
        dropout=0.1,
        out_seg_num=10,
        factor=10,
    ):
        super().__init__()
        self.self_attention = TwoStageAttentionLayer(
            out_seg_num, factor, d_model, n_heads, d_ff, dropout
        )
        self.cross_attention = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.MLP1 = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model)
        )
        self.linear_pred = nn.Linear(d_model, seg_len)

    def forward(self, x, cross: torch.Tensor):
        """Decode one scale.

        :param x: the output of the last decoder layer,
            ``[b, ts_d, out_seg_num, d_model]``
        :param cross: the output of the corresponding encoder layer,
            ``[b, ts_d, in_seg_num, d_model]``
        :return: tuple ``(dec_output, layer_predict)`` where ``dec_output``
            is ``[b, ts_d, out_seg_num, d_model]`` and ``layer_predict`` is
            ``[b, ts_d * out_seg_num, seg_len]``
        """
        batch = x.shape[0]
        x = self.self_attention(x)

        # Fold the series dimension into the batch so cross attention runs
        # independently per series, over the segment axis.
        x = rearrange(x, "b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model")
        cross = rearrange(
            cross, "b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model"
        )

        attended = self.cross_attention(x, cross, cross)
        # Residual + norm around cross attention, then around the MLP.
        x = self.norm1(x + self.dropout(attended))
        dec_output = self.norm2(x + self.MLP1(x))

        dec_output = rearrange(
            dec_output,
            "(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model",
            b=batch,
        )
        layer_predict = self.linear_pred(dec_output)
        layer_predict = rearrange(
            layer_predict, "b out_d seg_num seg_len -> b (out_d seg_num) seg_len"
        )
        return dec_output, layer_predict
class Decoder(nn.Module):
    """The decoder of Crossformer, making the final prediction by adding up predictions at each scale.

    :param seg_len: length of one predicted segment
    :param d_layers: number of decoder layers
    :param d_model: feature size
    :param n_heads: number of attention heads
    :param d_ff: feed-forward width
    :param dropout: dropout probability
    :param router: stored on the instance; not read anywhere in this class
    :param out_seg_num: number of output segments
    :param factor: router factor
    """

    def __init__(
        self,
        seg_len,
        d_layers,
        d_model,
        n_heads,
        d_ff,
        dropout,
        router=False,
        out_seg_num=10,
        factor=10,
    ):
        super().__init__()
        self.router = router
        self.decode_layers = nn.ModuleList(
            [
                DecoderLayer(
                    seg_len, d_model, n_heads, d_ff, dropout, out_seg_num, factor
                )
                for _ in range(d_layers)
            ]
        )

    def forward(self, x, cross):
        """Sum the per-scale predictions of every decoder layer.

        :param x: decoder input, ``[b, ts_d, seg_num, d_model]``
        :param cross: indexable collection of encoder outputs; entry ``i``
            is fed to decoder layer ``i``
        :return: prediction of shape ``[b, seg_num * seg_len, ts_d]``
        :raises ValueError: if the decoder has no layers, so there is
            nothing to predict
        """
        ts_d = x.shape[1]
        final_predict = None
        for i, layer in enumerate(self.decode_layers):
            x, layer_predict = layer(x, cross[i])
            if final_predict is None:
                final_predict = layer_predict
            else:
                final_predict = final_predict + layer_predict

        if final_predict is None:
            raise ValueError(
                "Decoder has no decode layers (d_layers=0); cannot make a prediction"
            )

        return rearrange(
            final_predict,
            "b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d",
            out_d=ts_d,
        )
class FullAttention(nn.Module):
    """The Attention operation (scaled dot-product with dropout).

    :param scale: multiplier applied to the raw scores before softmax;
        ``None`` selects ``1 / sqrt(E)`` where ``E`` is the query head size
    :param attention_dropout: dropout probability applied to the attention
        weights
    """

    def __init__(self, scale=None, attention_dropout=0.1):
        super().__init__()
        self.scale = scale
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values):
        """Compute attention.

        :param queries: ``[B, L, H, E]``
        :param keys: ``[B, S, H, E]``
        :param values: ``[B, S, H, D]``
        :return: contiguous tensor of shape ``[B, L, H, D]``
        """
        # Compare against None explicitly so an explicit scale of 0.0 is
        # honoured instead of silently falling back to the default.
        if self.scale is not None:
            scale = self.scale
        else:
            scale = 1.0 / sqrt(queries.shape[-1])

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        weights = self.dropout(torch.softmax(scale * scores, dim=-1))
        out = torch.einsum("bhls,bshd->blhd", weights, values)
        return out.contiguous()
class AttentionLayer(nn.Module):
    """The Multi-head Self-Attention (MSA) Layer.

    Projects queries, keys and values, splits them into ``n_heads`` heads,
    runs :class:`FullAttention` and projects the result back to ``d_model``.

    :param d_model: input and output feature size
    :param n_heads: number of attention heads
    :param d_keys: per-head query/key size; defaults to
        ``d_model // n_heads`` when falsy
    :param d_values: per-head value size; defaults to
        ``d_model // n_heads`` when falsy
    :param mix: when true, the attention output is transposed to
        ``[B, H, L, D]`` before being flattened to ``[B, L, H * D]``
    :param dropout: dropout probability for the attention weights
    """

    def __init__(
        self, d_model, n_heads, d_keys=None, d_values=None, mix=True, dropout=0.1
    ):
        super().__init__()
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = FullAttention(scale=None, attention_dropout=dropout)
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads
        self.mix = mix

    def forward(self, queries, keys, values):
        """Apply multi-head attention.

        :param queries: ``[B, L, d_model]``
        :param keys: ``[B, S, d_model]``
        :param values: ``[B, S, d_model]``
        :return: tensor of shape ``[B, L, d_model]``
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Split the projected features into H heads.
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out = self.inner_attention(queries, keys, values)
        if self.mix:
            # Swap the length and head axes before flattening; the result
            # is still viewed as [B, L, -1] below.
            out = out.transpose(2, 1).contiguous()
        out = out.view(B, L, -1)
        return self.out_projection(out)
class TwoStageAttentionLayer(nn.Module):
    """The Two Stage Attention (TSA) Layer.

    Input/output shape: ``[batch_size, Data_dim(D), Seg_num(L), d_model]``.

    :param seg_num: number of segments the layer is built for; fixes the
        first axis of the learnable router
    :param factor: number of router vectors per segment
    :param d_model: feature size
    :param n_heads: number of attention heads
    :param d_ff: hidden width of the two MLPs; defaults to ``4 * d_model``
        when falsy
    :param dropout: dropout probability
    """

    def __init__(self, seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.time_attention = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.dim_sender = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.dim_receiver = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))

        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

        self.MLP1 = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.MLP2 = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x: torch.Tensor):
        """Apply the cross-time stage followed by the cross-dimension stage.

        :param x: ``[b, ts_d, seg_num, d_model]``
        :return: tensor with the same shape as ``x``
        :raises ValueError: if the segment count of ``x`` differs from the
            ``seg_num`` the layer's router was created with
        """
        # Cross Time Stage: directly apply MSA to each dimension.
        batch = x.shape[0]
        time_in = rearrange(x, "b ts_d seg_num d_model -> (b ts_d) seg_num d_model")

        # The router holds one set of vectors per segment, so the input must
        # carry exactly that many segments; fail early with a clear message
        # rather than with a shape error deep inside the attention layers.
        if time_in.shape[1] != self.router.shape[0]:
            raise ValueError(
                f"expected {self.router.shape[0]} segments, got {time_in.shape[1]}"
            )

        time_enc = self.time_attention(time_in, time_in, time_in)
        dim_in = self.norm1(time_in + self.dropout(time_enc))
        dim_in = self.norm2(dim_in + self.dropout(self.MLP1(dim_in)))

        # Cross Dimension Stage: use a small set of learnable vectors to
        # aggregate and distribute messages to build the D-to-D connection.
        dim_send = rearrange(
            dim_in, "(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model", b=batch
        )
        batch_router = repeat(
            self.router,
            "seg_num factor d_model -> (repeat seg_num) factor d_model",
            repeat=batch,
        )
        # Router vectors gather from all dimensions, then every dimension
        # reads back from the gathered buffer.
        dim_buffer = self.dim_sender(batch_router, dim_send, dim_send)
        dim_receive = self.dim_receiver(dim_send, dim_buffer, dim_buffer)
        dim_enc = self.norm3(dim_send + self.dropout(dim_receive))
        dim_enc = self.norm4(dim_enc + self.dropout(self.MLP2(dim_enc)))

        return rearrange(
            dim_enc, "(b seg_num) ts_d d_model -> b ts_d seg_num d_model", b=batch
        )
class DSW_embedding(nn.Module):
    """Segment-wise embedding of a multivariate time series.

    Each series dimension is cut into consecutive, non-overlapping segments
    of ``seg_len`` time steps, and every segment is linearly projected to a
    ``d_model``-sized vector.
    """

    def __init__(self, seg_len, d_model):
        """Create the embedding.

        :param seg_len: number of time steps per segment
        :type seg_len: int
        :param d_model: size of the embedding produced for each segment
        :type d_model: int
        """
        super().__init__()
        self.seg_len = seg_len
        self.linear = nn.Linear(seg_len, d_model)

    def forward(self, x):
        """Embed ``x`` segment by segment.

        :param x: tensor of shape ``[batch, ts_len, ts_dim]``; ``ts_len``
            must be a multiple of ``seg_len``
        :return: tensor of shape
            ``[batch, ts_dim, ts_len // seg_len, d_model]``
        """
        batch, _, ts_dim = x.shape
        # Move time to the last axis and split it into (seg_num, seg_len);
        # the reshape raises if ts_len is not divisible by seg_len.
        x_segment = x.permute(0, 2, 1).reshape(batch, ts_dim, -1, self.seg_len)
        # nn.Linear acts on the trailing axis, so no flattening is needed.
        return self.linear(x_segment)