Source code for neurox.data.extraction.transformers_extractor

"""Representations Extractor for ``transformers`` toolkit models.

Given a file with input sentences and a ``transformers`` model, this
module extracts representations from all layers of the model. It
supports aggregation over the sub-words produced by the model's
tokenizer.

Can also be invoked as a script as follows:
    ``python -m neurox.data.extraction.transformers_extractor``
"""

import argparse
import sys

import numpy as np
import torch

from neurox.data.writer import ActivationsWriter

from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


def get_model_and_tokenizer(model_desc, device="cpu", random_weights=False):
    """
    Automatically get the appropriate ``transformers`` model and tokenizer based
    on the model description

    Parameters
    ----------
    model_desc : str
        Model description; can either be a model name like ``bert-base-uncased``,
        a comma separated list indicating <model>,<tokenizer> (since 1.0.8),
        or a path to a trained model. Whitespace around the comma separated
        names is ignored.
    device : str or torch.device, optional
        Device to load the model on, cpu or gpu. Default is cpu.
    random_weights : bool, optional
        Whether the weights of the model should be randomized. Useful for analyses
        where one needs an untrained model. The architecture (and the
        ``output_hidden_states`` setting) is still taken from ``model_desc``.

    Returns
    -------
    model : transformers model
        An instance of one of the transformers.modeling classes, in eval mode
        and configured to return hidden states of all layers
    tokenizer : transformers tokenizer
        An instance of one of the transformers.tokenization classes

    Raises
    ------
    ValueError
        If ``model_desc`` contains more than one comma.
    """
    parts = [part.strip() for part in model_desc.split(",")]
    if len(parts) > 2:
        raise ValueError(
            "model_desc must be '<model>' or '<model>,<tokenizer>', "
            f"got: {model_desc!r}"
        )
    model_name = parts[0]
    # With a single entry, the same name is used for model and tokenizer
    tokenizer_name = parts[-1]

    if random_weights:
        # Build the architecture from its config rather than calling
        # init_weights() on a pretrained model: recent transformers releases
        # mark submodules loaded by from_pretrained as already initialized,
        # so init_weights() may leave the pretrained weights in place.
        from transformers import AutoConfig

        print("Randomizing weights")
        config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
        model = AutoModel.from_config(config)
        # from_pretrained returns a model in eval mode, from_config does not;
        # switch explicitly so dropout is disabled during extraction.
        model.eval()
    else:
        model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
    model = model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    return model, tokenizer
def aggregate_repr(state, start, end, aggregation):
    """
    Function that aggregates activations/embeddings over a span of subword tokens.
    This function will usually be called once per word. For example, if we had the sentence::

        This is an example

    which is tokenized by BPE into::

        this is an ex @@am @@ple

    The function should be called 4 times::

        aggregate_repr(state, 0, 0, aggregation)
        aggregate_repr(state, 1, 1, aggregation)
        aggregate_repr(state, 2, 2, aggregation)
        aggregate_repr(state, 3, 5, aggregation)

    Returns a zero vector (of the same dtype as ``state``) if end is less
    than start, i.e. the request is to aggregate over an empty slice.

    Parameters
    ----------
    state : numpy.ndarray
        Matrix of size [ NUM_LAYERS x NUM_SUBWORD_TOKENS_IN_SENT x LAYER_DIM]
    start : int
        Index of the first subword of the word being processed
    end : int
        Index of the last subword of the word being processed (inclusive)
    aggregation : {'first', 'last', 'average'}
        Aggregation method for combining subword activations

    Returns
    -------
    word_vector : numpy.ndarray
        Matrix of size [NUM_LAYERS x LAYER_DIM]

    Raises
    ------
    ValueError
        If ``aggregation`` is not one of 'first', 'last' or 'average'.
    """
    # Validate up front: an unknown method used to fall through and return
    # None, which numpy silently turns into NaNs when assigned into a float array.
    if aggregation not in ("first", "last", "average"):
        raise ValueError(
            f"Invalid aggregation '{aggregation}', "
            "expected one of 'first', 'last' or 'average'"
        )

    if end < start:
        sys.stderr.write(
            "WARNING: An empty slice of tokens was encountered. "
            + "This probably implies a special unicode character or text "
            + "encoding issue in your original data that was dropped by the "
            + "transformer model's tokenizer.\n"
        )
        return np.zeros((state.shape[0], state.shape[2]), dtype=state.dtype)

    if aggregation == "first":
        return state[:, start, :]
    if aggregation == "last":
        return state[:, end, :]
    # aggregation == "average"; end is inclusive, hence end + 1
    return np.average(state[:, start : end + 1, :], axis=1)
def extract_sentence_representations(
    sentence,
    model,
    tokenizer,
    device="cpu",
    include_embeddings=True,
    aggregation="last",
    dtype="float32",
    include_special_tokens=False,
    tokenization_counts=None,
):
    """
    Get representations for a single sentence

    The extractor runs a detokenization procedure to combine subwords
    automatically. For instance, a sentence "Hello, how are you?" may be
    tokenized by the model as "Hell @@o , how are you @@?". This extractor
    automatically detokenizes the subtokens back into the original token.

    Parameters
    ----------
    sentence : str
        Sentence for which the extraction needs to be done. The returned
        output will have representations for exactly the same number of
        elements as tokens in this sentence (counted by
        `sentence.split(' ')`).
    model : transformers model
        An instance of one of the transformers.modeling classes
    tokenizer : transformers tokenizer
        An instance of one of the transformers.tokenization classes
    device : str, optional
        Specifies the device (CPU/GPU) on which the extraction should be
        performed. Defaults to 'cpu'
    include_embeddings : bool, optional
        Whether the embedding layer should be included in the final output,
        or just regular layers. Defaults to True
    aggregation : {'first', 'last', 'average'}, optional
        Aggregation method for combining subword activations. Defaults to
        'last'
    dtype : str, optional
        Data type in which the activations will be stored. Supports all
        numpy based tensor types. Common values are 'float32' and 'float16'.
        Defaults to 'float32'
    include_special_tokens : bool, optional
        Whether or not to include special tokens in the extracted
        representations. Special tokens are tokens not present in the
        original sentence, but are added by the tokenizer, such as [CLS],
        [SEP] etc.
    tokenization_counts : dict, optional
        Maps each context-padded word to the number of subword tokens it
        produced. It is filled in (and checked for consistency) by this
        function; pass the same dict across a dataset to share it between
        sentences. A fresh dict is used when omitted.

    Returns
    -------
    final_hidden_states : numpy.ndarray
        Numpy Matrix of size [``NUM_LAYERS`` x ``NUM_TOKENS`` x
        ``NUM_NEURONS``].
    detokenizer : list
        List of detokenized words. This will have the same number of elements
        as tokens in the original sentence, plus special tokens if requested.
        Each element preserves tokenization artifacts (such as `##`, `@@` etc)
        to enable further automatic processing.
    """
    # A shared mutable default would leak counts between calls (and between
    # different tokenizers), so create a fresh cache per call if none is given.
    if tokenization_counts is None:
        tokenization_counts = {}

    special_tokens = [
        x for x in tokenizer.all_special_tokens if x != tokenizer.unk_token
    ]
    # Only used for membership tests, so a set is the right container
    special_tokens_ids = set(tokenizer.convert_tokens_to_ids(special_tokens))

    ambiguity_msg = (
        "A token dropped by the tokenizer appeared next "
        "to a special token. Detokenizer cannot resolve "
        f"the ambiguity, please remove '{sentence}' from "
        "the dataset, or try a different tokenizer"
    )

    original_tokens = sentence.split(" ")
    # Add letters and spaces around each word since some tokenizers are context sensitive
    tmp_tokens = []
    if len(original_tokens) > 0:
        tmp_tokens.append(f"{original_tokens[0]} a")
    tmp_tokens += [f"a {x} a" for x in original_tokens[1:-1]]
    if len(original_tokens) > 1:
        tmp_tokens.append(f"a {original_tokens[-1]}")

    assert len(original_tokens) == len(
        tmp_tokens
    ), f"Original: {original_tokens}, Temp: {tmp_tokens}"

    with torch.no_grad():
        # Get tokenization counts if not already available
        for token_idx, token in enumerate(tmp_tokens):
            tok_ids = [
                x for x in tokenizer.encode(token) if x not in special_tokens_ids
            ]
            # Ignore the added letter tokens
            if token_idx != 0 and token_idx != len(tmp_tokens) - 1:
                # Word appearing in the middle of the sentence
                tok_ids = tok_ids[1:-1]
            elif token_idx == 0:
                # Word appearing at the beginning
                tok_ids = tok_ids[:-1]
            else:
                # Word appearing at the end
                tok_ids = tok_ids[1:]

            if token in tokenization_counts:
                assert tokenization_counts[token] == len(
                    tok_ids
                ), "Got different tokenization for already processed word"
            else:
                tokenization_counts[token] = len(tok_ids)

        ids = tokenizer.encode(sentence, truncation=True)
        input_ids = torch.tensor([ids]).to(device)
        # Hugging Face format: tuple of torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)
        # Tuple has 13 elements for base model: embedding outputs + hidden states at each layer
        all_hidden_states = model(input_ids)[-1]

        if include_embeddings:
            all_hidden_states = [
                hidden_states[0].cpu().numpy() for hidden_states in all_hidden_states
            ]
        else:
            all_hidden_states = [
                hidden_states[0].cpu().numpy()
                for hidden_states in all_hidden_states[1:]
            ]
        all_hidden_states = np.array(all_hidden_states, dtype=dtype)

    print('Sentence : "%s"' % (sentence))
    print("Original (%03d): %s" % (len(original_tokens), original_tokens))
    print(
        "Tokenized (%03d): %s"
        % (
            len(tokenizer.convert_ids_to_tokens(ids)),
            tokenizer.convert_ids_to_tokens(ids),
        )
    )

    assert all_hidden_states.shape[1] == len(ids)

    # Handle special tokens
    # filtered_ids will contain all ids if we are extracting with
    # special tokens, and only normal word/subword ids if we are
    # extracting without special tokens
    # all_hidden_states will also be filtered at this step to match
    # the ids in filtered ids
    filtered_ids = ids
    idx_special_tokens = [t_i for t_i, x in enumerate(ids) if x in special_tokens_ids]
    special_token_ids = [ids[t_i] for t_i in idx_special_tokens]

    if not include_special_tokens:
        idx_without_special_tokens = [
            t_i for t_i, x in enumerate(ids) if x not in special_tokens_ids
        ]
        filtered_ids = [ids[t_i] for t_i in idx_without_special_tokens]
        all_hidden_states = all_hidden_states[:, idx_without_special_tokens, :]
        special_token_ids = []

    assert all_hidden_states.shape[1] == len(filtered_ids)
    print(
        "Filtered (%03d): %s"
        % (
            len(tokenizer.convert_ids_to_tokens(filtered_ids)),
            tokenizer.convert_ids_to_tokens(filtered_ids),
        )
    )

    # Get actual tokens for filtered ids in order to do subword
    # aggregation
    segmented_tokens = tokenizer.convert_ids_to_tokens(filtered_ids)

    # Perform subword aggregation/detokenization
    # After aggregation, we should have |original_tokens| embeddings,
    # one for each word. If special tokens are included, then we will
    # have |original_tokens| + |special_tokens|
    counter = 0
    detokenized = []
    final_hidden_states = np.zeros(
        (
            all_hidden_states.shape[0],
            len(original_tokens) + len(special_token_ids),
            all_hidden_states.shape[2],
        ),
        dtype=dtype,
    )
    inputs_truncated = False

    # Keep track of what the previous token was. This is used to detect
    # special tokens followed/preceded by dropped tokens, which is an
    # ambiguous situation for the detokenizer
    prev_token_type = "NONE"

    last_special_token_pointer = 0
    for token_idx, token in enumerate(tmp_tokens):
        # Handle special tokens that sit right before this (non-dropped) word
        if include_special_tokens and tokenization_counts[token] != 0:
            while (
                last_special_token_pointer < len(idx_special_tokens)
                and counter == idx_special_tokens[last_special_token_pointer]
            ):
                assert prev_token_type != "DROPPED", ambiguity_msg
                prev_token_type = "SPECIAL"
                final_hidden_states[:, len(detokenized), :] = all_hidden_states[
                    :, counter, :
                ]
                detokenized.append(
                    segmented_tokens[idx_special_tokens[last_special_token_pointer]]
                )
                last_special_token_pointer += 1
                counter += 1

        current_word_start_idx = counter
        current_word_end_idx = counter + tokenization_counts[token]

        # Check for truncated hidden states in the case where the
        # original word was actually tokenized
        if (
            tokenization_counts[token] != 0
            and current_word_start_idx >= all_hidden_states.shape[1]
        ) or current_word_end_idx > all_hidden_states.shape[1]:
            # Keep room only for what was detokenized so far plus the
            # special tokens that are still to be appended below
            final_hidden_states = final_hidden_states[
                :,
                : len(detokenized)
                + len(special_token_ids)
                - last_special_token_pointer,
                :,
            ]
            inputs_truncated = True
            break

        if tokenization_counts[token] == 0:
            assert prev_token_type != "SPECIAL", ambiguity_msg
            prev_token_type = "DROPPED"
        else:
            prev_token_type = "NORMAL"

        # For dropped words the span is empty (end < start) and
        # aggregate_repr returns a zero vector
        final_hidden_states[:, len(detokenized), :] = aggregate_repr(
            all_hidden_states,
            current_word_start_idx,
            current_word_end_idx - 1,
            aggregation,
        )
        detokenized.append(
            "".join(segmented_tokens[current_word_start_idx:current_word_end_idx])
        )
        counter += tokenization_counts[token]

    # Append trailing special tokens (e.g. [SEP]), skipping any leftover
    # subwords of a truncated word
    if include_special_tokens:
        while counter < len(segmented_tokens):
            if last_special_token_pointer >= len(idx_special_tokens):
                break

            if counter == idx_special_tokens[last_special_token_pointer]:
                assert prev_token_type != "DROPPED", ambiguity_msg
                prev_token_type = "SPECIAL"
                final_hidden_states[:, len(detokenized), :] = all_hidden_states[
                    :, counter, :
                ]
                detokenized.append(
                    segmented_tokens[idx_special_tokens[last_special_token_pointer]]
                )
                last_special_token_pointer += 1
            counter += 1

    print("Detokenized (%03d): %s" % (len(detokenized), detokenized))
    print("Counter: %d" % (counter))

    if inputs_truncated:
        print("WARNING: Input truncated because of length, skipping check")
    else:
        assert counter == len(filtered_ids)
        assert len(detokenized) == len(original_tokens) + len(special_token_ids)
    print("===================================================================")

    return final_hidden_states, detokenized
def extract_representations(
    model_desc,
    input_corpus,
    output_file,
    device="cpu",
    aggregation="last",
    output_type="json",
    random_weights=False,
    ignore_embeddings=False,
    decompose_layers=False,
    filter_layers=None,
    dtype="float32",
    include_special_tokens=False,
):
    """
    Extract representations for an entire corpus and save them to disk

    Parameters
    ----------
    model_desc : str
        Model description; can either be a model name like ``bert-base-uncased``,
        a comma separated list indicating <model>,<tokenizer> (since 1.0.8),
        or a path to a trained model
    input_corpus : str
        Path to the input corpus (UTF-8 text), where each sentence is on its
        separate line
    output_file : str
        Path to output file. Supports all filetypes supported by
        ``data.writer.ActivationsWriter``.
    device : str, optional
        Specifies the device (CPU/GPU) on which the extraction should be
        performed. Defaults to 'cpu'
    aggregation : {'first', 'last', 'average'}, optional
        Aggregation method for combining subword activations. Defaults to 'last'
    output_type : str, optional
        Explicit definition of output file type if it cannot be derived from
        the ``output_file`` path
    random_weights : bool, optional
        Whether the weights of the model should be randomized. Useful for
        analyses where one needs an untrained model. Defaults to False.
    ignore_embeddings : bool, optional
        Whether the embedding layer should be excluded in the final output,
        or kept with the regular layers. Defaults to False
    decompose_layers : bool, optional
        Whether each layer should have its own output file, or all layers be
        saved in a single file. Defaults to False, i.e. single file
    filter_layers : str
        Comma separated list of layer indices to save. The format is the same
        as the one accepted by ``data.writer.ActivationsWriter``.
    dtype : str, optional
        Data type in which the activations will be stored. Supports all numpy
        based tensor types. Common values are 'float32' and 'float16'.
        Defaults to 'float32'
    include_special_tokens : bool, optional
        Whether or not to include special tokens in the extracted
        representations. Special tokens are tokens not present in the
        original sentence, but are added by the tokenizer, such as [CLS],
        [SEP] etc.
    """
    print(f"Loading model: {model_desc}")
    model, tokenizer = get_model_and_tokenizer(
        model_desc, device=device, random_weights=random_weights
    )

    print("Reading input corpus")

    def corpus_generator(input_corpus_path):
        # Lazily yield one stripped sentence per line; explicit encoding so
        # the result does not depend on the platform's locale
        with open(input_corpus_path, "r", encoding="utf-8") as fp:
            for line in fp:
                yield line.strip()

    print("Preparing output file")
    writer = ActivationsWriter.get_writer(
        output_file,
        filetype=output_type,
        decompose_layers=decompose_layers,
        filter_layers=filter_layers,
        dtype=dtype,
    )

    print("Extracting representations from model")
    tokenization_counts = {}  # Cache for tokenizer rules
    # Close the writer even if extraction fails midway, so partially written
    # output files are flushed and file handles are released
    try:
        for sentence_idx, sentence in enumerate(corpus_generator(input_corpus)):
            hidden_states, extracted_words = extract_sentence_representations(
                sentence,
                model,
                tokenizer,
                device=device,
                include_embeddings=(not ignore_embeddings),
                aggregation=aggregation,
                dtype=dtype,
                include_special_tokens=include_special_tokens,
                tokenization_counts=tokenization_counts,
            )

            print("Hidden states: ", hidden_states.shape)
            print("# Extracted words: ", len(extracted_words))

            writer.write_activations(sentence_idx, extracted_words, hidden_states)
    finally:
        writer.close()
# Placeholder strings for characters that are awkward in key names.
# Not referenced in this module; presumably used by HDF5 output, where "/"
# separates groups -- NOTE(review): confirm where this mapping is consumed.
HDF5_SPECIAL_TOKENS = {
    ".": "__DOT__",
    "/": "__SLASH__",
}
def main():
    """Parse command line arguments and run the extraction on a corpus."""
    parser = argparse.ArgumentParser()
    parser.add_argument("model_desc", help="Name of model")
    parser.add_argument(
        "input_corpus", help="Text file path with one sentence per line"
    )
    parser.add_argument(
        "output_file",
        help="Output file path where extracted representations will be stored",
    )
    # Let argparse validate the value: an assert would be stripped under
    # `python -O` and gives a traceback instead of a usage message
    parser.add_argument(
        "--aggregation",
        choices=["average", "first", "last"],
        help="first, last or average aggregation for word representation in the case of subword segmentation",
        default="last",
    )
    parser.add_argument(
        "--dtype",
        choices=["float16", "float32"],
        default="float32",
        help="Output dtype of the extracted representations",
    )
    parser.add_argument("--disable_cuda", action="store_true")
    parser.add_argument("--ignore_embeddings", action="store_true")
    parser.add_argument(
        "--random_weights",
        action="store_true",
        help="generate representations from randomly initialized model",
    )
    parser.add_argument(
        "--include_special_tokens",
        action="store_true",
        help="Include special tokens like [CLS] and [SEP] in the extracted representations",
    )

    # Adds the writer options read below (output_type, decompose_layers,
    # filter_layers)
    ActivationsWriter.add_writer_options(parser)

    args = parser.parse_args()

    if args.filter_layers is not None and args.ignore_embeddings:
        parser.error(
            "--filter_layers and --ignore_embeddings cannot be used at the same time"
        )

    if not args.disable_cuda and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    extract_representations(
        args.model_desc,
        args.input_corpus,
        args.output_file,
        device=device,
        aggregation=args.aggregation,
        output_type=args.output_type,
        random_weights=args.random_weights,
        ignore_embeddings=args.ignore_embeddings,
        dtype=args.dtype,
        decompose_layers=args.decompose_layers,
        filter_layers=args.filter_layers,
        include_special_tokens=args.include_special_tokens,
    )
if __name__ == "__main__":
    # Entry point when run as a script, e.g. via
    # ``python -m neurox.data.extraction.transformers_extractor``
    main()