This is documentation for Rasa Documentation v2.x, which is no longer actively maintained.
For up-to-date documentation, see the latest version (3.x).
rasa.utils.train_utils
normalize
Normalizes an array of positive numbers over its top ranking_length values.
All other values are set to 0.
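A minimal sketch of this normalization, assuming plain Python lists instead of arrays and assuming that a ranking_length of 0 (or one at least as long as the input) means "keep every value". The name normalize_sketch is hypothetical and is not Rasa's API.

```python
from typing import List, Sequence


def normalize_sketch(values: Sequence[float], ranking_length: int = 0) -> List[float]:
    """Keep the top `ranking_length` values, zero the rest, then rescale to sum to 1."""
    new_values = list(values)  # copy so the caller's data is not mutated
    if 0 < ranking_length < len(new_values):
        # Smallest value that still belongs to the top `ranking_length`.
        threshold = sorted(new_values, reverse=True)[ranking_length - 1]
        new_values = [v if v >= threshold else 0.0 for v in new_values]
    total = sum(new_values)
    if total > 0:
        new_values = [v / total for v in new_values]
    return new_values
```

With this threshold approach, values tied with the cut-off are all kept, so more than ranking_length entries can stay non-zero.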
update_similarity_type
If SIMILARITY_TYPE is set to 'auto', update the SIMILARITY_TYPE depending on the LOSS_TYPE.
Arguments:
config
- model configuration
Returns:
updated model configuration
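A sketch of the 'auto' resolution, assuming plain-string config keys ("similarity_type", "loss_type") and assuming that cross-entropy loss maps to inner-product similarity while margin loss maps to cosine similarity. The function name is hypothetical.

```python
from typing import Any, Dict


def update_similarity_type_sketch(config: Dict[str, Any]) -> Dict[str, Any]:
    # Only resolve the similarity type when the user left it on "auto".
    if config.get("similarity_type") == "auto":
        # Assumed pairing: cross_entropy -> inner, margin -> cosine.
        if config.get("loss_type") == "cross_entropy":
            config["similarity_type"] = "inner"
        elif config.get("loss_type") == "margin":
            config["similarity_type"] = "cosine"
    return config
```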
update_deprecated_loss_type
Updates LOSS_TYPE to 'cross_entropy' if it is set to 'softmax'.
Arguments:
config
- model configuration
Returns:
updated model configuration
update_deprecated_sparsity_to_density
Replaces the deprecated WEIGHT_SPARSITY option with CONNECTION_DENSITY = 1 - WEIGHT_SPARSITY.
Arguments:
config
- model configuration
Returns:
Updated model configuration
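The conversion is just the complement of the sparsity value. A minimal sketch, assuming the keys are the strings "weight_sparsity" and "connection_density"; whether the old key is removed is not stated above, so this sketch leaves it in place. The function name is hypothetical.

```python
from typing import Any, Dict


def update_deprecated_sparsity_to_density_sketch(config: Dict[str, Any]) -> Dict[str, Any]:
    if "weight_sparsity" in config:
        # Density is the complement of sparsity, e.g. 0.75 sparsity -> 0.25 density.
        config["connection_density"] = 1 - config["weight_sparsity"]
    return config
```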
align_token_features
Align token features to match tokens.
ConveRTTokenizer and LanguageModelTokenizer may split tokens into sub-tokens. In that case, the mean of the sub-token vectors is taken as the token vector.
Arguments:
list_of_tokens
- tokens for the examples
in_token_features
- token features from ConveRT
shape
- shape of feature matrix
Returns:
Token features.
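The averaging step can be sketched as follows, assuming each token's sub-token vectors are stored consecutively and that the number of sub-tokens per token is known. Plain lists stand in for feature matrices, and align_token_features_sketch and its parameters are hypothetical.

```python
from typing import List, Sequence


def align_token_features_sketch(
    sub_token_counts: Sequence[int],
    sub_token_vectors: Sequence[Sequence[float]],
) -> List[List[float]]:
    """Average consecutive sub-token vectors back into one vector per token."""
    if sum(sub_token_counts) != len(sub_token_vectors):
        raise ValueError("sub-token counts do not match the number of vectors")
    aligned: List[List[float]] = []
    offset = 0
    for count in sub_token_counts:
        if count < 1:
            raise ValueError("every token must have at least one sub-token")
        chunk = sub_token_vectors[offset : offset + count]
        dim = len(chunk[0])
        # Element-wise mean over the sub-tokens belonging to this token.
        aligned.append([sum(vec[d] for vec in chunk) / count for d in range(dim)])
        offset += count
    return aligned
```

For example, a token split into two sub-tokens with vectors [2, 4] and [4, 6] ends up with the single vector [3, 5].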
update_evaluation_parameters
If EVAL_NUM_EPOCHS is set to -1, evaluate at the end of the training.
Arguments:
config
- model configuration
Returns:
updated model configuration
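One plausible way to realize "evaluate at the end of the training" is to set the evaluation interval to the total number of epochs. This sketch assumes the keys "evaluate_every_number_of_epochs" and "epochs" and is not taken from Rasa's source.

```python
from typing import Any, Dict


def update_evaluation_parameters_sketch(config: Dict[str, Any]) -> Dict[str, Any]:
    if config.get("evaluate_every_number_of_epochs") == -1:
        # -1 means "evaluate once, at the end": use the total epoch count as the interval.
        config["evaluate_every_number_of_epochs"] = config["epochs"]
    return config
```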
load_tf_hub_model
Load the model from the cache if possible, otherwise from TFHub.
check_deprecated_options
Update the config according to changed config params.
If old model configuration parameters are present in the provided config, replace them with the new parameters and log a warning.
Arguments:
config
- model configuration
Returns:
updated model configuration
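The replace-and-warn pattern described above can be sketched generically. The DEPRECATED_TO_NEW mapping and its key names are hypothetical placeholders; the actual list of renamed options is not given here.

```python
import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

# Hypothetical mapping from deprecated option names to their replacements.
DEPRECATED_TO_NEW = {"old_option": "new_option"}


def check_deprecated_options_sketch(
    config: Dict[str, Any], mapping: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    mapping = DEPRECATED_TO_NEW if mapping is None else mapping
    updated = dict(config)  # work on a copy of the top-level dict
    for old_key, new_key in mapping.items():
        if old_key in updated:
            logger.warning("Option '%s' is deprecated, use '%s' instead.", old_key, new_key)
            # Move the user's value to the new key and drop the old one.
            updated[new_key] = updated.pop(old_key)
    return updated
```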
check_core_deprecated_options
Update the core config according to changed config params.
If old model configuration parameters are present in the provided config, replace them with the new parameters and log a warning.
Arguments:
config
- model configuration
Returns:
updated model configuration
entity_label_to_tags
Convert the output predictions for entities to the actual entity tags.
Arguments:
model_predictions
- the output predictions using the entity tag indices
entity_tag_specs
- the entity tag specifications
bilou_flag
- if 'True', the BILOU tagging schema was used
prediction_index
- the index in the batch of predictions to use for entity extraction
Returns:
A map from entity tag type (e.g. entity, role, group) to the actual entity tags and their confidences.
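A sketch of the index-to-tag conversion. It assumes predictions arrive as nested lists (batch × tokens) per tag type, that each tag type has its own index-to-tag dictionary standing in for entity_tag_specs, and that BILOU prefixes look like "B-", "I-", "L-", "U-". All names here are hypothetical.

```python
from typing import Dict, List, Tuple


def strip_bilou_prefix(tag: str) -> str:
    # "B-city" -> "city"; tags without a BILOU prefix (e.g. "O") are unchanged.
    if len(tag) > 2 and tag[1] == "-" and tag[0] in "BILU":
        return tag[2:]
    return tag


def entity_label_to_tags_sketch(
    predicted_ids: Dict[str, List[List[int]]],
    confidences: Dict[str, List[List[float]]],
    ids_to_tags: Dict[str, Dict[int, str]],
    bilou_flag: bool = False,
    prediction_index: int = 0,
) -> Tuple[Dict[str, List[str]], Dict[str, List[float]]]:
    tags: Dict[str, List[str]] = {}
    confs: Dict[str, List[float]] = {}
    for tag_type, batch in predicted_ids.items():
        ids = batch[prediction_index]  # pick one example from the batch
        names = [ids_to_tags[tag_type][i] for i in ids]
        if bilou_flag:
            names = [strip_bilou_prefix(n) for n in names]
        tags[tag_type] = names
        confs[tag_type] = list(confidences[tag_type][prediction_index])
    return tags, confs
```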
override_defaults
Override default config with the given config.
We cannot use the dict.update method because configs contain nested dicts.
Arguments:
defaults
- default config
custom
- user config containing new parameters
Returns:
updated config
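The reason dict.update is insufficient: it replaces a nested dict wholesale, dropping any default sub-keys the user did not repeat. A recursive merge avoids that. This is a sketch; the depth of merging in the actual implementation may differ. The example keys mirror a nested option such as hidden_layers_sizes.

```python
import copy
from typing import Any, Dict, Optional


def override_defaults_sketch(
    defaults: Optional[Dict[str, Any]], custom: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    config = copy.deepcopy(defaults) if defaults else {}  # never mutate the defaults
    if not custom:
        return config
    for key, value in custom.items():
        if isinstance(config.get(key), dict) and isinstance(value, dict):
            # Merge nested dicts key by key instead of replacing them wholesale.
            config[key] = override_defaults_sketch(config[key], value)
        else:
            config[key] = value
    return config
```

With defaults {"hidden_layers_sizes": {"text": [256, 128], "label": []}} and custom {"hidden_layers_sizes": {"text": [64]}}, the merged config keeps "label": [], which a plain dict.update would have lost.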
create_data_generators
Create data generators for train and optional validation data.
Arguments:
model_data
- The model data to use.
batch_sizes
- The batch size(s).
epochs
- The number of epochs to train.
batch_strategy
- The batch strategy to use.
eval_num_examples
- Number of examples to use for validation data.
random_seed
- The random seed.
shuffle
- Whether to shuffle data inside the data generator.
Returns:
The training data generator and optional validation data generator.
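A simplified sketch of the train/validation split and batching. It assumes a plain sequence of examples, a single fixed batch size, and sequential batching only, so epochs, variable batch sizes, and other batch strategies are left out. It returns zero-argument factories so that a fresh batch iterator can be created for each epoch. All names are hypothetical.

```python
import random
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple

BatchFactory = Callable[[], Iterator[List[Any]]]


def create_data_generators_sketch(
    examples: Sequence[Any],
    batch_size: int,
    eval_num_examples: int = 0,
    random_seed: Optional[int] = None,
    shuffle: bool = True,
) -> Tuple[BatchFactory, Optional[BatchFactory]]:
    rng = random.Random(random_seed)  # seeded for reproducible splits and shuffles
    pool = list(examples)
    validation: List[Any] = []
    if eval_num_examples > 0:
        # Hold out a random subset of examples for validation.
        rng.shuffle(pool)
        validation, pool = pool[:eval_num_examples], pool[eval_num_examples:]

    def make_factory(data: List[Any], do_shuffle: bool) -> BatchFactory:
        def epoch_batches() -> Iterator[List[Any]]:
            order = list(data)
            if do_shuffle:
                rng.shuffle(order)
            for start in range(0, len(order), batch_size):
                yield order[start : start + batch_size]

        return epoch_batches

    train = make_factory(pool, shuffle)
    val = make_factory(validation, False) if validation else None
    return train, val
```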
create_common_callbacks
Create common callbacks.
The following callbacks are created:
- RasaTrainingLogger callback
- Optional TensorBoard callback
- Optional RasaModelCheckpoint callback
Arguments:
epochs
- the number of epochs to train
tensorboard_log_dir
- optional directory that should be used for tensorboard
tensorboard_log_level
- defines when training metrics for tensorboard should be logged. Valid values: 'epoch' and 'batch'.
checkpoint_dir
- optional directory that should be used for model checkpointing
Returns:
A list of callbacks.
update_confidence_type
Set model confidence to auto if margin loss is used.
The option auto is reserved for the margin loss type. It will be removed once margin loss is deprecated.
Arguments:
component_config
- model configuration
Returns:
updated model configuration
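A minimal sketch of this rule, assuming the keys "loss_type" and "model_confidence" and the string values "margin" and "auto". The function name is hypothetical.

```python
from typing import Any, Dict


def update_confidence_type_sketch(component_config: Dict[str, Any]) -> Dict[str, Any]:
    if component_config.get("loss_type") == "margin":
        # "auto" is only meaningful together with margin loss.
        component_config["model_confidence"] = "auto"
    return component_config
```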
validate_configuration_settings
Validates that the combination of parameters in the configuration is set correctly.
Arguments:
component_config
- Configuration to validate.
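The shape of such a validator is a series of cross-parameter checks that raise on invalid combinations. The specific checks below (known loss types, evaluation interval not exceeding the epoch count) are illustrative assumptions, not Rasa's actual rule set.

```python
from typing import Any, Dict

KNOWN_LOSS_TYPES = {"cross_entropy", "margin"}  # assumption for illustration


def validate_configuration_settings_sketch(component_config: Dict[str, Any]) -> None:
    loss_type = component_config.get("loss_type")
    if loss_type not in KNOWN_LOSS_TYPES:
        raise ValueError(f"Unknown loss_type {loss_type!r}.")
    epochs = component_config.get("epochs", 1)
    eval_every = component_config.get("evaluate_every_number_of_epochs", -1)
    # -1 is treated as "evaluate at the end", so it is always allowed.
    if eval_every != -1 and eval_every > epochs:
        raise ValueError("evaluate_every_number_of_epochs cannot exceed epochs.")
```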
init_split_entities
Initialise the behaviour for splitting entities by comma (or not).
Returns:
A definition of the desired behaviour for splitting specific entity types, plus the default behaviour for any entity type that has no explicit setting.
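A sketch of how such a behaviour map could be built, assuming the setting is either a single boolean or a per-entity dictionary of booleans, and assuming the fallback is stored under the key "split_entities_by_comma". The function name and its parameters are hypothetical.

```python
from typing import Dict, Union

DEFAULT_KEY = "split_entities_by_comma"  # assumed key holding the fallback behaviour


def init_split_entities_sketch(
    split_entities_config: Union[bool, Dict[str, bool]], default_split_entity: bool
) -> Dict[str, bool]:
    if isinstance(split_entities_config, bool):
        # A single flag applies to every entity type.
        return {DEFAULT_KEY: split_entities_config}
    behaviour = dict(split_entities_config)  # per-entity overrides, copied
    behaviour[DEFAULT_KEY] = default_split_entity
    return behaviour
```

A caller could then look up an entity type with behaviour.get(entity_type, behaviour[DEFAULT_KEY]).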