| | import inspect |
| | |
| | import os |
| | import os.path as osp |
| | import shutil |
| | import warnings |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | |
| | |
| | |
| | import torch |
| | import torch.nn as nn |
| | from huggingface_hub import repo_exists, snapshot_download |
| | from huggingface_hub.utils import HFValidationError, validate_repo_id |
| | |
| | |
| | from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, |
| | AutoTokenizer, BitsAndBytesConfig, GenerationConfig, |
| | LlamaConfig, LlamaForCausalLM, PretrainedConfig, |
| | PreTrainedModel, SiglipImageProcessor, |
| | SiglipVisionModel) |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
|
| | from .configuration_llava import LlavaConfig |
| | |
| | from .utils import get_model_config |
| |
|
| | CONTROLLER_HEART_BEAT_EXPIRATION = 30 |
| | WORKER_HEART_BEAT_INTERVAL = 15 |
| |
|
| | LOGDIR = "." |
| |
|
| | |
| | IGNORE_INDEX = -100 |
| | IMAGE_TOKEN_INDEX = -200 |
| | DEFAULT_IMAGE_TOKEN = "<image>" |
| | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" |
| | DEFAULT_IM_START_TOKEN = "<im_start>" |
| | DEFAULT_IM_END_TOKEN = "<im_end>" |
| | IMAGE_PLACEHOLDER = "<image-placeholder>" |
| |
|
def is_deepspeed_zero3_enabled():
    """Report whether DeepSpeed ZeRO-3 is active.

    This stand-in performs no detection at all: it unconditionally yields
    ``None``, which callers testing the result for truthiness treat as
    "not enabled".
    """
    return None
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import (AutoConfig, AutoModel, PretrainedConfig, |
| | PreTrainedModel) |
| |
|
| |
|
class IdentityMap(nn.Module):
    """Pass-through projector: whatever goes in comes out untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Hand ``x`` back as-is; extra positional/keyword arguments are ignored."""
        return x

    @property
    def config(self):
        """Small mapping that labels this projector as the ``"identity"`` type."""
        return dict(mm_projector_type="identity")
| |
|
| |
|
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a two-layer GELU MLP with a residual connection.

    The residual is taken from the *normalized* input, i.e. the output is
    ``norm(x) + mlp(norm(x))``.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        # Sub-module order (Linear, GELU, Linear) fixes the state-dict keys.
        first_linear = nn.Linear(channels, channels)
        activation = nn.GELU()
        second_linear = nn.Linear(channels, channels)
        self.proj = nn.Sequential(first_linear, activation, second_linear)

    def forward(self, x):
        """Normalize ``x`` and add the MLP output to the normalized tensor."""
        normed = self.pre_norm(x)
        update = self.proj(normed)
        return normed + update
| |
|
| |
|
class DownSampleBlock(nn.Module):
    """Fold each 2x2 neighbourhood of a square token grid into one token.

    The sequence axis is viewed as a ``side x side`` grid (``side`` is the
    integer square root of the sequence length), odd sides are zero-padded to
    even, and groups of four tokens are concatenated along the channel axis.
    The result has a quarter of the (padded) tokens and four times the channels.
    """

    def forward(self, x):
        """Reshape ``x`` (batch, tokens, channels) to a grid, merge, and flatten again."""
        batch = x.shape[0]
        side = int(x.shape[1] ** 0.5)
        grid = x.reshape(batch, side, side, -1)
        merged = self.flat_square(grid)
        return merged.reshape(merged.shape[0], -1, merged.shape[-1])

    def flat_square(self, x):
        """Zero-pad both grid axes to even length, then merge 2x2 cells into channels."""
        n, w, h, c = x.size()

        # Pad one row of zeros when the first grid axis is odd.
        if w % 2 == 1:
            row_pad = torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)
            x = torch.concat([x, row_pad], dim=1).contiguous()
            n, w, h, c = x.size()

        # Pad one column of zeros when the second grid axis is odd.
        if h % 2 == 1:
            col_pad = torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)
            x = torch.concat([x, col_pad], dim=2).contiguous()
            n, w, h, c = x.size()

        half_h = int(h / 2)
        half_w = int(w / 2)
        # Pair up neighbours along the second grid axis ...
        x = x.view(n, w, half_h, int(c * 2))
        # ... swap the grid axes, then pair up along the other axis.
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(n, half_h, half_w, int(c * 4))
| |
|
| |
|
class MultimodalProjectorConfig(PretrainedConfig):
    """Configuration for the vision-to-language projector.

    Args:
        mm_projector_type: Name of the projector architecture (for example
            ``"identity"``, ``"linear"``, ``"mlp_downsample"`` or
            ``"mlp<N>x_gelu"``).
        **kwargs: Standard ``PretrainedConfig`` fields. They are forwarded to
            the base class so values stored in a saved ``config.json`` are
            kept when the config is re-created, instead of being dropped.
    """

    model_type = "v2l_projector"

    def __init__(self, mm_projector_type: str = None, **kwargs):
        # Forward kwargs: the base class owns the common config attributes.
        super().__init__(**kwargs)
        self.mm_projector_type = mm_projector_type
| |
|
| |
|
class MultimodalProjector(PreTrainedModel):
    """Maps vision-tower features into the language model's hidden size.

    Supported ``mm_projector_type`` values:

    * ``"identity"``       - features are passed through unchanged.
    * ``"linear"``         - a single linear layer.
    * ``"mlp_downsample"`` - 2x2 token merge, LayerNorm, then a 2-layer GELU MLP.
    * ``"mlp<N>x_gelu"``   - ``N`` linear layers with GELU between them.

    Args:
        mm_projector_cfg: Projector config carrying ``mm_projector_type``.
        config: Model config providing ``mm_hidden_size`` (input width) and
            ``hidden_size`` (output width).

    Raises:
        ValueError: If the projector type is not recognised.
    """

    config_class = MultimodalProjectorConfig

    def __init__(
        self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig
    ):
        # Imported here because the module's import block does not provide `re`.
        import re

        super().__init__(mm_projector_cfg)
        mm_projector_type = mm_projector_cfg.mm_projector_type

        if mm_projector_type == "identity":
            self.layers = IdentityMap()
        elif mm_projector_type == "linear":
            self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
        elif mm_projector_type == "mlp_downsample":
            self.layers = nn.Sequential(
                DownSampleBlock(),
                # The down-sample block concatenates 4 tokens along channels.
                nn.LayerNorm(config.mm_hidden_size * 4),
                nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        else:
            # `str(...)` keeps a missing (None) type on the ValueError path
            # below instead of failing inside `re.match`.
            mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", str(mm_projector_type))
            if mlp_gelu_match is None:
                raise ValueError(f"Unknown projector type: {mm_projector_type}")

            mlp_depth = int(mlp_gelu_match.group(1))
            modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
            self.layers = nn.Sequential(*modules)

    def forward(self, x, *args, **kwargs):
        """Project ``x``; extra positional/keyword arguments are ignored."""
        return self.layers(x)
| | |
| | |
def build_mm_projector(
    model_type_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    """Create the multimodal projector, or reload it from a checkpoint.

    Args:
        model_type_or_path: A projector type name when building from scratch,
            or a directory path when ``config.resume_path`` is set. ``None``
            means "no projector".
        config: Model config; ``resume_path`` selects the branch and
            ``model_dtype`` (a string such as ``"torch.float16"``) picks the dtype.

    Returns:
        The projector module, or ``None`` when ``model_type_or_path`` is ``None``.
    """
    if model_type_or_path is None:
        return None

    if not config.resume_path:
        # Fresh projector: here the first argument is an architecture name.
        projector_cfg = MultimodalProjectorConfig(model_type_or_path)
        projector = MultimodalProjector(projector_cfg, config)
        # NOTE(review): `eval` on a config-provided string runs arbitrary code
        # if the config is untrusted - consider a safe dtype lookup.
        return projector.to(eval(config.model_dtype))

    # Resuming: here the first argument must be an existing checkpoint directory.
    assert os.path.exists(
        model_type_or_path
    ), f"Resume mm projector path {model_type_or_path} does not exist!"
    return MultimodalProjector.from_pretrained(
        model_type_or_path, config, torch_dtype=eval(config.model_dtype)
    )
| |
|
| |
|
class VisionTower(nn.Module):
    """Base wrapper around a vision encoder.

    This class only stores selection settings; it never creates the encoder.
    ``forward`` and several properties read ``self.vision_tower``, so a
    subclass is expected to assign it (and set ``is_loaded``) before use.

    Args:
        vision_tower: Name or path identifying the encoder; stored as
            ``vision_tower_name``.
        args: Object optionally carrying ``mm_vision_select_layer`` (default
            ``-2``) and ``mm_vision_select_feature`` (default ``"patch"``).
        delay_load: Accepted for interface compatibility; not used here.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        # Returned by `config` while no encoder has been loaded.
        self.cfg_only = None

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden-state layer and token subset.

        ``"patch"`` drops the first token of the sequence axis; ``"cls_patch"``
        keeps every token.

        Raises:
            ValueError: If ``select_feature`` is neither of the two modes.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            return image_features[:, 1:]
        if self.select_feature == "cls_patch":
            return image_features
        raise ValueError(f"Unexpected select feature: {self.select_feature}")

    @staticmethod
    def _linear_interpolate_rows(old_weight, num_new_tokens):
        """Resample the rows of ``old_weight`` to ``num_new_tokens`` rows.

        New row ``i`` sits at fractional position ``i / (new - 1) * (old - 1)``
        in the old table and is the linear blend of its two neighbours. When
        the position is an exact integer the old row is copied as-is, so the
        first and last rows are preserved. The result keeps ``old_weight``'s
        dtype.
        """
        old_num_tokens = old_weight.size(0)
        # `max(..., 1)` avoids a division by zero for a single-token target.
        mapped_indices = (
            torch.arange(num_new_tokens, device=old_weight.device)
            / max(num_new_tokens - 1, 1)
            * (old_num_tokens - 1)
        )
        floor_indices = torch.clamp(
            mapped_indices.floor().long(), min=0, max=old_num_tokens - 1
        )
        ceil_indices = torch.clamp(
            mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1
        )
        # The weights must sum to one. Using (ceil - mapped) for the floor row
        # would make both weights zero whenever floor == ceil.
        upper_weight = (mapped_indices - floor_indices)[:, None]
        lower_weight = 1.0 - upper_weight
        blended = (
            upper_weight * old_weight[ceil_indices, :]
            + lower_weight * old_weight[floor_indices, :]
        )
        # The float index math promotes half-precision weights; cast back.
        return blended.to(old_weight.dtype)

    def _maybe_resize_pos_embeds(
        self,
        model: "PreTrainedModel",
        image_processor,
        resolution: int = -1,
        interpolate_mode: str = "linear",
    ):
        """Resize the encoder's position-embedding table for a new resolution.

        Does nothing when ``resolution`` is ``-1`` or already equals
        ``model.config.image_size``. Otherwise a new embedding table with
        ``(resolution // patch_size) ** 2`` rows is interpolated from the old
        one, and the model config, the embeddings module and the image
        processor are updated in place to the new resolution.

        Raises:
            NotImplementedError: For any ``interpolate_mode`` but ``"linear"``.
        """
        if resolution in [model.config.image_size, -1]:
            return
        print(
            f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
        )
        embeddings = model.vision_model.embeddings
        patch_size = embeddings.patch_size
        num_new_tokens = int((resolution // patch_size) ** 2)

        old_embeddings = embeddings.position_embedding
        if interpolate_mode != "linear":
            raise NotImplementedError

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        new_embeddings = nn.Embedding(
            num_new_tokens,
            old_embedding_dim,
            dtype=old_embeddings.weight.dtype,
            device=old_embeddings.weight.device,
        )
        if is_deepspeed_zero3_enabled():
            # NOTE(review): `deepspeed` is not imported in the visible part of
            # this file; this branch needs that import to run.
            params = [old_embeddings.weight, new_embeddings.weight]
            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                interpolated_embeds = self._linear_interpolate_rows(
                    old_embeddings.weight.data, num_new_tokens
                )
        else:
            interpolated_embeds = self._linear_interpolate_rows(
                old_embeddings.weight.data, num_new_tokens
            )
        new_embeddings.weight.data = interpolated_embeds

        # NOTE(review): an `_hf_hook` attached to the old embedding is not
        # transferred to the replacement - confirm whether it should be.

        new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)

        model.config.image_size = resolution
        if hasattr(image_processor, "crop_size"):
            image_processor.crop_size = resolution
        else:
            assert hasattr(image_processor, "size")
            image_processor.size = {"height": resolution, "width": resolution}

        embeddings.position_embedding = new_embeddings
        embeddings.image_size = resolution
        embeddings.num_patches = embeddings.num_positions = num_new_tokens
        embeddings.position_ids = (
            torch.arange(embeddings.num_positions)
            .expand((1, -1))
            .to(old_embeddings.weight.device)
        )

    def forward(self, images):
        """Encode ``images`` and return the selected features.

        A ``list`` input is encoded one image at a time (each gets a leading
        batch axis) and yields a list of feature tensors; any other input is
        encoded as one batch. Features are cast back to the input dtype.
        """
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        """A ``(1, hidden_size)`` zero tensor on the encoder's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype reported by the wrapped encoder."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device reported by the wrapped encoder."""
        return self.vision_tower.device

    @property
    def config(self):
        """The encoder's config once loaded, otherwise ``cfg_only``."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """Hidden size taken from ``config``."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """``(image_size // patch_size) ** 2`` computed from ``config``."""
        return (self.config.image_size // self.config.patch_size) ** 2
| |
|
| |
|
class SiglipVisionTower(VisionTower):
    """Vision tower backed by a SigLIP vision model and its image processor.

    Args:
        model_name_or_path: Identifier or path handed to both
            ``SiglipImageProcessor.from_pretrained`` and
            ``SiglipVisionModel.from_pretrained``.
        config: Model config; ``model_dtype`` (a string such as
            ``"torch.float16"``) selects the load dtype.
        state_dict: Optional state dict forwarded to the model loader.
    """

    def __init__(
        self, model_name_or_path: str, config: PretrainedConfig, state_dict=None
    ):
        super().__init__(model_name_or_path, config)

        processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        self.image_processor = processor

        # NOTE(review): `eval` on a config-provided string runs arbitrary code
        # if the config is untrusted - consider a safe dtype lookup.
        load_dtype = eval(config.model_dtype)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            torch_dtype=load_dtype,
            state_dict=state_dict,
        )
        self.is_loaded = True
| |
|
| |
|
| |
|
def build_vision_tower(
    model_name_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    """Instantiate the vision tower and record its feature width on ``config``.

    When ``config.resume_path`` is set (and the name does not contain
    ``"radio"``) the tower type is read from the checkpoint's ``architectures``
    entry; otherwise it is inferred from ``model_name_or_path`` itself. Only
    names containing ``"siglip"`` are supported.

    Returns:
        The vision tower, or ``None`` when ``model_name_or_path`` is ``None``.
        As a side effect ``config.mm_hidden_size`` is set.

    Raises:
        ValueError: If the tower type is not recognised.
    """
    if model_name_or_path is None:
        return None

    # Default: decide the tower type from the name/path string.
    vision_tower_name = model_name_or_path
    if config.resume_path and "radio" not in model_name_or_path:
        assert os.path.exists(
            model_name_or_path
        ), f"Resume vision tower path {model_name_or_path} does not exist!"
        resumed_cfg = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        # Resumed checkpoint: decide from its recorded architecture instead.
        vision_tower_name = resumed_cfg.architectures[0].lower()

    use_s2 = getattr(config, "s2", False)

    if "siglip" not in vision_tower_name:
        raise ValueError(f"Unknown vision tower: {model_name_or_path}")

    if use_s2:
        vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
        config.mm_hidden_size = vision_tower.hidden_size
    else:
        vision_tower = SiglipVisionTower(model_name_or_path, config)
        config.mm_hidden_size = vision_tower.config.hidden_size
    return vision_tower
| |
|
| |
|
| |
|
def has_tokenizer(repo_id_or_path: str) -> bool:
    """Tell whether a tokenizer config exists at a local path or Hub repo.

    A local ``tokenizer_config.json`` is checked first. Otherwise the argument
    is treated as a Hugging Face Hub repo id and the Hub is queried.

    Returns:
        ``True`` if ``tokenizer_config.json`` is found, ``False`` otherwise
        (including when the argument is not a valid repo id).
    """
    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
        return True

    # `file_exists` is not part of the module-level huggingface_hub import,
    # so bring it into scope here; without it the Hub check raised NameError.
    from huggingface_hub import file_exists

    try:
        return repo_exists(repo_id_or_path) and file_exists(
            repo_id_or_path, "tokenizer_config.json"
        )
    except HFValidationError:
        # Not a syntactically valid repo id (e.g. a local filesystem path).
        return False
| |
|
| |
|
def context_length_extension(config):
    """Enable linear RoPE scaling when the requested length exceeds the native one.

    If ``config.model_max_length`` is larger than
    ``config.max_position_embeddings``, ``config.rope_scaling`` is set to a
    linear scaling whose factor is the ratio rounded up to a whole number.
    Nothing changes when either value is missing or no extension is needed.

    Returns:
        The same ``config`` object (mutated in place).
    """
    # Imported here because the module's import block does not provide `math`.
    import math

    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)

    # Guard against a missing target length before comparing numbers.
    if (
        orig_ctx_len
        and model_max_length is not None
        and model_max_length > orig_ctx_len
    ):
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config
| |
|
| |
|
def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
):
    """Load the language model and its tokenizer.

    Args:
        model_name_or_path: Checkpoint of the language model.
        config: Top-level model config. ``model_dtype`` (a string such as
            ``"torch.float16"``) selects the load dtype, and ``hidden_size``
            is written back from the loaded model.
        attn_implementation: Stored on the LLM config as ``_attn_implementation``.
        model_max_length: Target context length; when given, RoPE scaling is
            configured through ``context_length_extension``.
        *args, **kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        ``(llm, tokenizer)``.

    Raises:
        ValueError: If no tokenizer is found at the checkpoint or in its
            ``llm`` sub-directory.
    """
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # NOTE(review): `eval` on a config-provided string runs arbitrary code if
    # the config is untrusted - consider a safe dtype lookup.
    llm = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        *args,
        config=llm_cfg,
        torch_dtype=eval(config.model_dtype),
        **kwargs,
    )

    # The tokenizer lives either next to the model or in an `llm` sub-folder.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    # Default to an empty architecture name so the tokenizer selection below
    # still works after the warning instead of failing on an unbound variable.
    llm_arch = ""
    try:
        llm_arch = getattr(llm_cfg, "architectures")[0].lower()
    except (AttributeError, TypeError, IndexError):
        warnings.warn(
            f'Cannot find LLM architecture, please check the "config.json" under "{llm_path}".'
        )

    tokenizer_kwargs = {
        "model_max_length": llm_cfg.model_max_length,
        "padding_side": "right",
    }
    if "mpt" in llm_arch:
        pass  # MPT uses the tokenizer defaults.
    elif "yi" in llm_path or (
        getattr(llm_cfg, "num_hidden_layers", -1) == 60
        and getattr(llm_cfg, "num_attention_heads", -1) == 56
    ):
        # Selected by "yi" in the path or by the 60-layer / 56-head shape.
        tokenizer_kwargs["use_fast"] = False
    else:
        tokenizer_kwargs["use_fast"] = False
        tokenizer_kwargs["legacy"] = False
    tokenizer = AutoTokenizer.from_pretrained(llm_path, **tokenizer_kwargs)

    # The projector builder reads `hidden_size` from the shared config.
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
| |
|
| |
|
def is_mm_model(model_path):
    """
    Check if the model at the given path is a visual language model.

    The decision is based solely on the ``architectures`` list in the model's
    config: the model counts as multimodal when any entry contains ``"llava"``
    (case-insensitive). A config without ``architectures`` is treated as not
    multimodal.

    Args:
        model_path (str): The path to the model.

    Returns:
        bool: True if the model is an MM model, False otherwise.
    """
    config = AutoConfig.from_pretrained(model_path)
    # `architectures` may be absent or None; treat both as an empty list.
    architectures = getattr(config, "architectures", None) or []
    return any("llava" in architecture.lower() for architecture in architectures)
| |
|
| |
|
def load_pretrained_model(
    model_path,
    model_name,
    model_base=None,
    load_8bit=False,
    load_4bit=False,
    device_map="auto",
    device="cuda",
    **kwargs,
):
    """Load a language or vision-language checkpoint for inference.

    Loading strategy, in order:

    * multimodal + LoRA/DoRA name + ``model_base``: load the base model, apply
      the extra non-LoRA weights, then merge the LoRA adapter;
    * multimodal + ``model_base``: load the base weights with the checkpoint's
      config;
    * multimodal: load directly from ``model_path``;
    * otherwise: load a plain causal LM (merging a PEFT adapter when
      ``model_base`` is given).

    Args:
        model_path: Checkpoint directory or Hub repo id.
        model_name: Name used to pick model-family branches (``"lora"``,
            ``"mpt"``, ``"mistral"``, ...).
        model_base: Optional base model for adapter / projector-only checkpoints.
        load_8bit: Load with 8-bit quantization.
        load_4bit: Load with 4-bit NF4 quantization (ignored if ``load_8bit``).
        device_map: Device map forwarded to the loaders.
        device: Target device; any value other than ``"cuda"`` pins every
            module to that device.
        **kwargs: Extra loader keyword arguments.

    Returns:
        ``(tokenizer, model, image_processor, context_len)``;
        ``image_processor`` is ``None`` for non-multimodal models.
    """
    kwargs = {"device_map": device_map, **kwargs}

    if device != "cuda":
        kwargs["device_map"] = {"": device}

    if load_8bit:
        kwargs["load_in_8bit"] = True
    elif load_4bit:
        kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        kwargs["torch_dtype"] = torch.float16

    # Evaluate once; the answer is needed again after the model is loaded.
    multimodal = is_mm_model(model_path)

    if multimodal:
        name_lower = model_name.lower()
        if "lora" in name_lower and model_base is None:
            warnings.warn(
                "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
            )
        if ("lora" in name_lower or "dora" in name_lower) and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            print(lora_cfg_pretrained)
            print("Loading LLaVA from base model...")
            config = AutoConfig.from_pretrained(model_base)
            prepare_config_for_eval(config, kwargs)
            model = LlavaLlamaModel.from_pretrained(
                model_base, low_cpu_mem_usage=True, config=config, **kwargs
            )
            tokenizer = model.tokenizer
            token_num, tokem_dim = (
                model.llm.lm_head.out_features,
                model.llm.lm_head.in_features,
            )
            if model.llm.lm_head.weight.shape[0] != token_num:
                model.llm.lm_head.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, tokem_dim, device=model.device, dtype=model.dtype
                    )
                )
                model.llm.embed_tokens.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, tokem_dim, device=model.device, dtype=model.dtype
                    )
                )

            print("Loading additional LLaVA weights...")
            # NOTE(review): `torch.load` unpickles the file; only use with
            # checkpoints from a trusted source.
            local_weights = os.path.join(model_path, "non_lora_trainables.bin")
            if os.path.exists(local_weights):
                non_lora_trainables = torch.load(local_weights, map_location="cpu")
            else:
                # No local file: fetch it from the Hub repo of the same name.
                from huggingface_hub import hf_hub_download

                cache_file = hf_hub_download(
                    repo_id=model_path,
                    filename="non_lora_trainables.bin",
                    subfolder=None,
                )
                non_lora_trainables = torch.load(cache_file, map_location="cpu")

            # Strip the "base_model." prefix (11 characters) from the keys.
            non_lora_trainables = {
                (k[11:] if k.startswith("base_model.") else k): v
                for k, v in non_lora_trainables.items()
            }
            # Collapse a doubled "model.model." prefix down to one level.
            if any(k.startswith("model.model.") for k in non_lora_trainables):
                non_lora_trainables = {
                    (k[6:] if k.startswith("model.") else k): v
                    for k, v in non_lora_trainables.items()
                }
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel

            print("Loading LoRA weights...")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging LoRA weights...")
            model = model.merge_and_unload()
            print("Model is loaded...")
        elif model_base is not None:
            print("Loading LLaVA from base model...")
            cfg_pretrained = AutoConfig.from_pretrained(
                model_path, trust_remote_code=True
            )
            # NOTE(review): neither `mm_config_wrapper` nor a local `config`
            # is defined in the visible part of this file, so this call is
            # expected to fail as written - confirm the intended helper and
            # argument (possibly `cfg_pretrained`).
            mm_config_wrapper(config, kwargs)
            if "mpt" in name_lower:
                if not os.path.isfile(os.path.join(model_path, "configuration_mpt.py")):
                    shutil.copyfile(
                        os.path.join(model_base, "configuration_mpt.py"),
                        os.path.join(model_path, "configuration_mpt.py"),
                    )
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                model = LlavaMPTForCausalLM.from_pretrained(
                    model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_base, use_fast=False, legacy=False
                )
                model = LlavaLlamaForCausalLM.from_pretrained(
                    model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
                )
        else:
            config = AutoConfig.from_pretrained(model_path)
            config.resume_path = model_path
            prepare_config_for_eval(config, kwargs)
            if "mpt" in name_lower:
                model = LlavaMPTForCausalLM.from_pretrained(
                    model_path, config=config, low_cpu_mem_usage=True, **kwargs
                )
            elif "mistral" in name_lower or "mixtral" in name_lower:
                model = LlavaMistralForCausalLM.from_pretrained(
                    model_path, config=config, low_cpu_mem_usage=True, **kwargs
                )
            elif "gemma" in name_lower:
                model = LlavaGemmaForCausalLM.from_pretrained(
                    model_path, config=config, low_cpu_mem_usage=True, **kwargs
                )
            else:
                model = LlavaLlamaModel(config=config, low_cpu_mem_usage=True, **kwargs)
            tokenizer = model.tokenizer
    else:
        if model_base is not None:
            # Plain language model stored as a PEFT adapter on `model_base`.
            from peft import PeftModel

            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(
                model_base, low_cpu_mem_usage=True, **kwargs
            )
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print("Convert to FP16...")
            model.to(torch.float16)
        else:
            if "mpt" in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(
                    model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=False, legacy=False
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path, low_cpu_mem_usage=True, **kwargs
                )

    model.eval()
    image_processor = None
    if multimodal:
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
        model.resize_token_embeddings(len(tokenizer))
        vision_tower = model.get_vision_tower()
        vision_tower.to(device=device, dtype=torch.float16)

        mm_projector = model.get_mm_projector()
        mm_projector.to(device=device, dtype=torch.float16)

        image_processor = vision_tower.image_processor

    # Plain causal LMs have no `.llm` sub-module; fall back to their own
    # config so this lookup does not fail for them.
    llm = getattr(model, "llm", None)
    llm_config = llm.config if llm is not None else model.config
    if hasattr(llm_config, "max_sequence_length"):
        # Prefer the top-level value when it exists, otherwise use the value
        # from the config that was actually checked.
        context_len = getattr(
            model.config, "max_sequence_length", llm_config.max_sequence_length
        )
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
| |
|
| |
|
def parse_model_name_or_path(config: PretrainedConfig, model_name="llm", suffix="_cfg"):
    """Resolve a component's name/path from its ``<model_name><suffix>`` config entry.

    A string entry is returned unchanged; a dict entry yields its first
    ``"architectures"`` item.

    Raises:
        ValueError: If the entry is missing or is neither a string nor a dict.
    """
    attr_name = f"{model_name}{suffix}"
    entry = getattr(config, attr_name, None)

    if isinstance(entry, dict):
        return entry["architectures"][0]
    if not isinstance(entry, str):
        raise ValueError(f"Invalid {attr_name} configuration!")
    return entry
| |
|
| |
|
def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
    """Normalize ``config`` and loader ``kwargs`` before building a model for eval.

    * Fills ``config.vision_tower_cfg`` from ``config.mm_vision_tower`` when
      it is missing.
    * Moves ``torch_dtype`` out of ``kwargs`` into ``config.model_dtype`` as a
      string, defaulting to ``torch.float16`` when the caller supplied none
      (for example on the 8-bit / 4-bit loading paths).
    * Forces ``kwargs["device_map"] = "cuda"`` for SigLIP vision towers.

    Both arguments are mutated in place.

    Raises:
        ValueError: If no vision tower entry can be found in ``config``.
    """
    try:
        if getattr(config, "vision_tower_cfg", None) is None:
            config.vision_tower_cfg = config.mm_vision_tower
    except AttributeError as err:
        raise ValueError(
            f"Invalid configuration! Cannot find vision_tower in config:\n{config}"
        ) from err

    # A missing `torch_dtype` used to raise KeyError; fall back to float16.
    config.model_dtype = str(kwargs.pop("torch_dtype", torch.float16))

    vision_tower_name = parse_model_name_or_path(config, "vision_tower")
    if "siglip" in vision_tower_name.lower():
        kwargs["device_map"] = "cuda"
| |
|
| |
|
class LlavaLlamaConfig(LlavaConfig):
    """``LlavaConfig`` variant registered under the ``"llava_llama"`` model type."""

    model_type = "llava_llama"
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from abc import ABC, abstractmethod |
| | from collections import OrderedDict |
| |
|
| |
|
class LlavaMetaModel(ABC):
    """Mixin that assembles a vision-language model from three components.

    The components are a language model (``llm``, with its ``tokenizer``), a
    ``vision_tower`` and an ``mm_projector``. The mixin also provides saving,
    accessors and thin delegations to the language model. It reads
    ``self.config`` and ``self.state_dict()``, so it is meant to be combined
    with a class that provides them.
    """

    @staticmethod
    def _resolve_component_cfgs(config):
        """Default ``config.model_dtype`` if needed and return the three component cfgs.

        Returns:
            The ``(llm_cfg, vision_tower_cfg, mm_projector_cfg)`` triple as
            produced by ``get_model_config``.

        Raises:
            ValueError: If ``get_model_config`` does not yield three entries.
        """
        if not hasattr(config, "model_dtype"):
            warnings.warn(
                "model_dtype not found in config, defaulting to torch.float16."
            )
            config.model_dtype = "torch.float16"

        cfgs = get_model_config(config)
        if len(cfgs) != 3:
            raise ValueError(
                "`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config."
            )
        return cfgs

    def init_vlm(self, config: "PretrainedConfig" = None, *args, **kwargs):
        """Build the three components from ``config`` unless already present.

        Extra ``*args`` / ``**kwargs`` are forwarded to
        ``build_llm_and_tokenizer``.
        """
        if (
            hasattr(self, "llm")
            or hasattr(self, "vision_tower")
            or hasattr(self, "mm_projector")
        ):
            # Components already exist; nothing to do.
            return

        llm_cfg, vision_tower_cfg, mm_projector_cfg = self._resolve_component_cfgs(
            config
        )

        self.llm, self.tokenizer = build_llm_and_tokenizer(
            llm_cfg, config, *args, **kwargs
        )
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)

        self.post_config()
        self.is_loaded = True

        assert (
            self.llm is not None
            or self.vision_tower is not None
            or self.mm_projector is not None
        ), "At least one of the components must be instantiated."

    @classmethod
    def load_from_config(cls, model_path_or_config, *args, **kwargs):
        """Placeholder; intentionally does nothing."""
        pass

    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs):
        """Create an instance from a checkpoint path or a ``LlavaConfig``.

        The instance is constructed with ``cls(config, ...)``. If construction
        already loaded the components it is returned directly; otherwise the
        components are built here and attached to the instance.

        Raises:
            NotImplementedError: If ``model_path_or_config`` is neither a
                string nor a ``LlavaConfig``.
            ValueError: If the component configs cannot be resolved.
        """
        kwargs.pop("config", None)

        if isinstance(model_path_or_config, str):
            config = AutoConfig.from_pretrained(model_path_or_config)
        elif isinstance(model_path_or_config, LlavaConfig):
            config = model_path_or_config
        else:
            raise NotImplementedError(
                f"wrong type, {type(model_path_or_config)} \
                                      {isinstance(model_path_or_config, LlavaConfig)}"
            )

        llm_cfg, vision_tower_cfg, mm_projector_cfg = cls._resolve_component_cfgs(
            config
        )

        vlm = cls(config, *args, **kwargs)

        if (
            hasattr(vlm, "llm")
            or hasattr(vlm, "vision_tower")
            or hasattr(vlm, "mm_projector")
        ):
            if getattr(vlm, "is_loaded", False):
                return vlm

        vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(
            llm_cfg, config, *args, **kwargs
        )
        vlm.vision_tower = build_vision_tower(vision_tower_cfg, config)
        vlm.mm_projector = build_mm_projector(mm_projector_cfg, config)

        # Finalize the *instance*: `post_config` is an instance method and the
        # loaded flag must not leak onto the class.
        vlm.post_config()
        vlm.is_loaded = True

        assert (
            vlm.llm is not None
            or vlm.vision_tower is not None
            or vlm.mm_projector is not None
        ), "At least one of the components must be instantiated."
        return vlm

    def save_pretrained(self, output_dir, state_dict=None):
        """Save each component into its own sub-folder of ``output_dir``.

        Layout: ``llm/`` (model + tokenizer), ``vision_tower/`` (model + image
        processor), ``mm_projector/``, and the top-level config in
        ``output_dir`` itself. When ``state_dict`` is ``None`` the model's own
        ``state_dict()`` is used; it is split per component by key substring.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None):
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        if self.get_llm():
            print(f"saving llm to {osp.join(output_dir, 'llm')}")
            self.llm.config._name_or_path = osp.join(output_dir, "llm")
            # Keep only LLM entries and drop everything up to the last "llm.".
            llm_state_dict = OrderedDict(
                {k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k}
            )
            self.llm.save_pretrained(
                os.path.join(output_dir, "llm"), state_dict=llm_state_dict
            )
            self.config.llm_cfg = self.llm.config

        if self.get_vision_tower():
            print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
            self.vision_tower.config._name_or_path = osp.join(
                output_dir, "vision_tower"
            )
            vision_tower_state_dict = OrderedDict(
                {
                    k.split("vision_tower.vision_tower.")[-1]: v
                    for k, v in state_dict.items()
                    if "vision_tower" in k
                }
            )
            self.vision_tower.vision_tower.save_pretrained(
                os.path.join(output_dir, "vision_tower"),
                state_dict=vision_tower_state_dict,
            )
            self.vision_tower.image_processor.save_pretrained(
                os.path.join(output_dir, "vision_tower")
            )
            self.config.vision_tower_cfg = self.vision_tower.config
            if hasattr(self.config.vision_tower_cfg, "auto_map"):
                # `auto_map` is kept only for towers whose class name
                # contains "radio".
                if "radio" not in self.get_vision_tower().__class__.__name__.lower():
                    delattr(self.config.vision_tower_cfg, "auto_map")

        if self.get_mm_projector():
            print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
            self.mm_projector.config._name_or_path = osp.join(
                output_dir, "mm_projector"
            )
            mm_projector_state_dict = OrderedDict(
                {
                    k.split("mm_projector.")[-1]: v
                    for k, v in state_dict.items()
                    if "mm_projector" in k
                }
            )
            self.mm_projector.save_pretrained(
                os.path.join(output_dir, "mm_projector"),
                state_dict=mm_projector_state_dict,
            )
            self.config.mm_projector_cfg = self.mm_projector.config

        self.config._name_or_path = output_dir
        self.config.architectures = [self.__class__.__name__]
        self.config.save_pretrained(output_dir)

    def get_llm(self):
        """Return the language model (first element if stored as a list), or ``None``."""
        llm = getattr(self, "llm", None)
        if type(llm) is list:
            llm = llm[0]
        return llm

    def get_lm_head(self):
        """Return the language model's ``lm_head``, or ``None`` if unavailable."""
        return getattr(self.get_llm(), "lm_head", None)

    def get_vision_tower(self):
        """Return the vision tower (first element if stored as a list), or ``None``."""
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_mm_projector(self):
        """Return the projector (first element if stored as a list), or ``None``."""
        mm_projector = getattr(self, "mm_projector", None)
        if type(mm_projector) is list:
            mm_projector = mm_projector[0]
        return mm_projector

    def post_config(self):
        """Sync the training flag with the LLM and fill missing component configs."""
        self.training = self.get_llm().training

        if getattr(self.config, "llm_cfg", None) is None:
            self.config.llm_cfg = self.llm.config
        if getattr(self.config, "vision_tower_cfg", None) is None:
            self.config.vision_tower_cfg = self.vision_tower.config
        if getattr(self.config, "mm_projector_cfg", None) is None:
            self.config.mm_projector_cfg = self.mm_projector.config

    def freezed_module_patch(self):
        """
        Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
        """
        if not self.training:
            return

        # The language model is deliberately left untouched here, even when
        # `tune_language_model` is off.
        if self.get_vision_tower() and not getattr(
            self.config, "tune_vision_tower", False
        ):
            self.get_vision_tower().eval()
        if self.get_mm_projector() and not getattr(
            self.config, "tune_mm_projector", False
        ):
            self.get_mm_projector().eval()

    def encode_images(self, images):
        """Run ``images`` through the vision tower and then the projector."""
        image_features = self.get_vision_tower()(images)
        image_features = self.get_mm_projector()(image_features)
        return image_features

    def _temporary_reorder_cache(self, past_key_values, sorted_idx):
        """Delegate cache reordering to the language model."""
        return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)

    def get_input_embeddings(self):
        """Delegate to the language model."""
        return self.get_llm().get_input_embeddings()

    def get_output_embeddings(self):
        """Delegate to the language model."""
        return self.get_llm().get_output_embeddings()

    def resize_token_embeddings(self, embed_size):
        """Resize the language model's token embeddings to ``embed_size``."""
        self.get_llm().resize_token_embeddings(embed_size)
| |
|
| |
|
| | |
class LlavaLlamaModel(LlavaMetaModel, PreTrainedModel):
    """Multimodal wrapper that feeds image-augmented embeddings to ``self.llm``.

    ``forward`` and ``generate`` first merge text and image inputs via
    ``prepare_inputs_labels_for_multimodal`` and then delegate to the
    wrapped LLM.
    """

    config_class = LlavaLlamaConfig
    # NOTE(review): forward() names its embedding argument ``inputs_embeds``;
    # the value below is kept as-is for compatibility -- confirm it is intended.
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True

    def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
        """Initialise the base model, then build the VLM parts via ``init_vlm``."""
        super().__init__(config)
        # __init__ must return None, so the result of init_vlm is not returned.
        self.init_vlm(*args, config=config, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        **kwargs,
    ):
        """Load a model, preferring ``cls.load_pretrained`` when it exists.

        All arguments are forwarded unchanged. When the class has no
        ``load_pretrained`` attribute, loading falls back to the parent
        class's ``from_pretrained``.
        """
        load_kwargs = dict(
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
        )
        if hasattr(cls, "load_pretrained"):
            return cls.load_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                **load_kwargs,
                **kwargs,
            )
        # Zero-argument super() is bound to ``cls``. The previous one-argument
        # form ``super(LlavaLlamaModel)`` produced an unbound super object on
        # which ``from_pretrained`` cannot be looked up.
        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **load_kwargs,
            **kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        images: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        seqlens_in_batch: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        dpo_forward: bool = False,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the wrapped LLM on multimodal inputs.

        When ``inputs_embeds`` is not supplied, the inputs are first passed
        through ``prepare_inputs_labels_for_multimodal`` together with
        ``images``. If the LLM's ``forward`` accepts ``seqlens_in_batch`` and
        the model is training (and ``dpo_forward`` is false), the batch is
        additionally repacked with ``repack_multimodal_data``.

        Returns:
            The LLM's output, or the tuple ``(outputs.logits, labels)`` when
            ``dpo_forward`` is true (labels being the ones actually fed to
            the LLM).
        """
        self.freezed_module_patch()
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images
            )

        # Packing is only possible when the LLM's forward takes the lengths.
        support_packing = (
            "seqlens_in_batch" in inspect.signature(self.llm.forward).parameters
        )

        if self.training and support_packing and not dpo_forward:
            (
                _,
                new_position_ids,
                new_attention_mask,
                _,
                new_inputs_embeds,
                new_labels,
                sorted_seqlens_in_batch,
            ) = self.repack_multimodal_data(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            )
            if sorted_seqlens_in_batch is None:
                sorted_seqlens_in_batch = seqlens_in_batch
            new_input_ids = None
            past_key_values = None
        else:
            new_attention_mask = attention_mask
            new_position_ids = position_ids
            new_inputs_embeds = inputs_embeds
            new_labels = labels
            new_input_ids = input_ids
            # Without a mask there is nothing to sum; fall back to the lengths
            # the caller supplied (possibly None) instead of crashing.
            if attention_mask is not None:
                sorted_seqlens_in_batch = attention_mask.sum(-1).int()
            else:
                sorted_seqlens_in_batch = seqlens_in_batch

        llm_kwargs = dict(
            input_ids=new_input_ids,
            attention_mask=new_attention_mask,
            position_ids=new_position_ids,
            past_key_values=past_key_values,
            inputs_embeds=new_inputs_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if support_packing:
            llm_kwargs["seqlens_in_batch"] = sorted_seqlens_in_batch

        # NOTE(review): ``.forward`` is invoked directly, as in the original
        # code, rather than calling the module -- confirm this is intended.
        outputs = self.llm.forward(**llm_kwargs)

        if dpo_forward:
            return outputs.logits, new_labels
        return outputs

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        images: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generation_kwargs,
    ):
        """Generate with the wrapped LLM from text and optional images.

        With ``images``, embeddings and the attention mask come from
        ``prepare_inputs_labels_for_multimodal``; otherwise ``input_ids`` is
        embedded with the model's input embeddings. The embeddings are cast
        to ``self.dtype`` and passed to ``self.llm.generate`` along with any
        extra ``generation_kwargs``.
        """
        if images is not None:
            (
                _,
                _,
                attention_mask,
                _,
                inputs_embeds,
                _,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, None, attention_mask, None, None, images
            )
        else:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        inputs_embeds = inputs_embeds.to(self.dtype)

        return self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generation_kwargs,
        )
| |
|
| |
|
| | |
| | |
| |
|