Source code for flood_forecast.meta_models.merging_model

import torch
from typing import Dict
from torch.nn.modules.activation import MultiheadAttention


class MergingModel(torch.nn.Module):
    """Combine a batched temporal tensor with a (non-batched) meta-data tensor.

    The merging strategy is selected by name from ``method_dict`` and the matching
    layer is constructed from ``other_params``.
    """

    def __init__(self, method: str, other_params: Dict):
        """
        Args:
            method: Key of ``method_dict``: "Bilinear", "Bilinear2", "MultiAttn" or "Concat".
                "Other" is listed but is only a placeholder with no layer implementation.
            other_params: Keyword arguments forwarded to the constructor of the selected layer
                (``torch.nn.Bilinear``, ``MultiModalSelfAttention`` or ``Concatenation``).

        Raises:
            KeyError: If ``method`` is not a key of ``method_dict``.
            NotImplementedError: If ``method`` maps to a placeholder rather than a layer class.
        """
        super().__init__()
        self.method_dict = {
            "Bilinear": torch.nn.Bilinear,
            "Bilinear2": torch.nn.Bilinear,
            "MultiAttn": MultiModalSelfAttention,
            "Concat": Concatenation,
            "Other": "other",
        }
        if method not in self.method_dict:
            raise KeyError(f"Unknown merging method {method!r}; expected one of {list(self.method_dict)}")
        layer_cls = self.method_dict[method]
        if not callable(layer_cls):
            # "Other" maps to a plain string placeholder; calling it would only produce an
            # opaque "'str' object is not callable" TypeError.
            raise NotImplementedError(f"Merging method {method!r} has no layer implementation yet")
        self.method_layer = layer_cls(**other_params)
        self.method = method

    def forward(self, temporal_data: torch.Tensor, meta_data: torch.Tensor) -> torch.Tensor:
        """Merge ``temporal_data`` with ``meta_data`` using the configured layer.

        Args:
            temporal_data: Batched temporal tensor whose dim 0 is the batch dimension
                (presumably (batch_size, seq_len, n_features) -- confirm against callers).
            meta_data: Meta-data WITHOUT a batch dimension, e.g. (d_meta,) or (1, d_meta);
                it is tiled once per batch element.

        Returns:
            The output of the selected merging layer.
        """
        batch_size = temporal_data.shape[0]
        # meta_data is assumed to carry no batch dimension; tile it across the batch.
        # For a (d_meta,) or (1, d_meta) input this yields (batch_size, 1, d_meta).
        meta_data = meta_data.repeat(batch_size, 1).unsqueeze(1)
        if self.method == "Bilinear":
            # torch.nn.Bilinear acts on the last dimension, so swap the last two axes of both
            # inputs to mix over the former dim 1, then swap the result back.
            meta_data = meta_data.permute(0, 2, 1)
            temporal_data = temporal_data.permute(0, 2, 1).contiguous()
            x = self.method_layer(temporal_data, meta_data)
            x = x.permute(0, 2, 1)
        elif self.method == "Bilinear2":
            # Repeat the meta vector along dim 1 so every temporal position is paired with it.
            temporal_shape = temporal_data.shape[1]
            meta_data = meta_data.repeat(1, temporal_shape, 1)
            x = self.method_layer(temporal_data, meta_data)
        else:
            x = self.method_layer(temporal_data, meta_data)
        return x
# A class to handle concatenation
class Concatenation(torch.nn.Module):
    """Concatenate temporal data with meta-data, optionally followed by a linear projection."""

    def __init__(self, cat_dim: int, repeat: bool = True, use_layer: bool = False,
                 combined_shape: int = 1, out_shape: int = 1):
        """
        Args:
            cat_dim: The dimension to concatenate along.
            repeat: If True, tile ``meta_data`` along dim 1 to the length of ``temporal_data``
                before concatenating; if False, ``meta_data`` is concatenated as given.
            use_layer: If True, project the concatenated tensor with a ``torch.nn.Linear`` layer.
            combined_shape: Input size of the Linear layer, i.e. the size of the last dimension of
                the concatenated tensor. Only used when ``use_layer`` is True.
            out_shape: Output size of the Linear layer. Only used when ``use_layer`` is True.
        """
        super().__init__()
        self.combined_shape = combined_shape
        self.out_shape = out_shape
        self.cat_dim = cat_dim
        self.repeat = repeat
        self.use_layer = use_layer
        if self.use_layer:
            self.linear = torch.nn.Linear(combined_shape, out_shape)

    def forward(self, temporal_data: torch.Tensor, meta_data: torch.Tensor) -> torch.Tensor:
        """
        Args:
            temporal_data: (batch_size, seq_len, d_model)
            meta_data: (batch_size, d_embedding) or (batch_size, 1, d_embedding)

        Returns:
            ``temporal_data`` and ``meta_data`` concatenated along ``cat_dim``, passed through the
            Linear layer when ``use_layer`` is True.
        """
        if self.repeat:
            if meta_data.dim() == 2:
                # (batch, d) -> (batch, 1, d). Without this, repeat() below would prepend a new
                # leading dim and fold the batch into dim 1, giving (1, batch * seq_len, d).
                meta_data = meta_data.unsqueeze(1)
            meta_data = meta_data.repeat(1, temporal_data.shape[1], 1)
        # TODO: with repeat=False, meta_data is used as given, so the caller must already have
        # matched its shape to temporal_data on every dimension except cat_dim.
        x = torch.cat((temporal_data, meta_data), self.cat_dim)
        if self.use_layer:
            x = self.linear(x)
        return x
class MultiModalSelfAttention(torch.nn.Module):
    """Attend from each temporal step to the meta-data using multi-head attention."""

    def __init__(self, d_model: int, n_heads: int, dropout: float):
        """
        Args:
            d_model: Embedding size shared by the temporal data and the meta-data.
            n_heads: Number of attention heads (must divide ``d_model``).
            dropout: Dropout probability applied to the attention weights.
        """
        # Required before assigning sub-modules on a torch.nn.Module.
        super().__init__()
        self.main_layer = MultiheadAttention(d_model, n_heads, dropout)

    def forward(self, temporal_data: torch.Tensor, meta_data: torch.Tensor) -> torch.Tensor:
        """
        Args:
            temporal_data: (batch_size, seq_len, d_model); used as the attention queries.
            meta_data: (batch_size, n_meta, d_model) or (batch_size, d_model); used as keys and values.

        Returns:
            Attention output of shape (batch_size, seq_len, d_model).
        """
        if meta_data.dim() == 2:
            meta_data = meta_data.unsqueeze(1)
        # MultiheadAttention (batch_first=False) expects (length, batch, embed) for all inputs.
        meta_data = meta_data.permute(1, 0, 2)
        temporal_data = temporal_data.permute(1, 0, 2)
        # The layer returns (attn_output, attn_weights); only the output is needed.
        attn_output, _ = self.main_layer(temporal_data, meta_data, meta_data)
        return attn_output.permute(1, 0, 2)