Version: Main/Unreleased

rasa.utils.tensorflow.callback

RasaTrainingLogger Objects

class RasaTrainingLogger(tf.keras.callbacks.Callback)

Callback for logging the status of training.

__init__

def __init__(epochs: int, silent: bool) -> None

Initializes the callback.

Arguments:

  • epochs - Total number of epochs.
  • silent - If 'True', the progress bar is disabled entirely.

on_epoch_end

def on_epoch_end(epoch: int, logs: Optional[Dict[Text, Any]] = None) -> None

Updates the logging output at the end of every epoch.

Arguments:

  • epoch - The current epoch.
  • logs - The training metrics.

on_train_end

def on_train_end(logs: Optional[Dict[Text, Any]] = None) -> None

Closes the progress bar after training.

Arguments:

  • logs - The training metrics.
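The sketch below makes the three hooks concrete. It is a minimal, dependency-free sketch, not Rasa's implementation. The class name `SketchTrainingLogger`, the `stream` parameter and the text-based progress line are hypothetical stand-ins for the real progress bar. The hand-written loop at the end only mimics how Keras calls `on_epoch_end` once per epoch (with a 0-based `epoch` index) and then `on_train_end` once.

```python
import io
import sys
from typing import Any, Dict, Optional, Text, TextIO


class SketchTrainingLogger:
    """Hypothetical, dependency-free stand-in for RasaTrainingLogger."""

    def __init__(self, epochs: int, silent: bool, stream: TextIO = sys.stdout) -> None:
        self.epochs = epochs
        self.silent = silent  # if True, the sketch writes nothing at all
        self.stream = stream  # hypothetical parameter so the output can be captured
        self.closed = False

    def _write(self, text: Text) -> None:
        if not self.silent:
            self.stream.write(text)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[Text, Any]] = None) -> None:
        # Keras passes a 0-based epoch index, so add 1 for display.
        line = f"\rEpochs: {epoch + 1}/{self.epochs}"
        metrics = ", ".join(f"{name}={value:.3f}" for name, value in (logs or {}).items())
        if metrics:
            line += f" {metrics}"
        self._write(line)  # "\r" redraws the same terminal line each epoch

    def on_train_end(self, logs: Optional[Dict[Text, Any]] = None) -> None:
        # "Close" the progress line by terminating it with a newline.
        self._write("\n")
        self.closed = True


# Mimic the order in which Keras invokes these hooks during training.
buffer = io.StringIO()
logger = SketchTrainingLogger(epochs=2, silent=False, stream=buffer)
for epoch, logs in enumerate([{"loss": 0.9}, {"loss": 0.4}]):
    logger.on_epoch_end(epoch, logs)
logger.on_train_end()
output = buffer.getvalue()
print(repr(output))
```

With `silent=True` the stream stays empty, which corresponds to the `silent` argument disabling the progress bar. The real class subclasses `tf.keras.callbacks.Callback`, so Keras invokes these hooks itself once the callback is passed to training.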

RasaModelCheckpoint Objects

class RasaModelCheckpoint(tf.keras.callbacks.Callback)

Callback for saving intermediate model checkpoints.

__init__

def __init__(checkpoint_dir: Path) -> None

Initializes the callback.

Arguments:

  • checkpoint_dir - Directory in which to store checkpoints.

on_epoch_end

def on_epoch_end(epoch: int, logs: Optional[Dict[Text, Any]] = None) -> None

Saves the model at the end of an epoch if the model has improved.

Arguments:

  • epoch - The current epoch.
  • logs - The training metrics.
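The docstring does not define what "improved" means. The sketch below illustrates one plausible rule, which is an assumption and not necessarily the criterion Rasa uses. Under this rule, the callback saves when no tracked validation metric got worse and at least one got strictly better, and all tracked metrics are treated as higher-is-better. The names `SketchModelCheckpoint` and `save_fn`, the `val_` key prefix and the `checkpoint` file name are all hypothetical. The injected `save_fn` replaces the real model-saving call so the sketch runs without TensorFlow.

```python
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Text


class SketchModelCheckpoint:
    """Hypothetical stand-in for RasaModelCheckpoint; saving is delegated to save_fn."""

    def __init__(self, checkpoint_dir: Path, save_fn: Callable[[Path], None]) -> None:
        self.checkpoint_file = checkpoint_dir / "checkpoint"  # hypothetical file name
        self.save_fn = save_fn
        self.best: Dict[Text, float] = {}

    def _has_improved(self, logs: Dict[Text, Any]) -> bool:
        # Assumption: only validation metrics ("val_" prefix) are tracked, higher is better.
        current = {k: float(v) for k, v in logs.items() if k.startswith("val_")}
        if not current:
            return False  # no validation metrics were computed this epoch
        if not self.best:
            self.best = current  # the first evaluation always counts as an improvement
            return True
        tracked = [k for k in self.best if k in current]
        if any(current[k] < self.best[k] for k in tracked):
            return False  # at least one metric got worse
        improved = {k: current[k] for k in tracked if current[k] > self.best[k]}
        self.best.update(improved)
        return bool(improved)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[Text, Any]] = None) -> None:
        if self._has_improved(logs or {}):
            self.save_fn(self.checkpoint_file)


history = [
    {"loss": 0.9, "val_acc": 0.70},  # epoch 0: first evaluation, saved
    {"loss": 0.7},                   # epoch 1: no validation metrics, skipped
    {"loss": 0.6, "val_acc": 0.65},  # epoch 2: worse, skipped
    {"loss": 0.5, "val_acc": 0.80},  # epoch 3: better, saved
]
saved = []  # one entry per save_fn call
checkpoint = SketchModelCheckpoint(Path("checkpoints"), save_fn=saved.append)
for epoch, logs in enumerate(history):
    checkpoint.on_epoch_end(epoch, logs)
print(len(saved), checkpoint.best)
```

Requiring that no metric regresses avoids overwriting a checkpoint that is better on one metric with one that is better on another. A single-metric rule would be simpler, but it would ignore the other tracked metrics.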