This API makes it easy to run inference on trained PyTorchForecast models. To use it you need three things: your model’s configuration file, a CSV file containing your data, and the path to your model weights.

example initialization
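```python
import json
from datetime import datetime  # for the start date passed to infer_now / make_plots
from flood_forecast.deployment.inference import InferenceMode

# CSV with the data to run inference on, and the trained model weights
new_water_data_path = "gs://predict_cfs/day_addition/01046000KGNR_flow.csv"
weight_path = "gs://predict_cfs/experiments/10_December_202009_34PM_model.pth"
# Model parameters, loaded from the saved JSON config file
with open("config.json") as f:
    config_test = json.load(f)
# Positional arguments: forecast_steps=336, num_prediction_samples=30,
# model_params, csv_path, weight_path, wandb_proj="river"
infer_model = InferenceMode(336, 30, config_test, new_water_data_path, weight_path, "river")
```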
example plotting
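Here is a minimal plotting sketch. It assumes `infer_model` is the `InferenceMode` instance created above. `plot_forecast` is a hypothetical helper, not part of flood_forecast. It calls `make_plots` (documented below) and keeps only the plot, following the documented return type `tuple(torch.Tensor, torch.Tensor, CSVTestLoader, matplotlib.pyplot.plot)`. The helper only needs an object with a `make_plots` method, so it can also be exercised with a stand-in object.

```python
from datetime import datetime
from typing import Any, Optional


def plot_forecast(infer_model: Any, start: datetime,
                  csv_path: Optional[str] = None,
                  wandb_plot_id: Optional[str] = None) -> Any:
    """Run inference starting at `start` and return the plot from make_plots.

    `infer_model` is assumed to be an InferenceMode instance; any object with
    a compatible make_plots method also works.
    """
    # make_plots is documented to return
    # (torch.Tensor, torch.Tensor, CSVTestLoader, plot); only the plot is kept.
    _, _, _, plot = infer_model.make_plots(
        start, csv_path=csv_path, wandb_plot_id=wandb_plot_id
    )
    return plot


# Usage with a real InferenceMode instance (placeholder date and plot id):
# fig = plot_forecast(infer_model, datetime(2020, 12, 10), wandb_plot_id="river_plot")
```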

class flood_forecast.deployment.inference.InferenceMode(forecast_steps: int, num_prediction_samples: int, model_params, csv_path: str, weight_path, wandb_proj: Optional[str] = None, torch_script=False)

Class to handle inference for models.

  • forecast_steps (int) – Number of time-steps to forecast (doesn’t have to be hours)

  • num_prediction_samples (int) – Number of prediction samples to generate (used for the confidence interval predictions)

  • model_params (Dict) – A dictionary of model parameters (ideally this should come from a saved JSON config file)

  • csv_path (str) – Path to the CSV test file you want to use for inference. Even if you aren’t using

  • weight_path (str) – Path to the model weights

  • wandb_proj (str, optional) – The name of the Wandb project; leave blank if you don’t want to log to Wandb, defaults to None

infer_now(some_date: datetime.datetime, csv_path=None, save_buck=None, save_name=None, use_torch_script=False)

Performs inference on a CSV file at a specified datetime.

  • some_date (datetime) – The date you want inference to begin on.

  • csv_path (str, optional) – A path to a CSV you want to perform inference on, defaults to None

  • save_buck (str, optional) – The GCP bucket where you want to save predictions, defaults to None

  • save_name (str, optional) – The file name under which to save the Pandas DataFrame to GCP, defaults to None

  • use_torch_script (bool, optional) – Whether to use a saved TorchScript version of your model, defaults to False


Returns a tuple consisting of the Pandas DataFrame with predictions + history, the prediction tensor, a tensor of the historical values, the forecast start index, the test loader, and a DataFrame of the prediction samples (e.g. the confidence interval predictions)

Return type

tuple(pd.DataFrame, torch.Tensor, torch.Tensor, int, CSVTestLoader, pd.DataFrame)
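The six positional return values are easy to mix up, so giving them names can help readability. This is a minimal sketch, assuming `infer_model` is an `InferenceMode` instance. `InferenceResult` and `run_inference` are hypothetical helpers, not part of flood_forecast, and the field order follows the return description above.

```python
from datetime import datetime
from typing import Any, NamedTuple, Optional


class InferenceResult(NamedTuple):
    """Named view of the 6-tuple returned by InferenceMode.infer_now."""
    df: Any              # pd.DataFrame with predictions + history
    predictions: Any     # torch.Tensor of predictions
    history: Any         # torch.Tensor of historical values
    forecast_start: int  # index where the forecast begins
    test_loader: Any     # CSVTestLoader used for inference
    samples: Any         # pd.DataFrame of prediction samples (confidence intervals)


def run_inference(infer_model: Any, start: datetime,
                  csv_path: Optional[str] = None) -> InferenceResult:
    """Call infer_now and wrap its return value in an InferenceResult."""
    return InferenceResult(*infer_model.infer_now(start, csv_path=csv_path))
```

Unpacking with `*` into a `NamedTuple` raises a `TypeError` if the number of returned values ever changes, so a mismatch fails loudly instead of silently shifting fields. Example: `result = run_inference(infer_model, datetime(2020, 12, 10))` followed by `result.df.head()` (placeholder date).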

make_plots(date: datetime.datetime, csv_path: Optional[str] = None, csv_bucket: Optional[str] = None, save_name=None, wandb_plot_id=None)

Function to create plots in inference mode.

  • date (datetime) – The datetime to start inference

  • csv_path (str, optional) – The path to the CSV file you want to use for inference, defaults to None

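  • csv_bucket (str, optional) – The GCP bucket where the predictions are saved (passed to infer_now as save_buck), defaults to None

  • save_name (str, optional) – The file name under which to save the predictions to GCP (passed to infer_now as save_name), defaults to None

  • wandb_plot_id (str, optional) – Identifier used when logging the plots to Wandb, defaults to None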



Return type

tuple(torch.Tensor, torch.Tensor, CSVTestLoader, matplotlib.pyplot.plot)

flood_forecast.deployment.inference.convert_to_torch_script(model: flood_forecast.time_model.PyTorchForecast, save_path: str) → flood_forecast.time_model.PyTorchForecast

Function to convert a PyTorch model to TorchScript and save it.

  • model (PyTorchForecast) – The PyTorchForecast model you wish to convert

  • save_path (str) – File name to save the TorchScript model under.


Returns the model with an added .script_model attribute

Return type

flood_forecast.time_model.PyTorchForecast

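This is a hedged sketch. `export_script_model` is a hypothetical wrapper, not part of flood_forecast. It calls `convert_to_torch_script` and returns the scripted module from the `.script_model` attribute that the function is documented to add. The converter can be injected, which keeps the sketch runnable and testable when flood_forecast is not installed.

```python
from typing import Any, Callable, Optional

try:
    from flood_forecast.deployment.inference import convert_to_torch_script
except ImportError:  # flood_forecast not installed; pass `converter` explicitly
    convert_to_torch_script = None


def export_script_model(model: Any, save_path: str,
                        converter: Optional[Callable[[Any, str], Any]] = None) -> Any:
    """Convert `model` to TorchScript, save it to `save_path`, and return the
    scripted module stored on the returned model's .script_model attribute."""
    converter = converter or convert_to_torch_script
    if converter is None:
        raise ImportError("flood_forecast is required to convert models")
    converted = converter(model, save_path)
    # convert_to_torch_script is documented to add a .script_model attribute
    if not hasattr(converted, "script_model"):
        raise AttributeError("expected a .script_model attribute after conversion")
    return converted.script_model
```

Example: `script = export_script_model(model, "model_script.pt")`, where `model` is a PyTorchForecast instance and the file name is a placeholder.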
flood_forecast.deployment.inference.load_model(model_params_dict, file_path: str, weight_path: str) → flood_forecast.time_model.PyTorchForecast

Function to load a PyTorchForecast model from an existing config file.

  • model_params_dict (Dict) – Dictionary of model parameters

  • file_path (str) – Path to the CSV file containing your data

  • weight_path (str) – Path to the model weights
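`load_model` takes the parsed parameter dictionary rather than the path to the config file. This minimal sketch reads the saved JSON config first, mirroring the initialization example. `load_from_config` is a hypothetical helper, and it assumes `file_path` is the data CSV. The loader can be injected so the sketch runs without flood_forecast installed.

```python
import json
from typing import Any, Callable, Optional

try:
    from flood_forecast.deployment.inference import load_model
except ImportError:  # flood_forecast not installed; pass `loader` explicitly
    load_model = None


def load_from_config(config_path: str, csv_path: str, weight_path: str,
                     loader: Optional[Callable[..., Any]] = None) -> Any:
    """Read a saved JSON config and pass the parsed dict to load_model."""
    loader = loader or load_model
    if loader is None:
        raise ImportError("flood_forecast is required to load models")
    # load_model expects the dictionary itself (model_params_dict)
    with open(config_path) as f:
        model_params_dict = json.load(f)
    return loader(model_params_dict, csv_path, weight_path)
```

Example, reusing names from the initialization example: `model = load_from_config("config.json", new_water_data_path, weight_path)`.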



Return type

flood_forecast.time_model.PyTorchForecast