Text classification for online conversations with machine learning on AWS
Online conversations are ubiquitous in modern life, spanning industries from video games to telecommunications. The resulting rapid growth in online conversation data has helped drive the development of state-of-the-art natural language processing (NLP) systems such as chatbots and natural language generation (NLG) models, and NLP techniques for text analysis have evolved alongside it. As a result, there is a need for fully managed services that can be integrated into applications through API calls without extensive machine learning (ML) expertise. AWS offers pre-trained AI services such as Amazon Comprehend, which can handle NLP use cases involving classification, text summarization, entity recognition, and more to gather insights from text.
Additionally, online conversations have led to widespread non-traditional usage of language. Traditional NLP techniques often perform poorly on this text data because of the constantly evolving, domain-specific vocabularies that exist within different platforms, and because words often deviate significantly from standard English spelling, either by accident or intentionally as a form of adversarial attack.
In this post, we describe multiple ML approaches for text classification of online conversations with tools and services available on AWS.
Before diving deep into this use case, please complete the following prerequisites:
For this post, we use the Jigsaw Unintended Bias in Toxicity Classification dataset, a benchmark for the specific problem of classification of toxicity in online conversations. The dataset provides toxicity labels as well as several subgroup attributes such as obscene, identity attack, insult, threat, and sexually explicit. Labels are provided as fractional values that represent the proportion of human annotators who believed the attribute applied to a given piece of text; annotators are rarely unanimous. To generate binary labels (for example, toxic or non-toxic), a threshold of 0.5 is applied to the fractional values, and comments with values greater than or equal to the threshold are treated as the positive class for that label.
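As a small illustration of the thresholding step, the following sketch converts fractional annotator-agreement values into binary labels. It follows the `label>=0.5` convention stated in the evaluation section of this post; the function name and the sample values are hypothetical.

```python
def binarize_labels(fractions, threshold=0.5):
    """Map annotator-agreement fractions to 0/1 labels (1 if fraction >= threshold)."""
    return [1 if f >= threshold else 0 for f in fractions]

# Hypothetical toxicity fractions for four comments.
toxicity = [0.0, 0.2, 0.5, 0.83]
print(binarize_labels(toxicity))  # [0, 0, 1, 1]
```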
Subword embedding and RNNs
For our first modeling approach, we use a combination of subword embedding and recurrent neural networks (RNNs) to train text classification models. Subword embeddings were introduced by Bojanowski et al. in 2017 as an improvement upon previous word-level embedding methods. Traditional Word2Vec skip-gram models are trained to learn a static vector representation of a target word that optimally predicts that word’s context. Subword models, on the other hand, represent each target word as a bag of the character n-grams that make up the word, where an n-gram is a sequence of n consecutive characters. This method allows the embedding model to better represent the underlying morphology of related words in the corpus and to compute embeddings for novel, out-of-vocabulary (OOV) words. This is particularly important in the context of online conversations, a problem space in which users often misspell words (sometimes intentionally to evade detection) and also use a unique, constantly evolving vocabulary that might not be captured by a general training corpus.
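To make the bag-of-n-grams idea concrete, here is a minimal sketch of character n-gram extraction. It assumes the boundary markers `<` and `>` described by Bojanowski et al.; the function name and the default n-gram range are illustrative choices, not settings taken from this post.

```python
def char_ngrams(word, min_n=3, max_n=6):
    """Return the character n-grams of a word, with < and > marking word boundaries."""
    wrapped = f"<{word}>"
    grams = []
    for n in range(min_n, max_n + 1):
        for i in range(len(wrapped) - n + 1):
            grams.append(wrapped[i:i + n])
    return grams

print(char_ngrams("where", min_n=3, max_n=3))
# ['<wh', 'whe', 'her', 'ere', 're>']

# A misspelling still shares most of its n-grams with the intended word.
shared = set(char_ngrams("stupid", 3, 3)) & set(char_ngrams("stupiid", 3, 3))
print(sorted(shared))
```

Because a misspelled or novel word overlaps with known words at the n-gram level, the model can still place it near related vocabulary.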
Amazon SageMaker makes it easy to train and optimize an unsupervised subword embedding model on your own corpus of domain-specific text data with the built-in BlazingText algorithm. We can also download existing general-purpose models trained on large datasets of online text, such as the following English language models available directly from fastText. From your SageMaker notebook instance, simply run the following to download a pretrained fastText model:
Whether you’ve trained your own embeddings with BlazingText or downloaded a pretrained model, the result is a zipped model binary that you can use with the gensim library to embed a given target word as a vector based on its constituent subwords:
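As a rough, library-free illustration of the idea (not the gensim call itself), the sketch below composes a word vector by averaging the vectors of its character n-grams. Everything here is an assumption for demonstration: the table is random rather than trained, `zlib.crc32` stands in for the hash function a real model uses, and the sizes are toy values. With gensim, a downloaded fastText binary is typically loaded with `gensim.models.fasttext.load_facebook_vectors` and then indexed by word.

```python
import zlib
import numpy as np

DIM, BUCKETS = 8, 1000  # toy sizes, far smaller than a real model
rng = np.random.default_rng(0)
ngram_table = rng.normal(size=(BUCKETS, DIM))  # stand-in for trained n-gram vectors

def ngrams(word, n=3):
    """Character n-grams of the word wrapped in boundary markers."""
    wrapped = f"<{word}>"
    return [wrapped[i:i + n] for i in range(len(wrapped) - n + 1)]

def embed(word):
    """Average the bucket vectors of the word's character n-grams."""
    rows = [zlib.crc32(g.encode("utf-8")) % BUCKETS for g in ngrams(word)]
    return ngram_table[rows].mean(axis=0)

vec = embed("stupiid")  # an out-of-vocabulary misspelling still gets a vector
print(vec.shape)        # (8,)
```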
After we preprocess a given segment of text, we can use this approach to generate a vector representation for each of the constituent words (as separated by spaces). We then use SageMaker and a deep learning framework such as PyTorch to train a customized RNN with a binary or multilabel classification objective to predict whether the text is toxic or not and the specific sub-type of toxicity based on labeled training examples.
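The post does not specify the RNN architecture, so the following is only a minimal sketch of the general idea, assuming a vanilla (Elman) RNN with one sigmoid output per label. It uses NumPy with random, untrained weights to show the forward pass; in practice the model would be defined and trained in PyTorch, and all parameter names here are hypothetical.

```python
import numpy as np

def rnn_predict(word_vectors, W_xh, W_hh, b_h, W_out, b_out):
    """Run a vanilla RNN over the word vectors and return one probability per label."""
    h = np.zeros(W_hh.shape[0])
    for x_t in word_vectors:                      # one step per word, in order
        h = np.tanh(W_xh @ x_t + W_hh @ h + b_h)  # hidden state carries context forward
    logits = W_out @ h + b_out
    return 1.0 / (1.0 + np.exp(-logits))          # independent sigmoid per label

rng = np.random.default_rng(0)
dim, hidden, n_labels = 8, 16, 6
params = (
    rng.normal(scale=0.1, size=(hidden, dim)),     # W_xh
    rng.normal(scale=0.1, size=(hidden, hidden)),  # W_hh
    np.zeros(hidden),                              # b_h
    rng.normal(scale=0.1, size=(n_labels, hidden)),  # W_out
    np.zeros(n_labels),                            # b_out
)
sentence = rng.normal(size=(5, dim))  # stand-in for five embedded words
probs = rnn_predict(sentence, *params)
print(probs.shape)  # (6,)
```

Using one sigmoid per label (rather than a softmax) lets a single comment be, for example, both an insult and a threat.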
To upload your preprocessed text to Amazon Simple Storage Service (Amazon S3), use the following code:
To initiate scalable, multi-GPU model training with SageMaker, enter the following code:
Within <source_dir_path>, we define a PyTorch Dataset that is used by train.py to prepare the text data for training and evaluation of the model:
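The sketch below is a hypothetical stand-in for such a dataset class, not the authors' code. It implements the map-style `__len__`/`__getitem__` protocol that a PyTorch `Dataset` subclass would implement, but uses only NumPy so that it runs without PyTorch. The class name, the `embed_fn` argument, and the padding length are assumptions made for illustration.

```python
import numpy as np

class ToxicityDataset:
    """Map-style dataset: item i is (padded word-vector matrix, label vector)."""

    def __init__(self, texts, labels, embed_fn, dim, max_len=32):
        self.texts, self.labels = texts, labels
        self.embed_fn, self.dim, self.max_len = embed_fn, dim, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Split on whitespace and truncate to max_len tokens.
        tokens = self.texts[idx].lower().split()[: self.max_len]
        # Rows past the last token stay zero, which acts as padding.
        x = np.zeros((self.max_len, self.dim), dtype=np.float32)
        for t, token in enumerate(tokens):
            x[t] = self.embed_fn(token)
        y = np.asarray(self.labels[idx], dtype=np.float32)
        return x, y

# Toy embedding function: every component equals the token length.
toy_embed = lambda token: np.full(4, len(token), dtype=np.float32)
ds = ToxicityDataset(["you are gr8", "hello"], [[0.7], [0.0]], toy_embed, dim=4, max_len=5)
x, y = ds[0]
print(len(ds), x.shape, y.shape)  # 2 (5, 4) (1,)
```

In a real training script, `embed_fn` would look up vectors from the loaded fastText or BlazingText model.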
Note that this code anticipates that the vectors.zip file containing your fastText or BlazingText embeddings will be stored in <source_dir_path>.
Additionally, you can easily deploy pretrained fastText models on their own to live SageMaker endpoints to compute embedding vectors on the fly for use in relevant word-level tasks. See the following GitHub example for more details.
Transformers with Hugging Face
For our second modeling approach, we transition to the usage of Transformers, introduced in the paper Attention Is All You Need. Transformers are deep learning models designed to deliberately avoid the pitfalls of RNNs by relying on a self-attention mechanism to draw global dependencies between input and output. The Transformer model architecture allows for significantly better parallelization and can achieve high performance in relatively short training time.
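To illustrate how self-attention relates every position to every other position in a single step, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. It omits multiple heads, masking, and positional encodings, and the weight matrices are random placeholders rather than trained parameters.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention over a sequence X (tokens x features)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])             # pairwise token similarities
    scores = scores - scores.max(axis=-1, keepdims=True)  # for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over tokens
    return weights @ V, weights

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 6))  # four tokens, six features each
W_q, W_k, W_v = (rng.normal(size=(6, 3)) for _ in range(3))
out, attn = self_attention(X, W_q, W_k, W_v)
print(out.shape, attn.shape)  # (4, 3) (4, 4)
```

Each row of `attn` sums to 1 and is computed from matrix products over the whole sequence, which is why the computation parallelizes well compared with the step-by-step recurrence of an RNN.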
Built on the success of Transformers, BERT, introduced in the paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, added bidirectional pre-training for language representation. Inspired by the Cloze task, BERT is pre-trained with masked language modeling (MLM), in which the model learns to recover the original words for randomly masked tokens. The BERT model is also pretrained on the next sentence prediction (NSP) task to predict if two sentences are in correct reading order. Since its advent in 2018, BERT and its variations have been widely used in text classification tasks.
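The masking procedure can be sketched as follows. This is a simplified illustration that assumes the scheme described in the BERT paper: about 15% of tokens are selected, and of those, 80% are replaced with a mask token, 10% with a random token, and 10% are left unchanged. The function name, toy vocabulary, and whitespace tokenization are hypothetical.

```python
import random

def mask_tokens(tokens, vocab, mask_token="[MASK]", mask_prob=0.15, seed=0):
    """Return (inputs, labels); labels hold the original token at selected positions, else None."""
    rng = random.Random(seed)
    inputs, labels = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            labels.append(tok)  # the model must recover this original token
            r = rng.random()
            if r < 0.8:
                inputs.append(mask_token)         # usual case: hide the token
            elif r < 0.9:
                inputs.append(rng.choice(vocab))  # sometimes swap in a random token
            else:
                inputs.append(tok)                # sometimes keep it unchanged
        else:
            labels.append(None)  # position is not part of the MLM loss
            inputs.append(tok)
    return inputs, labels

tokens = "the quick brown fox jumps over the lazy dog".split()
inputs, labels = mask_tokens(tokens, vocab=sorted(set(tokens)), mask_prob=0.3, seed=1)
print(inputs)
print(labels)
```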
Our solution uses a variant of BERT known as RoBERTa, which was introduced in the paper RoBERTa: A Robustly Optimized BERT Pretraining Approach. RoBERTa further improves BERT performance on a variety of natural language tasks through an optimized training procedure, which includes training longer on a corpus 10 times larger, using optimized hyperparameters, applying dynamic random masking, and removing the NSP task, among other changes.
Our RoBERTa-based models use the Hugging Face Transformers library, which is a popular open-source Python framework that provides high-quality implementations of all kinds of state-of-the-art Transformer models for a variety of NLP tasks. Hugging Face has partnered with AWS to enable you to easily train and deploy Transformer models on SageMaker. This functionality is available through Hugging Face AWS Deep Learning Container images, which include the Transformers, Tokenizers, and Datasets libraries, and optimized integration with SageMaker for model training and inference.
In our implementation, we inherit the RoBERTa architecture backbone from the Hugging Face Transformers framework and use SageMaker to train and deploy our own text classification model, which we call RoBERTox. RoBERTox uses byte pair encoding (BPE), introduced in Neural Machine Translation of Rare Words with Subword Units, to tokenize input text into subword representations. We can then train our models and tokenizers on the Jigsaw data or any large domain-specific corpus (such as the chat logs from a specific game) and use them for customized text classification. We define our custom classification model class in the following code:
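To make the BPE tokenization step mentioned above concrete (this illustrates the tokenizer idea, not the model class), the following sketch learns merge rules from a toy word-frequency table in the style of Sennrich et al. and applies them to a new word. It is a character-level simplification: RoBERTa's actual tokenizer operates on bytes and is provided by the Hugging Face Tokenizers library. The function names and the toy corpus are assumptions.

```python
from collections import Counter

def merge_pair(symbols, pair):
    """Replace every adjacent occurrence of `pair` in `symbols` with the merged symbol."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return out

def learn_bpe(word_counts, num_merges):
    """Learn up to `num_merges` merge rules from a {word: count} table."""
    vocab = {tuple(word) + ("</w>",): count for word, count in word_counts.items()}
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, count in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += count
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent adjacent pair
        merges.append(best)
        new_vocab = {}
        for symbols, count in vocab.items():
            merged = tuple(merge_pair(list(symbols), best))
            new_vocab[merged] = new_vocab.get(merged, 0) + count
        vocab = new_vocab
    return merges

def apply_bpe(word, merges):
    """Tokenize a word by applying the learned merges in order."""
    symbols = list(word) + ["</w>"]
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return symbols

merges = learn_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, num_merges=3)
print(merges)                       # [('e', 's'), ('es', 't'), ('est', '</w>')]
print(apply_bpe("lowest", merges))  # ['l', 'o', 'w', 'est</w>']
```

The unseen word "lowest" is still tokenized into known pieces, which is the property that makes subword tokenizers robust to novel or misspelled words.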
Before training, we prepare our text data and labels using Hugging Face’s datasets library and upload the result to Amazon S3:
We initiate training of the model in a similar fashion to the RNN:
Finally, the following Python code snippet illustrates the process of serving RoBERTox via a live SageMaker endpoint for real-time text classification for a JSON request:
Evaluation of model performance: Jigsaw unintended bias dataset
The following table contains performance metrics for models trained and evaluated on data from the Jigsaw Unintended Bias in Toxicity Classification Kaggle competition. We trained models for three different but interrelated tasks:
Binary case – The model was trained on the full training dataset to predict the toxicity label only
Fine-grained case – The subset of the training data for which toxicity>=0.5 was used to predict other toxicity sub-type labels (obscene, threat, insult, identity_attack, sexual_explicit)
Multitask case – The full training dataset was used to predict all six labels simultaneously
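The three task definitions above can be sketched as different views of the same labeled rows. The dictionary keys below (for example `toxicity` and `text`) and the sample rows are hypothetical and may not match the column names in the actual dataset files.

```python
SUBTYPES = ["obscene", "threat", "insult", "identity_attack", "sexual_explicit"]

def build_task(rows, task):
    """Return (text, label_list) pairs for the binary, fine_grained, or multitask case."""
    if task == "binary":
        return [(r["text"], [r["toxicity"]]) for r in rows]
    if task == "fine_grained":
        # Only comments already considered toxic, labeled with the sub-types.
        return [(r["text"], [r[k] for k in SUBTYPES]) for r in rows if r["toxicity"] >= 0.5]
    if task == "multitask":
        return [(r["text"], [r["toxicity"]] + [r[k] for k in SUBTYPES]) for r in rows]
    raise ValueError(f"unknown task: {task}")

rows = [
    {"text": "a", "toxicity": 0.8, "obscene": 0.1, "threat": 0.0,
     "insult": 0.7, "identity_attack": 0.0, "sexual_explicit": 0.0},
    {"text": "b", "toxicity": 0.1, "obscene": 0.0, "threat": 0.0,
     "insult": 0.0, "identity_attack": 0.0, "sexual_explicit": 0.0},
]
for task in ("binary", "fine_grained", "multitask"):
    examples = build_task(rows, task)
    print(task, len(examples), len(examples[0][1]))
```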
We trained RNN and RoBERTa models for each of these three tasks using the Jigsaw-provided fractional labels, which correspond to the proportion of annotators who thought the label was appropriate for the text, as well as with binary labels combined with class weights in the network loss function. In the binary labeling scheme, the proportions were thresholded at 0.5 for each available label (1 if label>=0.5, 0 otherwise), and the model loss functions were weighted based on the relative proportions of each binary label in the training dataset. In all cases, we found that using the fractional labels directly resulted in the best performance, indicating the added value of the information inherent in the degree of agreement between annotators.
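The difference between the two labeling schemes can be sketched with a per-example binary cross-entropy. The post does not give the exact weighting formula, so the positive-class weight below (number of negatives divided by number of positives) is an assumption, chosen because it is a common convention; the function names are hypothetical.

```python
import math

def bce(p, y, pos_weight=1.0):
    """Binary cross-entropy for a predicted probability p in (0, 1) and a target y in [0, 1]."""
    return -(pos_weight * y * math.log(p) + (1 - y) * math.log(1 - p))

def pos_weight_from_labels(binary_labels):
    """Assumed weighting scheme: ratio of negative to positive examples."""
    n_pos = sum(binary_labels)
    n_neg = len(binary_labels) - n_pos
    return n_neg / n_pos

# Fractional target: the loss is lowest when the prediction matches annotator agreement.
print(bce(0.6, 0.6) < bce(0.99, 0.6))  # True

# Binary target with class weighting: errors on the rare positive class cost more.
w = pos_weight_from_labels([1, 0, 0, 0])
print(w, bce(0.2, 1, pos_weight=w) > bce(0.2, 1))  # 3.0 True
```

With a fractional target such as 0.6, the model is not rewarded for predicting near-certain toxicity, which is one way to see how annotator disagreement carries extra information.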
We display two model metrics: the average precision (AP), which summarizes the precision-recall curve as the mean of the precision values achieved at each classification threshold, weighted by the increase in recall from the previous threshold, and the area under the receiver operating characteristic curve (AUC), which aggregates model performance across classification thresholds with respect to the true positive rate and false positive rate. Note that the true class for a given text instance in the test set corresponds to whether the true proportion is greater than or equal to 0.5 (1 if label>=0.5, 0 otherwise).
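As a worked illustration of both metrics, the sketch below computes AP and AUC in plain Python on four hypothetical predictions. The AP function assumes no tied scores; libraries such as scikit-learn provide more general implementations. The sample fractions and scores are made up for the example.

```python
def average_precision(y_true, scores):
    """AP: sum over thresholds of (increase in recall) x (precision at that threshold)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    n_pos = sum(y_true)
    tp, ap = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if y_true[i] == 1:
            tp += 1
            ap += (tp / rank) / n_pos  # recall rises by 1/n_pos at each positive
    return ap

def roc_auc(y_true, scores):
    """AUC: probability that a random positive is scored above a random negative."""
    pos = [s for s, y in zip(scores, y_true) if y == 1]
    neg = [s for s, y in zip(scores, y_true) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

fractions = [0.9, 0.1, 0.6, 0.4]                    # annotator agreement per comment
y_true = [1 if f >= 0.5 else 0 for f in fractions]  # true class: label >= 0.5
scores = [0.8, 0.3, 0.35, 0.6]                      # model outputs
print(round(average_precision(y_true, scores), 4))  # 0.8333
print(roc_auc(y_true, scores))                      # 0.75
```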
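[Table: AP and AUC by model and labeling scheme. Only the row labels "Subword Embedding + RNN" and "Binary labels + Class weighting" survive; the metric values are not available in this copy.]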
In this post, we presented two text classification approaches for online conversations using AWS ML services. You can generalize these solutions across online communication platforms, with industries such as gaming particularly likely to benefit from improved ability to detect harmful content. In future posts, we plan to further discuss an end-to-end architecture for seamless deployment of models into your AWS account.
If you’d like help accelerating your use of ML in your products and processes, please contact the Amazon ML Solutions Lab.
About the Authors
Ryan Brand is a Data Scientist in the Amazon Machine Learning Solutions Lab. He has specific experience in applying machine learning to problems in healthcare and the life sciences, and in his free time he enjoys reading history and science fiction.
Sourav Bhabesh is a Data Scientist at the Amazon ML Solutions Lab. He develops AI/ML solutions for AWS customers across various industries. He specializes in Natural Language Processing (NLP) and is passionate about deep learning. Outside of work, he enjoys reading books and traveling.
Liutong Zhou is an Applied Scientist at the Amazon ML Solutions Lab. He builds bespoke AI/ML solutions for AWS customers across various industries. He specializes in Natural Language Processing (NLP) and is passionate about multi-modal deep learning. He is a lyric tenor and enjoys singing operas outside of work.
Sia Gholami is a Senior Data Scientist at the Amazon ML Solutions Lab, where he builds AI/ML solutions for customers across various industries. He is passionate about natural language processing (NLP) and deep learning. Outside of work, Sia enjoys spending time in nature and playing tennis.