Upload folder using huggingface_hub
- config.json +90 -0
- configuration_bolmo.py +235 -0
- generation_config.json +6 -0
- model-00001-of-00007.safetensors +3 -0
- model-00002-of-00007.safetensors +3 -0
- model-00003-of-00007.safetensors +3 -0
- model-00004-of-00007.safetensors +3 -0
- model-00005-of-00007.safetensors +3 -0
- model-00006-of-00007.safetensors +3 -0
- model-00007-of-00007.safetensors +3 -0
- model.safetensors.index.json +447 -0
- modeling_bolmo.py +1351 -0
- special_tokens_map.json +5 -0
- tokenization_bolmo.py +378 -0
- tokenizer_config.json +34 -0
- utils_bolmo.py +127 -0
config.json
ADDED
@@ -0,0 +1,90 @@
+{
+  "add_expanded_embeddings": true,
+  "architectures": [
+    "BolmoForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "configuration_bolmo.BolmoConfig",
+    "AutoModelForCausalLM": "modeling_bolmo.BolmoForCausalLM"
+  },
+  "bos_token_id": 1,
+  "boundary_predictor_lookahead": 1,
+  "boundary_threshold": "sample:0",
+  "dtype": "float32",
+  "eos_token_id": 1,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 11008,
+  "layer_types": [
+    "sliding_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "full_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "full_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "full_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "full_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "full_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "full_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "full_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "sliding_attention",
+    "full_attention"
+  ],
+  "local_intermediate_size": 5504,
+  "local_rms_norm_eps": 1e-05,
+  "max_position_embeddings": 65536,
+  "model_type": "bolmo",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 32,
+  "num_key_value_heads": 32,
+  "num_local_decoder_layers": 4,
+  "num_local_encoder_layers": 1,
+  "num_local_heads": 16,
+  "pad_token_id": 0,
+  "rms_norm_eps": 1e-06,
+  "rope_scaling": null,
+  "rope_theta": 10000.0,
+  "sliding_window": 4096,
+  "subword_vocab_size": 100278,
+  "tie_word_embeddings": false,
+  "tokenizer_config": {
+    "bos_token_id": 1,
+    "bpe_token_end_id": 3,
+    "eos_token_id": 1,
+    "original_identifier": "allenai/dolma2-tokenizer",
+    "pad_token_id": 0,
+    "special_tokens": [
+      "<pad>",
+      "<bos>",
+      "<eos>",
+      "<bpe_token_end>"
+    ],
+    "special_tokens_first": true,
+    "vocab_size": 520
+  },
+  "transformers_version": "4.57.3",
+  "use_cache": true,
+  "vocab_size": 520
+}

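Because `auto_map` routes `AutoConfig` and `AutoModelForCausalLM` to the `configuration_bolmo.py` and `modeling_bolmo.py` files shipped in this upload, loading the checkpoint goes through `trust_remote_code`. A minimal sketch; the repo id below is a placeholder, since the actual Hub id is not shown in this diff:

```python
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "org/bolmo-7b"  # hypothetical; replace with the actual Hub id of this upload

# trust_remote_code makes transformers import configuration_bolmo.BolmoConfig and
# modeling_bolmo.BolmoForCausalLM from the repo instead of a built-in architecture.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # BolmoForCausalLM
```
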
configuration_bolmo.py
ADDED
@@ -0,0 +1,235 @@
+from dataclasses import asdict
+from typing import Any
+
+from transformers.configuration_utils import PretrainedConfig, layer_type_validation
+from transformers.modeling_rope_utils import rope_config_validation
+from .tokenization_bolmo import BolmoTokenizerConfig
+
+class BolmoConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Olmo3Model`]. It is used to instantiate an OLMo3
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the [allenai/OLMo-3-0725-1B](https://huggingface.co/allenai/OLMo-3-0725-1B).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50304):
+            Vocabulary size of the Olmo3 model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`Olmo3Model`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 1):
+            Padding token id.
+        bos_token_id (`int`, *optional*):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 50279):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Size of the sliding window for sliding window attention.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer. Defaults to sliding window attention
+            for 3 out of 4 layers, and full attention for every 4th layer.
+
+    ```python
+    >>> from transformers import Olmo3Model, Olmo3Config
+
+    >>> # Initializing a Olmo3 7B style configuration
+    >>> configuration = Olmo3Config()
+
+    >>> # Initializing a model from the Olmo3 7B style configuration
+    >>> model = Olmo3Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "bolmo"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
+        "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
+        "layers.*.self_attn.v_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
+        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=520,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        use_cache=True,
+        pad_token_id=1,
+        bos_token_id=None,
+        eos_token_id=50279,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        rms_norm_eps=1e-5,
+        sliding_window=4096,
+        layer_types=None,
+        # bolmo config
+        add_expanded_embeddings: bool = True,
+        boundary_predictor_lookahead: int = 1,
+        boundary_threshold: str = "sample:0",
+        num_local_encoder_layers: int = 1,
+        num_local_decoder_layers: int = 4,
+        num_local_heads: int = 16,
+        local_intermediate_size: int = 5504,
+        local_rms_norm_eps=1e-5,
+        subword_vocab_size: int = 100278,  # dolma2_tokenizer subword vocab size
+        tokenizer_config: BolmoTokenizerConfig | dict[str, Any] | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self._rope_scaling_validation()
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        self.rms_norm_eps = rms_norm_eps
+
+        self.sliding_window = sliding_window
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types)
+
+        # bolmo configuration
+        self.add_expanded_embeddings = add_expanded_embeddings
+        self.boundary_predictor_lookahead = boundary_predictor_lookahead
+        self.boundary_threshold = boundary_threshold
+        self.num_local_encoder_layers = num_local_encoder_layers
+        self.num_local_decoder_layers = num_local_decoder_layers
+        self.num_local_heads = num_local_heads
+        self.local_intermediate_size = local_intermediate_size
+        self.local_rms_norm_eps = local_rms_norm_eps
+        self.subword_vocab_size = subword_vocab_size
+
+        if tokenizer_config is None:
+            self.tokenizer_config = asdict(BolmoTokenizerConfig.bolmo())
+        elif isinstance(tokenizer_config, BolmoTokenizerConfig):
+            self.tokenizer_config = asdict(tokenizer_config)
+        else:
+            self.tokenizer_config = tokenizer_config
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        rope_config_validation(self)
+
+__all__ = ["BolmoConfig"]

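When `layer_types` is left as `None`, `__init__` builds the default pattern seen in `config.json` above: three sliding-window layers followed by one full-attention layer. A small self-contained illustration of that expression (the layer count of 8 is just an example value; the checkpoint uses 32):

```python
num_hidden_layers = 8  # example value only

# Same comprehension as in BolmoConfig.__init__: every 4th layer is full attention.
layer_types = [
    "sliding_attention" if (i + 1) % 4 != 0 else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
```
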
generation_config.json
ADDED
@@ -0,0 +1,6 @@
+{
+  "_from_model_config": true,
+  "eos_token_id": 50279,
+  "pad_token_id": 1,
+  "transformers_version": "4.57.3"
+}

model-00001-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81f8066229794b0a0e47fe6c7bf8561eae00a97ada22c81fb6367e2cdc7ef176
+size 4886163184

model-00002-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c9f33172be62508dafd949c4882ebe226e0fb9b3eba72a019b392025f7588538
+size 4857404864

model-00003-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7b3725524bec299eb0e7d2cd6b4ae0f42000c13e68aceaba74eeb25ee398823
+size 4857404880

model-00004-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09544694501e0155166d09c13394f759e07cb02d70e7a286b2ee2f50cca071ba
+size 4857404928

model-00005-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:00f2f618fbc0f28b5d5014d5cd60e5b10f759b476c9f2e9166a4d0bc37dc73e8
+size 4857404928

model-00006-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa5e5b88a94da922bda6f3460d1e202d4c72a2e9aba341ae1fe93cd8787bf1c4
+size 4857404928

model-00007-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17d220b8dd89c36a635cb2fcb90a57eff5c5013369ada364319cf6e77849e31d
+size 1359202312

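The seven `.safetensors` entries above are Git LFS pointer files (spec version, SHA-256 oid, byte size); the shard contents themselves live in LFS storage. A hedged sketch of pulling the resolved files with `huggingface_hub` (the repo id is again a placeholder, not the real Hub id of this upload):

```python
from huggingface_hub import snapshot_download

# Hypothetical repo id; snapshot_download fetches the real shard bytes, not the pointers.
local_dir = snapshot_download(
    "org/bolmo-7b",
    allow_patterns=["*.safetensors", "*.json", "*.py"],
)
print(local_dir)  # local cache directory containing the downloaded files
```
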
model.safetensors.index.json
ADDED
@@ -0,0 +1,447 @@
+{
+  "metadata": {
+    "total_parameters": 7633084576,
+    "total_size": 30532338304
+  },
+  "weight_map": {
+    "lm_head.weight": "model-00007-of-00007.safetensors",
+    "model.layers.0.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.0.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.0.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.0.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.0.post_feedforward_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.0.self_attn.k_norm.weight": "model-00001-of-00007.safetensors",
+    "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
+    "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
+    "model.layers.0.self_attn.q_norm.weight": "model-00001-of-00007.safetensors",
+    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
+    "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
+    "model.layers.1.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.post_feedforward_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.self_attn.k_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.self_attn.q_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.1.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.10.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.post_feedforward_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.self_attn.k_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.self_attn.q_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.10.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.post_feedforward_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.self_attn.k_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.self_attn.q_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.11.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.12.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.12.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.12.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.12.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.12.post_feedforward_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.12.self_attn.k_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.12.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.12.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.12.self_attn.q_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.12.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.12.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.13.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.post_feedforward_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.self_attn.k_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.self_attn.q_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.13.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.post_feedforward_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.self_attn.k_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.self_attn.q_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.14.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.post_feedforward_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.self_attn.k_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.self_attn.q_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.15.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.post_feedforward_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.self_attn.k_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.self_attn.q_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.16.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.post_feedforward_layernorm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.self_attn.k_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.self_attn.q_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.17.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.18.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.18.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.18.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.18.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.18.post_feedforward_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.18.self_attn.k_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.18.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.18.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.18.self_attn.q_norm.weight": "model-00004-of-00007.safetensors",
+    "model.layers.18.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.18.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "model.layers.19.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.post_feedforward_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.self_attn.k_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.self_attn.q_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.19.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.2.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.post_feedforward_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.self_attn.k_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.self_attn.q_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.2.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.20.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.post_feedforward_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.self_attn.k_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.self_attn.q_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.20.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.post_feedforward_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.self_attn.k_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.self_attn.q_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.21.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.post_feedforward_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.self_attn.k_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.self_attn.q_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.22.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.post_feedforward_layernorm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.self_attn.k_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.self_attn.q_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.23.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.24.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.24.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.24.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.24.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.24.post_feedforward_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.24.self_attn.k_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.24.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.24.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.24.self_attn.q_norm.weight": "model-00005-of-00007.safetensors",
+    "model.layers.24.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.24.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "model.layers.25.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.post_feedforward_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.self_attn.k_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.self_attn.q_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.25.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.post_feedforward_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.self_attn.k_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.self_attn.q_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.26.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.post_feedforward_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.self_attn.k_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.self_attn.q_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.27.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.post_feedforward_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.self_attn.k_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.self_attn.q_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.28.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.post_feedforward_layernorm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.self_attn.k_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.self_attn.q_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.29.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.3.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.post_feedforward_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.self_attn.k_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.self_attn.q_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.3.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.30.mlp.down_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.30.mlp.gate_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.30.mlp.up_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.30.post_attention_layernorm.weight": "model-00007-of-00007.safetensors",
+    "model.layers.30.post_feedforward_layernorm.weight": "model-00007-of-00007.safetensors",
+    "model.layers.30.self_attn.k_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.30.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.30.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.30.self_attn.q_norm.weight": "model-00006-of-00007.safetensors",
+    "model.layers.30.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.30.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "model.layers.31.mlp.down_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.mlp.gate_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.mlp.up_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.post_attention_layernorm.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.post_feedforward_layernorm.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.self_attn.k_norm.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.self_attn.k_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.self_attn.o_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.self_attn.q_norm.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.self_attn.q_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.31.self_attn.v_proj.weight": "model-00007-of-00007.safetensors",
+    "model.layers.4.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.post_feedforward_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.self_attn.k_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.self_attn.q_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.4.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.post_feedforward_layernorm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.self_attn.k_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.self_attn.q_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.5.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.6.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.6.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.6.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.6.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.6.post_feedforward_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.6.self_attn.k_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.6.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.6.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.6.self_attn.q_norm.weight": "model-00002-of-00007.safetensors",
+    "model.layers.6.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.6.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "model.layers.7.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.post_feedforward_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.self_attn.k_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.self_attn.q_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.7.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.post_feedforward_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.self_attn.k_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.self_attn.q_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.8.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.post_feedforward_layernorm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.self_attn.k_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.self_attn.q_norm.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "model.layers.9.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "model.local_decoder.in_projection.bias": "model-00001-of-00007.safetensors",
+    "model.local_decoder.in_projection.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.initial_norm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.pre_feedforward_layernorm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.pre_xlstm_layernorm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.fgate_preact.bias": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.fgate_preact.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.igate_preact.bias": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.igate_preact.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.k.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.multihead_norm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.ogate_preact.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.out_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.q.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.0.xlstm.v.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.pre_feedforward_layernorm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.pre_xlstm_layernorm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.fgate_preact.bias": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.fgate_preact.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.igate_preact.bias": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.igate_preact.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.k.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.multihead_norm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.ogate_preact.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.out_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.q.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.1.xlstm.v.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.pre_feedforward_layernorm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.pre_xlstm_layernorm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.fgate_preact.bias": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.fgate_preact.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.igate_preact.bias": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.igate_preact.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.k.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.multihead_norm.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.ogate_preact.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.out_proj.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.q.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.2.xlstm.v.weight": "model-00001-of-00007.safetensors",
+    "model.local_decoder.layers.3.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
| 409 |
+
"model.local_decoder.layers.3.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
|
| 410 |
+
"model.local_decoder.layers.3.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
|
| 411 |
+
"model.local_decoder.layers.3.pre_feedforward_layernorm.weight": "model-00001-of-00007.safetensors",
|
| 412 |
+
"model.local_decoder.layers.3.pre_xlstm_layernorm.weight": "model-00001-of-00007.safetensors",
|
| 413 |
+
"model.local_decoder.layers.3.xlstm.fgate_preact.bias": "model-00001-of-00007.safetensors",
|
| 414 |
+
"model.local_decoder.layers.3.xlstm.fgate_preact.weight": "model-00001-of-00007.safetensors",
|
| 415 |
+
"model.local_decoder.layers.3.xlstm.igate_preact.bias": "model-00001-of-00007.safetensors",
|
| 416 |
+
"model.local_decoder.layers.3.xlstm.igate_preact.weight": "model-00001-of-00007.safetensors",
|
| 417 |
+
"model.local_decoder.layers.3.xlstm.k.weight": "model-00001-of-00007.safetensors",
|
| 418 |
+
"model.local_decoder.layers.3.xlstm.multihead_norm.weight": "model-00001-of-00007.safetensors",
|
| 419 |
+
"model.local_decoder.layers.3.xlstm.ogate_preact.weight": "model-00001-of-00007.safetensors",
|
| 420 |
+
"model.local_decoder.layers.3.xlstm.out_proj.weight": "model-00001-of-00007.safetensors",
|
| 421 |
+
"model.local_decoder.layers.3.xlstm.q.weight": "model-00001-of-00007.safetensors",
|
| 422 |
+
"model.local_decoder.layers.3.xlstm.v.weight": "model-00001-of-00007.safetensors",
|
| 423 |
+
"model.local_encoder.boundary_predictor_module.k_proj_layer.weight": "model-00001-of-00007.safetensors",
|
| 424 |
+
"model.local_encoder.boundary_predictor_module.q_proj_layer.weight": "model-00001-of-00007.safetensors",
|
| 425 |
+
"model.local_encoder.byte_embedding.weight": "model-00001-of-00007.safetensors",
|
| 426 |
+
"model.local_encoder.layers.0.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
| 427 |
+
"model.local_encoder.layers.0.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
|
| 428 |
+
"model.local_encoder.layers.0.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
|
| 429 |
+
"model.local_encoder.layers.0.pre_feedforward_layernorm.weight": "model-00001-of-00007.safetensors",
|
| 430 |
+
"model.local_encoder.layers.0.pre_xlstm_layernorm.weight": "model-00001-of-00007.safetensors",
|
| 431 |
+
"model.local_encoder.layers.0.xlstm.fgate_preact.bias": "model-00001-of-00007.safetensors",
|
| 432 |
+
"model.local_encoder.layers.0.xlstm.fgate_preact.weight": "model-00001-of-00007.safetensors",
|
| 433 |
+
"model.local_encoder.layers.0.xlstm.igate_preact.bias": "model-00001-of-00007.safetensors",
|
| 434 |
+
"model.local_encoder.layers.0.xlstm.igate_preact.weight": "model-00001-of-00007.safetensors",
|
| 435 |
+
"model.local_encoder.layers.0.xlstm.k.weight": "model-00001-of-00007.safetensors",
|
| 436 |
+
"model.local_encoder.layers.0.xlstm.multihead_norm.weight": "model-00001-of-00007.safetensors",
|
| 437 |
+
"model.local_encoder.layers.0.xlstm.ogate_preact.weight": "model-00001-of-00007.safetensors",
|
| 438 |
+
"model.local_encoder.layers.0.xlstm.out_proj.weight": "model-00001-of-00007.safetensors",
|
| 439 |
+
"model.local_encoder.layers.0.xlstm.q.weight": "model-00001-of-00007.safetensors",
|
| 440 |
+
"model.local_encoder.layers.0.xlstm.v.weight": "model-00001-of-00007.safetensors",
|
| 441 |
+
"model.local_encoder.out_projection.bias": "model-00001-of-00007.safetensors",
|
| 442 |
+
"model.local_encoder.out_projection.weight": "model-00001-of-00007.safetensors",
|
| 443 |
+
"model.local_encoder.post_last_block_norm.weight": "model-00001-of-00007.safetensors",
|
| 444 |
+
"model.local_encoder.subword_embedding.weight": "model-00001-of-00007.safetensors",
|
| 445 |
+
"model.norm.weight": "model-00007-of-00007.safetensors"
|
| 446 |
+
}
|
| 447 |
+
}
|
modeling_bolmo.py
ADDED
|
@@ -0,0 +1,1351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Callable, Optional, Union, cast
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from transformers.utils.generic import TransformersKwargs
|
| 10 |
+
|
| 11 |
+
from transformers.activations import ACT2FN
|
| 12 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 13 |
+
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
| 14 |
+
from transformers.generation.utils import GenerateOutput
|
| 15 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 16 |
+
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
| 17 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 18 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 19 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 20 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 21 |
+
from transformers.processing_utils import Unpack
|
| 22 |
+
from transformers.utils import can_return_tuple
|
| 23 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
| 24 |
+
from transformers.utils.generic import check_model_inputs
|
| 25 |
+
|
| 26 |
+
from .configuration_bolmo import BolmoConfig
|
| 27 |
+
from .tokenization_bolmo import BolmoTokenizerConfig
|
| 28 |
+
from .utils_bolmo import compute_boundary_mask, pad_right, pad_left, MaskState
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from xlstm.xlstm_large.model import mLSTMLayer, mLSTMLayerConfig, mLSTMLayerStateType, soft_cap, mLSTMBackendConfig
|
| 32 |
+
except ImportError:
|
| 33 |
+
raise ImportError("The `xlstm` package is required to use Bolmo. Please install it via `pip install xlstm`.")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 37 |
+
class BolmoRMSNorm(nn.Module):
|
| 38 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 39 |
+
"""
|
| 40 |
+
BolmoRMSNorm is equivalent to T5LayerNorm
|
| 41 |
+
"""
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 44 |
+
self.variance_epsilon = eps
|
| 45 |
+
|
| 46 |
+
def forward(self, hidden_states):
|
| 47 |
+
input_dtype = hidden_states.dtype
|
| 48 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 49 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 50 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 51 |
+
return (self.weight * hidden_states).to(input_dtype)
|
| 52 |
+
|
| 53 |
+
def extra_repr(self):
|
| 54 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 58 |
+
"""
|
| 59 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 60 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 61 |
+
"""
|
| 62 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 63 |
+
if n_rep == 1:
|
| 64 |
+
return hidden_states
|
| 65 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 66 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def eager_attention_forward(
|
| 70 |
+
module: nn.Module,
|
| 71 |
+
query: torch.Tensor,
|
| 72 |
+
key: torch.Tensor,
|
| 73 |
+
value: torch.Tensor,
|
| 74 |
+
attention_mask: Optional[torch.Tensor],
|
| 75 |
+
scaling: float,
|
| 76 |
+
dropout: float = 0.0,
|
| 77 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 78 |
+
):
|
| 79 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 80 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 81 |
+
|
| 82 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 83 |
+
if attention_mask is not None:
|
| 84 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 85 |
+
attn_weights = attn_weights + causal_mask
|
| 86 |
+
|
| 87 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 88 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 89 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 90 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 91 |
+
|
| 92 |
+
return attn_output, attn_weights
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 96 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
q (`torch.Tensor`): The query tensor.
|
| 100 |
+
k (`torch.Tensor`): The key tensor.
|
| 101 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 102 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 103 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 104 |
+
Deprecated and unused.
|
| 105 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 106 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 107 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 108 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 109 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 110 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 111 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 112 |
+
Returns:
|
| 113 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 114 |
+
"""
|
| 115 |
+
q_type, k_type = q.dtype, k.dtype
|
| 116 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 117 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 118 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 119 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 120 |
+
return q_embed.to(q_type), k_embed.to(k_type)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def rotate_half(x):
|
| 124 |
+
"""Rotates half the hidden dims of the input."""
|
| 125 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 126 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 127 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class BolmoAttention(nn.Module):
|
| 131 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 132 |
+
|
| 133 |
+
def __init__(self, config: BolmoConfig, layer_idx: int):
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.config = config
|
| 136 |
+
self.layer_idx = layer_idx
|
| 137 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 138 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 139 |
+
self.scaling = self.head_dim**-0.5
|
| 140 |
+
self.attention_dropout = config.attention_dropout
|
| 141 |
+
self.is_causal = True
|
| 142 |
+
|
| 143 |
+
self.q_proj = nn.Linear(
|
| 144 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 145 |
+
)
|
| 146 |
+
self.k_proj = nn.Linear(
|
| 147 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 148 |
+
)
|
| 149 |
+
self.v_proj = nn.Linear(
|
| 150 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 151 |
+
)
|
| 152 |
+
self.o_proj = nn.Linear(
|
| 153 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 154 |
+
)
|
| 155 |
+
self.q_norm = BolmoRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
|
| 156 |
+
self.k_norm = BolmoRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
|
| 157 |
+
assert config.layer_types is not None
|
| 158 |
+
self.attention_type = config.layer_types[layer_idx]
|
| 159 |
+
self.sliding_window = config.sliding_window if self.attention_type == "sliding_attention" else None
|
| 160 |
+
|
| 161 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 162 |
+
def forward(
|
| 163 |
+
self,
|
| 164 |
+
hidden_states: torch.Tensor,
|
| 165 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 166 |
+
attention_mask: Optional[torch.Tensor],
|
| 167 |
+
past_key_values: Optional[Cache] = None,
|
| 168 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 169 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 170 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 171 |
+
input_shape = hidden_states.shape[:-1]
|
| 172 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 173 |
+
|
| 174 |
+
query_states = self.q_norm(self.q_proj(hidden_states))
|
| 175 |
+
key_states = self.k_norm(self.k_proj(hidden_states))
|
| 176 |
+
value_states = self.v_proj(hidden_states)
|
| 177 |
+
|
| 178 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 179 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 180 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 181 |
+
|
| 182 |
+
cos, sin = position_embeddings
|
| 183 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 184 |
+
|
| 185 |
+
if past_key_values is not None:
|
| 186 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 187 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 188 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 189 |
+
|
| 190 |
+
attention_interface: Callable = eager_attention_forward
|
| 191 |
+
if self.config._attn_implementation != "eager":
|
| 192 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 193 |
+
|
| 194 |
+
attn_output, attn_weights = attention_interface(
|
| 195 |
+
self,
|
| 196 |
+
query_states,
|
| 197 |
+
key_states,
|
| 198 |
+
value_states,
|
| 199 |
+
attention_mask,
|
| 200 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 201 |
+
scaling=self.scaling,
|
| 202 |
+
sliding_window=self.sliding_window,
|
| 203 |
+
**kwargs,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 207 |
+
attn_output = self.o_proj(attn_output)
|
| 208 |
+
return attn_output, attn_weights
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class BolmoMLP(nn.Module):
|
| 212 |
+
def __init__(self, config):
|
| 213 |
+
super().__init__()
|
| 214 |
+
self.config = config
|
| 215 |
+
self.hidden_size = config.hidden_size
|
| 216 |
+
self.intermediate_size = config.intermediate_size
|
| 217 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 218 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 219 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 220 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 221 |
+
|
| 222 |
+
def forward(self, x):
|
| 223 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 224 |
+
return down_proj
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class BolmoDecoderLayer(GradientCheckpointingLayer):
|
| 228 |
+
def __init__(self, config: BolmoConfig, layer_idx: int):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.hidden_size = config.hidden_size
|
| 231 |
+
self.self_attn = BolmoAttention(config=config, layer_idx=layer_idx)
|
| 232 |
+
|
| 233 |
+
self.mlp = BolmoMLP(config)
|
| 234 |
+
self.post_attention_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 235 |
+
self.post_feedforward_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 236 |
+
|
| 237 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 238 |
+
def forward(
|
| 239 |
+
self,
|
| 240 |
+
hidden_states: torch.Tensor,
|
| 241 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 242 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 243 |
+
past_key_values: Optional[Cache] = None,
|
| 244 |
+
use_cache: Optional[bool] = False,
|
| 245 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 246 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 247 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 248 |
+
) -> torch.Tensor:
|
| 249 |
+
residual = hidden_states
|
| 250 |
+
attn_out, _ = self.self_attn(
|
| 251 |
+
hidden_states=hidden_states,
|
| 252 |
+
attention_mask=attention_mask,
|
| 253 |
+
position_ids=position_ids,
|
| 254 |
+
past_key_values=past_key_values,
|
| 255 |
+
use_cache=use_cache,
|
| 256 |
+
cache_position=cache_position,
|
| 257 |
+
position_embeddings=position_embeddings,
|
| 258 |
+
**kwargs,
|
| 259 |
+
)
|
| 260 |
+
hidden_states = self.post_attention_layernorm(attn_out)
|
| 261 |
+
hidden_states = residual + hidden_states
|
| 262 |
+
|
| 263 |
+
# Fully Connected
|
| 264 |
+
residual = hidden_states
|
| 265 |
+
mlp_out = self.mlp(hidden_states)
|
| 266 |
+
hidden_states = self.post_feedforward_layernorm(mlp_out)
|
| 267 |
+
hidden_states = residual + hidden_states
|
| 268 |
+
|
| 269 |
+
return hidden_states
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class BolmoBoundaryPredictor(nn.Module):
|
| 273 |
+
def __init__(self, config: BolmoConfig):
|
| 274 |
+
super().__init__()
|
| 275 |
+
|
| 276 |
+
self.d_model = config.hidden_size
|
| 277 |
+
self.boundary_threshold = config.boundary_threshold
|
| 278 |
+
self.boundary_predictor_lookahead = config.boundary_predictor_lookahead
|
| 279 |
+
self.q_proj_layer = nn.Linear(self.d_model, self.d_model, bias=False)
|
| 280 |
+
self.k_proj_layer = nn.Linear(self.d_model, self.d_model, bias=False)
|
| 281 |
+
|
| 282 |
+
def forward(
|
| 283 |
+
self,
|
| 284 |
+
hidden_states: torch.Tensor,
|
| 285 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 286 |
+
epsilon: float = 1e-3,
|
| 287 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 288 |
+
if self.boundary_predictor_lookahead == 0:
|
| 289 |
+
# do not use the same rep for k and v, use current and one before as in H-Net + pad with negative to the left
|
| 290 |
+
cos_sim = torch.cat([
|
| 291 |
+
torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=hidden_states.dtype) * -1,
|
| 292 |
+
torch.einsum(
|
| 293 |
+
"b l d, b l d -> b l",
|
| 294 |
+
F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1),
|
| 295 |
+
F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1),
|
| 296 |
+
)
|
| 297 |
+
], dim=1)
|
| 298 |
+
else:
|
| 299 |
+
cos_sim = torch.einsum(
|
| 300 |
+
"b l d, b l d -> b l",
|
| 301 |
+
F.normalize(self.q_proj_layer(hidden_states[:, :-self.boundary_predictor_lookahead]), dim=-1),
|
| 302 |
+
F.normalize(self.k_proj_layer(hidden_states[:, self.boundary_predictor_lookahead:]), dim=-1),
|
| 303 |
+
)
|
| 304 |
+
boundary_logprobs = torch.log1p(-cos_sim.float().clip(max=1.0 - epsilon)) - math.log(2)
|
| 305 |
+
POSITIVE_LOGPROB = 0.0
|
| 306 |
+
NEGATIVE_LOGPROB = -100_000
|
| 307 |
+
if sequence_start_indices is None:
|
| 308 |
+
boundary_logprobs[:, 0] = POSITIVE_LOGPROB
|
| 309 |
+
else:
|
| 310 |
+
pad_mask = torch.arange(boundary_logprobs.shape[1], device=boundary_logprobs.device)[None, :] < sequence_start_indices[:, None]
|
| 311 |
+
boundary_logprobs = boundary_logprobs.masked_fill(pad_mask, NEGATIVE_LOGPROB)
|
| 312 |
+
boundary_logprobs[torch.arange(len(boundary_logprobs), device=boundary_logprobs.device), sequence_start_indices] = POSITIVE_LOGPROB
|
| 313 |
+
|
| 314 |
+
boundary_logprobs = F.pad(boundary_logprobs, (0, self.boundary_predictor_lookahead), "constant", NEGATIVE_LOGPROB)
|
| 315 |
+
boundary_mask = compute_boundary_mask(boundary_logprobs, self.boundary_threshold)
|
| 316 |
+
|
| 317 |
+
return boundary_logprobs, boundary_mask
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class BolmoXLSTMLayer(mLSTMLayer):
|
| 321 |
+
def __init__(self, config: BolmoConfig):
|
| 322 |
+
super().__init__(mLSTMLayerConfig(
|
| 323 |
+
embedding_dim=config.hidden_size,
|
| 324 |
+
num_heads=config.num_local_heads,
|
| 325 |
+
mlstm_backend=mLSTMBackendConfig(
|
| 326 |
+
chunkwise_kernel="chunkwise--triton_limit_chunk",
|
| 327 |
+
sequence_kernel="native_sequence__triton",
|
| 328 |
+
step_kernel="triton",
|
| 329 |
+
mode="train",
|
| 330 |
+
return_last_states=True,
|
| 331 |
+
autocast_kernel_dtype="float32",
|
| 332 |
+
)
|
| 333 |
+
))
|
| 334 |
+
|
| 335 |
+
# original forward adapted to support sequence_start_indices
|
| 336 |
+
# i.e. set the forget gate to zero at the start of sequence
|
| 337 |
+
def _original_forward(
|
| 338 |
+
self, x: torch.Tensor,
|
| 339 |
+
state: mLSTMLayerStateType | None = None,
|
| 340 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 341 |
+
) -> tuple[torch.Tensor, mLSTMLayerStateType | None]:
|
| 342 |
+
assert x.ndim == 3, f"Input must have shape [B, S, D], got {x.shape}"
|
| 343 |
+
B, S, _ = x.shape
|
| 344 |
+
if self.config.weight_mode == "single":
|
| 345 |
+
q = self.q(x)
|
| 346 |
+
k = self.k(x)
|
| 347 |
+
v = self.v(x)
|
| 348 |
+
o_preact = self.ogate_preact(x)
|
| 349 |
+
i_preact = soft_cap(
|
| 350 |
+
self.igate_preact(x), cap_value=self.config.gate_soft_cap
|
| 351 |
+
)
|
| 352 |
+
f_preact = soft_cap(
|
| 353 |
+
self.fgate_preact(x), cap_value=self.config.gate_soft_cap
|
| 354 |
+
)
|
| 355 |
+
elif self.config.weight_mode == "fused":
|
| 356 |
+
qkv_opreact = self.qkv_opreact(x)
|
| 357 |
+
q, k, v, o_preact = torch.tensor_split(
|
| 358 |
+
qkv_opreact,
|
| 359 |
+
(
|
| 360 |
+
self.qk_dim,
|
| 361 |
+
2 * self.qk_dim,
|
| 362 |
+
2 * self.qk_dim + self.v_dim,
|
| 363 |
+
),
|
| 364 |
+
dim=-1,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
if_preact = soft_cap(
|
| 368 |
+
self.ifgate_preact(x), cap_value=self.config.gate_soft_cap
|
| 369 |
+
)
|
| 370 |
+
i_preact, f_preact = torch.tensor_split(
|
| 371 |
+
if_preact, (self.config.num_heads,), dim=-1
|
| 372 |
+
)
|
| 373 |
+
else:
|
| 374 |
+
raise ValueError(f"Unknown weight_mode: {self.config.weight_mode}")
|
| 375 |
+
|
| 376 |
+
q = q.reshape(B, S, self.config.num_heads, -1).transpose(1, 2)
|
| 377 |
+
k = k.reshape(B, S, self.config.num_heads, -1).transpose(1, 2)
|
| 378 |
+
v = v.reshape(B, S, self.config.num_heads, -1).transpose(1, 2)
|
| 379 |
+
|
| 380 |
+
if sequence_start_indices is not None:
|
| 381 |
+
f_preact[torch.arange(B, device=f_preact.device), sequence_start_indices] = -100_000
|
| 382 |
+
|
| 383 |
+
i_preact = i_preact.transpose(1, 2)
|
| 384 |
+
f_preact = f_preact.transpose(1, 2)
|
| 385 |
+
if state is None:
|
| 386 |
+
c_initial, n_initial, m_initial = None, None, None
|
| 387 |
+
else:
|
| 388 |
+
c_initial, n_initial, m_initial = state
|
| 389 |
+
|
| 390 |
+
h, state = self.mlstm_backend(
|
| 391 |
+
q=q,
|
| 392 |
+
k=k,
|
| 393 |
+
v=v,
|
| 394 |
+
i=i_preact,
|
| 395 |
+
f=f_preact,
|
| 396 |
+
c_initial=c_initial,
|
| 397 |
+
n_initial=n_initial,
|
| 398 |
+
m_initial=m_initial,
|
| 399 |
+
)
|
| 400 |
+
expected_h_shape = (
|
| 401 |
+
B,
|
| 402 |
+
self.config.num_heads,
|
| 403 |
+
S,
|
| 404 |
+
self.v_dim // self.config.num_heads,
|
| 405 |
+
)
|
| 406 |
+
assert (
|
| 407 |
+
h.shape == expected_h_shape
|
| 408 |
+
), f"Got {h.shape}, expected {expected_h_shape}"
|
| 409 |
+
|
| 410 |
+
h = h.transpose(1, 2)
|
| 411 |
+
h_norm = self.multihead_norm(h)
|
| 412 |
+
h_norm = h_norm.reshape(B, S, -1)
|
| 413 |
+
|
| 414 |
+
h_out = self.ogate_act_fn(o_preact) * h_norm
|
| 415 |
+
|
| 416 |
+
y = self.out_proj(h_out)
|
| 417 |
+
return y, state
|
| 418 |
+
|
| 419 |
+
def forward( # type: ignore
|
| 420 |
+
self,
|
| 421 |
+
x: torch.Tensor,
|
| 422 |
+
past_key_values: Optional[dict] = None,
|
| 423 |
+
use_cache: bool = False,
|
| 424 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 425 |
+
cache_mask: Optional[MaskState] = None
|
| 426 |
+
):
|
| 427 |
+
if self.training:
|
| 428 |
+
self.mlstm_backend.config.mode = "train"
|
| 429 |
+
else:
|
| 430 |
+
self.mlstm_backend.config.mode = "inference"
|
| 431 |
+
|
| 432 |
+
if use_cache:
|
| 433 |
+
assert past_key_values is not None
|
| 434 |
+
|
| 435 |
+
prev_mode = self.mlstm_backend.config.mode
|
| 436 |
+
state = past_key_values.get("state", None)
|
| 437 |
+
|
| 438 |
+
if cache_mask is not None:
|
| 439 |
+
state_for_model = cast(mLSTMLayerStateType, tuple(cache_mask.selective_get(x, inv=True) for x in state) if state is not None else None)
|
| 440 |
+
else:
|
| 441 |
+
state_for_model = state
|
| 442 |
+
|
| 443 |
+
h, new_state = self._original_forward(
|
| 444 |
+
x,
|
| 445 |
+
state=state_for_model,
|
| 446 |
+
sequence_start_indices=sequence_start_indices
|
| 447 |
+
)
|
| 448 |
+
assert new_state is not None
|
| 449 |
+
|
| 450 |
+
if state is None or cache_mask is None:
|
| 451 |
+
state = new_state
|
| 452 |
+
else:
|
| 453 |
+
if cache_mask is not None:
|
| 454 |
+
for i in range(len(state)):
|
| 455 |
+
cache_mask.selective_put(new_state[i], state[i], inv=True)
|
| 456 |
+
|
| 457 |
+
past_key_values["state"] = state
|
| 458 |
+
self.mlstm_backend.config.mode = prev_mode
|
| 459 |
+
|
| 460 |
+
return h
|
| 461 |
+
else:
|
| 462 |
+
h, _ = super().forward(x)
|
| 463 |
+
return h
|
| 464 |
+
|
| 465 |
+
class BolmoLocalLayer(nn.Module):
|
| 466 |
+
def __init__(self, config: BolmoConfig):
|
| 467 |
+
super().__init__()
|
| 468 |
+
self.config = config
|
| 469 |
+
self.hidden_size = config.hidden_size
|
| 470 |
+
|
| 471 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 472 |
+
|
| 473 |
+
self.xlstm = BolmoXLSTMLayer(config)
|
| 474 |
+
|
| 475 |
+
local_mlp_config = copy.deepcopy(config)
|
| 476 |
+
local_mlp_config.intermediate_size = config.local_intermediate_size
|
| 477 |
+
self.mlp = BolmoMLP(local_mlp_config)
|
| 478 |
+
|
| 479 |
+
self.pre_xlstm_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 480 |
+
self.pre_feedforward_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 481 |
+
|
| 482 |
+
def forward(
|
| 483 |
+
self,
|
| 484 |
+
hidden_states: torch.Tensor,
|
| 485 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 486 |
+
past_key_values: Optional[dict] = None,
|
| 487 |
+
use_cache: Optional[bool] = False,
|
| 488 |
+
cache_mask: Optional[MaskState] = None,
|
| 489 |
+
) -> torch.Tensor:
|
| 490 |
+
residual = hidden_states
|
| 491 |
+
xlstm_out = self.xlstm(self.pre_xlstm_layernorm(hidden_states), sequence_start_indices=sequence_start_indices, past_key_values=past_key_values["xlstm"] if past_key_values is not None else None, use_cache=use_cache, cache_mask=cache_mask)
|
| 492 |
+
hidden_states = residual + xlstm_out
|
| 493 |
+
|
| 494 |
+
# Fully Connected
|
| 495 |
+
residual = hidden_states
|
| 496 |
+
ffn_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
|
| 497 |
+
hidden_states = residual + ffn_out
|
| 498 |
+
|
| 499 |
+
return hidden_states
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class BolmoLocalEncoder(nn.Module):
|
| 503 |
+
def __init__(self, config: BolmoConfig):
|
| 504 |
+
super().__init__()
|
| 505 |
+
self.config = config
|
| 506 |
+
self.hidden_size = config.hidden_size
|
| 507 |
+
self.add_expanded_embeddings = config.add_expanded_embeddings
|
| 508 |
+
|
| 509 |
+
self.byte_embedding = nn.Embedding(
|
| 510 |
+
config.vocab_size,
|
| 511 |
+
self.hidden_size,
|
| 512 |
+
)
|
| 513 |
+
if self.add_expanded_embeddings:
|
| 514 |
+
self.subword_embedding = nn.Embedding(
|
| 515 |
+
config.subword_vocab_size,
|
| 516 |
+
self.hidden_size,
|
| 517 |
+
)
|
| 518 |
+
else:
|
| 519 |
+
self.subword_embedding = None
|
| 520 |
+
|
| 521 |
+
self.layers = nn.ModuleList(
|
| 522 |
+
[BolmoLocalLayer(config) for _ in range(config.num_local_encoder_layers)]
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
self.post_last_block_norm = BolmoRMSNorm(
|
| 526 |
+
self.hidden_size,
|
| 527 |
+
config.local_rms_norm_eps,
|
| 528 |
+
)
|
| 529 |
+
self.out_projection = nn.Linear(
|
| 530 |
+
self.hidden_size,
|
| 531 |
+
self.hidden_size,
|
| 532 |
+
bias=True,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
self.boundary_predictor_module = BolmoBoundaryPredictor(config)
|
| 536 |
+
|
| 537 |
+
self.has_cache = False
|
| 538 |
+
|
| 539 |
+
def prepare_inference_cache(self, batch_size: int):
|
| 540 |
+
device = next(self.parameters()).device
|
| 541 |
+
self.has_cache = True
|
| 542 |
+
|
| 543 |
+
self.cache_seqlens = 0
|
| 544 |
+
self.last_h = torch.zeros((batch_size, self.hidden_size), dtype=self.out_projection.weight.dtype, device=device)
|
| 545 |
+
self.layer_states = [{"xlstm": {}} for _ in range(len(self.layers))]
|
| 546 |
+
|
| 547 |
+
def free_inference_cache(self):
|
| 548 |
+
self.has_cache = False
|
| 549 |
+
if hasattr(self, "cache_seqlens"):
|
| 550 |
+
del self.cache_seqlens
|
| 551 |
+
if hasattr(self, "last_h"):
|
| 552 |
+
del self.last_h
|
| 553 |
+
if hasattr(self, "layer_states"):
|
| 554 |
+
del self.layer_states
|
| 555 |
+
|
| 556 |
+
def _embed(self, tokens, expanded_input_ids: Optional[torch.Tensor] = None):
|
| 557 |
+
embeddings = self.byte_embedding(tokens)
|
| 558 |
+
if self.add_expanded_embeddings:
|
| 559 |
+
assert expanded_input_ids is not None and self.subword_embedding is not None
|
| 560 |
+
embeddings = embeddings + self.subword_embedding(expanded_input_ids)
|
| 561 |
+
|
| 562 |
+
return embeddings
|
| 563 |
+
|
| 564 |
+
def _pool(
|
| 565 |
+
self,
|
| 566 |
+
h: torch.Tensor,
|
| 567 |
+
boundary_mask: torch.Tensor | None,
|
| 568 |
+
n_patches: int,
|
| 569 |
+
boundary_state: Optional[MaskState] = None,
|
| 570 |
+
):
|
| 571 |
+
if self.has_cache and self.cache_seqlens > 0:
|
| 572 |
+
assert boundary_state is not None
|
| 573 |
+
if boundary_state.all():
|
| 574 |
+
assert h.shape[1] == 1
|
| 575 |
+
reduced_h = h
|
| 576 |
+
else:
|
| 577 |
+
reduced_h = h[[], :, :]
|
| 578 |
+
else:
|
| 579 |
+
assert boundary_mask is not None
|
| 580 |
+
|
| 581 |
+
L = h.shape[1]
|
| 582 |
+
token_idx = (
|
| 583 |
+
torch.arange(L, device=h.device)[None, :] + (~boundary_mask).long() * L # type: ignore
|
| 584 |
+
)
|
| 585 |
+
seq_sorted_indices = torch.argsort(token_idx, dim=1)
|
| 586 |
+
index = seq_sorted_indices[:, :n_patches, None].expand(
|
| 587 |
+
-1, -1, h.shape[-1]
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
reduced_h = torch.gather(
|
| 591 |
+
h,
|
| 592 |
+
dim=1,
|
| 593 |
+
index=index,
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
return reduced_h
|
| 597 |
+
|
| 598 |
+
def forward(
|
| 599 |
+
self,
|
| 600 |
+
input_ids,
|
| 601 |
+
true_boundary_mask: Optional[torch.Tensor] = None,
|
| 602 |
+
boundary_state: Optional[MaskState] = None,
|
| 603 |
+
pad_state: Optional[MaskState] = None,
|
| 604 |
+
expanded_input_ids: Optional[torch.Tensor] = None,
|
| 605 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 606 |
+
):
|
| 607 |
+
embeddings = self._embed(input_ids, expanded_input_ids)
|
| 608 |
+
|
| 609 |
+
# pass through encoder layers
|
| 610 |
+
if self.has_cache and self.cache_seqlens > 0:
|
| 611 |
+
assert pad_state is not None
|
| 612 |
+
|
| 613 |
+
# step those batch positions which are not currently idle (i.e. at a boundary position)
|
| 614 |
+
# if all batch positions are idle, skip the step entirely
|
| 615 |
+
# all positions being idle only happens if fuse_boundaries=False. In this case, the step where we
|
| 616 |
+
# obtain a new representation from the global model will have all positions for the local encoder being idle.
|
| 617 |
+
if not pad_state.all():
|
| 618 |
+
h = pad_state.selective_get(embeddings, inv=True)
|
| 619 |
+
|
| 620 |
+
for i, block in enumerate(self.layers):
|
| 621 |
+
h = block(h, past_key_values=self.layer_states[i], use_cache=True, cache_mask=pad_state)
|
| 622 |
+
|
| 623 |
+
if self.post_last_block_norm is not None:
|
| 624 |
+
h = self.post_last_block_norm(h)
|
| 625 |
+
|
| 626 |
+
pad_state.selective_put(h[:, -1, :], self.last_h, inv=True)
|
| 627 |
+
|
| 628 |
+
h = self.last_h.unsqueeze(1)
|
| 629 |
+
else:
|
| 630 |
+
h = embeddings
|
| 631 |
+
for i, block in enumerate(self.layers):
|
| 632 |
+
if self.has_cache:
|
| 633 |
+
use_cache = True
|
| 634 |
+
past_key_values = self.layer_states[i]
|
| 635 |
+
else:
|
| 636 |
+
use_cache = False
|
| 637 |
+
past_key_values = None
|
| 638 |
+
|
| 639 |
+
h = block(h, past_key_values=past_key_values, use_cache=use_cache, sequence_start_indices=sequence_start_indices)
|
| 640 |
+
|
| 641 |
+
if self.post_last_block_norm is not None:
|
| 642 |
+
h = self.post_last_block_norm(h)
|
| 643 |
+
|
| 644 |
+
if self.has_cache:
|
| 645 |
+
self.last_h.copy_(h[:, -1, :])
|
| 646 |
+
|
| 647 |
+
if not self.has_cache or self.cache_seqlens == 0: # only used for prefill
|
| 648 |
+
boundary_logprobs, boundary_mask = self.boundary_predictor_module(
|
| 649 |
+
h,
|
| 650 |
+
sequence_start_indices=sequence_start_indices,
|
| 651 |
+
)
|
| 652 |
+
if boundary_state is not None:
|
| 653 |
+
# can't predict through encoder - must be through prev local decoder step
|
| 654 |
+
boundary_mask[:, -1] = boundary_state.mask
|
| 655 |
+
else:
|
| 656 |
+
boundary_logprobs = boundary_mask = None
|
| 657 |
+
|
| 658 |
+
# overwrite with true boundaries
|
| 659 |
+
if true_boundary_mask is not None:
|
| 660 |
+
boundary_mask = true_boundary_mask
|
| 661 |
+
|
| 662 |
+
patch_embeddings = self._pool(
|
| 663 |
+
h=h,
|
| 664 |
+
boundary_mask=boundary_mask,
|
| 665 |
+
n_patches=int(cast(torch.Tensor, boundary_mask).sum(-1).max().item()) if boundary_mask is not None else 1,
|
| 666 |
+
boundary_state=boundary_state,
|
| 667 |
+
)
|
| 668 |
+
patch_embeddings = self.out_projection(patch_embeddings)
|
| 669 |
+
|
| 670 |
+
if self.has_cache:
|
| 671 |
+
self.cache_seqlens += input_ids.shape[1]
|
| 672 |
+
|
| 673 |
+
return h, patch_embeddings, boundary_logprobs, boundary_mask
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
class BolmoLocalDecoder(nn.Module):
|
| 677 |
+
def __init__(self, config: BolmoConfig):
|
| 678 |
+
super().__init__()
|
| 679 |
+
self.config = config
|
| 680 |
+
self.hidden_size = config.hidden_size
|
| 681 |
+
|
| 682 |
+
self.initial_norm = BolmoRMSNorm(
|
| 683 |
+
self.hidden_size,
|
| 684 |
+
eps=config.local_rms_norm_eps,
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
self.in_projection = nn.Linear(
|
| 688 |
+
self.hidden_size,
|
| 689 |
+
self.hidden_size,
|
| 690 |
+
bias=True,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
self.layers = nn.ModuleList(
|
| 694 |
+
[BolmoLocalLayer(config) for _ in range(config.num_local_decoder_layers)]
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
self.has_cache = False
|
| 698 |
+
|
| 699 |
+
def prepare_inference_cache(self, batch_size: int):
|
| 700 |
+
device = next(self.parameters()).device
|
| 701 |
+
self.has_cache = True
|
| 702 |
+
|
| 703 |
+
self.cache_seqlens = 0
|
| 704 |
+
self.last_value = torch.zeros((batch_size, self.hidden_size), dtype=self.in_projection.weight.dtype, device=device)
|
| 705 |
+
self.layer_states = [{"xlstm": {}} for _ in range(len(self.layers))]
|
| 706 |
+
|
| 707 |
+
def free_inference_cache(self):
|
| 708 |
+
self.has_cache = False
|
| 709 |
+
if hasattr(self, "cache_seqlens"):
|
| 710 |
+
del self.cache_seqlens
|
| 711 |
+
if hasattr(self, "last_value"):
|
| 712 |
+
del self.last_value
|
| 713 |
+
if hasattr(self, "layer_states"):
|
| 714 |
+
del self.layer_states
|
| 715 |
+
|
| 716 |
+
def _depool(
|
| 717 |
+
self,
|
| 718 |
+
embeds: torch.Tensor,
|
| 719 |
+
patch_embeds: torch.Tensor,
|
| 720 |
+
boundary_mask: Optional[torch.Tensor],
|
| 721 |
+
boundary_state: Optional[MaskState] = None,
|
| 722 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 723 |
+
) -> torch.Tensor:
|
| 724 |
+
if self.has_cache and self.cache_seqlens > 0:
|
| 725 |
+
assert boundary_state is not None
|
| 726 |
+
|
| 727 |
+
if patch_embeds.numel() > 0:
|
| 728 |
+
# we got a new value from the global model, so must be at boundary position
|
| 729 |
+
h_patch = patch_embeds[:, -1:, :]
|
| 730 |
+
h = embeds + h_patch
|
| 731 |
+
|
| 732 |
+
self.last_value.copy_(h_patch[:, -1])
|
| 733 |
+
else:
|
| 734 |
+
h = embeds + self.last_value.unsqueeze(1)
|
| 735 |
+
|
| 736 |
+
# skip pad positions until we get a new value from the global model
|
| 737 |
+
if patch_embeds.numel() == 0:
|
| 738 |
+
h = boundary_state.selective_get(h, inv=True)
|
| 739 |
+
else:
|
| 740 |
+
boundary_state = None
|
| 741 |
+
|
| 742 |
+
if h.shape[0] > 0:
|
| 743 |
+
for i, layer in enumerate(self.layers):
|
| 744 |
+
h = layer(h, past_key_values=self.layer_states[i], use_cache=True, cache_mask=boundary_state)
|
| 745 |
+
|
| 746 |
+
self.cache_seqlens += h.shape[1]
|
| 747 |
+
|
| 748 |
+
return h
|
| 749 |
+
else:
|
| 750 |
+
assert boundary_mask is not None
|
| 751 |
+
|
| 752 |
+
h_patch = patch_embeds
|
| 753 |
+
prepool_out = h_patch
|
| 754 |
+
|
| 755 |
+
# TODO(benjaminm): clipping is problematic if it happens too much; track clip %.
|
| 756 |
+
plug_back_idx = (torch.cumsum(boundary_mask, dim=1) - 1).clip(min=0, max=prepool_out.shape[1] - 1)
|
| 757 |
+
depool_out = torch.gather(
|
| 758 |
+
prepool_out,
|
| 759 |
+
dim=1,
|
| 760 |
+
index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.hidden_size),
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
depool_out_modulated = depool_out
|
| 764 |
+
h = depool_out_modulated + embeds
|
| 765 |
+
|
| 766 |
+
for i, layer in enumerate(self.layers):
|
| 767 |
+
if self.has_cache:
|
| 768 |
+
use_cache = True
|
| 769 |
+
past_key_values = self.layer_states[i]
|
| 770 |
+
else:
|
| 771 |
+
use_cache = False
|
| 772 |
+
past_key_values = None
|
| 773 |
+
|
| 774 |
+
h = layer(h, past_key_values=past_key_values, use_cache=use_cache, sequence_start_indices=sequence_start_indices)
|
| 775 |
+
|
| 776 |
+
if self.has_cache:
|
| 777 |
+
self.last_value.copy_(prepool_out[:, -1])
|
| 778 |
+
self.cache_seqlens += h.shape[1]
|
| 779 |
+
|
| 780 |
+
return h
|
| 781 |
+
|
| 782 |
+
def forward(
|
| 783 |
+
self,
|
| 784 |
+
embeds: torch.Tensor,
|
| 785 |
+
patch_embeds: torch.Tensor,
|
| 786 |
+
boundary_state: Optional[MaskState],
|
| 787 |
+
boundary_mask: torch.Tensor | None,
|
| 788 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 789 |
+
) -> torch.Tensor:
|
| 790 |
+
h = self.in_projection(embeds)
|
| 791 |
+
h_patch = self.initial_norm(patch_embeds)
|
| 792 |
+
|
| 793 |
+
return self._depool(
|
| 794 |
+
embeds=h,
|
| 795 |
+
patch_embeds=h_patch,
|
| 796 |
+
boundary_mask=boundary_mask,
|
| 797 |
+
boundary_state=boundary_state,
|
| 798 |
+
sequence_start_indices=sequence_start_indices,
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
class BolmoRotaryEmbedding(nn.Module):
|
| 803 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 804 |
+
|
| 805 |
+
def __init__(self, config: BolmoConfig, device=None, rope_type: Optional[str] = None):
|
| 806 |
+
super().__init__()
|
| 807 |
+
if rope_type is not None:
|
| 808 |
+
self.rope_type = rope_type
|
| 809 |
+
elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 810 |
+
# BC: "rope_type" was originally "type"
|
| 811 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 812 |
+
else:
|
| 813 |
+
self.rope_type = "default"
|
| 814 |
+
assert self.rope_type is not None
|
| 815 |
+
|
| 816 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 817 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 818 |
+
|
| 819 |
+
self.config = config
|
| 820 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 821 |
+
|
| 822 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 823 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 824 |
+
self.original_inv_freq = self.inv_freq
|
| 825 |
+
|
| 826 |
+
@torch.no_grad()
|
| 827 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 828 |
+
def forward(self, x, position_ids):
|
| 829 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 830 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 831 |
+
|
| 832 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 833 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 834 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 835 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 836 |
+
cos = emb.cos() * self.attention_scaling
|
| 837 |
+
sin = emb.sin() * self.attention_scaling
|
| 838 |
+
return cos, sin
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
class BolmoPreTrainedModel(PreTrainedModel):
|
| 842 |
+
config: BolmoConfig
|
| 843 |
+
base_model_prefix = "model"
|
| 844 |
+
supports_gradient_checkpointing = True
|
| 845 |
+
_no_split_modules = ["BolmoDecoderLayer"]
|
| 846 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 847 |
+
_supports_flash_attn = True
|
| 848 |
+
_supports_sdpa = True
|
| 849 |
+
_supports_flex_attn = True
|
| 850 |
+
|
| 851 |
+
_can_compile_fullgraph = True
|
| 852 |
+
_supports_attention_backend = True
|
| 853 |
+
_can_record_outputs = {
|
| 854 |
+
"hidden_states": BolmoDecoderLayer,
|
| 855 |
+
"attentions": BolmoAttention,
|
| 856 |
+
}
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
class BolmoModel(BolmoPreTrainedModel):
|
| 860 |
+
def __init__(self, config: BolmoConfig):
|
| 861 |
+
super().__init__(config)
|
| 862 |
+
self.padding_idx = config.pad_token_id
|
| 863 |
+
self.vocab_size = config.vocab_size
|
| 864 |
+
|
| 865 |
+
self.local_encoder = BolmoLocalEncoder(config)
|
| 866 |
+
self.local_decoder = BolmoLocalDecoder(config)
|
| 867 |
+
|
| 868 |
+
self.layers = nn.ModuleList(
|
| 869 |
+
[BolmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 870 |
+
)
|
| 871 |
+
self.norm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 872 |
+
self.gradient_checkpointing = False
|
| 873 |
+
self.rotary_embs = nn.ModuleDict(
|
| 874 |
+
{
|
| 875 |
+
"sliding_attention": BolmoRotaryEmbedding(config=config, rope_type="default"),
|
| 876 |
+
"full_attention": BolmoRotaryEmbedding(config=config),
|
| 877 |
+
}
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
self.tokenizer_config = BolmoTokenizerConfig(**config.tokenizer_config)
|
| 881 |
+
self._tokenizer = None
|
| 882 |
+
|
| 883 |
+
# Initialize weights and apply final processing
|
| 884 |
+
self.post_init()
|
| 885 |
+
|
| 886 |
+
def get_input_embeddings(self):
|
| 887 |
+
return self.local_encoder.byte_embedding
|
| 888 |
+
|
| 889 |
+
def set_input_embeddings(self, value: nn.Embedding): # type: ignore
|
| 890 |
+
self.local_encoder.byte_embedding = value
|
| 891 |
+
|
| 892 |
+
@property
|
| 893 |
+
def tokenizer(self):
|
| 894 |
+
if self._tokenizer is None:
|
| 895 |
+
self._tokenizer = self.tokenizer_config.build()
|
| 896 |
+
|
| 897 |
+
return self._tokenizer
|
| 898 |
+
|
| 899 |
+
def prefill_boundary_prediction_forward(
|
| 900 |
+
self,
|
| 901 |
+
input_ids: torch.Tensor,
|
| 902 |
+
expanded_input_ids: Optional[torch.Tensor] = None,
|
| 903 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 904 |
+
last_token_is_boundary: bool = False,
|
| 905 |
+
**kwargs,
|
| 906 |
+
) -> torch.Tensor:
|
| 907 |
+
_, _, _, boundary_mask = self.local_encoder.forward( # type: ignore
|
| 908 |
+
input_ids,
|
| 909 |
+
expanded_input_ids=expanded_input_ids,
|
| 910 |
+
boundary_state=MaskState(torch.full((input_ids.shape[0],), fill_value=last_token_is_boundary, device=input_ids.device, dtype=torch.bool)),
|
| 911 |
+
pad_state=MaskState(torch.zeros((input_ids.shape[0],), device=input_ids.device, dtype=torch.bool)),
|
| 912 |
+
sequence_start_indices=sequence_start_indices,
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
return cast(torch.Tensor, boundary_mask)
|
| 916 |
+
|
| 917 |
+
@check_model_inputs()
|
| 918 |
+
def forward(
|
| 919 |
+
self,
|
| 920 |
+
input_ids: torch.Tensor,
|
| 921 |
+
expanded_input_ids: Optional[torch.Tensor] = None,
|
| 922 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 923 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 924 |
+
past_key_values: Optional[Cache] = None,
|
| 925 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 926 |
+
use_cache: Optional[bool] = None,
|
| 927 |
+
boundary_mask: Optional[torch.Tensor] = None,
|
| 928 |
+
boundary_state: Optional[MaskState] = None,
|
| 929 |
+
pad_state: Optional[MaskState] = None,
|
| 930 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 931 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 932 |
+
) -> BaseModelOutputWithPast:
|
| 933 |
+
batch_size = input_ids.shape[0]
|
| 934 |
+
device = input_ids.device
|
| 935 |
+
|
| 936 |
+
if self.local_encoder.add_expanded_embeddings and expanded_input_ids is None and input_ids is not None:
|
| 937 |
+
# not optimized
|
| 938 |
+
expanded_input_ids_list: list[torch.Tensor] = []
|
| 939 |
+
for example_idx in range(batch_size):
|
| 940 |
+
expanded_input_ids_list.append(torch.tensor(self.tokenizer.expand_byte_ids(input_ids[example_idx].tolist()), dtype=torch.long, device=device))
|
| 941 |
+
expanded_input_ids = pad_right(expanded_input_ids_list, value=self.tokenizer.pad_token_id, multiple_of=1) # type: ignore
|
| 942 |
+
|
| 943 |
+
h_byte, h_patch, _, boundary_mask = self.local_encoder(
|
| 944 |
+
input_ids=input_ids,
|
| 945 |
+
expanded_input_ids=expanded_input_ids,
|
| 946 |
+
true_boundary_mask=boundary_mask,
|
| 947 |
+
boundary_state=boundary_state,
|
| 948 |
+
pad_state=pad_state,
|
| 949 |
+
)
|
| 950 |
+
|
| 951 |
+
if use_cache and past_key_values is None:
|
| 952 |
+
past_key_values = DynamicCache(config=self.config)
|
| 953 |
+
|
| 954 |
+
if cache_position is None:
|
| 955 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 956 |
+
cache_position: torch.Tensor = torch.arange(
|
| 957 |
+
past_seen_tokens, past_seen_tokens + h_patch.shape[1], device=device
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
if position_ids is None:
|
| 961 |
+
position_ids = cache_position.unsqueeze(0) # type: ignore
|
| 962 |
+
|
| 963 |
+
# It may already have been prepared by e.g. `generate`
|
| 964 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 965 |
+
# Prepare mask arguments
|
| 966 |
+
mask_kwargs = {
|
| 967 |
+
"config": self.config,
|
| 968 |
+
"input_embeds": h_patch,
|
| 969 |
+
"attention_mask": attention_mask,
|
| 970 |
+
"cache_position": cache_position,
|
| 971 |
+
"past_key_values": past_key_values,
|
| 972 |
+
"position_ids": position_ids,
|
| 973 |
+
}
|
| 974 |
+
# Create the masks
|
| 975 |
+
causal_mask_mapping = {
|
| 976 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
| 977 |
+
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
| 978 |
+
}
|
| 979 |
+
|
| 980 |
+
position_embeddings_mapping = {
|
| 981 |
+
"sliding_attention": self.rotary_embs["sliding_attention"](h_byte, position_ids),
|
| 982 |
+
"full_attention": self.rotary_embs["full_attention"](h_byte, position_ids),
|
| 983 |
+
}
|
| 984 |
+
|
| 985 |
+
if h_patch.numel() > 0:
|
| 986 |
+
# we need to convert from right-pad to left-pad and back for prefill
|
| 987 |
+
# since flash attention expects left-padding while the local encoder/decoder expect right-padded global tokens
|
| 988 |
+
# should add better left-pad support but this only affects prefill so OK for now
|
| 989 |
+
# although super inefficient!
|
| 990 |
+
if boundary_mask is not None: # prefill
|
| 991 |
+
n_boundaries = boundary_mask.sum(-1)
|
| 992 |
+
|
| 993 |
+
for i, current_n_boundaries in enumerate(n_boundaries):
|
| 994 |
+
h_patch[i, -current_n_boundaries:] = h_patch[i, :current_n_boundaries].clone()
|
| 995 |
+
|
| 996 |
+
h_patch_after_global = h_patch
|
| 997 |
+
|
| 998 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 999 |
+
h_patch_after_global = decoder_layer(
|
| 1000 |
+
h_patch_after_global,
|
| 1001 |
+
attention_mask=causal_mask_mapping[decoder_layer.self_attn.attention_type],
|
| 1002 |
+
position_ids=position_ids,
|
| 1003 |
+
past_key_values=past_key_values,
|
| 1004 |
+
cache_position=cache_position,
|
| 1005 |
+
position_embeddings=position_embeddings_mapping[decoder_layer.self_attn.attention_type],
|
| 1006 |
+
**kwargs,
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
if boundary_mask is not None: # prefill
|
| 1010 |
+
n_boundaries = boundary_mask.sum(-1)
|
| 1011 |
+
|
| 1012 |
+
for i, current_n_boundaries in enumerate(n_boundaries):
|
| 1013 |
+
h_patch_after_global[i, :current_n_boundaries] = h_patch_after_global[i, -current_n_boundaries:].clone()
|
| 1014 |
+
else:
|
| 1015 |
+
h_patch_after_global = h_patch
|
| 1016 |
+
|
| 1017 |
+
h_out = self.local_decoder.forward( # type: ignore
|
| 1018 |
+
embeds=h_byte,
|
| 1019 |
+
patch_embeds=h_patch_after_global,
|
| 1020 |
+
boundary_mask=boundary_mask,
|
| 1021 |
+
boundary_state=boundary_state,
|
| 1022 |
+
sequence_start_indices=sequence_start_indices,
|
| 1023 |
+
)
|
| 1024 |
+
h_out = self.norm(h_out)
|
| 1025 |
+
|
| 1026 |
+
return BaseModelOutputWithPast(
|
| 1027 |
+
last_hidden_state=h_out,
|
| 1028 |
+
past_key_values=past_key_values,
|
| 1029 |
+
)
|
| 1030 |
+
|
| 1031 |
+
|
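The forward pass above is byte -> patch -> byte: the local encoder pools bytes into patch embeddings at predicted boundaries, the global decoder layers run over patches, and the local decoder maps the result back to byte positions. A toy sketch of how a boundary mask partitions bytes into patches follows; the real pooling is learned inside `BolmoLocalEncoder`, and the mean-pooling here is only an assumption to make the data flow concrete.

```python
import torch

byte_states = torch.randn(1, 8, 16)                   # [batch, bytes, hidden]
boundary = torch.tensor([[0, 0, 1, 0, 1, 0, 0, 1]])   # 1 = this byte closes a patch

patch_ids = boundary.cumsum(-1) - boundary             # patch index of every byte
h_patch = torch.stack([
    byte_states[0, patch_ids[0] == p].mean(0)          # stand-in for the learned pooling
    for p in range(int(boundary.sum()))
])                                                      # [n_patches, hidden] -> global layers
```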
| 1032 |
+
class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
|
| 1033 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1034 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 1035 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 1036 |
+
|
| 1037 |
+
def __init__(self, config):
|
| 1038 |
+
super().__init__(config)
|
| 1039 |
+
self.model = BolmoModel(config)
|
| 1040 |
+
self.vocab_size = config.vocab_size
|
| 1041 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1042 |
+
|
| 1043 |
+
# Initialize weights and apply final processing
|
| 1044 |
+
self.post_init()
|
| 1045 |
+
|
| 1046 |
+
def get_output_embeddings(self):
|
| 1047 |
+
return self.lm_head
|
| 1048 |
+
|
| 1049 |
+
def set_output_embeddings(self, new_embeddings: nn.Linear):
|
| 1050 |
+
self.lm_head = new_embeddings
|
| 1051 |
+
|
| 1052 |
+
@can_return_tuple
|
| 1053 |
+
def forward(
|
| 1054 |
+
self,
|
| 1055 |
+
input_ids: torch.Tensor,
|
| 1056 |
+
expanded_input_ids: Optional[torch.Tensor] = None,
|
| 1057 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1058 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1059 |
+
past_key_values: Optional[Cache] = None,
|
| 1060 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1061 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 1062 |
+
use_cache: Optional[bool] = None,
|
| 1063 |
+
boundary_mask: Optional[torch.Tensor] = None,
|
| 1064 |
+
boundary_state: Optional[MaskState] = None,
|
| 1065 |
+
pad_state: Optional[MaskState] = None,
|
| 1066 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 1067 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1068 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1069 |
+
) -> CausalLMOutputWithPast:
|
| 1070 |
+
r"""
|
| 1071 |
+
Example:
|
| 1072 |
+
|
| 1073 |
+
```python
|
| 1074 |
+
>>> from transformers import AutoTokenizer, BolmoForCausalLM
|
| 1075 |
+
|
| 1076 |
+
>>> model = BolmoForCausalLM.from_pretrained("meta-olmo3/Bolmo-2-7b-hf")
|
| 1077 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo3/Bolmo-2-7b-hf")
|
| 1078 |
+
|
| 1079 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 1080 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1081 |
+
|
| 1082 |
+
>>> # Generate
|
| 1083 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1084 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1085 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1086 |
+
```"""
|
| 1087 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 1088 |
+
input_ids=input_ids,
|
| 1089 |
+
expanded_input_ids=expanded_input_ids,
|
| 1090 |
+
attention_mask=attention_mask,
|
| 1091 |
+
position_ids=position_ids,
|
| 1092 |
+
past_key_values=past_key_values,
|
| 1093 |
+
inputs_embeds=inputs_embeds,
|
| 1094 |
+
cache_position=cache_position,
|
| 1095 |
+
use_cache=use_cache,
|
| 1096 |
+
boundary_mask=boundary_mask,
|
| 1097 |
+
boundary_state=boundary_state,
|
| 1098 |
+
pad_state=pad_state,
|
| 1099 |
+
sequence_start_indices=sequence_start_indices,
|
| 1100 |
+
**kwargs,
|
| 1101 |
+
)
|
| 1102 |
+
|
| 1103 |
+
hidden_states = cast(torch.Tensor, outputs.last_hidden_state)
|
| 1104 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1105 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1106 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1107 |
+
|
| 1108 |
+
return CausalLMOutputWithPast(
|
| 1109 |
+
logits=logits,
|
| 1110 |
+
past_key_values=outputs.past_key_values,
|
| 1111 |
+
hidden_states=outputs.hidden_states,
|
| 1112 |
+
attentions=outputs.attentions,
|
| 1113 |
+
)
|
| 1114 |
+
|
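A minimal, illustrative forward call (placeholder names, not part of the file). The logits live over the doubled byte vocabulary: ids at or above `tokenizer.offset + 256` mean "the same byte, but it also closes the current patch", which is what `generate()` below relies on.

```python
import torch

# Assumed: `model` is a BolmoForCausalLM, `byte_ids` a LongTensor [batch, seq_len].
with torch.no_grad():
    out = model(byte_ids, logits_to_keep=1, use_cache=False)

next_logits = out.logits[:, -1]        # [batch, vocab_size], only the last position kept
next_byte = next_logits.argmax(-1)     # may be a plain byte id or a byte+boundary id
```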
| 1115 |
+
@torch.no_grad()
|
| 1116 |
+
def generate( # type: ignore
|
| 1117 |
+
self,
|
| 1118 |
+
inputs: torch.Tensor,
|
| 1119 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 1120 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
| 1121 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 1122 |
+
use_model_defaults: Optional[bool] = None,
|
| 1123 |
+
**kwargs,
|
| 1124 |
+
) -> Union[GenerateOutput, torch.Tensor]:
|
| 1125 |
+
# generic preprocessing
|
| 1126 |
+
|
| 1127 |
+
generation_config, model_kwargs = self._prepare_generation_config(
|
| 1128 |
+
generation_config, use_model_defaults, **kwargs
|
| 1129 |
+
)
|
| 1130 |
+
self._prepare_special_tokens(generation_config, device=self.model.device)
|
| 1131 |
+
|
| 1132 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
| 1133 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
| 1134 |
+
|
| 1135 |
+
# start of custom generate
|
| 1136 |
+
|
| 1137 |
+
expand_input_ids = self.model.local_encoder.add_expanded_embeddings
|
| 1138 |
+
batch_size = len(inputs)
|
| 1139 |
+
|
| 1140 |
+
if expand_input_ids:
|
| 1141 |
+
expanded_input_ids = []
|
| 1142 |
+
|
| 1143 |
+
for i in range(len(inputs)):
|
| 1144 |
+
expanded_input_ids.append(torch.tensor(self.model.tokenizer.expand_byte_ids(inputs[i].tolist()), device=self.device, dtype=torch.long))
|
| 1145 |
+
|
| 1146 |
+
expanded_input_ids = pad_left(expanded_input_ids, value=self.model.tokenizer.pad_token_id, multiple_of=1) # type: ignore
|
| 1147 |
+
else:
|
| 1148 |
+
expanded_input_ids = None
|
| 1149 |
+
|
| 1150 |
+
byte_input_ids = inputs
|
| 1151 |
+
sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
|
| 1152 |
+
batch_size, prompt_len = byte_input_ids.shape
|
| 1153 |
+
finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
|
| 1154 |
+
|
| 1155 |
+
boundary_offset = self.model.tokenizer.offset + 256
|
| 1156 |
+
eos = self.model.tokenizer.eos_token_id
|
| 1157 |
+
|
| 1158 |
+
self.model.local_encoder.free_inference_cache()
|
| 1159 |
+
self.model.local_decoder.free_inference_cache()
|
| 1160 |
+
|
| 1161 |
+
boundary_mask = self.model.prefill_boundary_prediction_forward( # type: ignore
|
| 1162 |
+
byte_input_ids,
|
| 1163 |
+
expanded_input_ids=expanded_input_ids,
|
| 1164 |
+
sequence_start_indices=sequence_start_indices,
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
+
self.model.local_encoder.prepare_inference_cache(batch_size)
|
| 1168 |
+
self.model.local_decoder.prepare_inference_cache(batch_size)
|
| 1169 |
+
|
| 1170 |
+
# roll back by one and force decoding to account for lookahead
|
| 1171 |
+
boundary_mask = boundary_mask[:, :-1]
|
| 1172 |
+
# need to roll one byte back and force decoding to detect whether the last byte is a boundary
|
| 1173 |
+
forced_decoding_ids = byte_input_ids[:, -1].cpu().tolist()
|
| 1174 |
+
byte_input_ids = byte_input_ids[:, :-1]
|
| 1175 |
+
expanded_input_ids = expanded_input_ids[:, :-1] if expanded_input_ids is not None else None
|
| 1176 |
+
# stays the same unless last token is pad.
|
| 1177 |
+
sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
|
| 1178 |
+
|
| 1179 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
| 1180 |
+
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
|
| 1181 |
+
generation_config = self._prepare_generated_length(
|
| 1182 |
+
generation_config=generation_config,
|
| 1183 |
+
has_default_max_length=has_default_max_length,
|
| 1184 |
+
has_default_min_length=has_default_min_length,
|
| 1185 |
+
model_input_name="input_ids",
|
| 1186 |
+
inputs_tensor=byte_input_ids,
|
| 1187 |
+
input_ids_length=byte_input_ids.shape[1],
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
logits_processor = self._get_logits_processor(
|
| 1191 |
+
generation_config=generation_config, # type: ignore
|
| 1192 |
+
input_ids_seq_length=byte_input_ids.shape[1],
|
| 1193 |
+
encoder_input_ids=byte_input_ids, # type: ignore
|
| 1194 |
+
logits_processor=logits_processor,
|
| 1195 |
+
device=byte_input_ids.device, # type: ignore
|
| 1196 |
+
model_kwargs=model_kwargs,
|
| 1197 |
+
)
|
| 1198 |
+
stopping_criteria = self._get_stopping_criteria(
|
| 1199 |
+
generation_config=generation_config, # type: ignore
|
| 1200 |
+
stopping_criteria=stopping_criteria,
|
| 1201 |
+
tokenizer=self.model.tokenizer,
|
| 1202 |
+
)
|
| 1203 |
+
|
| 1204 |
+
# output container
|
| 1205 |
+
generated = byte_input_ids
|
| 1206 |
+
|
| 1207 |
+
max_n_prefill_patches = boundary_mask.sum(-1).max().item()
|
| 1208 |
+
tokens_generated_plus_prefilled = max_n_prefill_patches
|
| 1209 |
+
bytes_generated = 0
|
| 1210 |
+
|
| 1211 |
+
# generation state
|
| 1212 |
+
boundary_state = MaskState(boundary_mask[:, -1].clone())
|
| 1213 |
+
pad_state = MaskState(torch.zeros(batch_size, dtype=torch.bool, device=self.device))
|
| 1214 |
+
next_tokens = torch.full((batch_size,), self.model.tokenizer.bpe_token_end_id, device=self.device, dtype=torch.long) # type: ignore
|
| 1215 |
+
non_boundary_generated_tokens = [[byte_input_ids[example_idx, -1].item()] for example_idx in range(batch_size)]
|
| 1216 |
+
bytes_since_boundary = (boundary_mask.flip(1).cumsum(-1) == 0).sum(-1)
|
| 1217 |
+
is_first_forward = True
|
| 1218 |
+
global_past_key_values = None
|
| 1219 |
+
|
| 1220 |
+
while not finished.all():
|
| 1221 |
+
input_ids_for_model = (
|
| 1222 |
+
generated
|
| 1223 |
+
if is_first_forward
|
| 1224 |
+
else torch.tensor([x[-1] for x in non_boundary_generated_tokens], device=generated.device, dtype=generated.dtype).unsqueeze(1)
|
| 1225 |
+
)
|
| 1226 |
+
assert not (
|
| 1227 |
+
(input_ids_for_model == self.model.tokenizer.bpe_token_end_id) |
|
| 1228 |
+
(input_ids_for_model >= boundary_offset)
|
| 1229 |
+
).any().item() # type: ignore
|
| 1230 |
+
if expand_input_ids:
|
| 1231 |
+
expanded_input_ids_for_model = torch.zeros_like(input_ids_for_model)
|
| 1232 |
+
for i in range(input_ids_for_model.shape[0]):
|
| 1233 |
+
expanded_input_ids_for_model[i, :] = torch.tensor(self.model.tokenizer.expand_byte_ids(
|
| 1234 |
+
generated[i, :].tolist(),
|
| 1235 |
+
n_last=input_ids_for_model.shape[1],
|
| 1236 |
+
), device=expanded_input_ids_for_model.device, dtype=expanded_input_ids_for_model.dtype)
|
| 1237 |
+
else:
|
| 1238 |
+
expanded_input_ids_for_model = None
|
| 1239 |
+
|
| 1240 |
+
out = self.forward( # type: ignore
|
| 1241 |
+
input_ids_for_model,
|
| 1242 |
+
expanded_input_ids=expanded_input_ids_for_model,
|
| 1243 |
+
boundary_mask=boundary_mask if is_first_forward else None,
|
| 1244 |
+
boundary_state=boundary_state,
|
| 1245 |
+
pad_state=pad_state,
|
| 1246 |
+
sequence_start_indices=sequence_start_indices,
|
| 1247 |
+
logits_to_keep=1,
|
| 1248 |
+
use_cache=True,
|
| 1249 |
+
past_key_values=global_past_key_values,
|
| 1250 |
+
)
|
| 1251 |
+
next_token_logits = cast(torch.Tensor, out.logits)
|
| 1252 |
+
global_past_key_values = out.past_key_values
|
| 1253 |
+
|
| 1254 |
+
if boundary_state.all():
|
| 1255 |
+
# new token, must not be boundary
|
| 1256 |
+
bytes_since_boundary[:] = 0
|
| 1257 |
+
else:
|
| 1258 |
+
boundary_state.selective_add(1, bytes_since_boundary, inv=True)
|
| 1259 |
+
|
| 1260 |
+
if any(x is not None for x in forced_decoding_ids):
|
| 1261 |
+
# only supported for the first token atm, so len(next_token_logits) == batch_size
|
| 1262 |
+
assert len(next_token_logits) == batch_size and is_first_forward
|
| 1263 |
+
for example_idx in range(batch_size):
|
| 1264 |
+
forced_decoding_id = forced_decoding_ids[example_idx]
|
| 1265 |
+
|
| 1266 |
+
if forced_decoding_id is not None:
|
| 1267 |
+
no_boundary_logit = next_token_logits[example_idx, 0, forced_decoding_id].item()
|
| 1268 |
+
boundary_logit = next_token_logits[example_idx, 0, forced_decoding_id + boundary_offset].item()
|
| 1269 |
+
|
| 1270 |
+
next_token_logits[example_idx, 0, :] = -100_000
|
| 1271 |
+
next_token_logits[example_idx, 0, forced_decoding_id] = no_boundary_logit
|
| 1272 |
+
next_token_logits[example_idx, 0, forced_decoding_id + boundary_offset] = boundary_logit
|
| 1273 |
+
|
| 1274 |
+
forced_decoding_ids[example_idx] = None # only force once
|
| 1275 |
+
|
| 1276 |
+
# passing input_ids to logit processor not implemented
|
| 1277 |
+
next_token_scores = logits_processor(None, next_token_logits[:, -1]) # type: ignore
|
| 1278 |
+
|
| 1279 |
+
if generation_config is not None and generation_config.do_sample:
|
| 1280 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
| 1281 |
+
new_next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 1282 |
+
else:
|
| 1283 |
+
new_next_tokens = torch.argmax(next_token_scores, dim=-1)
|
| 1284 |
+
|
| 1285 |
+
if boundary_state.all() or is_first_forward:
|
| 1286 |
+
tokens_generated_plus_prefilled += 1
|
| 1287 |
+
|
| 1288 |
+
next_tokens = new_next_tokens
|
| 1289 |
+
next_tokens_cpu = next_tokens.cpu()
|
| 1290 |
+
for example_idx in range(batch_size):
|
| 1291 |
+
if finished[example_idx].item():
|
| 1292 |
+
continue
|
| 1293 |
+
|
| 1294 |
+
next_token_cpu = next_tokens_cpu[example_idx].item()
|
| 1295 |
+
|
| 1296 |
+
if next_token_cpu >= boundary_offset:
|
| 1297 |
+
next_token_cpu -= boundary_offset
|
| 1298 |
+
|
| 1299 |
+
non_boundary_generated_tokens[example_idx].append(next_token_cpu)
|
| 1300 |
+
else:
|
| 1301 |
+
next_tokens[:] = self.model.tokenizer.bpe_token_end_id # type: ignore
|
| 1302 |
+
boundary_state.selective_put(new_next_tokens, next_tokens, inv=True)
|
| 1303 |
+
next_tokens_cpu = next_tokens.cpu()
|
| 1304 |
+
|
| 1305 |
+
for example_idx in range(batch_size):
|
| 1306 |
+
if finished[example_idx].item():
|
| 1307 |
+
continue
|
| 1308 |
+
|
| 1309 |
+
next_token_cpu = next_tokens_cpu[example_idx].item()
|
| 1310 |
+
|
| 1311 |
+
if not boundary_state.cpu_mask[example_idx].item():
|
| 1312 |
+
if next_token_cpu >= boundary_offset:
|
| 1313 |
+
next_token_cpu -= boundary_offset
|
| 1314 |
+
|
| 1315 |
+
non_boundary_generated_tokens[example_idx].append(next_token_cpu)
|
| 1316 |
+
|
| 1317 |
+
is_first_forward = False
|
| 1318 |
+
|
| 1319 |
+
boundary_state = MaskState(
|
| 1320 |
+
(next_tokens == self.model.tokenizer.bpe_token_end_id) |
|
| 1321 |
+
(next_tokens >= boundary_offset) |
|
| 1322 |
+
finished
|
| 1323 |
+
) # type: ignore
|
| 1324 |
+
pad_state = MaskState(
|
| 1325 |
+
(next_tokens == self.model.tokenizer.bpe_token_end_id) |
|
| 1326 |
+
finished
|
| 1327 |
+
)
|
| 1328 |
+
|
| 1329 |
+
# Force EOS for (previously) finished sequences
|
| 1330 |
+
next_tokens = torch.where(finished, torch.full_like(next_tokens, eos), next_tokens)
|
| 1331 |
+
|
| 1332 |
+
# Append next tokens
|
| 1333 |
+
generated = torch.cat([generated, next_tokens.unsqueeze(-1)], dim=1)
|
| 1334 |
+
|
| 1335 |
+
# Handle finished sequences
|
| 1336 |
+
stop_hit = next_tokens.eq(eos) | next_tokens.eq(eos + boundary_offset)
|
| 1337 |
+
|
| 1338 |
+
for i in range(batch_size):
|
| 1339 |
+
# passing `scores` to stopping criteria not implemented
|
| 1340 |
+
if stopping_criteria(torch.tensor(non_boundary_generated_tokens[i], dtype=torch.long).unsqueeze(0), None).squeeze(0).item(): # type: ignore
|
| 1341 |
+
stop_hit[i] = True
|
| 1342 |
+
|
| 1343 |
+
finished |= stop_hit
|
| 1344 |
+
bytes_generated += 1
|
| 1345 |
+
|
| 1346 |
+
return pad_left([
|
| 1347 |
+
torch.cat([byte_input_ids[i, :-1], torch.tensor(x, dtype=torch.long, device=byte_input_ids.device)])
|
| 1348 |
+
for i, x in enumerate(non_boundary_generated_tokens)
|
| 1349 |
+
], value=self.model.tokenizer.pad_token_id, multiple_of=1) # type: ignore
|
| 1350 |
+
|
| 1351 |
+
__all__ = ["BolmoForCausalLM", "BolmoModel", "BolmoPreTrainedModel"]
|
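End-to-end usage of the custom `generate()` above, as an illustrative sketch; `<this-repo>` is a placeholder for this repository id, and `trust_remote_code=True` is needed because the model and tokenizer classes ship in this upload.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("<this-repo>", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("<this-repo>", trust_remote_code=True)

byte_ids = tok("Byte-level language models ", return_tensors="pt").input_ids
out_ids = model.generate(byte_ids, max_new_tokens=32)

# generate() returns left-padded byte ids (prompt + generated bytes, boundary offsets stripped).
print(tok.decode(out_ids[0], skip_special_tokens=True))
```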
special_tokens_map.json
ADDED
|
@@ -0,0 +1,5 @@
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<bos>",
|
| 3 |
+
"eos_token": "<bos>",
|
| 4 |
+
"pad_token": "<pad>"
|
| 5 |
+
}
|
tokenization_bolmo.py
ADDED
|
@@ -0,0 +1,378 @@
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 6 |
+
|
| 7 |
+
# Source: https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
|
| 8 |
+
# Also implemented in https://docs.rs/tokenizers/latest/src/tokenizers/pre_tokenizers/byte_level.rs.html#13-39
|
| 9 |
+
_CHARS_TO_BYTES = {
|
| 10 |
+
"Ā": 0, "ā": 1, "Ă": 2, "ă": 3, "Ą": 4, "ą": 5, "Ć": 6, "ć": 7, "Ĉ": 8,
|
| 11 |
+
"ĉ": 9, "Ċ": 10, "ċ": 11, "Č": 12, "č": 13, "Ď": 14, "ď": 15, "Đ": 16,
|
| 12 |
+
"đ": 17, "Ē": 18, "ē": 19, "Ĕ": 20, "ĕ": 21, "Ė": 22, "ė": 23, "Ę": 24,
|
| 13 |
+
"ę": 25, "Ě": 26, "ě": 27, "Ĝ": 28, "ĝ": 29, "Ğ": 30, "ğ": 31, "Ġ": 32,
|
| 14 |
+
"!": 33, '"': 34, "#": 35, "$": 36, "%": 37, "&": 38, "'": 39, "(": 40,
|
| 15 |
+
")": 41, "*": 42, "+": 43, ",": 44, "-": 45, ".": 46, "/": 47, "0": 48,
|
| 16 |
+
"1": 49, "2": 50, "3": 51, "4": 52, "5": 53, "6": 54, "7": 55, "8": 56,
|
| 17 |
+
"9": 57, ":": 58, ";": 59, "<": 60, "=": 61, ">": 62, "?": 63, "@": 64,
|
| 18 |
+
"A": 65, "B": 66, "C": 67, "D": 68, "E": 69, "F": 70, "G": 71, "H": 72,
|
| 19 |
+
"I": 73, "J": 74, "K": 75, "L": 76, "M": 77, "N": 78, "O": 79, "P": 80,
|
| 20 |
+
"Q": 81, "R": 82, "S": 83, "T": 84, "U": 85, "V": 86, "W": 87, "X": 88,
|
| 21 |
+
"Y": 89, "Z": 90, "[": 91, "\\": 92, "]": 93, "^": 94, "_": 95, "`": 96,
|
| 22 |
+
"a": 97, "b": 98, "c": 99, "d": 100, "e": 101, "f": 102, "g": 103,
|
| 23 |
+
"h": 104, "i": 105, "j": 106, "k": 107, "l": 108, "m": 109, "n": 110,
|
| 24 |
+
"o": 111, "p": 112, "q": 113, "r": 114, "s": 115, "t": 116, "u": 117,
|
| 25 |
+
"v": 118, "w": 119, "x": 120, "y": 121, "z": 122, "{": 123, "|": 124,
|
| 26 |
+
"}": 125, "~": 126, "ġ": 127, "Ģ": 128, "ģ": 129, "Ĥ": 130, "ĥ": 131,
|
| 27 |
+
"Ħ": 132, "ħ": 133, "Ĩ": 134, "ĩ": 135, "Ī": 136, "ī": 137, "Ĭ": 138,
|
| 28 |
+
"ĭ": 139, "Į": 140, "į": 141, "İ": 142, "ı": 143, "IJ": 144, "ij": 145,
|
| 29 |
+
"Ĵ": 146, "ĵ": 147, "Ķ": 148, "ķ": 149, "ĸ": 150, "Ĺ": 151, "ĺ": 152,
|
| 30 |
+
"Ļ": 153, "ļ": 154, "Ľ": 155, "ľ": 156, "Ŀ": 157, "ŀ": 158, "Ł": 159,
|
| 31 |
+
"ł": 160, "¡": 161, "¢": 162, "£": 163, "¤": 164, "¥": 165, "¦": 166,
|
| 32 |
+
"§": 167, "¨": 168, "©": 169, "ª": 170, "«": 171, "¬": 172, "Ń": 173,
|
| 33 |
+
"®": 174, "¯": 175, "°": 176, "±": 177, "²": 178, "³": 179, "´": 180,
|
| 34 |
+
"µ": 181, "¶": 182, "·": 183, "¸": 184, "¹": 185, "º": 186, "»": 187,
|
| 35 |
+
"¼": 188, "½": 189, "¾": 190, "¿": 191, "À": 192, "Á": 193, "Â": 194,
|
| 36 |
+
"Ã": 195, "Ä": 196, "Å": 197, "Æ": 198, "Ç": 199, "È": 200, "É": 201,
|
| 37 |
+
"Ê": 202, "Ë": 203, "Ì": 204, "Í": 205, "Î": 206, "Ï": 207, "Ð": 208,
|
| 38 |
+
"Ñ": 209, "Ò": 210, "Ó": 211, "Ô": 212, "Õ": 213, "Ö": 214, "×": 215,
|
| 39 |
+
"Ø": 216, "Ù": 217, "Ú": 218, "Û": 219, "Ü": 220, "Ý": 221, "Þ": 222,
|
| 40 |
+
"ß": 223, "à": 224, "á": 225, "â": 226, "ã": 227, "ä": 228, "å": 229,
|
| 41 |
+
"æ": 230, "ç": 231, "è": 232, "é": 233, "ê": 234, "ë": 235, "ì": 236,
|
| 42 |
+
"í": 237, "î": 238, "ï": 239, "ð": 240, "ñ": 241, "ò": 242, "ó": 243,
|
| 43 |
+
"ô": 244, "õ": 245, "ö": 246, "÷": 247, "ø": 248, "ù": 249, "ú": 250,
|
| 44 |
+
"û": 251, "ü": 252, "ý": 253, "þ": 254, "ÿ": 255,
|
| 45 |
+
}
|
| 46 |
+
_BYTES_TO_CHARS = {v: k for k, v in _CHARS_TO_BYTES.items()}
|
| 47 |
+
|
| 48 |
+
def _bytes_to_chars(byte_sequence: bytes) -> str:
|
| 49 |
+
return "".join(_BYTES_TO_CHARS[byte] for byte in byte_sequence)
|
| 50 |
+
|
| 51 |
+
def _chars_to_bytes(char_sequence: str) -> list:
|
| 52 |
+
return list(bytes(_CHARS_TO_BYTES[char] for char in char_sequence))
|
| 53 |
+
|
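A quick illustrative check (not part of the file): the GPT-2 byte-to-unicode table above is a bijection, so the conversion round-trips losslessly for arbitrary byte strings.

```python
raw = "héllo".encode("utf-8")
as_chars = _bytes_to_chars(raw)                 # "hÃ©llo": one printable stand-in per byte
assert bytes(_chars_to_bytes(as_chars)) == raw  # lossless round trip
```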
| 54 |
+
@dataclass
|
| 55 |
+
class BolmoTokenizerConfig:
|
| 56 |
+
vocab_size: int
|
| 57 |
+
bos_token_id: int
|
| 58 |
+
pad_token_id: int
|
| 59 |
+
eos_token_id: int
|
| 60 |
+
bpe_token_end_id: int
|
| 61 |
+
special_tokens: list[str] = field(default_factory=list)
|
| 62 |
+
special_tokens_first: bool = True
|
| 63 |
+
original_identifier: Optional[str] = None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@classmethod
|
| 67 |
+
def bolmo(cls) -> "BolmoTokenizerConfig":
|
| 68 |
+
special_tokens = [
|
| 69 |
+
"<pad>",
|
| 70 |
+
"<bos>",
|
| 71 |
+
"<eos>",
|
| 72 |
+
"<bpe_token_end>",
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
return cls(
|
| 76 |
+
# *2 to accommodate fused boundary tokens
|
| 77 |
+
vocab_size=(len(special_tokens) + 256) * 2,
|
| 78 |
+
special_tokens=special_tokens,
|
| 79 |
+
bos_token_id=special_tokens.index("<bos>"),
|
| 80 |
+
pad_token_id=special_tokens.index("<pad>"),
|
| 81 |
+
eos_token_id=special_tokens.index("<bos>"),
|
| 82 |
+
bpe_token_end_id=special_tokens.index("<bpe_token_end>"),
|
| 83 |
+
original_identifier="allenai/dolma2-tokenizer",
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def build(self):
|
| 87 |
+
return BolmoTokenizer(tokenizer_config=self)
|
| 88 |
+
|
| 89 |
+
|
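The `* 2` above reserves a "fused boundary" copy of every id. With the four special tokens of `BolmoTokenizerConfig.bolmo()`, the id space works out as follows (illustrative arithmetic, matching the offsets used in `modeling_bolmo.py`):

```python
offset = 4                       # <pad>, <bos>, <eos>, <bpe_token_end>
vocab_size = (offset + 256) * 2  # = 520
boundary_offset = offset + 256   # = 260, the shift that marks "this byte closes a patch"

# ids   0..3    -> special tokens
# ids   4..259  -> plain bytes                         (offset + byte)
# ids 260..263  -> special tokens fused with a boundary
# ids 264..519  -> bytes fused with a boundary         (2 * offset + 256 + byte)
```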
| 90 |
+
class BolmoTokenizer(PreTrainedTokenizer):
|
| 91 |
+
TOKEN_ID_KEY = -1
|
| 92 |
+
|
| 93 |
+
def __init__(self, **kwargs):
|
| 94 |
+
tokenizer_config = kwargs.pop("tokenizer_config", BolmoTokenizerConfig.bolmo())
|
| 95 |
+
|
| 96 |
+
self.config = tokenizer_config
|
| 97 |
+
self.hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.original_identifier)
|
| 98 |
+
if self.config.special_tokens_first:
|
| 99 |
+
self.offset = len(tokenizer_config.special_tokens)
|
| 100 |
+
self.special_tokens_offset = 0
|
| 101 |
+
else:
|
| 102 |
+
self.offset = 0
|
| 103 |
+
self.special_tokens_offset = self.config.vocab_size - len(tokenizer_config.special_tokens)
|
| 104 |
+
|
| 105 |
+
self.byte_sequences = {}
|
| 106 |
+
|
| 107 |
+
for key, value in self.hf_tokenizer.get_vocab().items():
|
| 108 |
+
if key in self.config.special_tokens:
|
| 109 |
+
byte_sequence = [self.special_tokens_offset + self.config.special_tokens.index(key)]
|
| 110 |
+
elif value == self.hf_tokenizer.eos_token_id and self.eos_token_id is not None:
|
| 111 |
+
byte_sequence = [self.eos_token_id]
|
| 112 |
+
elif value == self.hf_tokenizer.bos_token_id and self.bos_token_id is not None:
|
| 113 |
+
byte_sequence = [self.bos_token_id]
|
| 114 |
+
elif value == self.hf_tokenizer.pad_token_id and self.pad_token_id is not None:
|
| 115 |
+
byte_sequence = [self.pad_token_id]
|
| 116 |
+
else:
|
| 117 |
+
byte_sequence = [self.offset + i for i in _chars_to_bytes(key)]
|
| 118 |
+
|
| 119 |
+
assert self.byte_sequences.get(value) is None
|
| 120 |
+
self.byte_sequences[value] = byte_sequence
|
| 121 |
+
|
| 122 |
+
self.byte_trie = {}
|
| 123 |
+
|
| 124 |
+
for token_id, byte_sequence in self.byte_sequences.items():
|
| 125 |
+
current_dict = self.byte_trie
|
| 126 |
+
for byte in byte_sequence[::-1]: # retrieved from the back so store in reverse order
|
| 127 |
+
if byte not in current_dict:
|
| 128 |
+
current_dict[byte] = {}
|
| 129 |
+
current_dict = current_dict[byte]
|
| 130 |
+
current_dict[BolmoTokenizer.TOKEN_ID_KEY] = token_id
|
| 131 |
+
|
| 132 |
+
self.add_bos_token = True
|
| 133 |
+
self.add_eos_token = False
|
| 134 |
+
self.padding_side = "left" # for generate
|
| 135 |
+
|
| 136 |
+
super().__init__(
|
| 137 |
+
bos_token=self.config.special_tokens[self.config.bos_token_id],
|
| 138 |
+
eos_token=self.config.special_tokens[self.config.eos_token_id],
|
| 139 |
+
pad_token=self.config.special_tokens[self.config.pad_token_id],
|
| 140 |
+
extra_ids=0,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def bos_token_id(self):
|
| 145 |
+
return self.config.bos_token_id
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def eos_token_id(self):
|
| 149 |
+
return self.config.eos_token_id
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def pad_token_id(self):
|
| 153 |
+
return self.config.pad_token_id
|
| 154 |
+
|
| 155 |
+
@property
|
| 156 |
+
def bpe_token_end_id(self):
|
| 157 |
+
return self.config.bpe_token_end_id
|
| 158 |
+
|
| 159 |
+
@property
|
| 160 |
+
def vocab_size(self):
|
| 161 |
+
return self.config.vocab_size
|
| 162 |
+
|
| 163 |
+
def _convert_id_to_token(self, index):
|
| 164 |
+
if index < self.offset:
|
| 165 |
+
return self.config.special_tokens[index - self.special_tokens_offset]
|
| 166 |
+
|
| 167 |
+
if index >= self.offset + 256 and index < self.offset * 2 + 256:
|
| 168 |
+
# special token with fused boundary
|
| 169 |
+
return self.config.special_tokens[index - self.offset - 256] + "b"
|
| 170 |
+
|
| 171 |
+
return _BYTES_TO_CHARS[index - self.offset - 256 - self.offset] + "b" if index >= self.offset + 256 else _BYTES_TO_CHARS[index - self.offset]
|
| 172 |
+
|
| 173 |
+
def _convert_token_to_id(self, token):
|
| 174 |
+
if token in self.config.special_tokens:
|
| 175 |
+
return self.config.special_tokens.index(token)
|
| 176 |
+
|
| 177 |
+
if token in [x + "b" for x in self.config.special_tokens]:
|
| 178 |
+
# special token with fused boundary
|
| 179 |
+
return self.offset + 256 + self.config.special_tokens.index(token[:-1])
|
| 180 |
+
|
| 181 |
+
if len(token) > 1 and token[-1] == "b":
|
| 182 |
+
return self.offset + 256 + self.offset + _CHARS_TO_BYTES[token[0]]
|
| 183 |
+
else:
|
| 184 |
+
return self.offset + _CHARS_TO_BYTES[token]
|
| 185 |
+
|
| 186 |
+
def get_vocab(self):
|
| 187 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 188 |
+
return vocab
|
| 189 |
+
|
| 190 |
+
def expand_byte_ids(self, byte_ids: list[int], n_last: Optional[int] = None) -> list[int]:
|
| 191 |
+
# search in the byte tree for the longest matching token at every byte position
|
| 192 |
+
expanded_ids = []
|
| 193 |
+
for i in range(len(byte_ids)):
|
| 194 |
+
if n_last is not None and i < len(byte_ids) - n_last:
|
| 195 |
+
continue
|
| 196 |
+
|
| 197 |
+
current_dict = self.byte_trie
|
| 198 |
+
current_expansion = None
|
| 199 |
+
|
| 200 |
+
for i in range(i, -1, -1):
|
| 201 |
+
byte = byte_ids[i]
|
| 202 |
+
|
| 203 |
+
if byte == self.bpe_token_end_id:
|
| 204 |
+
# skip bpe token end markers, needed for generation
|
| 205 |
+
continue
|
| 206 |
+
|
| 207 |
+
if byte >= self.offset + 256:
|
| 208 |
+
# ignore fused boundary
|
| 209 |
+
byte -= self.offset + 256
|
| 210 |
+
|
| 211 |
+
try:
|
| 212 |
+
current_dict = current_dict[byte]
|
| 213 |
+
if BolmoTokenizer.TOKEN_ID_KEY in current_dict:
|
| 214 |
+
current_expansion = current_dict[BolmoTokenizer.TOKEN_ID_KEY]
|
| 215 |
+
except KeyError:
|
| 216 |
+
assert current_expansion is not None
|
| 217 |
+
break
|
| 218 |
+
|
| 219 |
+
expanded_ids.append(current_expansion)
|
| 220 |
+
|
| 221 |
+
return expanded_ids
|
| 222 |
+
|
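`expand_byte_ids` annotates every byte position with the longest subword from the original vocabulary that ends exactly there, by walking the reversed-suffix trie built in `__init__`; these ids feed the expanded embeddings when `add_expanded_embeddings` is on. An illustrative call, assuming `tok` is an instantiated BolmoTokenizer:

```python
byte_ids = tok._bolmo_encode("hello")     # byte-level ids, no special tokens
expanded = tok.expand_byte_ids(byte_ids)  # one original-vocab subword id per byte position
assert len(expanded) == len(byte_ids)
```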
| 223 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
|
| 224 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 225 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 226 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 227 |
+
|
| 228 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
| 229 |
+
|
| 230 |
+
if token_ids_1 is not None:
|
| 231 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
| 232 |
+
|
| 233 |
+
return output
|
| 234 |
+
|
| 235 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
|
| 236 |
+
def get_special_tokens_mask(
|
| 237 |
+
self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
|
| 238 |
+
) -> list[int]:
|
| 239 |
+
"""
|
| 240 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 241 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 242 |
+
Args:
|
| 243 |
+
token_ids_0 (`List[int]`):
|
| 244 |
+
List of IDs.
|
| 245 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 246 |
+
Optional second list of IDs for sequence pairs.
|
| 247 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 248 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 249 |
+
Returns:
|
| 250 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 251 |
+
"""
|
| 252 |
+
if already_has_special_tokens:
|
| 253 |
+
return super().get_special_tokens_mask(
|
| 254 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
bos_token_id = [1] if self.add_bos_token else []
|
| 258 |
+
eos_token_id = [1] if self.add_eos_token else []
|
| 259 |
+
|
| 260 |
+
if token_ids_1 is None:
|
| 261 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
| 262 |
+
return (
|
| 263 |
+
bos_token_id
|
| 264 |
+
+ ([0] * len(token_ids_0))
|
| 265 |
+
+ eos_token_id
|
| 266 |
+
+ bos_token_id
|
| 267 |
+
+ ([0] * len(token_ids_1))
|
| 268 |
+
+ eos_token_id
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
|
| 272 |
+
def create_token_type_ids_from_sequences(
|
| 273 |
+
self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
|
| 274 |
+
) -> list[int]:
|
| 275 |
+
"""
|
| 276 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
| 277 |
+
sequence pair mask has the following format:
|
| 278 |
+
```
|
| 279 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 280 |
+
| first sequence | second sequence |
|
| 281 |
+
```
|
| 282 |
+
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
| 283 |
+
Args:
|
| 284 |
+
token_ids_0 (`List[int]`):
|
| 285 |
+
List of ids.
|
| 286 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 287 |
+
Optional second list of IDs for sequence pairs.
|
| 288 |
+
Returns:
|
| 289 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 290 |
+
"""
|
| 291 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 292 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 293 |
+
|
| 294 |
+
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
| 295 |
+
|
| 296 |
+
if token_ids_1 is not None:
|
| 297 |
+
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
| 298 |
+
|
| 299 |
+
return output
|
| 300 |
+
|
| 301 |
+
def _tokenize(self, text: str, **kwargs) -> list[str]:
|
| 302 |
+
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
| 303 |
+
tokens = self.convert_ids_to_tokens(self._bolmo_encode(text))
|
| 304 |
+
return tokens
|
| 305 |
+
|
| 306 |
+
def _patch_ids_to_byte_ids(self, input_ids: list[int]):
|
| 307 |
+
return [byte_token_id for token_id in input_ids for byte_token_id in self.byte_sequences[token_id]]
|
| 308 |
+
|
| 309 |
+
def _bolmo_encode(self, string: str, add_special_tokens=False):
|
| 310 |
+
input_ids = self.hf_tokenizer.encode(string, add_special_tokens=add_special_tokens)
|
| 311 |
+
return self._patch_ids_to_byte_ids(input_ids)
|
| 312 |
+
|
| 313 |
+
def _bolmo_decode(self, tokens: list[int], skip_special_tokens: bool = False) -> str:
|
| 314 |
+
return self._decode_to_bytes(tokens, skip_special_tokens=skip_special_tokens).decode("utf-8", errors="replace")
|
| 315 |
+
|
| 316 |
+
def _decode_to_bytes(self, tokens: list[int], skip_special_tokens: bool = False) -> bytes:
|
| 317 |
+
tokens_without_boundary = []
|
| 318 |
+
for token in tokens:
|
| 319 |
+
if token >= (self.offset + 256):
|
| 320 |
+
token -= self.offset + 256
|
| 321 |
+
|
| 322 |
+
tokens_without_boundary.append(token)
|
| 323 |
+
|
| 324 |
+
utf8_bytes = []
|
| 325 |
+
|
| 326 |
+
for token in tokens_without_boundary:
|
| 327 |
+
if token < self.offset:
|
| 328 |
+
if skip_special_tokens:
|
| 329 |
+
continue
|
| 330 |
+
else:
|
| 331 |
+
utf8_bytes.extend(self.config.special_tokens[token].encode("utf-8"))
|
| 332 |
+
else:
|
| 333 |
+
utf8_bytes.append(min(token - self.offset, 255))
|
| 334 |
+
|
| 335 |
+
return bytes(utf8_bytes)
|
| 336 |
+
|
| 337 |
+
def get_tokens_and_patch_lengths(self, original_input_ids: list[int], add_bos=False, strip_pad=False, skip_last=False):
|
| 338 |
+
if add_bos and self.bos_token_id is not None:
|
| 339 |
+
byte_tokens = [self.bos_token_id]
|
| 340 |
+
patch_lengths = [1]
|
| 341 |
+
else:
|
| 342 |
+
byte_tokens = []
|
| 343 |
+
patch_lengths = []
|
| 344 |
+
|
| 345 |
+
for idx, token in enumerate(original_input_ids):
|
| 346 |
+
# optionally skip last token to keep the length the same if add_bos=True
|
| 347 |
+
if skip_last and idx == len(original_input_ids) - 1:
|
| 348 |
+
break
|
| 349 |
+
|
| 350 |
+
token_byte_tokens = self._patch_ids_to_byte_ids([int(token)])
|
| 351 |
+
|
| 352 |
+
if strip_pad and all(t == self.pad_token_id for t in token_byte_tokens):
|
| 353 |
+
# skip padding tokens
|
| 354 |
+
continue
|
| 355 |
+
|
| 356 |
+
patch_lengths.append(len(token_byte_tokens))
|
| 357 |
+
byte_tokens.extend(token_byte_tokens)
|
| 358 |
+
|
| 359 |
+
return byte_tokens, patch_lengths
|
| 360 |
+
|
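Illustrative use of `get_tokens_and_patch_lengths`: it flattens subword ids from the original tokenizer into byte ids and records how many bytes each subword spans (`tok` is again an instantiated BolmoTokenizer).

```python
subword_ids = tok.hf_tokenizer.encode("byte patches", add_special_tokens=False)
byte_tokens, patch_lengths = tok.get_tokens_and_patch_lengths(subword_ids, add_bos=True)

assert sum(patch_lengths) == len(byte_tokens)
assert patch_lengths[0] == 1  # the prepended <bos> forms its own length-1 patch
```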
| 361 |
+
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
| 362 |
+
return self._bolmo_decode(self.convert_tokens_to_ids(tokens), skip_special_tokens=False) # type: ignore
|
| 363 |
+
|
| 364 |
+
def _decode(
|
| 365 |
+
self,
|
| 366 |
+
token_ids: Union[int, list[int]],
|
| 367 |
+
skip_special_tokens: bool = False,
|
| 368 |
+
clean_up_tokenization_spaces: Optional[bool] = None,
|
| 369 |
+
spaces_between_special_tokens: bool = True,
|
| 370 |
+
**kwargs,
|
| 371 |
+
) -> str:
|
| 372 |
+
if isinstance(token_ids, int):
|
| 373 |
+
token_ids = [token_ids]
|
| 374 |
+
|
| 375 |
+
return self._bolmo_decode(token_ids, skip_special_tokens=skip_special_tokens)
|
| 376 |
+
|
| 377 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
|
| 378 |
+
return () # type: ignore
|
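An illustrative round trip through the byte tokenizer defined above; building it downloads the underlying `allenai/dolma2-tokenizer` vocabulary.

```python
tok = BolmoTokenizerConfig.bolmo().build()

ids = tok("Hello, Bolmo!").input_ids              # <bos> followed by one id per UTF-8 byte
text = tok.decode(ids, skip_special_tokens=True)
assert text == "Hello, Bolmo!"
```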
tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<pad>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<bos>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
}
|
| 19 |
+
},
|
| 20 |
+
"auto_map": {
|
| 21 |
+
"AutoTokenizer": [
|
| 22 |
+
"tokenization_bolmo.BolmoTokenizer",
|
| 23 |
+
null
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
"bos_token": "<bos>",
|
| 27 |
+
"clean_up_tokenization_spaces": false,
|
| 28 |
+
"eos_token": "<bos>",
|
| 29 |
+
"extra_ids": 0,
|
| 30 |
+
"extra_special_tokens": {},
|
| 31 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 32 |
+
"pad_token": "<pad>",
|
| 33 |
+
"tokenizer_class": "BolmoTokenizer"
|
| 34 |
+
}
|
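Because of the `auto_map` entry above, the byte tokenizer loads directly from the repository (illustrative; `<this-repo>` is a placeholder):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("<this-repo>", trust_remote_code=True)
print(tok.bos_token, tok.eos_token, tok.pad_token)  # <bos> <bos> <pad> — eos is mapped to <bos> by design
```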
utils_bolmo.py
ADDED
|
@@ -0,0 +1,127 @@
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def compute_boundary_mask(boundary_logprobs: torch.Tensor, boundary_threshold: str) -> torch.Tensor:
|
| 8 |
+
if boundary_threshold.startswith("sample:"):
|
| 9 |
+
_, temperature = boundary_threshold.split(":")
|
| 10 |
+
temperature = float(temperature)
|
| 11 |
+
|
| 12 |
+
if temperature == 0:
|
| 13 |
+
return (boundary_logprobs > math.log(0.5))
|
| 14 |
+
elif temperature == 1:
|
| 15 |
+
return torch.bernoulli(torch.exp(boundary_logprobs)).to(torch.bool)
|
| 16 |
+
else:
|
| 17 |
+
raise NotImplementedError("Temperatures outside {0,1} are not implemented yet.")
|
| 18 |
+
elif boundary_threshold.startswith("topk:"):
|
| 19 |
+
_, topk = boundary_threshold.split(":")
|
| 20 |
+
topk = int(topk)
|
| 21 |
+
thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - (topk / boundary_logprobs.shape[1]))
|
| 22 |
+
return (boundary_logprobs >= thresholds.unsqueeze(-1))
|
| 23 |
+
elif boundary_threshold.startswith("topk_percent:"):
|
| 24 |
+
_, topk_percent = boundary_threshold.split(":")
|
| 25 |
+
topk_percent = float(topk_percent)
|
| 26 |
+
assert 0 <= topk_percent <= 1
|
| 27 |
+
thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - topk_percent)
|
| 28 |
+
return (boundary_logprobs >= thresholds.unsqueeze(-1))
|
| 29 |
+
else:
|
| 30 |
+
raise ValueError(f"Unknown boundary threshold: {boundary_threshold}")
|
| 31 |
+
|
| 32 |
+
|
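The three threshold modes accepted by `compute_boundary_mask`, shown on a toy batch of per-byte boundary log-probabilities (illustrative only):

```python
import torch

logp = torch.log(torch.tensor([[0.9, 0.2, 0.6, 0.1]]))   # [batch=1, seq_len=4]

compute_boundary_mask(logp, "sample:0")          # greedy: p > 0.5 -> [True, False, True, False]
compute_boundary_mask(logp, "sample:1")          # stochastic: Bernoulli(p) per byte
compute_boundary_mask(logp, "topk:2")            # roughly the 2 most likely boundaries per row
compute_boundary_mask(logp, "topk_percent:0.5")  # roughly the top 50% of bytes per row
```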
| 33 |
+
def _pad(tensors: list[torch.Tensor], multiple_of: int, direction: str, value):
|
| 34 |
+
max_len = max(t.size(0) for t in tensors)
|
| 35 |
+
if multiple_of > 1:
|
| 36 |
+
# Round up max_len to the nearest multiple_of
|
| 37 |
+
max_len = ((max_len + multiple_of - 1) // multiple_of) * multiple_of
|
| 38 |
+
padded = []
|
| 39 |
+
for t in tensors:
|
| 40 |
+
if direction == "left":
|
| 41 |
+
pad_shape = (max_len - t.size(0), 0)
|
| 42 |
+
elif direction == "right":
|
| 43 |
+
pad_shape = (0, max_len - t.size(0))
|
| 44 |
+
else:
|
| 45 |
+
raise ValueError(f"Unknown direction: {direction}. Must be 'left' or 'right'.")
|
| 46 |
+
padded.append(F.pad(t, pad_shape, value=value))
|
| 47 |
+
return torch.stack(padded, dim=0)
|
| 48 |
+
|
| 49 |
+
def pad_right(
|
| 50 |
+
tensors: list[torch.Tensor],
|
| 51 |
+
multiple_of: int = 128,
|
| 52 |
+
value=0,
|
| 53 |
+
):
|
| 54 |
+
return _pad(tensors, multiple_of, direction="right", value=value)
|
| 55 |
+
|
| 56 |
+
def pad_left(
|
| 57 |
+
tensors: list[torch.Tensor],
|
| 58 |
+
multiple_of: int = 128,
|
| 59 |
+
value=0,
|
| 60 |
+
):
|
| 61 |
+
return _pad(tensors, multiple_of, direction="left", value=value)
|
| 62 |
+
|
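Illustrative behaviour of the padding helpers: they stack ragged 1-D tensors into a batch, optionally rounding the padded length up to a multiple.

```python
import torch

a, b = torch.tensor([1, 2, 3]), torch.tensor([4])

pad_right([a, b], multiple_of=1, value=0)  # tensor([[1, 2, 3], [4, 0, 0]])
pad_left([a, b], multiple_of=1, value=0)   # tensor([[1, 2, 3], [0, 0, 4]])
pad_right([a, b], multiple_of=4, value=0)  # shape (2, 4): length rounded up to a multiple of 4
```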
| 63 |
+
class MaskState:
|
| 64 |
+
def __init__(self, mask):
|
| 65 |
+
self.cpu_mask = mask.cpu()
|
| 66 |
+
|
| 67 |
+
self.mask = mask
|
| 68 |
+
self.inv_mask = ~mask
|
| 69 |
+
self._all = self.cpu_mask.all().item()
|
| 70 |
+
self._any = self.cpu_mask.any().item()
|
| 71 |
+
|
| 72 |
+
def any(self):
|
| 73 |
+
return self._any
|
| 74 |
+
|
| 75 |
+
def all(self):
|
| 76 |
+
return self._all
|
| 77 |
+
|
| 78 |
+
def selective_get(self, x, inv=False):
|
| 79 |
+
# try to avoid sync through nonzero on index
|
| 80 |
+
if inv:
|
| 81 |
+
if self.all():
|
| 82 |
+
return x[[]]
|
| 83 |
+
elif not self.any():
|
| 84 |
+
return x
|
| 85 |
+
else:
|
| 86 |
+
return x[self.inv_mask]
|
| 87 |
+
else:
|
| 88 |
+
if self.all():
|
| 89 |
+
return x
|
| 90 |
+
elif not self.any():
|
| 91 |
+
return x[[]]
|
| 92 |
+
else:
|
| 93 |
+
return x[self.mask]
|
| 94 |
+
|
| 95 |
+
def selective_put(self, x, out, inv=False):
|
| 96 |
+
# try to avoid sync through nonzero on index
|
| 97 |
+
if inv:
|
| 98 |
+
if self.all():
|
| 99 |
+
return
|
| 100 |
+
elif not self.any():
|
| 101 |
+
out.copy_(x)
|
| 102 |
+
else:
|
| 103 |
+
out[self.inv_mask] = x
|
| 104 |
+
else:
|
| 105 |
+
if self.all():
|
| 106 |
+
out.copy_(x)
|
| 107 |
+
elif not self.any():
|
| 108 |
+
return
|
| 109 |
+
else:
|
| 110 |
+
out[self.mask] = x
|
| 111 |
+
|
| 112 |
+
def selective_add(self, x, out, inv=False):
|
| 113 |
+
# try to avoid sync through nonzero on index
|
| 114 |
+
if inv:
|
| 115 |
+
if self.all():
|
| 116 |
+
return
|
| 117 |
+
elif not self.any():
|
| 118 |
+
out.add_(x)
|
| 119 |
+
else:
|
| 120 |
+
out[self.inv_mask] += x
|
| 121 |
+
else:
|
| 122 |
+
if self.all():
|
| 123 |
+
out.add_(x)
|
| 124 |
+
elif not self.any():
|
| 125 |
+
return
|
| 126 |
+
else:
|
| 127 |
+
out[self.mask] += x
|
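Illustrative use of `MaskState`: it caches the mask on CPU so `generate()` can branch on `all()` / `any()` without extra device syncs, and reads or writes per-sequence values only where the mask selects them.

```python
import torch

state = MaskState(torch.tensor([True, False, True]))

counts = torch.zeros(3, dtype=torch.long)
state.selective_add(1, counts)                  # +1 where the mask is True
state.selective_add(1, counts, inv=True)        # +1 where the mask is False
# counts == tensor([1, 1, 1])

dst = torch.zeros(3, dtype=torch.long)
state.selective_put(torch.tensor([7, 8]), dst)  # scatter into the True positions
# dst == tensor([7, 0, 8])
```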