Notice: This is documentation for Rasa Open Source v2.5.x, which is no longer actively maintained. For up-to-date documentation, see the latest version (2.8.x).

Version: 2.5.x

rasa.utils.train_utils

normalize

normalize(values: np.ndarray, ranking_length: Optional[int] = 0) -> np.ndarray

Normalizes an array of positive numbers over the top ranking_length values.

Other values will be set to 0.
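
The described behaviour can be sketched as a minimal re-implementation (for illustration only; the exact handling of ties and edge cases may differ in the Rasa source):

```python
import numpy as np

def normalize(values: np.ndarray, ranking_length: int = 0) -> np.ndarray:
    """Illustrative re-implementation of the behaviour described above."""
    new_values = values.copy()
    if 0 < ranking_length < len(new_values):
        # Zero out everything below the ranking_length-th largest value.
        threshold = sorted(new_values, reverse=True)[ranking_length - 1]
        new_values[new_values < threshold] = 0.0
    total = np.sum(new_values)
    if total > 0:
        # Renormalize the surviving values so they sum to 1.
        new_values = new_values / total
    return new_values
```

For example, normalizing `[0.5, 0.3, 0.2]` with `ranking_length=2` zeroes the last entry and rescales the rest to `[0.625, 0.375, 0.0]`.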

update_similarity_type

update_similarity_type(config: Dict[Text, Any]) -> Dict[Text, Any]

If SIMILARITY_TYPE is set to 'auto', updates SIMILARITY_TYPE depending on the LOSS_TYPE.

Arguments:

  • config - model configuration

Returns:

updated model configuration
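
A hedged sketch of the resolution logic, assuming the common pairing of cross-entropy loss with inner-product similarity and margin loss with cosine similarity (the key names and exact mapping are assumptions; consult the Rasa source for the authoritative behaviour):

```python
from typing import Any, Dict

def update_similarity_type(config: Dict[str, Any]) -> Dict[str, Any]:
    # Assumed mapping: cross-entropy loss pairs with inner-product
    # similarity, margin loss with cosine similarity.
    if config.get("similarity_type") == "auto":
        if config.get("loss_type") == "cross_entropy":
            config["similarity_type"] = "inner"
        else:
            config["similarity_type"] = "cosine"
    return config
```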

update_deprecated_loss_type

update_deprecated_loss_type(config: Dict[Text, Any]) -> Dict[Text, Any]

Updates LOSS_TYPE to 'cross_entropy' if it is set to 'softmax'.

Arguments:

  • config - model configuration

Returns:

updated model configuration

align_token_features

align_token_features(list_of_tokens: List[List["Token"]], in_token_features: np.ndarray, shape: Optional[Tuple] = None) -> np.ndarray

Align token features to match tokens.

Tokenizers such as ConveRTTokenizer or the LanguageModelTokenizers might split tokens into sub-tokens. In that case we take the mean of the sub-token vectors as the token vector.

Arguments:

  • list_of_tokens - tokens for examples
  • in_token_features - token features from ConveRT
  • shape - shape of feature matrix

Returns:

Token features.
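
The mean-pooling step described above can be illustrated with a hypothetical helper (`mean_pool_subtokens` is not part of the Rasa API; it only demonstrates the averaging of sub-token vectors):

```python
import numpy as np

def mean_pool_subtokens(sub_token_vectors: np.ndarray, sub_token_counts) -> np.ndarray:
    """Average each token's sub-token vectors into a single token vector."""
    token_vectors = []
    offset = 0
    for count in sub_token_counts:
        # A token split into `count` sub-tokens is represented by the
        # mean of its sub-token feature vectors.
        token_vectors.append(sub_token_vectors[offset:offset + count].mean(axis=0))
        offset += count
    return np.stack(token_vectors)
```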

update_evaluation_parameters

update_evaluation_parameters(config: Dict[Text, Any]) -> Dict[Text, Any]

If EVAL_NUM_EPOCHS is set to -1, evaluate at the end of the training.

Arguments:

  • config - model configuration

Returns:

updated model configuration

load_tf_hub_model

load_tf_hub_model(model_url: Text) -> Any

Load the model from the cache if possible, otherwise from TFHub.

check_deprecated_options

check_deprecated_options(config: Dict[Text, Any]) -> Dict[Text, Any]

Update the config according to changed config params.

If old model configuration parameters are present in the provided config, replace them with the new parameters and log a warning.

Arguments:

  • config - model configuration

Returns:

updated model configuration

check_core_deprecated_options

check_core_deprecated_options(config: Dict[Text, Any]) -> Dict[Text, Any]

Update the core config according to changed config params.

If old model configuration parameters are present in the provided config, replace them with the new parameters and log a warning.

Arguments:

  • config - model configuration

Returns:

updated model configuration

entity_label_to_tags

entity_label_to_tags(model_predictions: Dict[Text, Any], entity_tag_specs: List["EntityTagSpec"], bilou_flag: bool = False, prediction_index: int = 0) -> Tuple[Dict[Text, List[Text]], Dict[Text, List[float]]]

Convert the output predictions for entities to the actual entity tags.

Arguments:

  • model_predictions - the output predictions using the entity tag indices
  • entity_tag_specs - the entity tag specifications
  • bilou_flag - if 'True', the BILOU tagging schema was used
  • prediction_index - the index in the batch of predictions to use for entity extraction

Returns:

A map of entity tag type, e.g. entity, role, group, to actual entity tags and confidences.

override_defaults

override_defaults(defaults: Optional[Dict[Text, Any]], custom: Optional[Dict[Text, Any]]) -> Dict[Text, Any]

Override default config with the given config.

We cannot use the dict.update method because configs contain nested dicts.

Arguments:

  • defaults - default config
  • custom - user config containing new parameters

Returns:

updated config
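
The nested-merge behaviour can be sketched as follows; this is an illustrative re-implementation, not the exact Rasa source:

```python
from typing import Any, Dict, Optional

def override_defaults(
    defaults: Optional[Dict[str, Any]], custom: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Merge `custom` into a copy of `defaults`, recursing into nested dicts."""
    config = dict(defaults) if defaults else {}
    if custom:
        for key, value in custom.items():
            if isinstance(config.get(key), dict) and isinstance(value, dict):
                # Merge nested dicts key by key instead of replacing them
                # wholesale, which is what dict.update would do.
                config[key] = override_defaults(config[key], value)
            else:
                config[key] = value
    return config
```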

create_data_generators

create_data_generators(model_data: RasaModelData, batch_sizes: Union[int, List[int]], epochs: int, batch_strategy: Text = SEQUENCE, eval_num_examples: int = 0, random_seed: Optional[int] = None) -> Tuple[RasaBatchDataGenerator, Optional[RasaBatchDataGenerator]]

Create data generators for train and optional validation data.

Arguments:

  • model_data - The model data to use.
  • batch_sizes - The batch size(s).
  • epochs - The number of epochs to train.
  • batch_strategy - The batch strategy to use.
  • eval_num_examples - Number of examples to use for validation data.
  • random_seed - The random seed.

Returns:

The training data generator and optional validation data generator.

create_common_callbacks

create_common_callbacks(epochs: int, tensorboard_log_dir: Optional[Text] = None, tensorboard_log_level: Optional[Text] = None, checkpoint_dir: Optional[Path] = None) -> List["Callback"]

Create common callbacks.

The following callbacks are created:

  • RasaTrainingLogger callback
  • Optional TensorBoard callback
  • Optional RasaModelCheckpoint callback

Arguments:

  • epochs - the number of epochs to train
  • tensorboard_log_dir - optional directory that should be used for TensorBoard
  • tensorboard_log_level - defines when training metrics for TensorBoard should be logged. Valid values: 'epoch' and 'batch'.
  • checkpoint_dir - optional directory that should be used for model checkpointing

Returns:

A list of callbacks.

update_confidence_type

update_confidence_type(component_config: Dict[Text, Any]) -> Dict[Text, Any]

Set model confidence to 'auto' if margin loss is used.

The 'auto' option is reserved for the margin loss type and will be removed once margin loss is deprecated.

Arguments:

  • component_config - model configuration

Returns:

updated model configuration
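
A minimal sketch of the described update, assuming the configuration key names "loss_type" and "model_confidence" as stand-ins for Rasa's internal constants:

```python
from typing import Any, Dict

def update_confidence_type(component_config: Dict[str, Any]) -> Dict[str, Any]:
    # The key names "loss_type" and "model_confidence" are assumptions
    # standing in for Rasa's internal constants.
    if component_config.get("loss_type") == "margin":
        component_config["model_confidence"] = "auto"
    return component_config
```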

validate_configuration_settings

validate_configuration_settings(component_config: Dict[Text, Any]) -> None

Validates that the combination of parameters in the configuration is correctly set.

Arguments:

  • component_config - Configuration to validate.

init_split_entities

init_split_entities(split_entities_config: Union[bool, Dict[Text, Any]], default_split_entity: bool) -> Dict[Text, bool]

Initialise the behaviour for splitting entities by comma (or not).

Returns:

A dictionary defining the desired splitting behaviour for specific entity types, as well as the default behaviour for any entity type without an explicit setting.
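
As a hedged sketch of how the two input forms could be reconciled (the key name "default" is an assumption standing in for Rasa's internal constant):

```python
from typing import Any, Dict, Union

DEFAULT_KEY = "default"  # assumed name for the fallback entry

def init_split_entities(
    split_entities_config: Union[bool, Dict[str, Any]], default_split_entity: bool
) -> Dict[str, bool]:
    if isinstance(split_entities_config, bool):
        # A bare boolean becomes the default for all entity types.
        return {DEFAULT_KEY: split_entities_config}
    # Keep per-entity-type settings and fill in the default behaviour.
    config = dict(split_entities_config)
    config.setdefault(DEFAULT_KEY, default_split_entity)
    return config
```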