
This is unreleased documentation for the Main/Unreleased version of the Rasa documentation.
For the latest released documentation, see the latest version (3.x).

Version: Main/Unreleased

rasa.engine.training.graph_trainer

GraphTrainer Objects

class GraphTrainer()

Trains a model using a graph schema.

__init__

def __init__(model_storage: ModelStorage, cache: TrainingCache,
             graph_runner_class: Type[GraphRunner]) -> None

Initializes a GraphTrainer.

Arguments:

  • model_storage - Storage which graph components can use to persist and load. Also used for packaging the trained model.
  • cache - Cache used to store fingerprints and outputs.
  • graph_runner_class - The class to instantiate the runner from.
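Because `graph_runner_class` is a class rather than an instance, the trainer can presumably build a fresh runner for each schema it needs to execute. The sketch below illustrates only that dependency-injection pattern. `ToyGraphTrainer`, `ToyRunner`, `run_schema`, and the dictionary-based storage and cache are hypothetical stand-ins, not Rasa's `ModelStorage`, `TrainingCache`, or `GraphRunner` APIs.

```python
from typing import Any, Callable, Dict, Type


class ToyRunner:
    """Hypothetical runner: calls every node function in a schema."""

    def __init__(self, schema: Dict[str, Callable[[], Any]]) -> None:
        self.schema = schema

    def run(self) -> Dict[str, Any]:
        # Call each node function and collect its output by node name.
        return {name: fn() for name, fn in self.schema.items()}


class ToyGraphTrainer:
    """Hypothetical trainer mirroring GraphTrainer's constructor arguments."""

    def __init__(
        self,
        model_storage: Dict[str, Any],
        cache: Dict[str, Any],
        graph_runner_class: Type[ToyRunner],
    ) -> None:
        self.model_storage = model_storage
        self.cache = cache
        # Keep the class, not an instance, so a new runner is built per schema.
        self.graph_runner_class = graph_runner_class

    def run_schema(self, schema: Dict[str, Callable[[], Any]]) -> Dict[str, Any]:
        runner = self.graph_runner_class(schema)
        return runner.run()


trainer = ToyGraphTrainer(model_storage={}, cache={}, graph_runner_class=ToyRunner)
outputs = trainer.run_schema({"tokenizer": lambda: "tokens", "featurizer": lambda: [1, 2]})
print(outputs)  # {'tokenizer': 'tokens', 'featurizer': [1, 2]}
```

Passing a class keeps the trainer decoupled from any one execution backend: a test can inject a trivial runner like `ToyRunner`, while production code injects a real one.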

train

def train(model_configuration: GraphModelConfiguration,
          importer: TrainingDataImporter,
          output_filename: Path,
          force_retraining: bool = False,
          is_finetuning: bool = False) -> ModelMetadata

Trains and packages a model and returns the metadata of the trained model.

Arguments:

  • model_configuration - The model configuration (schemas, language, etc.)
  • importer - The importer that provides the training data.
  • output_filename - The location to save the packaged model.
  • force_retraining - If True, the cache is skipped and all components are retrained.
  • is_finetuning - True if we want to finetune the model.

Returns:

The metadata describing the trained model.
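The `force_retraining` flag decides whether cached results are consulted at all. The following is a minimal sketch of that control flow under stated assumptions: components are plain functions keyed by name, the cache is a dict from a hypothetical fingerprint string to a stored output, and packaging is reduced to returning a metadata dict that records `output_filename`. `toy_fingerprint` and `toy_train` are illustrative names, not Rasa APIs, and Rasa's real fingerprinting is not necessarily computed this way.

```python
import hashlib
from pathlib import Path
from typing import Any, Callable, Dict, List


def toy_fingerprint(name: str, config: Dict[str, Any]) -> str:
    """Hypothetical fingerprint: a hash of a component's name and config."""
    return hashlib.sha256(f"{name}:{sorted(config.items())}".encode()).hexdigest()


def toy_train(
    components: Dict[str, Callable[[Dict[str, Any]], Any]],
    configs: Dict[str, Dict[str, Any]],
    cache: Dict[str, Any],
    output_filename: Path,
    force_retraining: bool = False,
) -> Dict[str, Any]:
    """Train each component unless a cached output exists for its fingerprint."""
    retrained: List[str] = []
    outputs: Dict[str, Any] = {}
    for name, train_fn in components.items():
        key = toy_fingerprint(name, configs[name])
        if not force_retraining and key in cache:
            # Cache hit: reuse the stored output instead of retraining.
            outputs[name] = cache[key]
            continue
        # Cache miss or forced retraining: train and refresh the cache entry.
        outputs[name] = train_fn(configs[name])
        cache[key] = outputs[name]
        retrained.append(name)
    # Stand-in for packaging: return metadata that records where the model goes.
    return {"model_file": str(output_filename), "retrained": retrained, "outputs": outputs}


cache: Dict[str, Any] = {}
components = {"featurizer": lambda cfg: cfg["dim"] * 2}
configs = {"featurizer": {"dim": 3}}
first = toy_train(components, configs, cache, Path("model.tar.gz"))
second = toy_train(components, configs, cache, Path("model.tar.gz"))
forced = toy_train(components, configs, cache, Path("model.tar.gz"), force_retraining=True)
print(first["retrained"], second["retrained"], forced["retrained"])
# ['featurizer'] [] ['featurizer']
```

The second call finds the featurizer's fingerprint in the cache and skips training; the forced call ignores the cache and retrains anyway.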

fingerprint

def fingerprint(
        train_schema: GraphSchema,
        importer: TrainingDataImporter,
        is_finetuning: bool = False
) -> Dict[Text, Union[FingerprintStatus, Any]]

Runs the graph using fingerprints to determine which nodes need to re-run.

Nodes that have a matching fingerprint key in the cache can either be removed from the graph entirely or, if their output is needed by descendant nodes, replaced with a cached value.
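The remove-or-replace decision can be sketched as a pass over the graph. This is a hypothetical illustration, not Rasa's pruning code: it assumes the graph is a dict mapping each node to the nodes it consumes, that cache hits are already known, and that a cached node's output is needed only when at least one direct child has to re-run (a child that is itself a hit is removed or served from the cache, so it no longer reads its parent).

```python
from typing import Dict, List, Set


def plan_pruning(parents: Dict[str, List[str]], hits: Set[str]) -> Dict[str, str]:
    """Classify each node as 'run', 'use_cached', or 'remove'.

    `parents` maps each node to the nodes whose outputs it consumes;
    `hits` holds the nodes whose fingerprint was found in the cache.
    """
    plan: Dict[str, str] = {}
    for node in parents:
        if node not in hits:
            plan[node] = "run"
            continue
        # A cached node is still needed if some direct child has to re-run.
        children = [child for child, deps in parents.items() if node in deps]
        needed = any(child not in hits for child in children)
        plan[node] = "use_cached" if needed else "remove"
    return plan


graph = {
    "load_data": [],
    "tokenizer": ["load_data"],
    "featurizer": ["tokenizer"],
    "classifier": ["featurizer"],
}
# Suppose only the classifier's configuration changed since the last run.
plan = plan_pruning(graph, hits={"load_data", "tokenizer", "featurizer"})
print(plan)
# {'load_data': 'remove', 'tokenizer': 'remove', 'featurizer': 'use_cached', 'classifier': 'run'}
```

Only the featurizer's cached output has to be loaded, because the classifier is the single node that re-runs and it consumes nothing else.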

Arguments:

  • train_schema - The train graph schema that will be run in fingerprint mode.
  • importer - The importer that provides the training data.
  • is_finetuning - True if we want to finetune the model.

Returns:

Mapping of node names to fingerprint results.
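To make the shape of the returned mapping concrete, here is a toy fingerprinting pass. It is a sketch under assumptions, not Rasa's implementation: `ToyFingerprintStatus` is a hypothetical stand-in for `FingerprintStatus`, and each node's key is assumed to hash the node's own config together with its parents' keys, so a change upstream also invalidates every downstream node. Rasa may derive its keys differently.

```python
import hashlib
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Set


@dataclass
class ToyFingerprintStatus:
    """Hypothetical stand-in for a per-node fingerprint result."""

    key: str
    is_hit: bool


def toy_fingerprint_graph(
    parents: Dict[str, List[str]],
    configs: Dict[str, Dict[str, Any]],
    cached_keys: Set[str],
) -> Dict[str, ToyFingerprintStatus]:
    """Map node names to fingerprint results.

    Assumes `parents` lists every node after all of its parents.
    """
    results: Dict[str, ToyFingerprintStatus] = {}
    for node, deps in parents.items():
        # The key covers the node's config and its parents' keys.
        payload = json.dumps(
            {"node": node, "config": configs[node], "inputs": [results[d].key for d in deps]},
            sort_keys=True,
        )
        key = hashlib.sha256(payload.encode()).hexdigest()
        results[node] = ToyFingerprintStatus(key=key, is_hit=key in cached_keys)
    return results


graph = {"tokenizer": [], "featurizer": ["tokenizer"], "classifier": ["featurizer"]}
old = {"tokenizer": {"lower": True}, "featurizer": {"dim": 8}, "classifier": {"epochs": 10}}
previous = toy_fingerprint_graph(graph, old, cached_keys=set())
cache = {status.key for status in previous.values()}

new = dict(old, classifier={"epochs": 20})  # only the classifier changed
current = toy_fingerprint_graph(graph, new, cached_keys=cache)
print({name: status.is_hit for name, status in current.items()})
# {'tokenizer': True, 'featurizer': True, 'classifier': False}
```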