If you’re reading this article, you probably know about Deep Learning Transformer models like BERT. They’re revolutionizing the way we do Natural Language Processing (NLP).
In case you don’t know, we wrote about the history and impact of BERT and the Transformer architecture in a previous post.
These models perform very well. But why? And why does BERT perform so well in comparison to other Transformer models?
Some might say that there’s nothing special about BERT, and that the Transformer architecture alone is responsible for the state-of-the-art (SOTA) improvements we’ve seen in recent years.
This is confusing since we need to know exactly what parts of BERT are worth replicating in future models. Otherwise, we may end up producing copycat models which are unnecessarily large and contain layers and training objectives that don’t contribute to improvements.
To clear the confusion, we’ll look at some of the latest research in this area and propose that:
- Masking is the key: “Masking” is the training objective responsible for most of the success we attribute to BERT and BERT-like models.
- Masking needs Attention: While masking is the critical element that differentiates BERT from other models, it’s built on the attention mechanism introduced via the Transformer architecture.
- We still don’t understand masking: Despite being a core element in recent NLP SOTA results, we still don’t fully understand how masking works. We’ll look at recent research that shows syntactic aspects of language (like word ordering), previously thought to be critical to the success of masking, are unimportant. This raises important questions about what we think these models are learning about language.
- We may need to re-evaluate how we learn language: How humans learn a language is a topic of ongoing debate. For example, we believe that humans are doing more than simple co-occurrence and pattern matching when they learn the meaning of words. But if this is in fact how models like BERT learn language, and if they can perform at or near human level, then, are humans just organic statistical inference engines that learn and understand language in the same way? Maybe we need to revisit how we think humans learn meaning from language.
It seems clear that at the moment, practice is far ahead of theory in the technological life cycle of Deep Learning for NLP. We’re using approaches like masking which seem to work, and we fiddle with the numbers a bit and it works a little better or a little worse. Yet we can’t fully explain why that happens!
Some people may find this frustrating and disappointing. If we can’t explain it, then how can we use it?
Maybe we don’t need to fully explain these models. We can still confidently use them in production and important business applications as we strive, in parallel, to better understand their inner workings.
Representation and learning: how Transformer models learn context
BERT isn’t the only show in town, or in this case, the only muppet on the street! There are other models that use the Transformer architecture without learning objectives like masking (don’t worry, we’ll define all these terms shortly). We can look at how these models perform, and what goes on inside them, to compare how well BERT’s learning objective learns context from its inputs.
Context is what separates BERT, and other Transformer models like GPT-3, from previous models like Word2Vec, which learned only one “meaning” for a word. In this sense, Word2Vec produces static definitions, or embeddings (vectors representing the word in question): each word gets a single vector, which doesn’t change depending on the context in which the word is used.
This has obvious limits. You’re currently reading a post about BERT, but you can also receive mail in the post, or put up a wooden post in your garden, and so on. Each use of “post” has a very different meaning. This issue prevents models like Word2Vec from matching the results of models like BERT (…or maybe not? Later we’ll discuss whether BERT is simply a massively scaled-up version of Word2Vec with dynamic context).
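To make the difference concrete, here is a toy Python sketch. It assumes a made-up vocabulary and random stand-in vectors rather than a trained Word2Vec model. A static lookup returns the same vector for “post” in both sentences, while a deliberately crude context-aware version, which averages each word with its neighbors, gives two different vectors. Real contextual models like BERT use attention, not this averaging trick.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy vocabulary with one fixed vector per word, as in a static
# embedding model such as Word2Vec (real vectors would be learned, not random).
vocab = ["you", "are", "reading", "a", "post", "about", "bert",
         "put", "up", "wooden", "in", "the", "garden"]
static_table = {word: rng.normal(size=4) for word in vocab}

def static_embedding(sentence, word):
    # Context is ignored: the same word always maps to the same vector.
    return static_table[word]

def toy_contextual_embedding(sentence, word, window=2):
    # Toy stand-in for a contextual model (NOT how BERT works): average the
    # word's vector with the vectors of its neighbors within a small window.
    tokens = sentence.split()
    i = tokens.index(word)
    neighbors = tokens[max(0, i - window): i + window + 1]
    return np.mean([static_table[t] for t in neighbors], axis=0)

s1 = "you are reading a post about bert"
s2 = "put up a wooden post in the garden"
print(np.allclose(static_embedding(s1, "post"), static_embedding(s2, "post")))  # True
print(np.allclose(toy_contextual_embedding(s1, "post"),
                  toy_contextual_embedding(s2, "post")))  # False
```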
In this section, we’ll look at three different types of models with three different learning objectives, and show that these learning objectives change how information on meaning and context is passed between the layers of their neural networks. This will lead us to claim that masking, when used as a learning objective, is the reason why models like BERT learn context better than non-masking alternatives. The models we will look at in this post are:
- Machine Translation (MT) models: These were the first models to show the potential of the Transformer architecture to produce significant improvements in applications such as Neural Machine Translation (NMT).
- Language Models (LM): LMs such as GPT-like models, and their Recurrent Neural Network (RNN) predecessors, learn by predicting the next word in a sequence.
- Masked Language Models (MLM): MLMs like BERT use an approach called masking, where randomly selected words in the text sequence are hidden and the model tries to predict them. This has very different implications for how these models work, which we will get into shortly.
Representation: how Transformer models are similar
All Transformer models share a common approach in their neural network design. While there are many nuances that differentiate these networks, at a high level we can think of them as having two main parts:
- Representation: This is where the model takes in a text input and represents it as a vector we call an embedding. The representation part of the network encodes the textual information into a format the network can “read” or process. This includes encoding things like the position of a word in a sentence, and its relationship to other words in that sentence.
In Transformers this encoding of information is performed via the “Attention” mechanism (the specifics of attention are beyond the scope of this post; we discussed them in detail in a previous post, and there are good illustrated guides to how attention works in Transformer-based networks). For this article, all we need to know is that attention is a mechanism that encodes information about our input text into a vector format, an embedding, which our neural network can process.
- Learning: Once the Transformer model can represent the text as an embedding, we can then create a learning objective that will enable the model to “learn” from the text. We hope that the learning objective will enable these models to learn things like syntactic structure, semantics, and context, which will enable them to perform well in a wide range of linguistic tasks. The learning objective dictates what these models can learn from their training data, as it controls how information flows through their networks.
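For readers who want a concrete picture of the Representation step, this is a minimal NumPy sketch of single-head scaled dot-product self-attention, the core operation inside the attention layers. It is a simplification under stated assumptions: the weight matrices are random placeholders for learned parameters, and real Transformers add multiple heads, residual connections, layer normalization and feed-forward layers.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention.

    X: (seq_len, d_model) token vectors (e.g. word + position embeddings).
    Returns (seq_len, d_k) context-aware vectors and the attention weights.
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])  # how strongly each token attends to each other token
    weights = softmax(scores, axis=-1)       # each row sums to 1
    return weights @ V, weights              # each output mixes information from all tokens

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 8, 4
X = rng.normal(size=(seq_len, d_model))
# Hypothetical random projections standing in for learned weights.
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape)  # (5, 4)
```

Each output row is a weighted blend of every token’s value vector, which is how information about the surrounding words ends up inside each token’s embedding.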
Learning: how Transformer models are different
Each Transformer model has its own learning objective, which takes as its input the output from the representation part of the network. The representation part consists of the tokenizer and attention layers, which take the input and turn it into an embedding that represents the input text in vector format.
The choice of learning objective is critical for Transformer models because this choice:
- Defines how the network processes the input: Depending on the learning objective, the Transformer model will either process the input one token (word) at a time, i.e. it will not be able to “look ahead” at the next word, or it will be able to process text “bidirectionally” and have access to both past and future tokens at learning time.
- Defines how information flows through the network: The learning objective defines how information flows through the network. This is important since a deep neural network like the Transformer or BERT consists of multiple layers, and each layer builds on the output from the previous layer.
Ideally, we’d like each layer to add information to its input and build on what was learned in the previous layer. In this way, a deep neural network should, in theory, build up more complex, higher-level knowledge as information passes from lower to higher layers. The learning objective defines this information flow.
For example, in Convolutional Neural Networks (CNNs) used in image processing, the lower layers of the network learn general features like edges and the outlines of shapes. The higher layers then learn more nuanced detail, like facial features and other aspects that differentiate between images of the same type. If the image is a cat, the lower layers will identify the general shape of a cat and be able to differentiate it from a dog or a person, and the higher layers will be able to differentiate between different types of cats. Ideally (we don’t know this for certain, but it seems likely), we want language models to learn from text in a similar way, so that the higher layers capture semantic as well as syntactic features.
The three types of models we’re looking at in this post have different learning objectives:
- Machine Translation (MT) Learning Objective: MT models take the input text and try to predict the words in a target sentence. The input sentence could be in English, and the target sentence a French translation of that sentence. MT models differ slightly from the other models discussed here, since the representation layer embedding is not used to predict the output directly. Instead, for an MT model, the output of the representation layer is passed to a decoder, which generates text in the target language from that embedding. For our purposes, we can still think of it in the same way as the other models, i.e. the learning objective tries to predict the words of a target sentence given input in a different source language.
- Language Models (LM) Learning Objective: LM models like GPT try to predict the next word given the previous words in the input sentence. The learning objective thus has access only to the past words in the sentence, i.e. it can’t “look” forward to the next word in the sentence to help it predict the current word. In this way, an LM model is said to be unidirectional as it only “reads” text from the start of the input to the end of the input.
- Masked Language Models (MLM) Learning Objective: Instead of predicting the next word, MLMs attempt to predict a “masked” word that is randomly selected from the input. This is a form of cloze test, where a participant is asked to predict a hidden word in a sentence; cloze tests are commonly used to assess a person’s language ability. Since a cloze test requires the participant to see all of the text to understand the context of the missing word, MLMs need to be able to “see” all the text at the same time. This is in contrast to LMs, which only see past words when predicting the current word. This is why we say that MLMs are “bidirectional”: they have access to words both before and after the word being predicted.
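The three objectives differ mainly in what the model gets to see and what it has to predict. The hypothetical helpers below (the names are ours, not from any library) build training pairs from a single sentence for each objective, using plain word lists instead of a real tokenizer.

```python
def lm_examples(tokens):
    # LM: predict each next token from the tokens before it only.
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

def mlm_example(tokens, position, mask_token="[MASK]"):
    # MLM: hide one token; the model sees everything else, before AND after it.
    masked = list(tokens)
    target = masked[position]
    masked[position] = mask_token
    return masked, target

def mt_example(source_tokens, target_tokens):
    # MT: the whole source sentence is the input; the translation is the target.
    return source_tokens, target_tokens

tokens = "this is the best post I have ever read".split()
print(lm_examples(tokens)[2])  # (['this', 'is', 'the'], 'best')
print(mlm_example(tokens, 3))  # (['this', 'is', 'the', '[MASK]', 'post', 'I', 'have', 'ever', 'read'], 'best')
print(mt_example("the cat sleeps".split(), "le chat dort".split()))
```

Note that the LM pair for “best” only contains “this is the”, while the MLM input also contains “post I have ever read”, which is exactly the extra context a cloze-style objective provides.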
Do Masked Language Models (MLMs) perform better?
The key question we want to answer is whether MLM models are better at learning information about language than the other models. While this is a difficult thing to measure, since we don’t know exactly how humans learn language in the first place, we can estimate it by looking at how information about the input changes between the layers of the network. Specifically, we’re looking at:
- How the model learns about the input: Initially, the model needs to understand the input it’s given. It will look at the available information and try to understand as much about the input as possible before it generates the output required by the learning objective.
- How the model uses that information to perform its task: The learning objective defines the goal of our model, i.e. what it’s trying to predict. Once the model has learned something about the available input, it can then focus on the output it needs to generate to try and meet the requirements of the learning objective.
For example, imagine you were given a task like a cloze test and asked to predict the missing word:
“this is the [MISSING WORD] post I have ever read”
In this case, you would read the whole sentence first and focus on the context around the missing word. At first, you wouldn’t think about the missing word itself. You would look at the words before and after it, and use them as an anchor to predict the missing word.
Once you have thought about the surrounding words, you try to think about the missing word. You come up with some potential options for that word and think about whether they fit the context. In this case, you would (obviously) think about words like “best”, “greatest”, “perfect”, “outstanding”, or “finest” and choose the one you think is correct.
Some researchers have looked at how information changes between the layers of a network and tried to estimate the information that is lost or gained from one layer to the next. They do this in several different ways. One way is to compare the input with the representations produced at different depths, i.e. what’s the difference between the input and the 2nd layer, the difference between the input and the 5th layer, the Nth layer, and so on.
By doing this, they showed that:
- Machine Translation changes less between layers: Since MT is not predicting a single token for a cloze-style learning objective, it shows less change between the representations of consecutive layers. You can see this in the diagram below, which shows less and less change between layers. The model is trying to translate the entire sentence, so it doesn’t have the two-step change in focus we mentioned above, where the model first understands the context and then predicts a label.
- Language Models show evidence of forgetting and re-remembering: In contrast, LMs show more change between layers: as they initially focus on the context of the surrounding words, the difference between layers gets progressively smaller, and then we see larger changes between layers as the model attempts to predict the next word.
- Masked Language Models show the greatest change between layers: The MLM shows a much starker change between layers as it undergoes this two-step process. Initially, the MLM discards, or “loses”, information that it doesn’t need right now, i.e. the model focuses on the context of the surrounding words and not on the missing word it’s trying to predict. Then we see large changes in the representations between layers as it shifts focus to re-remembering the information relevant to the missing token it’s trying to predict, i.e. the goal of the learning objective.
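One simple way to put a number on “how much the representation changes between layers” is to compare the hidden states of consecutive layers with a similarity measure. The study uses its own, more involved measures; this sketch swaps in linear centered kernel alignment (CKA), a standard representation-similarity score, and uses random matrices as hypothetical stand-ins for real hidden states. A change of 0 means two layers hold the same information up to rotation and scaling.

```python
import numpy as np

def linear_cka(X, Y):
    """Linear centered kernel alignment between two representation matrices.

    X, Y: (n_tokens, dim) activations for the same tokens from two layers.
    Returns a similarity in [0, 1]; 1 means identical up to rotation and scale.
    """
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    cross = np.linalg.norm(Y.T @ X, "fro") ** 2
    return cross / (np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro"))

def layer_change(layers):
    # 1 - CKA between consecutive layers: larger values mean a bigger change.
    return [1 - linear_cka(a, b) for a, b in zip(layers, layers[1:])]

rng = np.random.default_rng(0)
# Hypothetical activations: replace with real per-layer hidden states from a model.
layer0 = rng.normal(size=(200, 16))
Q, _ = np.linalg.qr(rng.normal(size=(16, 16)))  # random orthogonal rotation
layer1 = 3.0 * layer0 @ Q                         # same information, rotated and scaled
layer2 = rng.normal(size=(200, 16))               # unrelated representation
changes = layer_change([layer0, layer1, layer2])
print([round(c, 3) for c in changes])
```

With real models, plotting these values layer by layer is one rough way to look for the “lose, then reconstruct” pattern described above.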
The Masked Language Model secret: context encoding and token reconstruction
The cloze-like task of MLMs, i.e. masking a random word in the input, appears to force the Transformer models to undergo a two-part process which the authors of the study describe as:
- Context Encoding: This is where, as we noted earlier, the model is forced to focus on the context of the surrounding words rather than on the goal of the learning objective. MLMs appear to do this better than LMs, since they have access to all the words in the input text and not, like LMs, only the preceding words. MT models have no token-level prediction requirement like LMs or MLMs, so they also “forget” information about the current word as it passes through the layers, without a later phase that recovers it.
- Token Reconstruction: MLMs show a clear second phase in which the model shifts focus from learning context to predicting the missing, or masked, word. To do this, it tries to “recover” or “reconstruct” information about the current input token, since it’s trying to understand how relevant the current token is to the other tokens in the sentence. In the context encoding phase, the model looks at the surrounding words, e.g. if the current word is “bank”, we might first look for surrounding words like “fishing” or “money” to know what type of “bank” we are referring to, i.e. a river “bank” or a financial institution.
The below diagram from another experiment in the study shows the impact of this two-phase process more clearly:
In the above diagram, we can see that the MLM starts to reconstruct the information about the current input token, whereas the MT and LM models continue to lose information relating to the input token, i.e. they don’t have a clear “reconstruction” phase.
Masked Language Models like BERT benefit from the future
Studies like those noted above show that MLMs seem to be the better choice for large pre-trained models like BERT, which need to have a more general knowledge of language to perform better on downstream NLP tasks.
This seems related to the fact that MLMs can access future information in the input text, thanks to the bidirectional processing provided by the Transformer architecture. Remember the introduction, where we noted that masking is crucial but depends on attention? The Attention mechanism of the Transformer architecture enables the model to “look” into the future when it’s learning the context of the words surrounding the current input token.
Language Models also have access to this feature since they use the same underlying architecture, but the nature of their learning objectives means that they only look at the previous words in the sentence. They don’t “look” ahead at the next words since this creates some issues in terms of the design of the network. For example, if the model can have access to all the tokens in the input then it can, in theory, “cheat” and always predict the correct next word without learning anything.
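In practice, the usual way to stop an LM from “cheating” is an attention mask rather than a different architecture. This NumPy sketch contrasts a causal (look-back-only) mask, as used by GPT-style LMs, with the unrestricted bidirectional attention of a BERT-style encoder; the uniform scores are just a placeholder so the pattern is easy to read.

```python
import numpy as np

def attention_weights(scores, causal):
    """Turn raw attention scores (seq_len x seq_len) into attention weights.

    causal=True mimics an LM: token i may only attend to tokens 0..i.
    causal=False mimics an MLM/BERT encoder: every token sees every token.
    """
    scores = scores.copy()
    if causal:
        future = np.triu(np.ones_like(scores, dtype=bool), k=1)  # positions j > i
        scores[future] = -np.inf  # exp(-inf) = 0, so future tokens get zero weight
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.zeros((4, 4))  # equal scores, so only the mask shapes the result
causal_w = attention_weights(scores, causal=True)
bidir_w = attention_weights(scores, causal=False)
print(causal_w.round(2))  # lower-triangular: no weight on future tokens
print(bidir_w.round(2))   # every token attends equally to all four tokens
```

The same attention layers are used in both cases; only the mask changes what each token is allowed to see.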
So there is a tradeoff between model simplicity and learning objective benefits. MLMs make the models learn more about the input text, but this comes at the cost of greater model complexity.
GPT-3, for example, is an LM which uses the decoder part of the Transformer architecture to predict the next word in an input sequence. Models like GPT-3 can compensate for the difference in learning objectives by being trained on a much larger and more diverse dataset. The performance of GPT-3 shows that you can still achieve SOTA results via this route.
However, from the above results, we can see the benefit of MLM due to clear context encoding and token reconstruction phases. If you want to learn context from your input text then MLMs learn more from less. This appears to be related to MLMs’ ability to access the future context surrounding the current input token.
This contextual advantage is key to MLM performance in many NLP tasks. However, some tasks are not suited to this bidirectional access, such as text generation with models like GPT-X, where the main goal is to predict the next word or words in a sequence of text. These models may be better suited to training via an LM learning objective, which enhances their ability to predict future words given an input.
How does masking work in BERT?
Before we wrap up, it’s worth specifically looking at BERT and its masking process implementation.
In the previous sections, we talked about masking as a general concept where you mask out a random word. Sounds relatively easy.
But, if you’ve worked with Transformer models in the past, you know that things turn out to be much more complicated. There are two main parts to the BERT masking implementation:
- Mask 15% of input tokens: Masking in BERT doesn’t just mask one token. Instead, it randomly chooses 15% of the input tokens and masks those. The 15% was chosen via a trial and error approach. If you mask more than that, it becomes difficult for the model to learn from the input since too much information is hidden. If you mask fewer than 15%, it could take longer and require more data to learn enough context, since each sequence provides fewer masked tokens to learn from and the model can predict them more easily.
- Mask token, correct token, or wrong token: This is where it starts to get a little weird! After choosing the tokens we want to mask, we then need to decide if we actually want to mask them or not. We don’t end up masking all the tokens we said we would. Instead, we have another random selection process (which we’ll discuss below) where we choose to either add a mask token hiding the word, or we replace the token with a different word that is randomly selected, or we replace the token with the correct word we were initially going to mask. Sounds weird, right? It is, and we’ll try to clarify it below.
To mask or not to mask?
Once we’ve selected our 15% of input tokens we want to mask, we then need to decide whether we want to mask the tokens or not. What they do in BERT is as follows:
- 80% Replace with a [MASK] token: For 80% of the selected inputs the token is replaced with a [MASK] token similar to the classic cloze tests mentioned earlier.
- 10% Replace with an incorrect word: For 10% of the selected inputs, the token is replaced by another randomly selected word whose only requirement is that it’s different from the selected token.
- 10% Replace with the correct word: The remaining 10% of the time the selected token is simply replaced with the correct token.
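Putting the two steps together, here is a minimal Python sketch of the selection and replacement procedure described above, working on plain word lists. It is an illustration, not the reference BERT code: the function name and the None labels are our own choices, it ignores special tokens like [CLS] and [SEP], and it follows the description above by always picking a different random word (as far as we know, the original implementation simply samples from the whole vocabulary).

```python
import random

MASK = "[MASK]"

def bert_style_mask(tokens, vocab, mask_prob=0.15, seed=None):
    """Sketch of BERT-style corruption: select ~15% of tokens, then
    80% -> [MASK], 10% -> a different random word, 10% -> unchanged.

    Returns (corrupted, labels). labels[i] holds the original token at
    selected positions (the prediction targets) and None everywhere else.
    """
    rng = random.Random(seed)
    corrupted = list(tokens)
    labels = [None] * len(tokens)
    n_select = max(1, round(len(tokens) * mask_prob))
    for i in rng.sample(range(len(tokens)), n_select):
        labels[i] = tokens[i]  # the model is scored on recovering this token
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK  # 80%: hide the token
        elif r < 0.9:
            # 10%: swap in a randomly chosen word that differs from the original
            corrupted[i] = rng.choice([w for w in vocab if w != tokens[i]])
        # remaining 10%: leave the original token in place
    return corrupted, labels

vocab = ["this", "is", "the", "best", "post", "i", "have", "ever", "read"]
tokens = "this is the best post i have ever read".split()
corrupted, labels = bert_style_mask(tokens, vocab, seed=0)
print(corrupted)
print(labels)
```

Note that labels cover every selected position, including the 10% left unchanged, so the model can never assume that an unmasked word is safe to ignore.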
The choice of these ratios is somewhat arbitrary; the authors of the BERT paper note that it was chosen through a trial-and-error process. Nevertheless, the thinking behind it seems to be:
- If you use the MASK token all the time: Using only the MASK token resulted in the model learning very little about the context of the surrounding words. This seems to occur because the model knows it can “forget” all the information about the surrounding words and focus only on the target word. This is similar to an LM-like approach, and it means the MLM doesn’t develop the clear two-phase information flow (context encoding, then token reconstruction) we saw earlier, which was so important to the improved learning seen in MLMs.
- If you use the MASK token and the right word: To address this shortcoming, you could use the MASK token 80% of the time and then replace the token with the correct word 20% of the time. But there’s a problem. The model will know that when the MASK is not there, then the word is correct and it needs to learn nothing and just keep the current word since it’s correct. In other words, the model can “cheat” and learn nothing since it knows the non-masked token is always correct.
- If you use the MASK token and the wrong word: Alternatively, if every selected token that isn’t masked is replaced with a wrong word, then the model will know that whenever the MASK doesn’t appear, the selected token is the wrong one, and it will just treat it like another MASK, i.e. you would likely encounter the same problem as above.
The result is that the 80/10/10 split forces BERT to learn more about all the tokens in the input and not just the current input token. Or, as the authors noted, “The advantage of this procedure is that the Transformer encoder does not know which words it will be asked to predict or which have been replaced by random words, so it is forced to keep a distributional contextual representation of every input token.”
The on masking latest research?
Don’t worry, I didn’t have a stroke while writing this article. The title of this section is jumbled on purpose, to show that word order is important for our understanding of language. This is why bag-of-words models like TF-IDF perform poorly on some NLP tasks: they don’t maintain the word order of the input sentence, so they lose some meaning when word order is important for the task being performed.
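A quick way to see the problem: once a sentence is reduced to word counts, the jumbled heading and an unjumbled version of it are indistinguishable. TF-IDF only reweights these same per-word counts, so it can’t tell them apart either. This is a minimal sketch using Python’s standard library, with deliberately naive tokenization.

```python
from collections import Counter

def bag_of_words(text):
    # Lowercase, drop the question mark, and count words; word order is discarded.
    return Counter(text.lower().replace("?", "").split())

original = "The latest research on masking?"
jumbled = "The on masking latest research?"
print(bag_of_words(original) == bag_of_words(jumbled))  # True
```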
MLMs like BERT are thought to preserve the order and positioning of words in the input text. This is often cited as a reason for their improved performance. Similarly, the cloze-like task is thought to force these models to learn syntactic and semantic aspects of language. This is in contrast to models like Word2Vec which learn semantic meaning solely from the distributional properties of large amounts of text.
Two new research papers shine more light on masking and MLMs, and show that we still don’t fully understand all the nuances and intricacies of this new learning objective:
- Masked Language Modeling and the Distributional Hypothesis: Order Word Matters Pre-training for Little: This paper claims to show that the ordering of words is not important for MLMs. The authors do this by randomly shuffling the words in the pre-training sentences and showing that this has a limited impact on overall performance across a range of NLP tasks. They propose that MLMs’ improved performance could be due to:
- Poorly constructed NLP evaluation frameworks: The benchmarks used to test and evaluate these models may be too simple, so the results do not reflect the models’ actual ability at linguistic tasks. The authors claim these evaluation frameworks need to be improved to keep pace with the performance of Transformer models like BERT. We also noted this in an earlier post when talking about the potential limits of Transformer models.
- MLMs’ ability to model higher-order word co-occurrence statistics: The paper claims that BERT performs so well solely due to its ability to learn from co-occurrence data. Masking and Attention simply enable BERT to learn more information than models like Word2Vec.
- On the Inductive Bias of Masked Language Modeling: From Statistical to Syntactic Dependencies: This paper focuses on the cloze-task aspect of MLMs and claims that:
- The MLM cloze task is not really a cloze task: Since MLMs randomly select the words to be masked, it is not a carefully designed cloze test. In a cloze test, the missing words are generally selected to ensure that the question requires the participant to know something about the syntactic nature of the language. Randomly selecting the missing words means that common or uninformative words like “this”, “that”, “the” and so on could be chosen. Depending on the task, these words may not be useful in forcing the model to learn something meaningful for that task.
- MLMs don’t learn syntax directly: It was thought that the cloze-like test forced the model to learn information about the syntactic structure of language. However, the paper claims to show that MLMs instead learn direct statistical dependencies between tokens, and that this enables them to learn syntax indirectly. Previously it was thought that it was not possible to learn syntax solely from statistical inference. This paper claims that MLMs show that this is indeed possible.
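To see why shuffling might not hurt as much as expected, consider what it leaves intact. This sketch is a simplification (the paper’s own perturbation procedures differ in their details): it shuffles the words of each sentence, then counts which pairs of words appear in the same sentence. Word order is destroyed, but those sentence-level co-occurrence counts, the kind of distributional signal the first paper argues MLMs exploit, are unchanged.

```python
import random
from collections import Counter
from itertools import combinations

def shuffle_words(sentence, seed=None):
    """Randomly reorder the words of one sentence (same words, new order)."""
    rng = random.Random(seed)
    words = sentence.split()
    rng.shuffle(words)
    return " ".join(words)

def sentence_cooccurrence(sentences):
    """Count how often each unordered pair of words shares a sentence."""
    counts = Counter()
    for sentence in sentences:
        # Sorting makes each pair canonical, so ("a", "b") and ("b", "a") merge.
        for pair in combinations(sorted(sentence.split()), 2):
            counts[pair] += 1
    return counts

corpus = ["masking is the key to the success of bert",
          "the model predicts the hidden word from its context"]
shuffled = [shuffle_words(s, seed=i) for i, s in enumerate(corpus)]
print(shuffled)
print(sentence_cooccurrence(corpus) == sentence_cooccurrence(shuffled))  # True
```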
To summarize, let’s revisit the original four points we noted for discussion in the introduction:
- Masking is key to BERT’s success: We saw research showing how masking as a learning objective changes the information flow within deep learning Transformer neural networks like BERT. This change creates a two-phase learning process where the model first performs context encoding to learn about the words surrounding the current input word. Then it performs token reconstruction, where it attempts to predict the output word selected via the cloze-like MLM test. It’s the nature of this two-phase process that seems to differentiate MLMs from other approaches like LMs and MT.
- Masking needs attention: While masking enables BERT to perform awesomely, it still needs the Transformer architecture to do so. The Transformer network uses Attention to process text bi-directionally, which means models can look at text before and after the current input token. Non-MLMs don’t fully utilize this feature of the Transformer, and thus don’t benefit in the same way in terms of their ability to learn linguistic features.
- We still don’t understand masking: We saw that recent research has thrown up more questions than answers about how masking improves performance. However, this doesn’t mean masking doesn’t work. It just works in ways we didn’t think of. Indeed, one of the papers we discussed showed that we may be able to learn high-level knowledge, like language syntax, from statistical inference alone. This wasn’t something we thought possible until recently.
- Do we need to re-evaluate how we learn language? If it’s indeed possible to learn semantic and syntactic knowledge from something as simple as co-occurrence statistics, then what does that say about language? Specifically, what does it say about how we learn language? While this is a contentious area of research, it was thought that language was a complex and unique field of knowledge and required uniquely human-like skills to understand it. But maybe we are, in the same way as models like BERT, learning language by identifying when common words appear together? If this is the case, then why do we seem to learn things like semantics much quicker and with much less data than these models? It’s in this way that, as we strive to better understand models like BERT and learning objectives such as masking, we may also learn more about our unique ability to learn and understand language.
And finally, should all models be like BERT and embrace masking? As with most things in NLP and Machine Learning, the answer is: it depends.
We can confidently say that most models would be improved by using an approach like masking in pre-training. However, there are some cases, such as text generation, where a learning objective like masking may not suit the task, since having access to future words in the input may run counter to the goal of the specific task. For more general models, where the goal is to train a model that has general linguistic ability and can perform well on a wide range of downstream tasks, it does seem that those types of models would benefit from a masking-like approach.
The main source for much of this research into the Transformer is the amazing work of people like Lena Voita. Her blog is a great source of all things Transformers, and there is even material for a course she teaches, which is well worth checking out. The work of people like Lena in explaining what is going on inside these models is critical to improving our understanding of how models like BERT learn linguistic tasks. For this post, I specifically referenced her work on:
- The evolution of representations in the Transformer: This paper has lots of detail on how the representation of inputs changes between layers and due to the learning objective of the model.
- Evolution of representations blog post: On her blog, Lena has great posts that explain her papers in a more accessible format, which you can check out if you don’t want to dive into the full detail of the paper.
- Video about the paper: This is a great video about the paper that explains some of the findings we discussed in this post.
The other papers that are worth checking out for more detail on all things BERT and Masking are:
- The original BERT paper: The original BERT paper provides a lot of detail and explains things like the authors’ attempts to find the right masking ratios.
- New research on the ordering of words in MLM: This is new research on MLMs that shows that the order of words does not seem to matter and what that means for our understanding of models like BERT and GPT.
- New research on how much MLMs can learn from statistics: This paper looks at the cloze-task aspect of MLMs and claims that it helps the models learn syntactic knowledge not directly, but indirectly via statistical inference.