PyTorch Training

flood_forecast.pytorch_training.multi_crit(crit_multi: List, output, labels, valid=None)

Used for computing the loss when there are multiple criteria.

Parameters:
  • crit_multi (List) – The list of criteria to use for training.

  • output (torch.Tensor) – The output of the model.

  • labels (torch.Tensor) – The real values for the target.

  • valid (_type_, optional) – _description_, defaults to None

Returns:

_description_

Return type:

_type_
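One way to read "multiple criteria" is that every criterion scores the same output and the scores are merged into a single value. The sketch below illustrates that reading in plain Python, so it runs without PyTorch. The names `combine_criteria`, `mse`, and `mae` are hypothetical, and summation is an assumed combination rule; none of it is taken from the `multi_crit` source.

```python
from typing import Callable, List, Sequence

# A criterion takes (output, labels) and returns a scalar loss.
Criterion = Callable[[Sequence[float], Sequence[float]], float]


def mse(output: Sequence[float], labels: Sequence[float]) -> float:
    """Mean squared error over two equal-length sequences."""
    return sum((o - y) ** 2 for o, y in zip(output, labels)) / len(labels)


def mae(output: Sequence[float], labels: Sequence[float]) -> float:
    """Mean absolute error over two equal-length sequences."""
    return sum(abs(o - y) for o, y in zip(output, labels)) / len(labels)


def combine_criteria(criteria: List[Criterion], output, labels) -> float:
    # Assumption: the individual losses are simply added together.
    return sum(crit(output, labels) for crit in criteria)


total = combine_criteria([mse, mae], [1.0, 2.0, 4.0], [1.0, 2.0, 2.0])
print(total)
```

With real PyTorch criteria the same pattern would apply to tensors, with each criterion being an `nn.Module` loss instead of a plain function.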

flood_forecast.pytorch_training.handle_meta_data(model: PyTorchForecast)

A function to initialize models with meta-data.

Parameters:

model (PyTorchForecast) – A PyTorchForecast model with meta_data parameter block in config file.

Returns:

Returns a tuple of the initial meta-representation

Return type:

tuple(PyTorchForecast, torch.Tensor, float)

flood_forecast.pytorch_training.make_crit(model_params: Dict) → Module | List

A function to create the criterion for training from the parameters.

Parameters:

model_params (Dict) – The training parameters Dict block in the Flow Forecast (FF) configuration.
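The return annotation (`Module | List`) suggests the result is either one criterion or a list of them. A minimal sketch of that idea follows, assuming a hypothetical `"criterion"` key whose value is either one name or a list of names. The registry `CRITERIA`, the function `build_criterion`, and the key name are illustrative only and do not describe the actual configuration schema.

```python
from typing import Callable, Dict, List, Sequence, Union


def mse(output: Sequence[float], labels: Sequence[float]) -> float:
    """Mean squared error."""
    return sum((o - y) ** 2 for o, y in zip(output, labels)) / len(labels)


def mae(output: Sequence[float], labels: Sequence[float]) -> float:
    """Mean absolute error."""
    return sum(abs(o - y) for o, y in zip(output, labels)) / len(labels)


# Hypothetical registry mapping configuration names to loss callables.
CRITERIA: Dict[str, Callable] = {"MSE": mse, "MAE": mae}


def build_criterion(params: Dict) -> Union[Callable, List[Callable]]:
    spec = params["criterion"]  # assumed key name
    if isinstance(spec, (list, tuple)):
        # Several names: return one callable per name, in order.
        return [CRITERIA[name] for name in spec]
    # A single name: return the callable itself.
    return CRITERIA[spec]


single = build_criterion({"criterion": "MSE"})
several = build_criterion({"criterion": ["MSE", "MAE"]})
```

Returning either a single object or a list keeps simple configurations simple, at the cost of callers having to check which of the two they received.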

flood_forecast.pytorch_training.train_transformer_style(model: PyTorchForecast, training_params: Dict, takes_target=False, forward_params: Dict = {}, model_filepath: str = 'model_save', class2=False) → None

Function to train any PyTorchForecast model.

Parameters:
  • model (PyTorchForecast) – A properly wrapped PyTorchForecast model

  • training_params (Dict) – A dictionary of the necessary parameters for training.

  • takes_target (bool, optional) – A parameter to determine whether a model requires the target, defaults to False

  • forward_params (Dict, optional) – [description], defaults to {}

  • model_filepath (str, optional) – The file path to load model weights from, defaults to “model_save”

flood_forecast.pytorch_training.get_meta_representation(column_id: str, uuid: str, meta_model: PyTorchForecast) → Tensor

flood_forecast.pytorch_training.handle_scaling(validation_dataset, src, output: Tensor, labels, probabilistic, m, output_std)

Function that handles un-scaling the model output.

Parameters:
  • validation_dataset ([type]) – A dataset object for the validation dataset. We use its inverse scale method.

  • src (torch.Tensor) – The source values.

  • output (torch.Tensor) – The output of the model.

  • labels (torch.Tensor) – The real values for the target.

  • probabilistic (bool) – Whether the model is probabilistic or not.

  • m (int) – The number of targets; a value greater than 1 means there are multiple targets.

  • output_std ([type]) – [description]

Returns:

[description]

Return type:

[type]
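Un-scaling means mapping values from the normalized space the model works in back to the original units before a loss is reported. The sketch below shows the idea with a toy standardizing dataset in plain Python. `ToyScaledDataset` and its `inverse_scale` method are hypothetical stand-ins and are not the library's dataset class.

```python
from statistics import mean, pstdev
from typing import List, Sequence


class ToyScaledDataset:
    """Toy dataset that standardizes its targets and can undo the scaling."""

    def __init__(self, raw_targets: Sequence[float]) -> None:
        self.mu = mean(raw_targets)
        self.sigma = pstdev(raw_targets)
        # Values the model would be trained on: zero mean, unit variance.
        self.scaled = [(v - self.mu) / self.sigma for v in raw_targets]

    def inverse_scale(self, values: Sequence[float]) -> List[float]:
        # Invert the standardization: x = z * sigma + mu.
        return [v * self.sigma + self.mu for v in values]


dataset = ToyScaledDataset([10.0, 20.0, 30.0])
scaled_prediction = [0.0, 1.0]  # model output in scaled space
unscaled = dataset.inverse_scale(scaled_prediction)
print(unscaled)
```

Both the model output and the labels need the same inverse transform; comparing an un-scaled output with scaled labels would give a meaningless loss.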

flood_forecast.pytorch_training.compute_loss(labels, output, src, criterion, validation_dataset, probabilistic=None, output_std=None, m=1)

Function for computing the loss.

Parameters:
  • labels (torch.Tensor) – The real values for the target. Shape can be variable but should follow (batch_size, time)

  • output (torch.Tensor) – The output of the model

  • src (torch.Tensor) – The source values (only really needed for the MASELoss function)

  • criterion (torch.nn.Loss or some variation) – The loss function to use

  • validation_dataset (torch.utils.data.dataset) – Only passed when unscaling of data is needed.

  • probabilistic (bool, optional) – Whether the model is probabilistic and returns a distribution, defaults to None

  • output_std ([type], optional) – The standard deviation of the output distribution, defaults to None

  • m (int, optional) – The number of targets, defaults to 1

Returns:

Returns the computed loss

Return type:

float
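To see why a loss such as MASE needs `src` in addition to `labels` and `output`, consider the textbook mean absolute scaled error: the forecast's mean absolute error is divided by the error of a naive "repeat the previous value" forecast over the history. The sketch below implements that textbook definition in plain Python. The function name `mase` is illustrative, and the library's MASELoss may use a different baseline.

```python
from typing import Sequence


def mase(labels: Sequence[float], output: Sequence[float],
         src: Sequence[float]) -> float:
    """Textbook MASE: forecast MAE scaled by naive one-step MAE on the history."""
    forecast_mae = sum(abs(y - o) for y, o in zip(labels, output)) / len(labels)
    # Naive baseline: predict each history value with the one before it.
    naive_errors = [abs(src[i] - src[i - 1]) for i in range(1, len(src))]
    naive_mae = sum(naive_errors) / len(naive_errors)
    return forecast_mae / naive_mae


score = mase(labels=[5.0, 6.0], output=[4.0, 8.0], src=[1.0, 2.0, 4.0, 3.0])
print(score)
```

A score below 1 means the forecast beats the naive baseline on average; the example above scores above 1. The sketch does not guard against a constant history, which would make the denominator zero.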

flood_forecast.pytorch_training.torch_single_train(model: PyTorchForecast, opt: Optimizer, criterion: Type[_Loss], data_loader: DataLoader, takes_target: bool, meta_data_model: PyTorchForecast, meta_data_model_representation: Tensor, meta_loss=None, multi_targets=1, forward_params: Dict = {}) → float

Function that performs training of a single model. Runs through one epoch of the data.

Parameters:
  • model (PyTorchForecast) – The PyTorchForecast model that is trained

  • opt (optim.Optimizer) – The optimizer to use in the code

  • criterion (Type[torch.nn.modules.loss._Loss]) – The loss function to use

  • data_loader (DataLoader) – The data loader that supplies the batches for the epoch

  • takes_target (bool) – A boolean that indicates whether the model takes the target during training

  • meta_data_model (PyTorchForecast) – If supplied a model that handles meta-data else None.

  • meta_data_model_representation (torch.Tensor) – [description]

  • meta_loss ([type], optional) – [description], defaults to None

  • multi_targets (int, optional) – [description], defaults to 1

  • forward_params (Dict, optional) – [description], defaults to {}

Raises:

ValueError – [description]

Returns:

[description]

Return type:

float
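"One epoch" means a single pass over every batch the data loader yields, with a parameter update after each batch. The sketch below shows that loop shape in plain Python for a one-parameter model y = w * x trained with gradient descent on mean squared error. `train_one_epoch` is a hypothetical name, and returning the mean batch loss is an assumption rather than the documented behavior of `torch_single_train`.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]


def train_one_epoch(weight: float, batches: List[Batch],
                    learning_rate: float) -> Tuple[float, float]:
    """Run one pass over the batches; return (updated weight, mean batch loss)."""
    running_loss = 0.0
    for inputs, targets in batches:
        preds = [weight * x for x in inputs]               # forward pass
        n = len(targets)
        loss = sum((p - t) ** 2 for p, t in zip(preds, targets)) / n
        # Gradient of the MSE with respect to the weight.
        grad = sum(2 * (p - t) * x
                   for p, t, x in zip(preds, targets, inputs)) / n
        weight -= learning_rate * grad                     # parameter update
        running_loss += loss
    return weight, running_loss / len(batches)


batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
weight, first_loss = train_one_epoch(0.0, batches, learning_rate=0.05)
weight, second_loss = train_one_epoch(weight, batches, learning_rate=0.05)
print(first_loss, second_loss)
```

In PyTorch the hand-written gradient and update are replaced by `opt.zero_grad()`, `loss.backward()`, and `opt.step()`, but the batch loop and the running loss keep the same structure.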

flood_forecast.pytorch_training.compute_validation(validation_loader: DataLoader, model, epoch: int, sequence_size: int, criterion: Type[_Loss], device: device, decoder_structure=False, meta_data_model=None, use_wandb: bool = False, meta_model=None, multi_targets=1, val_or_test='validation_loss', probabilistic=False, classification=False) → float

Function to compute the validation loss metrics.

Parameters:
  • validation_loader (DataLoader) – The data-loader of either validation or test-data

  • model ([type]) – The model to evaluate

  • epoch (int) – The epoch where the validation/test loss is being computed.

  • sequence_size (int) – The length of the sequence

  • criterion (Type[torch.nn.modules.loss._Loss]) – The loss function to use

  • device (torch.device) – The device to run the computation on

  • decoder_structure (bool, optional) – Whether the model should use sequential decoding, defaults to False

  • meta_data_model (PyTorchForecast, optional) – The model to handle the meta-data, defaults to None

  • use_wandb (bool, optional) – Whether Weights and Biases is in use, defaults to False

  • meta_model (bool, optional) – Whether the model leverages meta-data, defaults to None

  • multi_targets (int, optional) – The number of targets the model predicts, defaults to 1

  • val_or_test (str, optional) – Whether validation or test loss is computed, defaults to “validation_loss”

  • probabilistic (bool, optional) – Whether the model is probabilistic, defaults to False

Returns:

The loss of the first metric in the list.

Return type:

float
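The return description ("the loss of the first metric in the list") implies that several metrics can be tracked while only the first is handed back to the caller. The sketch below shows that pattern in plain Python: each metric is averaged over the batches and the first one is returned alongside the full table. `validate`, `mse`, and `mae` are hypothetical names, and averaging per batch is an assumption.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]


def mse(output: Sequence[float], labels: Sequence[float]) -> float:
    return sum((o - y) ** 2 for o, y in zip(output, labels)) / len(labels)


def mae(output: Sequence[float], labels: Sequence[float]) -> float:
    return sum(abs(o - y) for o, y in zip(output, labels)) / len(labels)


def validate(batches: List[Batch], predict: Callable[[float], float],
             criteria: Dict[str, Callable]) -> Tuple[float, Dict[str, float]]:
    """Average every metric over the batches; return the first plus all of them."""
    totals = {name: 0.0 for name in criteria}
    for inputs, targets in batches:
        preds = [predict(x) for x in inputs]  # no parameter updates here
        for name, crit in criteria.items():
            totals[name] += crit(preds, targets)
    averages = {name: total / len(batches) for name, total in totals.items()}
    first_name = next(iter(averages))  # dicts preserve insertion order
    return averages[first_name], averages


val_batches = [([1.0, 2.0], [2.0, 6.0]), ([3.0], [6.0])]
first, all_metrics = validate(val_batches, lambda x: 2.0 * x,
                              {"MSE": mse, "MAE": mae})
print(first, all_metrics)
```

Returning a single float keeps the value easy to use for decisions such as early stopping, while the remaining metrics can still be logged separately.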