import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from clip import load, tokenize
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from data.imagnet_prompts import imagenet_classes
from data.fewshot_datasets import fewshot_datasets
from data.cls_to_names import *
from utils.ModelStock import stock_model

_tokenizer = _Tokenizer()

DOWNLOAD_ROOT = '~/.cache/clip'


class ClipImageEncoder(nn.Module):
    def __init__(self, device, arch="ViT-L/14", image_resolution=224, n_class=1000):
        super(ClipImageEncoder, self).__init__()
        clip, embed_dim, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
        self.encoder = clip.visual
        del clip.transformer
        torch.cuda.empty_cache()

        self.cls_head = nn.Linear(embed_dim, n_class)

    @property
    def dtype(self):
        return self.encoder.conv1.weight.dtype

    def forward(self, image):
        x = self.encoder(image.type(self.dtype))
        output = self.cls_head(x)
        return output


class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x
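# A minimal sketch (not part of the original module) of the EOT pooling used in
# TextEncoder.forward above: CLIP's end-of-text token has the highest id in the
# vocabulary, so argmax over the padded token ids recovers its position. The
# word ids below are illustrative; 49406/49407 are CLIP's SOT/EOT ids.
#
#   toks = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])  # SOT "a" "photo" EOT pad pad
#   toks.argmax(dim=-1)  # tensor([3]) -> index of the EOT embedding to project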
class PromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, batch_size=None, n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False):
        super().__init__()
        n_cls = len(classnames)
        self.learned_cls = learned_cls
        dtype = clip_model.dtype
        self.dtype = dtype
        self.device = clip_model.visual.conv1.weight.device
        ctx_dim = clip_model.ln_final.weight.shape[0]
        self.ctx_dim = ctx_dim
        self.batch_size = batch_size

        # self.ctx, prompt_prefix = self.reset_prompt(ctx_dim, ctx_init, clip_model)

        if ctx_init:
            # use given words to initialize context vectors
            print("Initializing the context with given words: [{}]".format(ctx_init))
            ctx_init = ctx_init.replace("_", " ")
            if '[CLS]' in ctx_init:
                ctx_list = ctx_init.split(" ")
                split_idx = ctx_list.index("[CLS]")
                ctx_init = ctx_init.replace("[CLS] ", "")
                ctx_position = "middle"
            else:
                split_idx = None
            self.split_idx = split_idx
            n_ctx = len(ctx_init.split(" "))
            prompt = tokenize(ctx_init).to(self.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            print("Random initialization: initializing a generic context")
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)
            self.split_idx = None  # no [CLS] marker when using random initialization
        self.prompt_prefix = prompt_prefix

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        # batch-wise prompt tuning for test-time adaptation
        if self.batch_size is not None:
            ctx_vectors = ctx_vectors.repeat(batch_size, 1, 1)  # (N, L, D)
        self.ctx_init_state = ctx_vectors.detach().clone()
        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        if not self.learned_cls:
            classnames = [name.replace("_", " ") for name in classnames]
            name_lens = [len(_tokenizer.encode(name)) for name in classnames]
            prompts = [prompt_prefix + " " + name + "." for name in classnames]
        else:
            print("Random initialization: initializing a learnable class token")
            cls_vectors = torch.empty(n_cls, 1, ctx_dim, dtype=dtype)  # assume each learnable cls_token is only 1 word
            nn.init.normal_(cls_vectors, std=0.02)
            cls_token = "X"
            name_lens = [1 for _ in classnames]
            prompts = [prompt_prefix + " " + cls_token + "." for _ in classnames]

            self.cls_init_state = cls_vectors.detach().clone()
            self.cls = nn.Parameter(cls_vectors)  # to be optimized

        tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        if self.learned_cls:
            self.register_buffer("token_suffix", embedding[:, 1 + n_ctx + 1:, :])  # ..., EOS
        else:
            self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        self.ctx_init = ctx_init
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = ctx_position

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.classnames = classnames

    def reset(self):
        ctx_vectors = self.ctx_init_state
        self.ctx.copy_(ctx_vectors)  # to be optimized
        if self.learned_cls:
            cls_vectors = self.cls_init_state
            self.cls.copy_(cls_vectors)

    def reset_classnames(self, classnames, arch):
        self.n_cls = len(classnames)
        if not self.learned_cls:
            classnames = [name.replace("_", " ") for name in classnames]
            name_lens = [len(_tokenizer.encode(name)) for name in classnames]
            prompts = [self.prompt_prefix + " " + name + "." for name in classnames]
        else:
            cls_vectors = torch.empty(self.n_cls, 1, self.ctx_dim, dtype=self.dtype)  # assume each learnable cls_token is only 1 word
            nn.init.normal_(cls_vectors, std=0.02)
            cls_token = "X"
            name_lens = [1 for _ in classnames]
            prompts = [self.prompt_prefix + " " + cls_token + "." for _ in classnames]
            # TODO: re-init the cls parameters
            # self.cls = nn.Parameter(cls_vectors)  # to be optimized
            self.cls_init_state = cls_vectors.detach().clone()
        tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)

        clip, _, _ = load(arch, device=self.device, download_root=DOWNLOAD_ROOT)

        with torch.no_grad():
            embedding = clip.token_embedding(tokenized_prompts).type(self.dtype)

        self.token_prefix = embedding[:, :1, :]
        self.token_suffix = embedding[:, 1 + self.n_ctx:, :]  # CLS, EOS

        self.name_lens = name_lens
        self.tokenized_prompts = tokenized_prompts
        self.classnames = classnames

    def forward(self, init=None):
        # the init will be used when computing CLIP directional loss
        if init is not None:
            ctx = init
        else:
            ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        elif ctx.size(0) != self.n_cls:
            ctx = ctx.unsqueeze(1).expand(-1, self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix
        if self.batch_size is not None:
            # This way only works in the single-GPU setting (the batch size could instead be passed as an argument to forward())
            prefix = prefix.repeat(self.batch_size, 1, 1, 1)
            suffix = suffix.repeat(self.batch_size, 1, 1, 1)

        if self.learned_cls:
            assert self.class_token_position == "end"
        if self.class_token_position == "end":
            if self.learned_cls:
                cls = self.cls
                prompts = torch.cat(
                    [
                        prefix,  # (n_cls, 1, dim)
                        ctx,     # (n_cls, n_ctx, dim)
                        cls,     # (n_cls, 1, dim)
                        suffix,  # (n_cls, *, dim)
                    ],
                    dim=-2,
                )
            else:
                prompts = torch.cat(
                    [
                        prefix,  # (n_cls, 1, dim)
                        ctx,     # (n_cls, n_ctx, dim)
                        suffix,  # (n_cls, *, dim)
                    ],
                    dim=-2,
                )
        elif self.class_token_position == "middle":
            # TODO: make this work with a batch of prompts
            if self.split_idx is not None:
                half_n_ctx = self.split_idx  # split the ctx at the position of [CLS] in `ctx_init`
            else:
                half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)
        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)
        else:
            raise ValueError("class_token_position must be 'end', 'middle', or 'front'")

        return prompts
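# Shape sketch for PromptLearner.forward with class_token_position == "end"
# (illustrative only; the concrete sizes are assumptions, e.g. ctx_dim = 768
# for ViT-L/14, and tokenize() pads every sequence to CLIP's length of 77):
#
#   token_prefix : (n_cls, 1,     ctx_dim)  SOS embedding per class
#   ctx          : (n_cls, n_ctx, ctx_dim)  shared learnable context, expanded per class
#   token_suffix : (n_cls, *,     ctx_dim)  class-name tokens, EOS, padding
#
#   prompts = torch.cat([prefix, ctx, suffix], dim=-2)  # (n_cls, 77, ctx_dim)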
class ClipTestTimeTuning(nn.Module):
    def __init__(self, device, classnames, batch_size, criterion='cosine', arch="ViT-L/14",
                 n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False,
                 pubmedclip_path=None, merge=False, state_dict=None):
        super(ClipTestTimeTuning, self).__init__()
        clip, _, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)

        if pubmedclip_path is not None:
            ft_dict = torch.load(pubmedclip_path, map_location=f'cuda:{device}')
            if merge:
                print("Merging the weights of CLIP and the fine-tuned state dict using the WiSE-FT approach")
                # WiSE-FT approach: interpolate between zero-shot and fine-tuned weights
                merged_dict = {}
                alpha = 0.50  # interpolation coefficient; adjust as needed
                for key in clip.state_dict().keys():
                    merged_dict[key] = alpha * ft_dict[key] + (1 - alpha) * clip.state_dict()[key]
                # clip.load_state_dict(state_dict)

                # Model Stock
                # state_dict = stock_model(state_dict, clip.state_dict())
            else:
                merged_dict = ft_dict
            clip.load_state_dict(merged_dict)

        if state_dict is not None:
            clip.load_state_dict(state_dict)

        self.visual = clip.visual
        self.text_encoder = TextEncoder(clip)
        self.logit_scale = clip.logit_scale.data
        # prompt tuning
        self.prompt_learner = PromptLearner(clip, classnames, batch_size, n_ctx, ctx_init, ctx_position, learned_cls)
        self.criterion = criterion
        self.l2_norm_cal = False

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    # restore the initial state of the prompt_learner (tunable prompt)
    def reset(self):
        self.prompt_learner.reset()

    def reset_classnames(self, classnames, arch):
        self.prompt_learner.reset_classnames(classnames, arch)

    def get_text_features(self, normalize=True):
        text_features = []
        prompts = self.prompt_learner()
        tokenized_prompts = self.prompt_learner.tokenized_prompts
        t_features = self.text_encoder(prompts, tokenized_prompts)
        if normalize:
            t_features = t_features / t_features.norm(dim=-1, keepdim=True)
        text_features.append(t_features)
        text_features = torch.stack(text_features, dim=0)

        return torch.mean(text_features, dim=0)

    def inference(self, image, return_logits=False, normalize=True):
        with torch.no_grad():
            image_features = self.visual(image.type(self.dtype))

        # with torch.no_grad():
        text_features = self.get_text_features(normalize=normalize)
        if normalize:
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # [c-tpt] --------------------------------------------
        if self.l2_norm_cal:
            prompt_mean = text_features.mean(0)
            feature_distance = text_features - prompt_mean
            l2_norm = torch.linalg.norm(feature_distance, dim=-1)
            l2_norm_mean = l2_norm.mean()

            # for saving to csv file
            self.l2_norm_mean = l2_norm_mean.item()
            # for training
            self.l2_norm_mean_training = l2_norm_mean
        # -----------------------------------------------------

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()
        if return_logits:
            return logits, image_features, text_features
        return logits

    def forward(self, input, return_logits=False, normalize=True):
        if isinstance(input, tuple):
            view_0, view_1, view_2 = input
            # NOTE: contrast_prompt_tuning / directional_prompt_tuning are
            # expected to be defined elsewhere; they are not part of this module.
            return self.contrast_prompt_tuning(view_0, view_1, view_2)
        elif len(input.size()) == 2:
            return self.directional_prompt_tuning(input)
        else:
            return self.inference(input, return_logits, normalize)
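# The `merge` branch in ClipTestTimeTuning.__init__ implements WiSE-FT weight
# interpolation (Wortsman et al., "Robust fine-tuning of zero-shot models").
# A minimal standalone sketch, assuming both state dicts share the same keys:
#
#   def wise_ft(zeroshot_sd, finetuned_sd, alpha=0.5):
#       return {k: alpha * finetuned_sd[k] + (1 - alpha) * zeroshot_sd[k]
#               for k in zeroshot_sd}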
def get_coop(clip_arch, test_set, device, n_ctx, ctx_init, classnames, learned_cls=False,
             pubmedclip_path=None, merge=False, state_dict=None):
    # if test_set in fewshot_datasets:
    #     classnames = eval("{}_classes".format(test_set.lower()))
    # elif test_set == 'bongard':
    #     if learned_cls:
    #         classnames = ['X', 'X']
    #     else:
    #         classnames = ['True', 'False']
    # else:
    #     classnames = imagenet_classes

    model = ClipTestTimeTuning(device, classnames, None, arch=clip_arch, n_ctx=n_ctx,
                               ctx_init=ctx_init, learned_cls=learned_cls,
                               pubmedclip_path=pubmedclip_path, merge=merge, state_dict=state_dict)

    return model
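# Example usage (illustrative only; the architecture, device, prompt init, and
# class list below are assumptions, not values prescribed by this module):
#
#   model = get_coop("ViT-L/14", "imagenet", device=0, n_ctx=4,
#                    ctx_init="a_photo_of_a", classnames=imagenet_classes)
#   logits = model(images)  # (batch, n_cls) logits for test-time adaptation
#   model.reset()           # restore the tunable prompt to its initial state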