Trainer

flood_forecast.trainer.handle_model_evaluation1(test_acc, params: Dict) → None

Utility function that handles model evaluation. Currently used primarily for forecasting models.

Parameters:
  • trained_model (PyTorchForecast) – A PyTorchForecast model that has already been trained.

  • params (Dict) – A dictionary of the trained model parameters.

  • model_type (str) – The type of model. Almost always PyTorch in practice.

flood_forecast.trainer.handle_core_eval(trained_model, params: Dict, model_type: str)

Handles the core evaluation of a trained model.

Parameters:
  • trained_model – The trained model to evaluate.

  • params (Dict) – A dictionary of the model parameters.

  • model_type (str) – The type of model.

flood_forecast.trainer.train_function(model_type: str, params: Dict) → PyTorchForecast

Trains a Model (TimeSeriesModel) or da_rnn and returns the trained model.

Parameters:
  • model_type (str) – Type of the model. In almost all cases this will be ‘PyTorch’.

  • params (Dict) – Dictionary containing all the parameters needed to run the model.

Returns:

A trained model

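import json
from flood_forecast.trainer import train_function

# Load the training configuration and train the model it describes.
with open("model_config.json") as f:
    params_dict = json.load(f)
trained_model = train_function("PyTorch", params_dict)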

For information on what params_dict should include, see the Confluence pages on training models.

flood_forecast.trainer.correct_stupid_sklearn_error(training_conf: Dict) → Dict

Sklearn only allows scaler params in the form of tuples, but JSON has no tuple type, so the list loaded from the JSON config has to be converted to a tuple.

Parameters:
  • training_conf (Dict) – The training configuration containing the scaling params, which are lists after JSON loading.
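The list-to-tuple conversion described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper name `lists_to_tuples` and the example keys `feature_range` and `clip` are assumptions chosen for the sketch, and the real configuration may nest its scaler params differently.

```python
import copy
import json
from typing import Dict


def lists_to_tuples(scaler_params: Dict) -> Dict:
    """Return a copy of scaler_params with every top-level list value as a tuple.

    Hypothetical helper for illustration; not part of flood_forecast.
    """
    converted = copy.deepcopy(scaler_params)
    for key, value in converted.items():
        if isinstance(value, list):
            converted[key] = tuple(value)
    return converted


# JSON has no tuple type, so a range written as [0, 1] is loaded as a list.
# The keys below are assumed example scaler params.
raw = json.loads('{"feature_range": [0, 1], "clip": false}')
fixed = lists_to_tuples(raw)
print(fixed)  # {'feature_range': (0, 1), 'clip': False}
```

The sketch returns a copy so the loaded configuration is left unchanged; whether the actual function copies or mutates its argument is not stated in this documentation.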

flood_forecast.trainer.main()

Main function which is called from the command line.

Entrypoint for training all TS models.
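As a rough sketch of what a command-line entrypoint of this kind might look like, the example below parses a config path, loads the JSON file, and passes it to a training function. The flag name `--params`, the helper `run_from_cli`, the stand-in `fake_train`, and the `model_name` key are all assumptions made for this illustration; the actual arguments accepted by main() are not documented here.

```python
import argparse
import json
import os
import tempfile
from typing import Callable, Dict, List


def run_from_cli(argv: List[str], train_fn: Callable[[str, Dict], object]) -> object:
    """Parse argv, load the JSON config it names, and hand it to train_fn."""
    parser = argparse.ArgumentParser(description="Train a time series model.")
    # Hypothetical flag name, chosen for this sketch only.
    parser.add_argument("--params", required=True, help="Path to a JSON config file")
    args = parser.parse_args(argv)
    with open(args.params) as f:
        params_dict = json.load(f)
    return train_fn("PyTorch", params_dict)


def fake_train(model_type: str, params: Dict) -> Dict:
    # Stand-in for train_function so the sketch runs without the library.
    return {"model_type": model_type, "params": params}


# Write a tiny config file, then run the entrypoint against it.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    json.dump({"model_name": "demo"}, tmp)

result = run_from_cli(["--params", tmp.name], fake_train)
os.remove(tmp.name)
print(result["model_type"])
```

Passing the training function in as an argument keeps the sketch runnable without the library installed; a real entrypoint would more likely call train_function directly.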