Inference
This API makes it easy to run inference on trained PyTorchForecast models. To use it you need three things: your model's configuration file, a CSV containing your data, and a path to your saved model weights.
import json
from datetime import datetime
from flood_forecast.deployment.inference import InferenceMode

# Paths to the inference data and the trained model weights (here stored on GCS)
new_water_data_path = "gs://predict_cfs/day_addition/01046000KGNR_flow.csv"
weight_path = "gs://predict_cfs/experiments/10_December_202009_34PM_model.pth"

# Load the configuration file that was saved when the model was trained
with open("config.json") as y:
    config_test = json.load(y)

# Forecast 336 time-steps with 30 prediction samples, logging to the "river" W&B project
infer_model = InferenceMode(336, 30, config_test, new_water_data_path, weight_path, "river")
- class flood_forecast.deployment.inference.InferenceMode(forecast_steps: int, num_prediction_samples: int, model_params, csv_path: str | DataFrame, weight_path, wandb_proj: str = None, torch_script=False)[source]
- __init__(forecast_steps: int, num_prediction_samples: int, model_params, csv_path: str | DataFrame, weight_path, wandb_proj: str = None, torch_script=False)[source]
Class to handle inference for trained models. The csv_path argument accepts either a file path or an in-memory DataFrame; a usage sketch follows the parameter list.
- Parameters:
forecast_steps (int) – Number of time-steps to forecast (doesn’t have to be hours)
num_prediction_samples (int) – The number of prediction samples
model_params (Dict) – A dictionary of model parameters (ideally this should come from the saved JSON config file)
csv_path (Union[str, pd.DataFrame]) – Path to the CSV test file to use for inference, or a Pandas DataFrame.
weight_path (str) – Path to the model weights (.pth file)
wandb_proj (str, optional) – The name of the Weights & Biases project; leave blank if you don’t want to log to W&B, defaults to None
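A minimal sketch of constructing InferenceMode from an in-memory DataFrame rather than a CSV path (the local file name is illustrative):
import pandas as pd

river_df = pd.read_csv("01046000KGNR_flow.csv")  # illustrative local copy of the flow data
infer_model_df = InferenceMode(336, 30, config_test, river_df, weight_path, "river")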
- infer_now(some_date: datetime, csv_path=None, save_buck=None, save_name=None, use_torch_script=False)[source]
Performs inference on a CSV file beginning at a specified date-time; a usage sketch follows the return description.
- Parameters:
some_date (datetime) – The date-time at which inference should begin.
csv_path (str, optional) – A path to a CSV you want to perform inference on, defaults to None
save_buck (str, optional) – The GCP bucket where you want to save predictions, defaults to None
save_name (str, optional) – The name of the file to save the Pandas data-frame to GCP as, defaults to None
use_torch_script (bool, optional) – Whether to use a saved TorchScript version of your model, defaults to False
- Returns:
Returns a tuple consisting of the Pandas DataFrame with predictions and history, the prediction tensor, a tensor of the historical values, the forecast start index, the test loader, and a DataFrame of the prediction samples (e.g. the confidence interval predictions)
- Return type:
tuple(pd.DataFrame, torch.Tensor, int, CSVTestLoader, pd.DataFrame)
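A minimal usage sketch, assuming the infer_model object created in the example above; the bucket and save name are illustrative:
inference_date = datetime(2020, 12, 10)  # the date-time at which the forecast begins
result = infer_model.infer_now(
    inference_date,
    save_buck="predict_cfs",                 # illustrative GCS bucket for saving predictions
    save_name="forecasts/01046000KGNR.csv",  # illustrative object name for the saved DataFrame
)
forecast_df = result[0]  # DataFrame containing predictions and history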
- infer_now_classification(data=None, over_lap_seq=True, save_buck=None, save_name=None, batch_size=1)[source]
Function to perform classification/anomaly detection on sequences in real time; a usage sketch follows the parameter list.
- Parameters:
data (Union[pd.DataFrame, str], optional) – The data to perform inference on
over_lap_seq (bool, optional) – Whether to step through the DataFrame one row at a time or by the full sequence length, defaults to True
batch_size (int, optional) – The batch size to use, defaults to 1
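A minimal sketch, assuming infer_model wraps a classification/anomaly-detection model; the file name is illustrative:
import pandas as pd

sequence_df = pd.read_csv("sensor_sequences.csv")  # illustrative data file
class_preds = infer_model.infer_now_classification(data=sequence_df, over_lap_seq=True, batch_size=1)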
- make_plots(date: datetime, csv_path: str = None, csv_bucket: str = None, save_name=None, wandb_plot_id=None)[source]
Function to create plots in inference mode; a usage sketch follows the return type.
- Parameters:
date (datetime) – The datetime to start inference
csv_path (str, optional) – The path to the CSV file you want to use for inference, defaults to None
csv_bucket (str, optional) – The bucket where the CSV file is located, defaults to None
save_name (str, optional) – Where to save the output csv, defaults to None
wandb_plot_id (str, optional) – The id under which to save the wandb plot on the dashboard, defaults to None
- Returns:
A tuple containing two tensors, the test data loader, and the generated plot (see return type below)
- Return type:
tuple(torch.Tensor, torch.Tensor, CSVTestLoader, matplotlib.pyplot.plot)
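A minimal sketch, reusing infer_model and new_water_data_path from the first example; the bucket, save name, and plot id are illustrative:
plot_output = infer_model.make_plots(
    datetime(2020, 12, 10),
    csv_path=new_water_data_path,
    csv_bucket="predict_cfs",             # illustrative GCS bucket
    save_name="plots/01046000KGNR.csv",   # illustrative output file name
    wandb_plot_id="river_inference_plot", # illustrative W&B plot id
)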
- flood_forecast.deployment.inference.convert_to_torch_script(model: PyTorchForecast, save_path: str) → PyTorchForecast [source]
Function to convert a PyTorchForecast model to TorchScript and save it; a usage sketch follows the return type.
- Parameters:
model (PyTorchForecast) – The PyTorchForecast model you wish to convert
save_path (str) – File name to save the TorchScript model under.
- Returns:
Returns the model with an added .script_model attribute
- Return type:
PyTorchForecast
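A minimal sketch, assuming a PyTorchForecast model obtained with load_model (documented below); the save path is illustrative:
from flood_forecast.deployment.inference import convert_to_torch_script, load_model

model = load_model(config_test, new_water_data_path, weight_path)
scripted_model = convert_to_torch_script(model, "scripted_model.pt")  # illustrative save path
# scripted_model.script_model now holds the TorchScript version of the model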
- flood_forecast.deployment.inference.load_model(model_params_dict, file_path: str, weight_path: str) → PyTorchForecast [source]
Function to load a PyTorchForecast model from an existing config file; a usage sketch follows the return type.
- Parameters:
model_params_dict (Dict) – Dictionary of model parameters
file_path (str) – The path to the CSV file used for inference
weight_path (str) – The path to the model weights (can be GCS)
- Returns:
Returns a PyTorchForecast model initialized with the proper data
- Return type:
PyTorchForecast
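A minimal sketch, reusing the config dictionary, data path, and weight path from the first example:
from flood_forecast.deployment.inference import load_model

model = load_model(config_test, new_water_data_path, weight_path)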