from dataclasses import dataclass, field
from typing import Optional, Union

from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer


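# Character <-> byte map used by GPT-2-style byte-level BPE vocabularies: every
# byte value 0-255 is assigned a printable unicode character so that arbitrary
# byte sequences can be stored as vocabulary strings (stored here as char -> byte).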
_CHARS_TO_BYTES = {
    "Ā": 0, "ā": 1, "Ă": 2, "ă": 3, "Ą": 4, "ą": 5, "Ć": 6, "ć": 7, "Ĉ": 8,
    "ĉ": 9, "Ċ": 10, "ċ": 11, "Č": 12, "č": 13, "Ď": 14, "ď": 15, "Đ": 16,
    "đ": 17, "Ē": 18, "ē": 19, "Ĕ": 20, "ĕ": 21, "Ė": 22, "ė": 23, "Ę": 24,
    "ę": 25, "Ě": 26, "ě": 27, "Ĝ": 28, "ĝ": 29, "Ğ": 30, "ğ": 31, "Ġ": 32,
    "!": 33, '"': 34, "#": 35, "$": 36, "%": 37, "&": 38, "'": 39, "(": 40,
    ")": 41, "*": 42, "+": 43, ",": 44, "-": 45, ".": 46, "/": 47, "0": 48,
    "1": 49, "2": 50, "3": 51, "4": 52, "5": 53, "6": 54, "7": 55, "8": 56,
    "9": 57, ":": 58, ";": 59, "<": 60, "=": 61, ">": 62, "?": 63, "@": 64,
    "A": 65, "B": 66, "C": 67, "D": 68, "E": 69, "F": 70, "G": 71, "H": 72,
    "I": 73, "J": 74, "K": 75, "L": 76, "M": 77, "N": 78, "O": 79, "P": 80,
    "Q": 81, "R": 82, "S": 83, "T": 84, "U": 85, "V": 86, "W": 87, "X": 88,
    "Y": 89, "Z": 90, "[": 91, "\\": 92, "]": 93, "^": 94, "_": 95, "`": 96,
    "a": 97, "b": 98, "c": 99, "d": 100, "e": 101, "f": 102, "g": 103,
    "h": 104, "i": 105, "j": 106, "k": 107, "l": 108, "m": 109, "n": 110,
    "o": 111, "p": 112, "q": 113, "r": 114, "s": 115, "t": 116, "u": 117,
    "v": 118, "w": 119, "x": 120, "y": 121, "z": 122, "{": 123, "|": 124,
    "}": 125, "~": 126, "ġ": 127, "Ģ": 128, "ģ": 129, "Ĥ": 130, "ĥ": 131,
    "Ħ": 132, "ħ": 133, "Ĩ": 134, "ĩ": 135, "Ī": 136, "ī": 137, "Ĭ": 138,
    "ĭ": 139, "Į": 140, "į": 141, "İ": 142, "ı": 143, "IJ": 144, "ij": 145,
    "Ĵ": 146, "ĵ": 147, "Ķ": 148, "ķ": 149, "ĸ": 150, "Ĺ": 151, "ĺ": 152,
    "Ļ": 153, "ļ": 154, "Ľ": 155, "ľ": 156, "Ŀ": 157, "ŀ": 158, "Ł": 159,
    "ł": 160, "¡": 161, "¢": 162, "£": 163, "¤": 164, "¥": 165, "¦": 166,
    "§": 167, "¨": 168, "©": 169, "ª": 170, "«": 171, "¬": 172, "Ń": 173,
    "®": 174, "¯": 175, "°": 176, "±": 177, "²": 178, "³": 179, "´": 180,
    "µ": 181, "¶": 182, "·": 183, "¸": 184, "¹": 185, "º": 186, "»": 187,
    "¼": 188, "½": 189, "¾": 190, "¿": 191, "À": 192, "Á": 193, "Â": 194,
    "Ã": 195, "Ä": 196, "Å": 197, "Æ": 198, "Ç": 199, "È": 200, "É": 201,
    "Ê": 202, "Ë": 203, "Ì": 204, "Í": 205, "Î": 206, "Ï": 207, "Ð": 208,
    "Ñ": 209, "Ò": 210, "Ó": 211, "Ô": 212, "Õ": 213, "Ö": 214, "×": 215,
    "Ø": 216, "Ù": 217, "Ú": 218, "Û": 219, "Ü": 220, "Ý": 221, "Þ": 222,
    "ß": 223, "à": 224, "á": 225, "â": 226, "ã": 227, "ä": 228, "å": 229,
    "æ": 230, "ç": 231, "è": 232, "é": 233, "ê": 234, "ë": 235, "ì": 236,
    "í": 237, "î": 238, "ï": 239, "ð": 240, "ñ": 241, "ò": 242, "ó": 243,
    "ô": 244, "õ": 245, "ö": 246, "÷": 247, "ø": 248, "ù": 249, "ú": 250,
    "û": 251, "ü": 252, "ý": 253, "þ": 254, "ÿ": 255,
}

_BYTES_TO_CHARS = {v: k for k, v in _CHARS_TO_BYTES.items()}


def _bytes_to_chars(byte_sequence: bytes) -> str:
    return "".join(_BYTES_TO_CHARS[byte] for byte in byte_sequence)


def _chars_to_bytes(char_sequence: str) -> list:
    return list(bytes(_CHARS_TO_BYTES[char] for char in char_sequence))


@dataclass
class BolmoTokenizerConfig:
    vocab_size: int
    bos_token_id: int
    pad_token_id: int
    eos_token_id: int
    bpe_token_end_id: int
    special_tokens: list[str] = field(default_factory=list)
    special_tokens_first: bool = True
    original_identifier: Optional[str] = None

    @classmethod
    def bolmo(cls) -> "BolmoTokenizerConfig":
        special_tokens = [
            "<pad>",
            "<bos>",
            "<eos>",
            "<bpe_token_end>",
        ]

        return cls(
            # One plain and one "b"-marked copy of every byte value and special token.
            vocab_size=(len(special_tokens) + 256) * 2,
            special_tokens=special_tokens,
            bos_token_id=special_tokens.index("<bos>"),
            pad_token_id=special_tokens.index("<pad>"),
            eos_token_id=special_tokens.index("<eos>"),
            bpe_token_end_id=special_tokens.index("<bpe_token_end>"),
            original_identifier="allenai/dolma2-tokenizer",
        )

    def build(self):
        return BolmoTokenizer(tokenizer_config=self)


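# Byte-level tokenizer wrapping an existing HF BPE tokenizer. With
# `special_tokens_first=True` its id space is laid out as:
#   [0, offset)                      special tokens
#   [offset, offset + 256)           raw bytes
#   [offset + 256, 2*offset + 256)   "b"-marked special tokens
#   [2*offset + 256, vocab_size)     "b"-marked raw bytes
# where offset == len(special_tokens). This layout is inferred from the id
# arithmetic in the conversion and decoding methods below.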
class BolmoTokenizer(PreTrainedTokenizer):
    # Sentinel key in the byte trie marking that the path walked so far spells a
    # complete BPE token; its value is that token's id in the HF vocabulary.
    TOKEN_ID_KEY = -1

    def __init__(self, **kwargs):
        tokenizer_config = kwargs.pop("tokenizer_config", BolmoTokenizerConfig.bolmo())

        self.config = tokenizer_config
        self.hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.original_identifier)
        if self.config.special_tokens_first:
            self.offset = len(tokenizer_config.special_tokens)
            self.special_tokens_offset = 0
        else:
            self.offset = 0
            self.special_tokens_offset = self.config.vocab_size - len(tokenizer_config.special_tokens)

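        # Expand every token of the underlying BPE vocabulary into the byte-level
        # id sequence it corresponds to (special tokens map to a single special id).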
        self.byte_sequences = {}
        for key, value in self.hf_tokenizer.get_vocab().items():
            if key in self.config.special_tokens:
                byte_sequence = [self.special_tokens_offset + self.config.special_tokens.index(key)]
            elif value == self.hf_tokenizer.eos_token_id and self.eos_token_id is not None:
                byte_sequence = [self.eos_token_id]
            elif value == self.hf_tokenizer.bos_token_id and self.bos_token_id is not None:
                byte_sequence = [self.bos_token_id]
            elif value == self.hf_tokenizer.pad_token_id and self.pad_token_id is not None:
                byte_sequence = [self.pad_token_id]
            else:
                byte_sequence = [self.offset + i for i in _chars_to_bytes(key)]

            assert self.byte_sequences.get(value) is None
            self.byte_sequences[value] = byte_sequence

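        # Build a trie over the *reversed* byte sequences so that expand_byte_ids
        # can walk backwards from a position and find the longest BPE token that
        # ends there.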
        self.byte_trie = {}
        for token_id, byte_sequence in self.byte_sequences.items():
            current_dict = self.byte_trie
            for byte in byte_sequence[::-1]:
                if byte not in current_dict:
                    current_dict[byte] = {}
                current_dict = current_dict[byte]
            current_dict[BolmoTokenizer.TOKEN_ID_KEY] = token_id

        self.add_bos_token = True
        self.add_eos_token = False
        self.padding_side = "left"

        super().__init__(
            bos_token=self.config.special_tokens[self.config.bos_token_id],
            eos_token=self.config.special_tokens[self.config.eos_token_id],
            pad_token=self.config.special_tokens[self.config.pad_token_id],
            extra_ids=0,
        )

    @property
    def bos_token_id(self):
        return self.config.bos_token_id

    @property
    def eos_token_id(self):
        return self.config.eos_token_id

    @property
    def pad_token_id(self):
        return self.config.pad_token_id

    @property
    def bpe_token_end_id(self):
        return self.config.bpe_token_end_id

    @property
    def vocab_size(self):
        return self.config.vocab_size

    def _convert_id_to_token(self, index):
        # Plain special tokens.
        if index < self.offset:
            return self.config.special_tokens[index - self.special_tokens_offset]
        # "b"-marked special tokens.
        if self.offset + 256 <= index < self.offset * 2 + 256:
            return self.config.special_tokens[index - self.offset - 256] + "b"
        # "b"-marked bytes.
        if index >= self.offset * 2 + 256:
            return _BYTES_TO_CHARS[index - self.offset * 2 - 256] + "b"
        # Plain bytes.
        return _BYTES_TO_CHARS[index - self.offset]

    def _convert_token_to_id(self, token):
        if token in self.config.special_tokens:
            return self.special_tokens_offset + self.config.special_tokens.index(token)
        if token in [x + "b" for x in self.config.special_tokens]:
            # "b"-marked special tokens sit directly after the plain bytes.
            return self.offset + 256 + self.config.special_tokens.index(token[:-1])
        if len(token) > 1 and token[-1] == "b":
            # "b"-marked byte: mirror of the corresponding branch in _convert_id_to_token.
            return self.offset * 2 + 256 + _CHARS_TO_BYTES[token[:-1]]
        return self.offset + _CHARS_TO_BYTES[token]

    def get_vocab(self):
        return {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}

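    # For each position in `byte_ids` (or only the last `n_last` positions), walk
    # backwards through the reversed-byte trie and return the id of the longest
    # BPE token whose byte sequence ends at that position.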
    def expand_byte_ids(self, byte_ids: list[int], n_last: Optional[int] = None) -> list[int]:
        expanded_ids = []
        for i in range(len(byte_ids)):
            if n_last is not None and i < len(byte_ids) - n_last:
                continue

            current_dict = self.byte_trie
            current_expansion = None

            # Walk backwards from position i, remembering the last (i.e. longest)
            # complete token seen along the way.
            for j in range(i, -1, -1):
                byte = byte_ids[j]

                # <bpe_token_end> markers are not part of any token's byte sequence.
                if byte == self.bpe_token_end_id:
                    continue

                # Fold "b"-marked ids onto their plain counterparts so they match
                # the trie keys.
                if byte >= self.offset + 256:
                    byte -= self.offset + 256

                try:
                    current_dict = current_dict[byte]
                    if BolmoTokenizer.TOKEN_ID_KEY in current_dict:
                        current_expansion = current_dict[BolmoTokenizer.TOKEN_ID_KEY]
                except KeyError:
                    assert current_expansion is not None
                    break

            expanded_ids.append(current_expansion)

        return expanded_ids

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    def create_token_type_ids_from_sequences(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output

    def _tokenize(self, text: str, **kwargs) -> list[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words."""
        return self.convert_ids_to_tokens(self._bolmo_encode(text))

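    # Encoding is two-stage: the text is first tokenized with the underlying HF
    # BPE tokenizer, then every BPE token id is replaced by the byte-level id
    # sequence computed in __init__.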
    def _patch_ids_to_byte_ids(self, input_ids: list[int]):
        return [byte_token_id for token_id in input_ids for byte_token_id in self.byte_sequences[token_id]]

    def _bolmo_encode(self, string: str, add_special_tokens=False):
        input_ids = self.hf_tokenizer.encode(string, add_special_tokens=add_special_tokens)
        return self._patch_ids_to_byte_ids(input_ids)

    def _bolmo_decode(self, tokens: list[int], skip_special_tokens: bool = False) -> str:
        return self._decode_to_bytes(tokens, skip_special_tokens=skip_special_tokens).decode("utf-8", errors="replace")

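    # Map byte-level ids back to raw bytes: "b"-marked ids are first folded onto
    # their plain counterparts, special-token ids are skipped or rendered as their
    # literal strings, and everything else becomes a single byte.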
    def _decode_to_bytes(self, tokens: list[int], skip_special_tokens: bool = False) -> bytes:
        tokens_without_boundary = []
        for token in tokens:
            if token >= (self.offset + 256):
                token -= self.offset + 256
            tokens_without_boundary.append(token)

        utf8_bytes = []
        for token in tokens_without_boundary:
            if token < self.offset:
                if skip_special_tokens:
                    continue
                else:
                    utf8_bytes.extend(self.config.special_tokens[token].encode("utf-8"))
            else:
                utf8_bytes.append(min(token - self.offset, 255))

        return bytes(utf8_bytes)

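    # Expand original BPE token ids into byte-level ids and record how many byte
    # ids each BPE token contributes (its patch length). Optionally prepend BOS,
    # drop padding-only tokens, and skip the final token.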
    def get_tokens_and_patch_lengths(self, original_input_ids: list[int], add_bos=False, strip_pad=False, skip_last=False):
        if add_bos and self.bos_token_id is not None:
            byte_tokens = [self.bos_token_id]
            patch_lengths = [1]
        else:
            byte_tokens = []
            patch_lengths = []

        for idx, token in enumerate(original_input_ids):
            if skip_last and idx == len(original_input_ids) - 1:
                break

            token_byte_tokens = self._patch_ids_to_byte_ids([int(token)])

            if strip_pad and all(t == self.pad_token_id for t in token_byte_tokens):
                continue

            patch_lengths.append(len(token_byte_tokens))
            byte_tokens.extend(token_byte_tokens)

        return byte_tokens, patch_lengths

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        return self._bolmo_decode(self.convert_tokens_to_ids(tokens), skip_special_tokens=False)

    def _decode(
        self,
        token_ids: Union[int, list[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        spaces_between_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        return self._bolmo_decode(token_ids, skip_special_tokens=skip_special_tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        return ()
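

if __name__ == "__main__":
    # Minimal round-trip sketch, not part of the library API. It assumes the
    # `allenai/dolma2-tokenizer` checkpoint referenced above can be downloaded.
    tokenizer = BolmoTokenizerConfig.bolmo().build()
    byte_ids = tokenizer("Hello, world!")["input_ids"]
    print(byte_ids)
    print(tokenizer.decode(byte_ids, skip_special_tokens=True))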