Version: 3.x
rasa.engine.training.hooks
TrainingHook Objects
class TrainingHook(GraphNodeHook)
Caches fingerprints and outputs of nodes during model training.
__init__
def __init__(cache: TrainingCache, model_storage: ModelStorage,
pruned_schema: GraphSchema) -> None
Initializes a TrainingHook.
Arguments:
cache
- Cache used to store fingerprints and outputs.
model_storage
- Used to cache Resources.
pruned_schema
- The pruned training schema.
on_before_node
def on_before_node(node_name: Text, execution_context: ExecutionContext,
config: Dict[Text, Any],
received_inputs: Dict[Text, Any]) -> Dict
Calculates the run fingerprint for use in on_after_node.
on_after_node
def on_after_node(node_name: Text, execution_context: ExecutionContext,
config: Dict[Text, Any], output: Any,
input_hook_data: Dict) -> None
Stores the fingerprints and caches the output of the node.
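The hand-off between the two methods can be illustrated with a small standalone sketch. It does not import Rasa and is not the library's implementation: `ToyCachingHook`, `run_node`, the `fingerprint_key` entry, and the in-memory dict used as a cache are all hypothetical stand-ins, and the `execution_context` argument is omitted for brevity. The only point it is meant to show is that the dict returned by `on_before_node` comes back to `on_after_node` as `input_hook_data`.

```python
import hashlib
import json
from typing import Any, Callable, Dict, Text


class ToyCachingHook:
    """Standalone illustration only; this is not Rasa's TrainingHook."""

    def __init__(self) -> None:
        # Hypothetical in-memory stand-in for a training cache.
        self.cache: Dict[Text, Any] = {}

    def on_before_node(
        self,
        node_name: Text,
        config: Dict[Text, Any],
        received_inputs: Dict[Text, Any],
    ) -> Dict:
        # Build a deterministic key from the node name, its config and inputs.
        payload = json.dumps(
            {"node": node_name, "config": config, "inputs": received_inputs},
            sort_keys=True,
            default=str,
        )
        fingerprint = hashlib.sha256(payload.encode("utf-8")).hexdigest()
        # Whatever is returned here is handed back to on_after_node.
        return {"fingerprint_key": fingerprint}

    def on_after_node(
        self,
        node_name: Text,
        config: Dict[Text, Any],
        output: Any,
        input_hook_data: Dict,
    ) -> None:
        # Store the node output under the key computed before the node ran.
        self.cache[input_hook_data["fingerprint_key"]] = output


def run_node(
    hook: ToyCachingHook,
    node_name: Text,
    config: Dict[Text, Any],
    inputs: Dict[Text, Any],
    fn: Callable[..., Any],
) -> Any:
    """Hypothetical driver that calls the hook around a node function."""
    hook_data = hook.on_before_node(node_name, config, inputs)
    output = fn(**inputs)
    hook.on_after_node(node_name, config, output, hook_data)
    return output
```

Calling `run_node(hook, "tokenize", {"lower": True}, {"text": "Hi There"}, lambda text: text.lower().split())` would leave one entry in `hook.cache`, keyed by the fingerprint that `on_before_node` produced for that node, config, and input.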
LoggingHook Objects
class LoggingHook(GraphNodeHook)
Logs the training of components.
__init__
def __init__(pruned_schema: GraphSchema) -> None
Creates the hook.
Arguments:
pruned_schema
- The pruned schema, which indicates whether a component is cached.
on_before_node
def on_before_node(node_name: Text, execution_context: ExecutionContext,
config: Dict[Text, Any],
received_inputs: Dict[Text, Any]) -> Dict
Logs the training start of a graph node.
on_after_node
def on_after_node(node_name: Text, execution_context: ExecutionContext,
config: Dict[Text, Any], output: Any,
input_hook_data: Dict) -> None
Logs when a component has finished training.
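As a rough sketch of how a logging hook might use the cached-or-not information, the standalone example below replaces the pruned schema with a plain set of cached node names. `ToyLoggingHook`, `cached_nodes`, the logger name, and the message wording are assumptions made for illustration and are not taken from Rasa. The `execution_context` argument is again omitted.

```python
import logging
from typing import Any, Dict, Set, Text

# Hypothetical logger name, chosen for this sketch only.
logger = logging.getLogger("toy_logging_hook")


class ToyLoggingHook:
    """Standalone illustration only; this is not Rasa's LoggingHook."""

    def __init__(self, cached_nodes: Set[Text]) -> None:
        # Assumed stand-in for the information the pruned schema provides.
        self.cached_nodes = cached_nodes

    def on_before_node(
        self,
        node_name: Text,
        config: Dict[Text, Any],
        received_inputs: Dict[Text, Any],
    ) -> Dict:
        # Only announce training for nodes that actually have to run.
        if node_name not in self.cached_nodes:
            logger.info("Starting to train component '%s'.", node_name)
        return {}

    def on_after_node(
        self,
        node_name: Text,
        config: Dict[Text, Any],
        output: Any,
        input_hook_data: Dict,
    ) -> None:
        # Distinguish a cache hit from a completed training run.
        if node_name in self.cached_nodes:
            logger.info("Restored component '%s' from cache.", node_name)
        else:
            logger.info("Finished training component '%s'.", node_name)
```

With `logging.basicConfig(level=logging.INFO)` enabled, a node listed in `cached_nodes` produces a single "restored" message, while any other node produces a start and a finish message.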