Source code for flood_forecast.transformer_xl.dummy_torch

"""
A small dummy model specifically for unit and integration testing purposes
"""
import torch
from torch import nn


class DummyTorchModel(nn.Module):
    """Minimal stand-in model that always predicts ones.

    The forward pass ignores the values of its input and only uses the
    batch dimension and device, returning a tensor of ones with shape
    ``(batch_size, forecast_length)``.
    """

    def __init__(self, forecast_length: int):
        """A dummy model that will return a tensor of ones (batch_size, forecast_len)

        :param forecast_length: The length to forecast
        :type forecast_length: int
        """
        super().__init__()
        self.out_len = forecast_length
        # This layer is never used in ``forward``; it exists only so the
        # module owns at least one parameter (presumably so code that
        # expects a non-empty ``parameters()``, e.g. an optimizer, works).
        self.linear_test_layer = nn.Linear(3, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """The forward pass for the dummy model

        :param x: Here the data is irrelevant. Only the batch size
            (``x.size(0)``) and the device are read.
        :type x: torch.Tensor
        :return: A tensor of ones with shape ``(batch_size, out_len)``,
            created on ``x``'s device with ``requires_grad=True``
        :rtype: torch.Tensor
        """
        batch_sz = x.size(0)
        # requires_grad=True is kept from the original implementation;
        # presumably so a loss computed on this output supports backward().
        return torch.ones(
            batch_sz, self.out_len, requires_grad=True, device=x.device
        )