from dataclasses import dataclass, field
from functools import lru_cache
from typing import Optional, Union
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer

# Source: https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
# Also implemented in https://docs.rs/tokenizers/latest/src/tokenizers/pre_tokenizers/byte_level.rs.html#13-39
_CHARS_TO_BYTES = {
    "Ā": 0, "ā": 1, "Ă": 2, "ă": 3, "Ą": 4, "ą": 5, "Ć": 6, "ć": 7, "Ĉ": 8,
    "ĉ": 9, "Ċ": 10, "ċ": 11, "Č": 12, "č": 13, "Ď": 14, "ď": 15, "Đ": 16,
    "đ": 17, "Ē": 18, "ē": 19, "Ĕ": 20, "ĕ": 21, "Ė": 22, "ė": 23, "Ę": 24,
    "ę": 25, "Ě": 26, "ě": 27, "Ĝ": 28, "ĝ": 29, "Ğ": 30, "ğ": 31, "Ġ": 32,
    "!": 33, '"': 34, "#": 35, "$": 36, "%": 37, "&": 38, "'": 39, "(": 40,
    ")": 41, "*": 42, "+": 43, ",": 44, "-": 45, ".": 46, "/": 47, "0": 48,
    "1": 49, "2": 50, "3": 51, "4": 52, "5": 53, "6": 54, "7": 55, "8": 56,
    "9": 57, ":": 58, ";": 59, "<": 60, "=": 61, ">": 62, "?": 63, "@": 64, 
    "A": 65, "B": 66, "C": 67, "D": 68, "E": 69, "F": 70, "G": 71, "H": 72,
    "I": 73, "J": 74, "K": 75, "L": 76, "M": 77, "N": 78, "O": 79, "P": 80,
    "Q": 81, "R": 82, "S": 83, "T": 84, "U": 85, "V": 86, "W": 87, "X": 88,
    "Y": 89, "Z": 90, "[": 91, "\\": 92, "]": 93, "^": 94, "_": 95, "`": 96, 
    "a": 97, "b": 98, "c": 99, "d": 100, "e": 101, "f": 102, "g": 103,
    "h": 104, "i": 105, "j": 106, "k": 107, "l": 108, "m": 109, "n": 110,
    "o": 111, "p": 112, "q": 113, "r": 114, "s": 115, "t": 116, "u": 117,
    "v": 118, "w": 119, "x": 120, "y": 121, "z": 122, "{": 123, "|": 124,
    "}": 125, "~": 126, "ġ": 127, "Ģ": 128, "ģ": 129, "Ĥ": 130, "ĥ": 131,
    "Ħ": 132, "ħ": 133, "Ĩ": 134, "ĩ": 135, "Ī": 136, "ī": 137, "Ĭ": 138,
    "ĭ": 139, "Į": 140, "į": 141, "İ": 142, "ı": 143, "IJ": 144, "ij": 145,
    "Ĵ": 146, "ĵ": 147, "Ķ": 148, "ķ": 149, "ĸ": 150, "Ĺ": 151, "ĺ": 152,
    "Ļ": 153, "ļ": 154, "Ľ": 155, "ľ": 156, "Ŀ": 157, "ŀ": 158, "Ł": 159,
    "ł": 160, "¡": 161, "¢": 162, "£": 163, "¤": 164, "¥": 165, "¦": 166,
    "§": 167, "¨": 168, "©": 169, "ª": 170, "«": 171, "¬": 172, "Ń": 173,
    "®": 174, "¯": 175, "°": 176, "±": 177, "²": 178, "³": 179, "´": 180,
    "µ": 181, "¶": 182, "·": 183, "¸": 184, "¹": 185, "º": 186, "»": 187,
    "¼": 188, "½": 189, "¾": 190, "¿": 191, "À": 192, "Á": 193, "Â": 194,
    "Ã": 195, "Ä": 196, "Å": 197, "Æ": 198, "Ç": 199, "È": 200, "É": 201,
    "Ê": 202, "Ë": 203, "Ì": 204, "Í": 205, "Î": 206, "Ï": 207, "Ð": 208,
    "Ñ": 209, "Ò": 210, "Ó": 211, "Ô": 212, "Õ": 213, "Ö": 214, "×": 215,
    "Ø": 216, "Ù": 217, "Ú": 218, "Û": 219, "Ü": 220, "Ý": 221, "Þ": 222,
    "ß": 223, "à": 224, "á": 225, "â": 226, "ã": 227, "ä": 228, "å": 229,
    "æ": 230, "ç": 231, "è": 232, "é": 233, "ê": 234, "ë": 235, "ì": 236,
    "í": 237, "î": 238, "ï": 239, "ð": 240, "ñ": 241, "ò": 242, "ó": 243,
    "ô": 244, "õ": 245, "ö": 246, "÷": 247, "ø": 248, "ù": 249, "ú": 250,
    "û": 251, "ü": 252, "ý": 253, "þ": 254, "ÿ": 255,
}
_BYTES_TO_CHARS = {v: k for k, v in _CHARS_TO_BYTES.items()}

def _bytes_to_chars(byte_sequence: bytes) -> str:
    return "".join(_BYTES_TO_CHARS[byte] for byte in byte_sequence)

def _chars_to_bytes(char_sequence: str) -> list:
    return list(bytes(_CHARS_TO_BYTES[char] for char in char_sequence))
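
# Illustrative round trip through the byte-level alphabet:
#   _bytes_to_chars(b" hi")  -> "Ġhi"          (the space byte 0x20 maps to "Ġ")
#   _chars_to_bytes("Ġhi")   -> [32, 104, 105]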

@dataclass
class BolmoTokenizerConfig:
    vocab_size: int
    bos_token_id: int
    pad_token_id: int
    eos_token_id: int
    bpe_token_end_id: int
    special_tokens: list[str] = field(default_factory=list)
    special_tokens_first: bool = True
    original_identifier: Optional[str] = None


    @classmethod
    def bolmo(cls) -> "BolmoTokenizerConfig":
        special_tokens = [
            "<pad>",
            "<bos>",
            "<eos>",
            "<bpe_token_end>",
        ]

        return cls(
            # *2 to accommodate fused boundary tokens
            vocab_size=(len(special_tokens) + 256) * 2,
            special_tokens=special_tokens,
            bos_token_id=special_tokens.index("<bos>"),
            pad_token_id=special_tokens.index("<pad>"),
            eos_token_id=special_tokens.index("<eos>"),
            bpe_token_end_id=special_tokens.index("<bpe_token_end>"),
            original_identifier="allenai/dolma2-tokenizer",
        )
    
    def build(self):
        return BolmoTokenizer(tokenizer_config=self)
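
    # Minimal usage sketch (assumes the "allenai/dolma2-tokenizer" files can be loaded by
    # AutoTokenizer from the local cache or the Hub):
    #   config = BolmoTokenizerConfig.bolmo()
    #   tokenizer = config.build()
    #   byte_ids = tokenizer("Hello world")["input_ids"]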


class BolmoTokenizer(PreTrainedTokenizer):
    TOKEN_ID_KEY = -1

    def __init__(self, **kwargs):
        tokenizer_config = kwargs.pop("tokenizer_config", BolmoTokenizerConfig.bolmo())

        self.config = tokenizer_config
        self.hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.original_identifier)
        if self.config.special_tokens_first:
            self.offset = len(tokenizer_config.special_tokens)
            self.special_tokens_offset = 0
        else:
            self.offset = 0
            self.special_tokens_offset = self.config.vocab_size - len(tokenizer_config.special_tokens)
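
        # With the default bolmo() config (special_tokens_first=True, 4 special tokens,
        # offset == 4, vocab_size == 520) the resulting id layout is:
        #   [0, 4)     special tokens
        #   [4, 260)   raw bytes 0-255
        #   [260, 264) special tokens with a fused boundary
        #   [264, 520) raw bytes 0-255 with a fused boundary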

        self.byte_sequences = {}
    
        for key, value in self.hf_tokenizer.get_vocab().items():
            if key in self.config.special_tokens:
                byte_sequence = [self.special_tokens_offset + self.config.special_tokens.index(key)]
            elif value == self.hf_tokenizer.eos_token_id and self.eos_token_id is not None:
                byte_sequence = [self.eos_token_id]
            elif value == self.hf_tokenizer.bos_token_id and self.bos_token_id is not None:
                byte_sequence = [self.bos_token_id]
            elif value == self.hf_tokenizer.pad_token_id and self.pad_token_id is not None:
                byte_sequence = [self.pad_token_id]
            else:
                byte_sequence = [self.offset + i for i in _chars_to_bytes(key)]

            assert self.byte_sequences.get(value) is None
            self.byte_sequences[value] = byte_sequence

        self.byte_trie = {}

        for token_id, byte_sequence in self.byte_sequences.items():
            current_dict = self.byte_trie
            for byte in byte_sequence[::-1]:  # lookups walk sequences back-to-front, so store bytes in reverse order
                if byte not in current_dict:
                    current_dict[byte] = {}
                current_dict = current_dict[byte]
            current_dict[BolmoTokenizer.TOKEN_ID_KEY] = token_id
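
        # The trie is keyed by byte ids in reverse order: walking backwards from a position in
        # a byte-id sequence and following trie edges recovers, via TOKEN_ID_KEY entries, every
        # wrapped-tokenizer token that ends at that position (used by expand_byte_ids).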

        self.add_bos_token = True
        self.add_eos_token = False
        self.padding_side = "left"  # left padding for generation with decoder-only models

        super().__init__(
            bos_token=self.config.special_tokens[self.config.bos_token_id],
            eos_token=self.config.special_tokens[self.config.eos_token_id],
            pad_token=self.config.special_tokens[self.config.pad_token_id],
            extra_ids=0,
        )

    @property
    def bos_token_id(self):
        return self.config.bos_token_id

    @property
    def eos_token_id(self):
        return self.config.eos_token_id
    
    @property
    def pad_token_id(self):
        return self.config.pad_token_id

    @property
    def bpe_token_end_id(self):
        return self.config.bpe_token_end_id 

    @property
    def vocab_size(self):
        return self.config.vocab_size

    def _convert_id_to_token(self, index):
        if index < self.offset:
            return self.config.special_tokens[index - self.special_tokens_offset]

        if self.offset + 256 <= index < self.offset * 2 + 256:
            # special token with fused boundary
            return self.config.special_tokens[index - self.offset - 256] + "b"

        if index >= self.offset + 256:
            # byte with fused boundary
            return _BYTES_TO_CHARS[index - 2 * self.offset - 256] + "b"

        return _BYTES_TO_CHARS[index - self.offset]

    def _convert_token_to_id(self, token):
        if token in self.config.special_tokens:
            return self.special_tokens_offset + self.config.special_tokens.index(token)

        if token in [x + "b" for x in self.config.special_tokens]:
            # special token with fused boundary
            return self.offset + 256 + self.config.special_tokens.index(token[:-1])

        if len(token) > 1 and token[-1] == "b":
            return self.offset + 256 + _CHARS_TO_BYTES[token[0]]
        else:
            return self.offset + _CHARS_TO_BYTES[token]

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        return vocab

    def expand_byte_ids(self, byte_ids: list[int], n_last: Optional[int] = None) -> list[int]:
        # for every byte position, search the byte trie for the longest matching token ending at that position
        expanded_ids = []
        for i in range(len(byte_ids)):
            if n_last is not None and i < len(byte_ids) - n_last:
                continue

            current_dict = self.byte_trie
            current_expansion = None

            for j in range(i, -1, -1):
                byte = byte_ids[j]

                if byte == self.bpe_token_end_id:
                    # skip bpe token end markers, needed for generation
                    continue

                if byte >= self.offset + 256:
                    # ignore fused boundary
                    byte -= self.offset + 256

                try:
                    current_dict = current_dict[byte]
                    if BolmoTokenizer.TOKEN_ID_KEY in current_dict:
                        current_expansion = current_dict[BolmoTokenizer.TOKEN_ID_KEY]
                except KeyError:
                    assert current_expansion is not None
                    break

            expanded_ids.append(current_expansion)

        return expanded_ids
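
    # Sketch (assuming " the" is a single token in the wrapped Dolma vocabulary): for byte ids
    # spelling " the", the entry at the last position is the BPE id of " the" itself rather
    # than that of the shorter suffixes "he" or "e".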

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:
        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```
        if token_ids_1 is None, only returns the first portion of the mask (0s).
        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output

    def _tokenize(self, text: str, **kwargs) -> list[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        tokens = self.convert_ids_to_tokens(self._bolmo_encode(text))
        return tokens

    def _patch_ids_to_byte_ids(self, input_ids: list[int]):
        return [byte_token_id for token_id in input_ids for byte_token_id in self.byte_sequences[token_id]]

    def _bolmo_encode(self, string: str, add_special_tokens=False):
        input_ids = self.hf_tokenizer.encode(string, add_special_tokens=add_special_tokens)
        return self._patch_ids_to_byte_ids(input_ids)
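
    # E.g. (illustrative): "Hello world" is first encoded into Dolma BPE ids by the wrapped
    # tokenizer, and each BPE id is then expanded into its byte ids shifted by `offset`.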

    def _bolmo_decode(self, tokens: list[int], skip_special_tokens: bool = False) -> str:
        return self._decode_to_bytes(tokens, skip_special_tokens=skip_special_tokens).decode("utf-8", errors="replace")

    def _decode_to_bytes(self, tokens: list[int], skip_special_tokens: bool = False) -> bytes:
        tokens_without_boundary = []
        for token in tokens:
            if token >= (self.offset + 256):
                token -= self.offset + 256

            tokens_without_boundary.append(token)

        utf8_bytes = []

        for token in tokens_without_boundary:
            if token < self.offset:
                if skip_special_tokens:
                    continue
                else:
                    utf8_bytes.extend(self.config.special_tokens[token].encode("utf-8"))
            else:
                utf8_bytes.append(min(token - self.offset, 255))

        return bytes(utf8_bytes)

    def get_tokens_and_patch_lengths(self, original_input_ids: list[int], add_bos=False, strip_pad=False, skip_last=False):
        if add_bos and self.bos_token_id is not None:
            byte_tokens = [self.bos_token_id]
            patch_lengths = [1]
        else:
            byte_tokens = []
            patch_lengths = []

        for idx, token in enumerate(original_input_ids):
            # optionally skip last token to keep the length the same if add_bos=True
            if skip_last and idx == len(original_input_ids) - 1:
                break

            token_byte_tokens = self._patch_ids_to_byte_ids([int(token)])

            if strip_pad and all(t == self.pad_token_id for t in token_byte_tokens):
                # skip padding tokens
                continue

            patch_lengths.append(len(token_byte_tokens))
            byte_tokens.extend(token_byte_tokens)

        return byte_tokens, patch_lengths
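
    # Sketch (hypothetical `bpe_ids` from the wrapped tokenizer):
    #   byte_tokens, patch_lengths = tokenizer.get_tokens_and_patch_lengths(bpe_ids, add_bos=True)
    #   # one patch length per kept BPE token (plus 1 for <bos>), summing to len(byte_tokens)
    #   assert sum(patch_lengths) == len(byte_tokens)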

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        return self._bolmo_decode(self.convert_tokens_to_ids(tokens), skip_special_tokens=False)  # type: ignore

    def _decode(
        self,
        token_ids: Union[int, list[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        spaces_between_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        return self._bolmo_decode(token_ids, skip_special_tokens=skip_special_tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        return ()  # type: ignore
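

if __name__ == "__main__":
    # Smoke-test sketch: requires the "allenai/dolma2-tokenizer" files to be loadable by
    # AutoTokenizer; prints byte-level ids and the round-tripped text.
    tokenizer = BolmoTokenizerConfig.bolmo().build()
    ids = tokenizer.encode("Hello world", add_special_tokens=True)
    print(ids)
    print(tokenizer.decode(ids, skip_special_tokens=True))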