Author(s): Fabio Yáñez Romero

Originally published on Towards AI.

Bert from Sesame Street is figuring out how to train BERT from zero. Source: DALL-E 3.

Token Masking is a widely used strategy for training language models, both in their classification variants and in generation models. The BERT language model introduced it, and it has since been used in many of BERT's variants (RoBERTa, ALBERT, DeBERTa…). However, Token Masking is only one strategy within a larger group called Text Corruption. In the BART research paper, numerous experiments were performed to train an encoder-decoder generation model with different Text Corruption strategies.

Text corruption strategies. Source: “BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension”.

Before discussing the different Text Corruption techniques, we will go over the concepts that all Text Corruption methods for Large Language Models (LLMs) share.

From supervised to self-supervised

A large amount of text is used in the initial training of a language model, with the objective that the model learns to represent the language correctly, storing this knowledge implicitly in its parameter weights. This massive amount of text must have labels for training, since we must calculate the cross-entropy loss between the model's output and the reference data. However, annotating such a large amount of data is infeasible. Therefore, we resort to automatic label generation, turning the supervised problem into a self-supervised one. In this case, the corrupted sequence serves as the model's training input, while all or part of the original sequence serves as the labels. Which of the two depends on the nature of the model (encoder-only or encoder-decoder).

Corruption probability

With automatic labels, the model learns the label associated with each training example without us having to annotate the data. In Text Corruption (especially in Token Masking, Token Deletion, and Text Infilling), each word is corrupted with a fixed probability, usually around 15–20%. This probability is kept low so the model can still learn the context of each sentence even when the sequence is corrupted.

Some Text Corruption techniques, such as Sentence Permutation or Document Rotation, do not corrupt individual words with a given probability. This makes them compatible with the other corruption techniques, as discussed below.

Differences between Classification and Generation

When training language models with Text Corruption, the labels vary depending on whether it is a classification model (encoder-only) or a generation model (encoder-decoder). In classification models, the labels correspond to the original tokens at the corrupted positions only: if a word is masked in a sentence, the label is the initial sequence, but the loss is computed only at the corrupted positions. In generation models, since the model must be able to generate text continuously, the output label is the entire uncorrupted initial sequence, and the loss is computed over the whole sequence.
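To make this difference concrete, here is a minimal sketch (not taken from the original article; the example sentence, the <mask> placeholder, and the use of string tokens instead of integer IDs are illustrative assumptions) of how the labels look in each case:

# Minimal sketch of label construction for the two model families.
# The tokens are illustrative; real pipelines work with integer token IDs.
# -100 is the usual "ignore" value for positions excluded from the loss.

original  = ["The", "cat", "sat", "on", "the", "mat"]
corrupted = ["The", "<mask>", "sat", "on", "the", "mat"]

# Encoder-only (classification, BERT-style): the loss is computed only at
# the corrupted positions, so every other position gets the ignore label.
labels_classification = [
    orig if cor == "<mask>" else -100
    for orig, cor in zip(original, corrupted)
]
print(labels_classification)  # [-100, 'cat', -100, -100, -100, -100]

# Encoder-decoder (generation, BART-style): the decoder must reproduce the
# whole uncorrupted sequence, so the label is the original sequence itself.
labels_generation = list(original)
print(labels_generation)      # ['The', 'cat', 'sat', 'on', 'the', 'mat']

The -100 value mirrors the ignore index that the Huggingface data collator produces in the labels tensor shown later in this article.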
Setup

Now that we have briefly introduced the points that all Text Corruption strategies have in common, let's discuss the different techniques used to corrupt texts, giving code examples in each case. We will start with a single document in the code examples to see how the different strategies work. We will use Stanza, a library developed by Stanford NLP that offers different NLP tools which are very useful for our preprocessing.

import stanza
stanza.download('en')

# Text used in our examples
text = "Huntington's disease is a neurodegenerative autosomal disease results due to expansion of polymorphic CAG repeats in the huntingtin gene. Phosphorylation of the translation initiation factor 4E-BP results in the alteration of the translation control leading to unwanted protein synthesis and neuronal function. Consequences of mutant huntington (mhtt) gene transcription are not well known. Variability of age of onset is an important factor of Huntington's disease separating adult and juvenile types. The factors which are taken into account are-genetic modifiers, maternal protection i.e excessive paternal transmission, superior ageing genes and environmental threshold. A major focus has been given to the molecular pathogenesis which includes-motor disturbance, cognitive disturbance and neuropsychiatric disturbance. The diagnosis part has also been taken care of. This includes genetic testing and both primary and secondary symptoms. The present review also focuses on the genetics and pathology of Huntington's disease."

# We will use a Stanza pipeline to split the text,
# getting each different sentence as an element of the list.
nlp = stanza.Pipeline('en', use_gpu=False)
doc = nlp(text)
sentences = [sentence.text for sentence in doc.sentences]

Token Masking

Token Masking replaces random words in the text with <mask>, and the model must discover the masked word.

Token Masking example.

BERT introduced this strategy, and it remains the first and best-known Text Corruption strategy. It consists of corrupting an input sequence by masking random words, which are then used as labels during training.

In classification models, we can use the DataCollatorForLanguageModeling class directly from Huggingface Transformers to generate the necessary labels, which allows us to train models like BERT or RoBERTa.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling
import torch

def load_dataset_mlm(sentences, tokenizer_class=AutoTokenizer,
                     collator_class=DataCollatorForLanguageModeling,
                     mlm=True, mlm_probability=0.20):
    tokenizer = tokenizer_class.from_pretrained('google-bert/bert-base-uncased')
    inputs = tokenizer(sentences, return_tensors='pt',
                       padding=True, truncation=True)

    # Random masking configuration
    data_collator = collator_class(
        tokenizer=tokenizer,
        mlm=mlm,
        mlm_probability=mlm_probability
    )

    """The collator expects a tuple of tensors, so you have to split
    the input tensor, remove the first dimension of each piece and
    pass them as a tuple."""
    tuple_ids = torch.split(inputs['input_ids'], 1, dim=0)
    tuple_ids = list(tuple_ids)
    for tensor in range(len(tuple_ids)):
        tuple_ids[tensor] = tuple_ids[tensor].squeeze(0)
    tuple_ids = tuple(tuple_ids)

    # Get input_ids, attention_masks and labels for each sentence.
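    # Note (added comment): with mlm=True, DataCollatorForLanguageModeling
    # follows the BERT masking recipe: each selected token is replaced by
    # [MASK] 80% of the time, by a random token 10% of the time, and left
    # unchanged 10% of the time. Positions that are not selected receive the
    # label -100, so they are ignored by the cross-entropy loss.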
    batch = data_collator(tuple_ids)
    return batch['input_ids'], inputs['attention_mask'], batch['labels']

input_ids, attention_mask, labels = load_dataset_mlm(sentences)

"""
input_ids[0]:
tensor([  101, 16364,  1005,  1055,   103,  2003,  1037,   103, 10976,  3207,
          103, 25284,   103, 25426, 16870,  4295,  3463,  2349,  2000,   103,
         1997, 26572, 18078,  6187,  2290, 17993,  1999,  1996,  5933,  7629,
          103,   103,   102,     0,     0])

attention_mask[0]:
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0])

labels[0]:
tensor([ -100,  -100,  -100,  -100,  4295,  -100,  -100, 11265,  -100,  -100,
         6914,  -100,  8285,  -100,  2389,  -100,  -100,  -100,  -100,  4935,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         4962,  1012,  -100,  -100,  -100])
"""

Notice that the generated input_ids have an integer number for each token of the […]