Inference
This API makes it easy to run inference on trained PyTorchForecast models. To use this code you need three things: your model’s configuration file, a CSV file containing your data, and the path to your model weights.
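import json
# datetime is used to pick the start date passed to infer_now or make_plots
from datetime import datetime
from flood_forecast.deployment.inference import InferenceMode

# New data to run inference on and the trained model weights
new_water_data_path = "gs://predict_cfs/day_addition/01046000KGNR_flow.csv"
weight_path = "gs://predict_cfs/experiments/10_December_202009_34PM_model.pth"
# Load the model's JSON configuration file
with open("config.json") as f:
    config_test = json.load(f)
# 336 forecast steps, 30 prediction samples, log to the "river" W&B project
infer_model = InferenceMode(336, 30, config_test, new_water_data_path, weight_path, "river")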
class flood_forecast.deployment.inference.InferenceMode(forecast_steps: int, num_prediction_samples: int, model_params, csv_path: str, weight_path, wandb_proj: Optional[str] = None, torch_script=False)
__init__(forecast_steps: int, num_prediction_samples: int, model_params, csv_path: str, weight_path, wandb_proj: Optional[str] = None, torch_script=False)
Class to handle inference for models.
Parameters
forecast_steps (int) – Number of time steps to forecast (these do not have to be hours)
num_prediction_samples (int) – Number of prediction samples
model_params (Dict) – A dictionary of model parameters (ideally loaded from the saved JSON config file)
csv_path (str) – Path to the CSV test file you want to use for inference. Even if you aren’t using
weight_path (str) – Path to the model weights
wandb_proj (str, optional) – The name of the Weights & Biases project; leave blank if you don’t want to log to W&B, defaults to None
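The usage example at the top of this page passes every argument positionally. To see which value lands in which parameter, the sketch below binds the same values to a hypothetical stand-in function that only copies the documented `__init__` signature. It performs no inference and does not import flood_forecast; an empty dict stands in for the loaded config.

```python
import inspect
from typing import Optional


def inference_mode_init(forecast_steps: int, num_prediction_samples: int,
                        model_params, csv_path: str, weight_path,
                        wandb_proj: Optional[str] = None, torch_script=False):
    """Hypothetical stand-in mirroring the documented InferenceMode.__init__ signature."""


# Same positional values as the usage example; {} replaces the real config dict
bound = inspect.signature(inference_mode_init).bind(
    336,
    30,
    {},
    "gs://predict_cfs/day_addition/01046000KGNR_flow.csv",
    "gs://predict_cfs/experiments/10_December_202009_34PM_model.pth",
    "river",
)
bound.apply_defaults()  # fills in torch_script=False

for name, value in bound.arguments.items():
    print(f"{name} = {value!r}")
```

So 336 is forecast_steps, 30 is num_prediction_samples, and "river" is the W&B project name. Passing these as keyword arguments in your own calls makes the mapping explicit.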
infer_now(some_date: datetime.datetime, csv_path=None, save_buck=None, save_name=None, use_torch_script=False)
Performs inference on a CSV file at a specified datetime.
Parameters
some_date (datetime) – The date you want inference to begin on
csv_path (str, optional) – A path to a CSV you want to perform inference on, defaults to None
save_buck (str, optional) – The GCP bucket where you want to save predictions, defaults to None
save_name (str, optional) – The name of the file to save the Pandas DataFrame to GCP as, defaults to None
use_torch_script (bool, optional) – Whether to use a saved TorchScript version of your model, defaults to False
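Inference begins at some_date, so that timestamp has to appear in the CSV, and a model that reads a fixed window of past values also needs enough rows before it. The helper below is a hypothetical pre-flight check using only the standard library; it is not part of flood_forecast. The column name `datetime`, the `history_len` value, and the toy data are assumptions, so substitute the date column and history length from your own config.

```python
import csv
import io
from datetime import datetime


def find_forecast_start(csv_text: str, some_date: datetime, history_len: int,
                        date_col: str = "datetime") -> int:
    """Hypothetical check: return the row index of some_date in the CSV.

    Raises ValueError if the date is missing or fewer than history_len
    rows precede it. date_col and history_len are assumptions.
    """
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    for idx, row in enumerate(rows):
        if datetime.fromisoformat(row[date_col]) == some_date:
            if idx < history_len:
                raise ValueError(
                    f"only {idx} rows before {some_date}, need {history_len}")
            return idx
    raise ValueError(f"{some_date} not found in column {date_col!r}")


# Toy hourly data for illustration only
sample_csv = """datetime,cfs
2020-12-01 00:00:00,10.0
2020-12-01 01:00:00,11.5
2020-12-01 02:00:00,12.1
2020-12-01 03:00:00,12.8
2020-12-01 04:00:00,13.0
"""
print(find_forecast_start(sample_csv, datetime(2020, 12, 1, 3), history_len=3))  # 3
```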
Returns
A tuple consisting of the Pandas DataFrame with predictions + history, the prediction tensor, a tensor of the historical values, the forecast start index, the test loader, and a DataFrame of the prediction samples (e.g. the confidence-interval predictions)
Return type
tuple(pd.DataFrame, torch.Tensor, torch.Tensor, int, CSVTestLoader, pd.DataFrame)
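The last element of the returned tuple holds the prediction samples behind the confidence interval. One simple way to turn samples into an interval is an empirical percentile band per forecast step. The sketch below assumes a hypothetical layout of one list of sampled values per forecast step and uses only the standard library; with a pandas DataFrame laid out the same way (steps as rows, samples as columns), `df.quantile([0.05, 0.95], axis=1)` gives a similar band, though its default interpolation differs slightly.

```python
import statistics


def summarize_samples(samples_by_step):
    """Mean and empirical 5th/95th percentiles for each forecast step.

    samples_by_step is a hypothetical layout: one list of sampled
    predictions per forecast step (at least two samples per step).
    """
    summary = []
    for step_samples in samples_by_step:
        cuts = statistics.quantiles(step_samples, n=20)  # 19 cut points: 5%, 10%, ..., 95%
        summary.append({
            "mean": statistics.fmean(step_samples),
            "lower_5": cuts[0],
            "upper_95": cuts[-1],
        })
    return summary


# Toy values for two forecast steps, 20 samples each
toy_samples = [
    [float(v) for v in range(1, 21)],
    [float(v) for v in range(11, 31)],
]
for step, band in enumerate(summarize_samples(toy_samples)):
    print(step, band)
```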
make_plots(date: datetime.datetime, csv_path: Optional[str] = None, csv_bucket: Optional[str] = None, save_name=None, wandb_plot_id=None)
Function to create plots in inference mode.
Parameters
date (datetime) – The datetime to start inference
csv_path (str, optional) – The path to the CSV file you want to use for inference, defaults to None
csv_bucket (str, optional) – [description], defaults to None
save_name ([type], optional) – [description], defaults to None
wandb_plot_id ([type], optional) – [description], defaults to None
Returns
[description]
Return type
tuple(torch.Tensor, torch.Tensor, CSVTestLoader, matplotlib.pyplot.plot)
flood_forecast.deployment.inference.convert_to_torch_script(model: flood_forecast.time_model.PyTorchForecast, save_path: str) → flood_forecast.time_model.PyTorchForecast
Function to convert a PyTorch model to TorchScript and save it.
Parameters
model (PyTorchForecast) – The PyTorchForecast model you wish to convert
save_path (str) – File name to save the TorchScript model under
Returns
The model with an added .script_model attribute
Return type
PyTorchForecast
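The documentation does not show how convert_to_torch_script works internally, and whether it scripts or traces the model is not stated here. As a generic sketch of the standard PyTorch scripting workflow, the block below compiles a small module with `torch.jit.script`, saves it with `torch.jit.save`, and reloads it with `torch.jit.load`. `TinyForecaster` is a hypothetical toy model, not a flood_forecast class. Requires PyTorch.

```python
import os
import tempfile

import torch
from torch import nn


class TinyForecaster(nn.Module):
    """Hypothetical toy model: maps a history window to a forecast window."""

    def __init__(self, history_len: int, forecast_len: int):
        super().__init__()
        self.linear = nn.Linear(history_len, forecast_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyForecaster(history_len=8, forecast_len=2).eval()
scripted = torch.jit.script(model)  # compile the module to TorchScript

save_path = os.path.join(tempfile.mkdtemp(), "tiny_model.pt")
torch.jit.save(scripted, save_path)  # serialize code + weights to one file
reloaded = torch.jit.load(save_path)  # no TinyForecaster class needed to load

x = torch.randn(1, 8)
with torch.no_grad():
    same = torch.allclose(model(x), reloaded(x))
print(same)
```

The reloaded module carries its own compiled code, so it can run without the Python class definition that produced it.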
flood_forecast.deployment.inference.load_model(model_params_dict, file_path: str, weight_path: str) → flood_forecast.time_model.PyTorchForecast
Function to load a PyTorchForecast model from an existing config file.
Parameters
model_params_dict (Dict) – Dictionary of model parameters
file_path (str) – [description]
weight_path (str) – Path to the model weights
Returns
The loaded PyTorchForecast model
Return type
PyTorchForecast
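The signature suggests load_model pairs a parameter dictionary (which determines the architecture) with a weight file (which restores the trained values). The sketch below shows that generic pattern in plain PyTorch under stated assumptions: `build_model` and its `n_in`/`n_out` keys are hypothetical, and the weights are assumed to be a saved `state_dict`; the documentation does not say which format flood_forecast uses. Requires PyTorch.

```python
import os
import tempfile

import torch
from torch import nn


def build_model(params: dict) -> nn.Module:
    """Hypothetical builder: the config decides the architecture."""
    return nn.Linear(params["n_in"], params["n_out"])


params = {"n_in": 4, "n_out": 1}

# Simulate a trained model whose weights were saved as a state_dict
trained = build_model(params)
weight_path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(trained.state_dict(), weight_path)

# Rebuild from the same config, then restore the saved weights on CPU
restored = build_model(params)
restored.load_state_dict(torch.load(weight_path, map_location="cpu"))
restored.eval()  # inference mode: disables dropout, freezes batch-norm stats
```

Rebuilding from the config before loading weights means the weight file only has to store tensors, while the architecture stays defined by the parameter dictionary.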