"""Source code for flood_forecast.da_rnn.model."""

import torch
from torch import nn
from flood_forecast.da_rnn.modules import Encoder, Decoder


class DARNN(nn.Module):
    """DA-RNN model: an ``Encoder`` whose encoded output is fed to a ``Decoder``."""

    def __init__(
        self,
        input_size: int,
        hidden_size_encoder: int,
        T: int,
        decoder_hidden_size: int,
        out_feats=1,
    ):
        """Build the encoder and decoder sub-modules.

        :param input_size: number of underlying factors (e.g. 81)
        :param hidden_size_encoder: dimension of the encoder hidden state
        :param T: number of time steps (e.g. 10)
        :param decoder_hidden_size: dimension of the decoder hidden state
        :param out_feats: passed through to ``Decoder`` as its last argument
            (defaults to 1); presumably the number of output features -- confirm
            against ``Decoder``
        """
        # nn.Module.__init__ must run before any sub-module is assigned as an
        # attribute; without it torch raises AttributeError on the line below.
        super().__init__()
        self.encoder = Encoder(input_size, hidden_size_encoder, T)
        # The decoder consumes encoder states, so it is sized by hidden_size_encoder.
        self.decoder = Decoder(hidden_size_encoder, decoder_hidden_size, T, out_feats)

    def forward(self, x: torch.Tensor, y_history: torch.Tensor):
        """Encode ``x`` and decode the encodings together with ``y_history``.

        :param x: input series passed to the encoder
            (NOTE(review): presumably (batch, T, input_size) -- confirm against Encoder)
        :param y_history: past target values passed to the decoder
        :return: whatever ``Decoder`` returns for these inputs (the prediction)
        """
        # The encoder returns a pair; only the second element (the encodings)
        # is used here. The first (input_weighted) is discarded.
        _, input_encoded = self.encoder(x)
        return self.decoder(input_encoded, y_history)