|
|
|
|
|
import math |
|
|
from typing import List, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from clip import load, tokenize |
|
|
from .simple_tokenizer import SimpleTokenizer as _Tokenizer |
|
|
from data.imagnet_prompts import imagenet_classes |
|
|
from data.fewshot_datasets import fewshot_datasets |
|
|
from data.cls_to_names import * |
|
|
from utils.ModelStock import stock_model |
|
|
|
|
|
# Shared BPE tokenizer instance, used to measure class-name token lengths
# when building prompts.
_tokenizer = _Tokenizer()

# Default cache directory for downloaded CLIP checkpoints.
DOWNLOAD_ROOT='~/.cache/clip'
|
|
|
|
|
class ClipImageEncoder(nn.Module):
    """CLIP visual tower with a linear classification head on top.

    The text transformer of the loaded CLIP model is deleted to free GPU
    memory; only the image encoder is kept and a fresh linear head maps
    its embeddings to ``n_class`` logits.
    """

    def __init__(self, device, arch="ViT-L/14", image_resolution=224, n_class=1000):
        super(ClipImageEncoder, self).__init__()
        clip_model, embed_dim, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
        self.encoder = clip_model.visual
        # The text branch is unused here; drop it and release cached memory.
        del clip_model.transformer
        torch.cuda.empty_cache()

        self.cls_head = nn.Linear(embed_dim, n_class)

    @property
    def dtype(self):
        # Parameter dtype of the visual stem (fp16 or fp32, depending on load()).
        return self.encoder.conv1.weight.dtype

    def forward(self, image):
        # Cast input to the encoder's dtype, embed, then classify.
        features = self.encoder(image.type(self.dtype))
        return self.cls_head(features)
|
|
|
|
|
|
|
|
class TextEncoder(nn.Module):
    """Runs CLIP's text transformer over pre-embedded prompt tokens.

    Accepts token *embeddings* (rather than raw token ids) so that
    learnable prompt vectors can be spliced into the sequence before
    encoding. ``tokenized_prompts`` is still needed to locate each
    sequence's EOT token.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        # Add positional information, then run the transformer in
        # (seq, batch, dim) layout as CLIP expects.
        hidden = prompts + self.positional_embedding.type(self.dtype)
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden).type(self.dtype)

        # Each sequence is summarized by the activation at its EOT position
        # (EOT has the highest token id, hence argmax), then projected into
        # the joint image/text embedding space.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        batch_indices = torch.arange(hidden.shape[0])
        return hidden[batch_indices, eot_positions] @ self.text_projection
|
|
|
|
|
|
|
|
class PromptLearner(nn.Module):
    """Learnable prompt-context module (CoOp/TPT style).

    Maintains ``n_ctx`` learnable context vectors that are concatenated with
    the frozen token embeddings of each class name to form per-class prompt
    sequences. Optionally the class token itself is a single learnable
    vector (``learned_cls``). Snapshots of the initial context are kept so
    the prompt can be reset between test-time tuning episodes.
    """

    def __init__(self, clip_model, classnames, batch_size=None, n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False):
        super().__init__()
        n_cls = len(classnames)
        self.learned_cls = learned_cls
        dtype = clip_model.dtype
        self.dtype = dtype
        self.device = clip_model.visual.conv1.weight.device
        # Width of the text transformer = dimensionality of one context token.
        ctx_dim = clip_model.ln_final.weight.shape[0]
        self.ctx_dim = ctx_dim
        self.batch_size = batch_size

        if ctx_init:
            # Initialize the context from the embeddings of the given words.
            print("Initializing the contect with given words: [{}]".format(ctx_init))
            ctx_init = ctx_init.replace("_", " ")
            if '[CLS]' in ctx_init:
                # "[CLS]" marks where the class token should be inserted;
                # record its position and force middle placement.
                ctx_list = ctx_init.split(" ")
                split_idx = ctx_list.index("[CLS]")
                ctx_init = ctx_init.replace("[CLS] ", "")
                ctx_position = "middle"
            else:
                split_idx = None
            self.split_idx = split_idx
            n_ctx = len(ctx_init.split(" "))
            prompt = tokenize(ctx_init).to(self.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Skip the SOT token (index 0); keep the n_ctx word embeddings.
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # No init phrase: random context, rendered as "X" placeholders
            # in the textual prompts below.
            # NOTE(review): self.split_idx is NOT set on this path, so
            # forward() with class_token_position == "middle" would raise
            # AttributeError — confirm "middle" is only used with ctx_init.
            print("Random initialization: initializing a generic context")
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.prompt_prefix = prompt_prefix

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        if self.batch_size is not None:
            # One independent copy of the context per batch element.
            ctx_vectors = ctx_vectors.repeat(batch_size, 1, 1)
        # Frozen snapshot used by reset(); the Parameter is what gets tuned.
        self.ctx_init_state = ctx_vectors.detach().clone()
        self.ctx = nn.Parameter(ctx_vectors)

        if not self.learned_cls:
            classnames = [name.replace("_", " ") for name in classnames]
            name_lens = [len(_tokenizer.encode(name)) for name in classnames]
            prompts = [prompt_prefix + " " + name + "." for name in classnames]
        else:
            # Replace every class name with one learnable token vector; the
            # textual prompt carries a single "X" placeholder in its place.
            print("Random initialization: initializing a learnable class token")
            cls_vectors = torch.empty(n_cls, 1, ctx_dim, dtype=dtype)
            nn.init.normal_(cls_vectors, std=0.02)
            cls_token = "X"
            name_lens = [1 for _ in classnames]
            prompts = [prompt_prefix + " " + cls_token + "." for _ in classnames]

            self.cls_init_state = cls_vectors.detach().clone()
            self.cls = nn.Parameter(cls_vectors)

        tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # Frozen embedding pieces surrounding the learnable context:
        # prefix = SOT token; suffix = class tokens + "." + EOT + padding.
        self.register_buffer("token_prefix", embedding[:, :1, :])
        if self.learned_cls:
            # Extra +1 skips the "X" placeholder that the learnable cls replaces.
            self.register_buffer("token_suffix", embedding[:, 1 + n_ctx + 1:, :])
        else:
            self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])

        self.ctx_init = ctx_init
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = ctx_position
        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.classnames = classnames

    def reset(self):
        """Restore the learnable context (and cls token) to their initial snapshots."""
        ctx_vectors = self.ctx_init_state
        # NOTE(review): in-place copy on a Parameter — presumably called
        # under torch.no_grad() at the call site; confirm callers.
        self.ctx.copy_(ctx_vectors)
        if self.learned_cls:
            cls_vectors = self.cls_init_state
            self.cls.copy_(cls_vectors)

    def reset_classnames(self, classnames, arch):
        """Rebuild the frozen prompt token buffers for a new class-name list.

        Reloads the CLIP model named by ``arch`` to re-embed the prompts.
        """
        self.n_cls = len(classnames)
        if not self.learned_cls:
            classnames = [name.replace("_", " ") for name in classnames]
            name_lens = [len(_tokenizer.encode(name)) for name in classnames]
            prompts = [self.prompt_prefix + " " + name + "." for name in classnames]
        else:
            cls_vectors = torch.empty(self.n_cls, 1, self.ctx_dim, dtype=self.dtype)
            nn.init.normal_(cls_vectors, std=0.02)
            cls_token = "X"
            name_lens = [1 for _ in classnames]
            prompts = [self.prompt_prefix + " " + cls_token + "." for _ in classnames]

            # NOTE(review): only the snapshot is refreshed here — self.cls is
            # not re-created, so if n_cls changed the Parameter shape is stale.
            self.cls_init_state = cls_vectors.detach().clone()
        tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)

        clip, _, _ = load(arch, device=self.device, download_root=DOWNLOAD_ROOT)

        with torch.no_grad():
            embedding = clip.token_embedding(tokenized_prompts).type(self.dtype)

        # Plain attribute assignment overwrites the registered buffers.
        self.token_prefix = embedding[:, :1, :]
        # NOTE(review): unlike __init__, the learned_cls case gets no extra +1
        # here, so the suffix still contains the "X" placeholder embedding —
        # confirm whether this asymmetry is intentional.
        self.token_suffix = embedding[:, 1 + self.n_ctx :, :]

        self.name_lens = name_lens
        self.tokenized_prompts = tokenized_prompts
        self.classnames = classnames

    def forward(self, init=None):
        """Assemble full prompt embeddings: prefix + context (+ cls) + suffix.

        ``init`` optionally overrides the stored context vectors. Returns the
        per-class prompt embedding tensor consumed by TextEncoder.
        """
        if init is not None:
            ctx = init
        else:
            ctx = self.ctx
        if ctx.dim() == 2:
            # Single shared context: broadcast across all classes.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        elif not ctx.size()[0] == self.n_cls:
            # Batched contexts: add a class axis per batch element.
            ctx = ctx.unsqueeze(1).expand(-1, self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix
        if self.batch_size is not None:
            # Expand the frozen pieces to match the batched context.
            prefix = prefix.repeat(self.batch_size, 1, 1, 1)
            suffix = suffix.repeat(self.batch_size, 1, 1, 1)

        if self.learned_cls:
            # Learnable class tokens are only supported with 'end' placement.
            assert self.class_token_position == "end"
        if self.class_token_position == "end":
            if self.learned_cls:
                cls = self.cls
                prompts = torch.cat(
                    [
                        prefix,
                        ctx,
                        cls,
                        suffix,
                    ],
                    dim=-2,
                )
            else:
                prompts = torch.cat(
                    [
                        prefix,
                        ctx,
                        suffix,
                    ],
                    dim=-2,
                )
        elif self.class_token_position == "middle":
            # Split the context around the class tokens; the split point comes
            # from the "[CLS]" marker if one was given, else the midpoint.
            if self.split_idx is not None:
                half_n_ctx = self.split_idx
            else:
                half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                # Class-name tokens sit at the front of the suffix buffer.
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,
                        ctx_i_half1,
                        class_i,
                        ctx_i_half2,
                        suffix_i,
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            # Class tokens first, then the learnable context.
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,
                        class_i,
                        ctx_i,
                        suffix_i,
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError

        return prompts
|
|
|
|
|
|
|
|
class ClipTestTimeTuning(nn.Module):
    """CLIP assembled for test-time prompt tuning.

    Combines the frozen CLIP visual tower, a TextEncoder driven by learnable
    prompts (PromptLearner), and CLIP's fixed logit scale. Optionally loads a
    fine-tuned checkpoint (e.g. PubMedCLIP) and can merge it with the
    zero-shot weights via WiSE-FT style weight-space interpolation.
    """

    def __init__(self, device, classnames, batch_size, criterion='cosine', arch="ViT-L/14",
                 n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False, pubmedclip_path=None,
                 merge=False, state_dict=None):
        super(ClipTestTimeTuning, self).__init__()
        clip, _, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
        if pubmedclip_path is not None:
            # NOTE(review): assumes `device` is a CUDA ordinal (e.g. 0) so the
            # f-string forms a valid device id — confirm against callers.
            ft_dict = torch.load(pubmedclip_path, map_location=f'cuda:{device}')
            if merge:
                print("Merging the weights of clip and state dict using WiSE-FT approach")
                # WiSE-FT: linearly interpolate fine-tuned and zero-shot weights.
                # state_dict() rebuilds the whole dict on every call, so fetch
                # it once instead of twice per parameter inside the loop.
                zero_shot_dict = clip.state_dict()
                alpha = 0.50  # equal weighting of fine-tuned and zero-shot weights
                merged_dict = {
                    key: alpha * ft_dict[key] + (1 - alpha) * zero_shot_dict[key]
                    for key in zero_shot_dict
                }
            else:
                merged_dict = ft_dict
            clip.load_state_dict(merged_dict)
        if state_dict is not None:
            clip.load_state_dict(state_dict)
        self.visual = clip.visual
        self.text_encoder = TextEncoder(clip)
        # Raw tensor (not a Parameter): the logit scale stays frozen.
        self.logit_scale = clip.logit_scale.data

        self.prompt_learner = PromptLearner(clip, classnames, batch_size, n_ctx, ctx_init, ctx_position, learned_cls)
        self.criterion = criterion
        self.l2_norm_cal = False  # when True, inference() records prompt-dispersion stats

    @property
    def dtype(self):
        # Parameter dtype of the visual stem (fp16 or fp32, depending on load()).
        return self.visual.conv1.weight.dtype

    def reset(self):
        """Restore the learnable prompt context to its initial state."""
        self.prompt_learner.reset()

    def reset_classnames(self, classnames, arch):
        """Rebuild the prompt token buffers for a new set of class names."""
        self.prompt_learner.reset_classnames(classnames, arch)

    def get_text_features(self, normalize=True):
        """Encode the current (learnable) prompts into per-class text features.

        Returns the mean over prompt sets; with the single prompt set used
        here the stack/mean is a structural no-op.
        """
        text_features = []
        prompts = self.prompt_learner()
        tokenized_prompts = self.prompt_learner.tokenized_prompts
        t_features = self.text_encoder(prompts, tokenized_prompts)
        if normalize:
            t_features = t_features / t_features.norm(dim=-1, keepdim=True)
        text_features.append(t_features)
        text_features = torch.stack(text_features, dim=0)

        return torch.mean(text_features, dim=0)

    def inference(self, image, return_logits=False, normalize=True):
        """Compute class logits for `image`.

        The image encoder runs under no_grad (frozen); only the text side
        carries gradients for prompt tuning.
        """
        with torch.no_grad():
            image_features = self.visual(image.type(self.dtype))

        text_features = self.get_text_features(normalize=normalize)
        if normalize:
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        if self.l2_norm_cal:
            # Dispersion of the per-class text features around their centroid:
            # mean L2 distance of each class embedding from the mean embedding.
            prompt_mean = text_features.mean(0)
            feature_distance = text_features - prompt_mean
            l2_norm = torch.linalg.norm(feature_distance, dim=-1)
            l2_norm_mean = l2_norm.mean()
            # Detached scalar for logging ...
            self.l2_norm_mean = l2_norm_mean.item()
            # ... and the tensor version, kept in the graph for training use.
            self.l2_norm_mean_training = l2_norm_mean

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        if return_logits:
            return logits, image_features, text_features

        return logits

    def forward(self, input, return_logits=False, normalize=True):
        """Dispatch on input form: 3-view tuple, 2-D tensor, or image batch."""
        # Fixed: isinstance against builtin `tuple` instead of the deprecated
        # typing.Tuple special form (same runtime behavior, proper idiom).
        if isinstance(input, tuple):
            view_0, view_1, view_2 = input
            # NOTE(review): contrast_prompt_tuning / directional_prompt_tuning
            # are not defined in this class as shown — presumably added
            # elsewhere (mixin or monkey-patch); confirm.
            return self.contrast_prompt_tuning(view_0, view_1, view_2)
        elif len(input.size()) == 2:
            return self.directional_prompt_tuning(input)
        else:
            return self.inference(input, return_logits, normalize)
|
|
|
|
|
|
|
|
def get_coop(clip_arch, test_set, device, n_ctx, ctx_init, classnames, learned_cls=False, pubmedclip_path=None, merge=False, state_dict=None):
    """Factory for a ClipTestTimeTuning model set up for CoOp-style tuning.

    `test_set` is accepted for interface compatibility but unused here;
    the class names are supplied explicitly via `classnames`.
    """
    return ClipTestTimeTuning(
        device,
        classnames,
        None,  # batch_size=None -> a single shared prompt context
        arch=clip_arch,
        n_ctx=n_ctx,
        ctx_init=ctx_init,
        learned_cls=learned_cls,
        pubmedclip_path=pubmedclip_path,
        merge=merge,
        state_dict=state_dict,
    )
|
|
|
|
|
|