Version: Master/Unreleased

rasa.engine.training.graph_trainer

GraphTrainer Objects

class GraphTrainer()

Trains a model using a graph schema.

__init__

def __init__(model_storage: ModelStorage, cache: TrainingCache, graph_runner_class: Type[GraphRunner]) -> None

Initializes a GraphTrainer.

Arguments:

  • model_storage - Storage which graph components can use to persist and load their data. Also used for packaging the trained model.
  • cache - Cache used to store fingerprints and outputs.
  • graph_runner_class - The class to instantiate the runner from.

train

def train(train_schema: GraphSchema, predict_schema: GraphSchema, domain_path: Path, output_filename: Path) -> GraphRunner

Trains and packages a model and returns the prediction graph runner.

Arguments:

  • train_schema - The graph schema used for training.
  • predict_schema - The graph schema used for prediction.
  • domain_path - The path to the domain file.
  • output_filename - The location to save the packaged model.

Returns:

A graph runner loaded with the predict schema.
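
The following is a minimal sketch of the call pattern described above, not Rasa's implementation. It does not import Rasa. SketchTrainer and SketchRunner are hypothetical stand-ins that only mirror the documented signatures, plain dicts stand in for ModelStorage and TrainingCache, and lists of node names stand in for GraphSchema. It shows that the trainer receives a runner class rather than an instance, and that train returns a runner built from the predict schema.

```python
from pathlib import Path
from typing import Any, Dict, List, Type


class SketchRunner:
    """Hypothetical stand-in for a GraphRunner implementation."""

    def __init__(self, schema: List[str], storage: Dict[str, Any]) -> None:
        self.schema = schema
        self.storage = storage


class SketchTrainer:
    """Hypothetical stand-in mirroring the documented GraphTrainer signatures."""

    def __init__(
        self,
        model_storage: Dict[str, Any],
        cache: Dict[str, Any],
        graph_runner_class: Type[SketchRunner],
    ) -> None:
        self._model_storage = model_storage
        self._cache = cache
        # A class is stored, not an instance: the runner is created on demand.
        self._graph_runner_class = graph_runner_class

    def train(
        self,
        train_schema: List[str],
        predict_schema: List[str],
        domain_path: Path,
        output_filename: Path,
    ) -> SketchRunner:
        # "Run" the training graph: each node only records that it was trained.
        for node in train_schema:
            self._model_storage[node] = f"trained with {domain_path.name}"
        # Record where the packaged model would go (no file is written here).
        self._model_storage["packaged_to"] = output_filename
        # Return a runner built from the predict schema.
        return self._graph_runner_class(predict_schema, self._model_storage)


storage: Dict[str, Any] = {}
trainer = SketchTrainer(model_storage=storage, cache={}, graph_runner_class=SketchRunner)
runner = trainer.train(
    train_schema=["train_nlu", "train_core"],
    predict_schema=["predict_nlu", "predict_core"],
    domain_path=Path("domain.yml"),
    output_filename=Path("models/model.tar.gz"),
)
```

Passing the runner class instead of an instance lets the trainer decide when to build the runner and with which schema, which is why the same class can serve both the train and the predict graphs.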