rasa.nlu.classifiers._diet_classifier
DIETClassifier Objects
A multi-task model for intent classification and entity extraction.
DIET stands for Dual Intent and Entity Transformer.
The architecture is based on a transformer that is shared by both tasks.
A sequence of entity labels is predicted through a Conditional Random Field (CRF)
tagging layer on top of the transformer output sequence corresponding to the
input sequence of tokens. The transformer output for the __CLS__ token and the
intent labels are embedded into a single semantic vector space. We use the
dot-product loss to maximize the similarity with the target label and minimize
similarities with negative samples.
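The dot-product loss described above can be illustrated with a minimal sketch. It is not Rasa's implementation: the function names (`dot`, `dot_product_loss`) and the toy two-dimensional embeddings are hypothetical, and the loss is assumed here to be a softmax cross-entropy over dot-product similarities, with the target label scored against a handful of negative labels.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def dot_product_loss(cls_embedding, target_embedding, negative_embeddings):
    """Softmax cross-entropy over dot-product similarities.

    Hypothetical helper: the target similarity sits at index 0, the
    similarities to the negative samples follow it.
    """
    sims = [dot(cls_embedding, target_embedding)]
    sims += [dot(cls_embedding, neg) for neg in negative_embeddings]
    # Numerically stable log-sum-exp over all similarities.
    highest = max(sims)
    log_z = highest + math.log(sum(math.exp(s - highest) for s in sims))
    # Negative log-probability of the target label.
    return log_z - sims[0]


# Toy embeddings (assumed values, for illustration only).
cls_vec = [1.0, 0.0]
good = dot_product_loss(cls_vec, [1.0, 0.0], [[0.0, 1.0], [-1.0, 0.0]])
bad = dot_product_loss(cls_vec, [-1.0, 0.0], [[0.0, 1.0], [1.0, 0.0]])
```

In this sketch the loss is lower when the __CLS__ embedding points toward the target label embedding (`good`) than when it points toward a negative sample (`bad`), which is the behaviour that training pushes toward.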
__init__
Declare instance variables with default values.
label_key
Return key if intent classification is activated.
label_sub_key
Return sub key if intent classification is activated.
preprocess_train_data
Prepares data for training.
Performs sanity checks on the training data and extracts encodings for the labels.
train
Train the embedding intent classifier on a data set.
process
Augments the message with intents, entities, and diagnostic data.
persist
Persist this model into the passed directory.
Return the metadata necessary to load the model again.
load
Loads the trained model from the provided directory.
DIET Objects
batch_loss
Calculates the loss for the given batch.
Arguments:
batch_in - The batch.
Returns:
The loss of the given batch.
prepare_for_predict
Prepares the model for prediction.
batch_predict
Predicts the output of the given batch.
Arguments:
batch_in - The batch.
Returns:
The predicted output for the given batch.