| import torch |
| import torch.nn as nn |
| import pdb |
| import math |
| from transformers.activations import ACT2FN |
| from einops import rearrange, reduce, repeat |
| from inspect import isfunction |
| import math |
| import torch.nn.functional as F |
| from torch import nn, einsum |
| from einops import rearrange, repeat |
| from typing import Optional, Any |
|
|
try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Any ordinary exception while importing xformers disables the
    # memory-efficient attention path; unlike a bare ``except:`` this no longer
    # swallows KeyboardInterrupt / SystemExit.
    XFORMERS_IS_AVAILBLE = False
|
|
| import importlib |
| import numpy as np |
| import cv2, os |
| import torch.distributed as dist |
|
|
|
|
def count_params(model, verbose=False):
    """Return the total number of scalar elements over all parameters of ``model``.

    When ``verbose`` is true the count is also printed in millions, prefixed
    by the model's class name.
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {n_elements*1.e-6:.2f} M params.")
    return n_elements
|
|
|
|
def check_istarget(name, para_list):
    """Tell whether ``name`` contains any of the fragments in ``para_list``.

    name: full name of the source parameter
    para_list: iterable of partial target parameter names
    Returns ``True`` on the first fragment that is a substring of ``name``,
    otherwise ``False``.
    """
    return any(fragment in name for fragment in para_list)
|
|
|
|
def instantiate_from_config(config):
    """Build an object described by ``{"target": "pkg.mod.Name", "params": {...}}``.

    ``target`` is resolved with ``get_obj_from_str`` and called with the
    optional ``params`` mapping as keyword arguments. The sentinel strings
    ``"__is_first_stage__"`` and ``"__is_unconditional__"`` yield ``None``;
    any other config without ``"target"`` raises ``KeyError``.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        return factory(**config.get("params", dict()))
    if config in ("__is_first_stage__", "__is_unconditional__"):
        return None
    raise KeyError("Expected key `target` to instantiate.")
|
|
|
|
def get_obj_from_str(string, reload=False):
    """Return the object named by a dotted path such as ``"package.module.Name"``.

    The text after the last dot is the attribute; everything before it is the
    module to import. With ``reload=True`` that module is reloaded first.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module = importlib.import_module(module_path, package=None)
    return getattr(module, attr_name)
|
|
|
|
def load_npz_from_dir(data_dir):
    """Concatenate (axis 0) the ``arr_0`` array of every file in ``data_dir``.

    Files are read in sorted filename order, so the result is deterministic
    (``os.listdir`` order is arbitrary). Each ``.npz`` archive is closed after
    its array is read instead of leaking an open file handle.
    """
    arrays = []
    for data_name in sorted(os.listdir(data_dir)):
        with np.load(os.path.join(data_dir, data_name)) as archive:
            arrays.append(archive["arr_0"])
    return np.concatenate(arrays, axis=0)
|
|
|
|
def load_npz_from_paths(data_paths):
    """Concatenate (axis 0) the ``arr_0`` array stored in each ``.npz`` path, in the given order.

    Each archive is opened in a ``with`` block so its file handle is closed.
    """
    arrays = []
    for data_path in data_paths:
        with np.load(data_path) as archive:
            arrays.append(archive["arr_0"])
    return np.concatenate(arrays, axis=0)
|
|
|
|
def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
    """Resize ``image`` so both sides are multiples of 64 (Lanczos interpolation).

    If ``resize_short_edge`` is given, the scale factor makes the shorter side
    roughly that long; otherwise it makes the area roughly ``max_resolution``
    pixels. Each side is rounded to the nearest multiple of 64 but never below
    64, because a side that rounds to 0 would make ``cv2.resize`` fail.
    """
    h, w = image.shape[:2]
    if resize_short_edge is not None:
        k = resize_short_edge / min(h, w)
    else:
        # area-based target: scale each side by sqrt(target_area / area)
        k = (max_resolution / (h * w)) ** 0.5
    h = max(1, int(np.round(h * k / 64))) * 64
    w = max(1, int(np.round(w * k / 64))) * 64
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
    return image
|
|
|
|
def setup_dist(args):
    """Select CUDA device ``args.local_rank`` and start an NCCL process group, once.

    Does nothing when a default process group is already initialised.
    Rendezvous settings are read from the environment (``init_method="env://"``).
    """
    if not dist.is_initialized():
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group("nccl", init_method="env://")
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch.nn as nn |
| import math |
| from inspect import isfunction |
| import torch |
| from torch import nn |
| import torch.distributed as dist |
|
|
|
|
def gather_data(data, return_np=True):
    """Collect ``data`` from every rank into a list (rank order) via ``all_gather``.

    Requires an initialised ``torch.distributed`` process group. With
    ``return_np`` each gathered tensor is moved to CPU and converted to NumPy.
    """
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(data) for _ in range(world_size)]
    dist.all_gather(gathered, data)
    if not return_np:
        return gathered
    return [tensor.cpu().numpy() for tensor in gathered]
|
|
|
|
def autocast(f):
    """Decorator: call ``f`` inside ``torch.cuda.amp.autocast``.

    At call time the wrapper reads the current GPU autocast dtype
    (``torch.get_autocast_gpu_dtype()``) and cache setting
    (``torch.is_autocast_cache_enabled()``) and enables autocast with them.
    ``functools.wraps`` keeps ``f``'s name, docstring and ``__wrapped__``.
    """
    import functools

    @functools.wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=True,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast
|
|
|
|
def extract_into_tensor(a, t, x_shape):
    """Gather ``a`` at indices ``t`` (last dim) and reshape for broadcasting.

    The result has shape ``(t.shape[0], 1, ..., 1)`` with
    ``len(x_shape) - 1`` trailing singleton dims. Presumably ``a`` and ``t``
    are 1-D (a per-timestep table and a batch of timesteps) — confirm at caller.
    """
    batch_size, *_ = t.shape
    gathered = a.gather(-1, t)
    trailing = [1] * (len(x_shape) - 1)
    return gathered.reshape(batch_size, *trailing)
|
|
|
|
def noise_like(shape, device, repeat=False):
    """Draw standard-normal noise of ``shape`` on ``device``.

    With ``repeat=True`` one sample of shape ``(1, *shape[1:])`` is drawn and
    tiled ``shape[0]`` times along dim 0, so every batch entry is identical.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    single = torch.randn((1, *shape[1:]), device=device)
    tile_counts = (shape[0],) + (1,) * (len(shape) - 1)
    return single.repeat(*tile_counts)
|
|
|
|
def default(val, d):
    """Return ``val`` if it is not ``None``, otherwise the fallback ``d``.

    A fallback that is a plain function (``inspect.isfunction``) is called
    with no arguments, so expensive defaults can be produced lazily.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
|
|
|
|
def exists(val):
    """Report whether ``val`` is set, i.e. anything other than ``None`` (0 and [] count)."""
    if val is None:
        return False
    return True
|
|
|
|
def identity(*args, **kwargs):
    """Return a new ``nn.Identity`` module; all arguments are ignored."""
    del args, kwargs  # unused
    return nn.Identity()
|
|
|
|
def uniq(arr):
    """Distinct elements of ``arr`` in first-occurrence order, as a dict keys view."""
    return dict.fromkeys(arr, True).keys()
|
|
|
|
def mean_flat(tensor):
    """
    Average ``tensor`` over every dimension except dim 0 (the batch).
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
|
|
|
|
def ismap(x):
    """True for a 4-D ``torch.Tensor`` with more than 3 entries on dim 1; False otherwise."""
    return isinstance(x, torch.Tensor) and x.dim() == 4 and x.shape[1] > 3
|
|
|
|
def isimage(x):
    """True for a 4-D ``torch.Tensor`` whose dim 1 has size 1 or 3; False otherwise."""
    return isinstance(x, torch.Tensor) and x.dim() == 4 and x.shape[1] in (1, 3)
|
|
|
|
def max_neg_value(t):
    """Most negative finite value representable in ``t``'s floating dtype."""
    finfo = torch.finfo(t.dtype)
    return -finfo.max
|
|
|
|
def shape_to_str(x):
    """Render ``x.shape`` as its sizes joined by ``"x"``, e.g. ``"2x3x4"``."""
    return "x".join(map(str, x.shape))
|
|
|
|
def init_(tensor):
    """Fill ``tensor`` in place from U(-1/sqrt(d), 1/sqrt(d)), d = its last-dim size; return it."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    return tensor.uniform_(-bound, bound)
|
|
|
|
|
|
def disabled_train(self, mode=True):
    """Drop-in replacement for ``nn.Module.train`` that never changes the mode.

    Assigning it over a module's ``train`` freezes its train/eval state;
    ``mode`` is ignored and the object itself is returned.
    """
    del mode
    return self
|
|
|
|
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place, outside autograd, and return the module.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
|
def scale_module(module, scale):
    """
    Multiply every parameter of ``module`` by ``scale`` in place, outside autograd, and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module
|
|
|
|
def conv_nd(dims, *args, **kwargs):
    """
    Build ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d`` for ``dims`` 1 / 2 / 3.

    Remaining arguments go to the chosen constructor unchanged; any other
    ``dims`` raises ``ValueError``.
    """
    for n_dims, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == n_dims:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
def linear(*args, **kwargs):
    """
    Factory for ``nn.Linear``; every argument is forwarded as-is.
    """
    layer_cls = nn.Linear
    return layer_cls(*args, **kwargs)
|
|
|
|
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build ``nn.AvgPool1d`` / ``nn.AvgPool2d`` / ``nn.AvgPool3d`` for ``dims`` 1 / 2 / 3.

    Remaining arguments go to the chosen constructor unchanged; any other
    ``dims`` raises ``ValueError``.
    """
    for n_dims, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == n_dims:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
def nonlinearity(type="silu"):
    """Return a new activation module selected by name.

    :param type: ``"silu"`` -> ``nn.SiLU()``, ``"leaky_relu"`` ->
        ``nn.LeakyReLU()`` (default slope). The parameter name shadows the
        builtin but is kept for keyword-argument compatibility.
    :raises ValueError: for any other name. Previously ``None`` was returned
        silently, which only failed later when the "module" was used.
    """
    if type == "silu":
        return nn.SiLU()
    if type == "leaky_relu":
        return nn.LeakyReLU()
    raise ValueError(f"unsupported nonlinearity type: {type}")
|
|
|
|
class GroupNormSpecific(nn.GroupNorm):
    """GroupNorm that keeps 16-bit (float16/bfloat16) inputs in their own precision.

    Inputs of any other dtype are converted to float32 for the computation.
    The output always has the input's dtype.
    """

    def forward(self, x):
        keep_precision = x.dtype in (torch.float16, torch.bfloat16)
        norm_input = x if keep_precision else x.float()
        return super().forward(norm_input).type(x.dtype)
|
|
|
|
def normalization(channels, num_groups=32):
    """
    Create the standard normalization layer used by these blocks.

    :param channels: number of input channels.
    :param num_groups: number of channel groups (default 32).
    :return: a ``GroupNormSpecific`` module.
    """
    norm_layer = GroupNormSpecific(num_groups=num_groups, num_channels=channels)
    return norm_layer
|
|
|
|
class HybridConditioner(nn.Module):
    """Bundles a concat conditioner and a cross-attention conditioner.

    Both sub-modules are built with ``instantiate_from_config``. ``forward``
    encodes each input with its own conditioner and returns the results as
    single-element lists under ``"c_concat"`` and ``"c_crossattn"``.
    """

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        encoded_concat = self.concat_conditioner(c_concat)
        encoded_crossattn = self.crossattn_conditioner(c_crossattn)
        return dict(c_concat=[encoded_concat], c_crossattn=[encoded_crossattn])
|
|
def exists(val):
    """Return ``True`` unless ``val`` is ``None`` (re-definition of the earlier ``exists``)."""
    is_missing = val is None
    return not is_missing
|
|
|
|
def uniq(arr):
    """Distinct elements of ``arr`` in first-seen order, as a dict keys view."""
    seen = {}
    for element in arr:
        seen[element] = True
    return seen.keys()
|
|
|
|
def default(val, d):
    """``val`` when it is not ``None``; otherwise ``d()`` if ``d`` is a function, else ``d``."""
    if exists(val):
        return val
    if isfunction(d):
        return d()
    return d
|
|
|
|
def max_neg_value(t):
    """Negated ``torch.finfo(t.dtype).max``, the lowest finite value of ``t``'s dtype."""
    dtype_info = torch.finfo(t.dtype)
    lowest = -dtype_info.max
    return lowest
|
|
|
|
def init_(tensor):
    """In-place uniform init in ±1/sqrt(fan), fan = size of the last dim; returns ``tensor``."""
    fan = tensor.shape[-1]
    limit = 1 / math.sqrt(fan)
    tensor.uniform_(-limit, limit)
    return tensor
|
|
|
|
| |
class GEGLU(nn.Module):
    """Gated GELU unit: project to ``2 * dim_out``, then multiply one half by GELU of the other."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: input projection (Linear+GELU, or GEGLU when ``glu``), dropout, output Linear.

    The hidden width is ``int(dim * mult)``; ``dim_out`` defaults to ``dim``.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
|
|
|
|
def zero_module(module):
    """
    Zero all parameters of ``module`` in place (through detached views) and return the module.
    """
    params = list(module.parameters())
    for p in params:
        p.detach().zero_()
    return module
|
|
|
|
def Normalize(in_channels, num_groups=32):
    """GroupNorm over ``in_channels`` channels (eps=1e-6, learnable affine).

    NOTE: a later ``def Normalize(in_channels, norm_type="group")`` in this
    module rebinds this name once the module finishes importing.
    """
    return torch.nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)
|
|
|
|
class RelativePosition(nn.Module):
    """Learned relative-position embedding table.

    Adapted from https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py

    The table holds one ``num_units``-dim row per clipped offset in
    ``[-max_relative_position, max_relative_position]`` (xavier-uniform init).
    """

    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        table = torch.Tensor(max_relative_position * 2 + 1, num_units)
        self.embeddings_table = nn.Parameter(table)
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        """Return ``(length_q, length_k, num_units)``; entry ``[i, j]`` embeds offset ``j - i`` (clipped)."""
        max_pos = self.max_relative_position
        device = self.embeddings_table.device
        q_pos = torch.arange(length_q, device=device)
        k_pos = torch.arange(length_k, device=device)
        offsets = k_pos.unsqueeze(0) - q_pos.unsqueeze(1)
        # shift clipped offsets into [0, 2 * max_pos] to index table rows
        indices = offsets.clamp(-max_pos, max_pos) + max_pos
        return self.embeddings_table[indices.long()]
|
|
|
|
class TemporalCrossAttention(nn.Module):
    """Multi-head (cross-)attention on ``(batch, seq, dim)`` inputs with optional fixed masks.

    A fixed mask of shape ``[1, L, L]`` may be built in ``__init__``:

    * ``img_video_joint_train`` with ``joint_train_mode="same_batch"``:
      ``L = temporal_length + image_length``, with every row/column at index
      ``>= temporal_length`` zeroed (``"diff_batch"`` uses no fixed mask);
    * ``use_tempoal_causal_attn``: lower-triangular (causal) mask;
    * ``tempoal_attn_type="sparse_causal"``: position ``i`` sees ``i-1`` and ``i``;
    * ``tempoal_attn_type="sparse_causal_first"``: position ``i`` sees ``0`` and ``i``.

    Optional relative-position terms are added to the logits and outputs.
    All q/k/v/out weights and the output bias are initialised to zero.
    """

    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        temporal_length=None,
        image_length=None,
        use_relative_position=False,
        img_video_joint_train=False,
        use_tempoal_causal_attn=False,
        bidirectional_causal_attn=False,
        tempoal_attn_type=None,
        joint_train_mode="same_batch",
        **kwargs,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.context_dim = context_dim

        self.scale = dim_head**-0.5
        self.heads = heads
        self.temporal_length = temporal_length
        self.use_relative_position = use_relative_position
        self.img_video_joint_train = img_video_joint_train
        self.bidirectional_causal_attn = bidirectional_causal_attn
        self.joint_train_mode = joint_train_mode
        assert joint_train_mode in ["same_batch", "diff_batch"]
        self.tempoal_attn_type = tempoal_attn_type

        if bidirectional_causal_attn:
            # the reverse (anti-causal) pass complements a causal forward pass
            assert use_tempoal_causal_attn
        if tempoal_attn_type:
            assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"]
            assert not use_tempoal_causal_attn
            assert not (
                img_video_joint_train and (self.joint_train_mode == "same_batch")
            )
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        assert not (
            img_video_joint_train
            and (self.joint_train_mode == "same_batch")
            and use_tempoal_causal_attn
        )
        # ``self.mask`` is a plain tensor attribute (not a buffer); forward()
        # moves it to the logits' device on each call.
        if img_video_joint_train:
            if self.joint_train_mode == "same_batch":
                mask = torch.ones(
                    [1, temporal_length + image_length, temporal_length + image_length]
                )
                mask[:, temporal_length:, :] = 0
                mask[:, :, temporal_length:] = 0
                self.mask = mask
            else:
                self.mask = None
        elif use_tempoal_causal_attn:
            self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
        elif tempoal_attn_type == "sparse_causal":
            # causal mask minus every key at least two steps back -> {i-1, i}
            mask1 = torch.tril(
                torch.ones([1, temporal_length, temporal_length])
            ).bool()
            mask2 = torch.zeros([1, temporal_length, temporal_length])
            mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
                torch.ones([1, temporal_length - 2, temporal_length - 2])
            )
            mask2 = (1 - mask2).bool()
            self.mask = mask1 & mask2
        elif tempoal_attn_type == "sparse_causal_first":
            # causal mask minus keys 1..i-1 -> {0, i}
            mask1 = torch.tril(
                torch.ones([1, temporal_length, temporal_length])
            ).bool()
            mask2 = torch.zeros([1, temporal_length, temporal_length])
            mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril(
                torch.ones([1, temporal_length - 2, temporal_length - 2])
            )
            mask2 = (1 - mask2).bool()
            self.mask = mask1 & mask2
        else:
            self.mask = None

        if use_relative_position:
            assert temporal_length is not None
            self.relative_position_k = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length
            )
            self.relative_position_v = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length
            )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

        nn.init.constant_(self.to_q.weight, 0)
        nn.init.constant_(self.to_k.weight, 0)
        nn.init.constant_(self.to_v.weight, 0)
        nn.init.constant_(self.to_out[0].weight, 0)
        nn.init.constant_(self.to_out[0].bias, 0)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (defaults to ``x``).

        :param x: ``(batch, seq, query_dim)`` queries.
        :param context: optional ``(batch, seq_k, context_dim)`` key/value source.
        :param mask: optional extra mask; 0/False entries are blocked. When a
            fixed mask exists the two are combined with ``&``, so this is
            presumably a bool tensor broadcastable to ``(batch*heads, seq, seq_k)``.
        :return: ``(batch, seq, query_dim)``.
        """
        nh = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v))
        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if self.use_relative_position:
            len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
            k2 = self.relative_position_k(len_q, len_k)
            sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale
            sim += sim2

        if exists(self.mask):
            if mask is None:
                mask = self.mask.to(sim.device)
            else:
                mask = self.mask.to(sim.device).bool() & mask

        # Keep the logits before masking. The reverse pass must not see the
        # forward causal mask: combined with its own upper-triangular mask,
        # that would leave only the diagonal.
        sim_unmasked = sim
        # Local constant, always bound. A conditional assignment here would
        # leave the name unbound whenever no mask is present.
        neg_value = -1e9
        if mask is not None:
            sim = sim + (1 - mask.float()) * neg_value

        attn = sim.softmax(dim=-1)
        out = einsum("b i j, b j d -> b i d", attn, v)

        if self.bidirectional_causal_attn:
            mask_reverse = torch.triu(
                torch.ones(
                    [1, self.temporal_length, self.temporal_length], device=sim.device
                )
            )
            sim_reverse = sim_unmasked.float().masked_fill(mask_reverse == 0, neg_value)
            # softmax runs in float32; cast back so the einsum matches v's dtype
            attn_reverse = sim_reverse.softmax(dim=-1).to(v.dtype)
            out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
            out += out_reverse

        if self.use_relative_position:
            v2 = self.relative_position_v(len_q, len_v)
            out2 = einsum("b t s, t s d -> b t d", attn, v2)
            out += out2
        out = rearrange(out, "(b h) n d -> b n (h d)", h=nh)
        return self.to_out(out)
|
|
|
|
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the ``h*w`` positions of a ``(b, c, h, w)`` map, plus residual.

    The input is normalised with ``Normalize``; q/k/v/output projections are
    1x1 convolutions, and logits are scaled by ``c ** -0.5``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)
        b, c, h, w = query.shape

        # attention weights (b, query_pos, key_pos), softmax over key positions
        query = rearrange(query, "b c h w -> b (h w) c")
        key = rearrange(key, "b c h w -> b c (h w)")
        weights = torch.einsum("bij,bjk->bik", query, key) * (int(c) ** (-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # weighted sum of values: (b, c, query_pos)
        value = rearrange(value, "b c h w -> b c (h w)")
        weights = rearrange(weights, "b i j -> b j i")
        attended = torch.einsum("bij,bjk->bik", value, weights)
        attended = rearrange(attended, "b c (h w) -> b c h w", h=h)
        return x + self.proj_out(attended)
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head (cross-)attention on ``(batch, seq, dim)`` tensors.

    ``forward`` is the plain einsum implementation; ``efficient_forward``
    uses ``xformers.ops.memory_efficient_attention`` and does not support
    masks. With ``sa_shared_kv`` and ``shared_type="only_first"`` every batch
    entry attends to the keys/values of batch entry 0, in both paths. The
    other accepted ``shared_type`` values raise ``NotImplementedError``.
    """

    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        sa_shared_kv=False,
        shared_type="only_first",
        **kwargs,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.sa_shared_kv = sa_shared_kv
        assert shared_type in [
            "only_first",
            "all_frames",
            "first_and_prev",
            "only_prev",
            "full",
            "causal",
            "full_qkv",
        ]
        self.shared_type = shared_type

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.attention_op: Optional[Any] = None

    def _share_kv(self, k, v, b):
        """Apply the ``sa_shared_kv`` policy to projected ``(b, n, inner_dim)`` keys/values."""
        if not self.sa_shared_kv:
            return k, v
        if self.shared_type != "only_first":
            raise NotImplementedError
        # take batch entry 0 and broadcast it to all ``b`` entries
        return tuple(
            rearrange(t[0].unsqueeze(0), "b n c -> (b n) c")
            .unsqueeze(0)
            .repeat(b, 1, 1)
            for t in (k, v)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads
        b = x.shape[0]

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        k, v = self._share_kv(k, v, b)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            # mask: True = keep; flattened per batch entry, shared across heads
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim=-1)
        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)

    def efficient_forward(self, x, context=None, mask=None):
        # masks are unsupported here; reject them before doing any work
        if exists(mask):
            raise NotImplementedError
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        k, v = self._share_kv(k, v, b)
        # (b, n, heads*dim_head) -> (b*heads, n, dim_head)
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )
        # (b*heads, n, dim_head) -> (b, n, heads*dim_head)
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)
|
|
|
|
class VideoSpatialCrossAttention(CrossAttention):
    """Per-frame spatial cross-attention for a ``(b, c, t, h, w)`` video, with residual.

    Frames are flattened to ``((b t), (h w), c)`` tokens and attended with
    ``CrossAttention.forward``. The residual is added in token space before
    reshaping back. ``context`` (presumably ``(b, n, context_dim)``) is
    repeated per frame. ``mask`` is accepted but ignored.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
        super().__init__(query_dim, context_dim, heads, dim_head, dropout)

    def forward(self, x, context=None, mask=None):
        b, c, t, h, w = x.shape
        if context is not None:
            # Tokens are batched as (b t): all frames of sample 0 first. Each
            # sample's context must therefore repeat t times consecutively;
            # ``repeat(t, 1, 1)`` would tile whole batches instead.
            context = context.repeat_interleave(t, dim=0)
        tokens = spatial_attn_reshape(x)
        # ``super().forward`` (not ``super.forward``); the residual lives in token space.
        tokens = super().forward(tokens, context=context) + tokens
        return spatial_attn_reshape_back(tokens, b, h)
|
|
|
|
def spatial_attn_reshape(x):
    """``(b, c, t, h, w)`` video -> ``((b t), (h w), c)`` per-frame token sequences."""
    pattern = "b c t h w -> (b t) (h w) c"
    return rearrange(x, pattern)
|
|
|
|
def spatial_attn_reshape_back(x, b, h):
    """Inverse of ``spatial_attn_reshape``: ``((b t), (h w), c)`` -> ``(b, c, t, h, w)``."""
    pattern = "(b t) (h w) c -> b c t h w"
    return rearrange(x, pattern, b=b, h=h)
|
|
|
|
def temporal_attn_reshape(x):
    """``(b, c, t, h, w)`` video -> ``((b h w), t, c)`` per-pixel frame sequences."""
    pattern = "b c t h w -> (b h w) t c"
    return rearrange(x, pattern)
|
|
|
|
def temporal_attn_reshape_back(x, b, h, w):
    """Inverse of ``temporal_attn_reshape``: ``((b h w), t, c)`` -> ``(b, c, t, h, w)``."""
    pattern = "(b h w) t c -> b c t h w"
    return rearrange(x, pattern, b=b, h=h, w=w)
|
|
|
|
def local_spatial_temporal_attn_reshape(x, window_size):
    """Split a ``(B, C, T, H, W)`` video into non-overlapping ``window_size`` spatial windows.

    Returns ``((b nh nw), (t wh ww), c)``: one token sequence per window,
    covering all frames. H and W must be divisible by ``window_size``.
    """
    _, _, _, H, W = x.shape
    n_rows = H // window_size
    n_cols = W // window_size
    windowed = rearrange(
        x,
        "b c t (nh wh) (nw ww) -> b c t nh wh nw ww",
        nh=n_rows,
        nw=n_cols,
        wh=window_size,
        ww=window_size,
    ).contiguous()
    return rearrange(windowed, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c")
|
|
|
|
def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t):
    """Inverse of ``local_spatial_temporal_attn_reshape``: back to ``(b, c, t, h, w)``."""
    _, _, _ = x.shape
    n_rows = h // window_size
    n_cols = w // window_size
    windowed = rearrange(
        x,
        "(b nh nw) (t wh ww) c -> b c t nh wh nw ww",
        b=b,
        nh=n_rows,
        nw=n_cols,
        t=t,
        wh=window_size,
        ww=window_size,
    )
    return rearrange(windowed, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)")
|
|
|
|
class SpatialTemporalTransformer(nn.Module):
    """
    Transformer block for video-like data (5D tensor).
    First, project the input (aka embedding) with NO reshape.
    Then apply standard transformer action.
    The 5D -> 3D reshape operation will be done in the specific attention module.

    Structure: Normalize -> 1x1x1 Conv3d in-projection -> ``depth`` x
    ``BasicTransformerBlockST`` -> zero-initialised 1x1x1 Conv3d
    out-projection, plus a residual connection to the input.
    ``cross_attn_on_tempoal`` is accepted but not used here.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        temporal_length=None,
        image_length=None,
        use_relative_position=True,
        img_video_joint_train=False,
        cross_attn_on_tempoal=False,
        temporal_crossattn_type="selfattn",
        order="stst",
        temporalcrossfirst=False,
        split_stcontext=False,
        temporal_context_dim=None,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head

        self.norm = Normalize(in_channels)
        self.proj_in = nn.Conv3d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        block_options = dict(
            dropout=dropout,
            context_dim=context_dim,
            temporal_length=temporal_length,
            image_length=image_length,
            use_relative_position=use_relative_position,
            img_video_joint_train=img_video_joint_train,
            temporal_crossattn_type=temporal_crossattn_type,
            order=order,
            temporalcrossfirst=temporalcrossfirst,
            split_stcontext=split_stcontext,
            temporal_context_dim=temporal_context_dim,
        )
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlockST(inner_dim, n_heads, d_head, **block_options, **kwargs)
                for _ in range(depth)
            ]
        )

        self.proj_out = zero_module(
            nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x, context=None, temporal_context=None, **kwargs):
        assert x.dim() == 5, f"x shape = {x.shape}"
        residual = x
        x = self.proj_in(self.norm(x))
        for block in self.transformer_blocks:
            x = block(x, context=context, temporal_context=temporal_context, **kwargs)
        return self.proj_out(x) + residual
|
|
|
|
class STAttentionBlock2(nn.Module):
    """Spatial attention followed by temporal attention on a ``(b, c, t, h, w)`` tensor.

    Each stage has an optional GroupNorm (``attn_norm_type="group"``), a 1x1
    qkv projection, an attention module and a zero-initialised output
    projection. Each stage's result is added residually.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
        temporal_length=16,
        image_length=8,
        use_relative_position=False,
        img_video_joint_train=False,
        attn_norm_type="group",
        use_tempoal_causal_attn=False,
    ):
        """
        version 1: guided_diffusion implemented version
        version 2: remove args input argument
        """
        super().__init__()

        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint

        self.temporal_length = temporal_length
        self.image_length = image_length
        self.use_relative_position = use_relative_position
        self.img_video_joint_train = img_video_joint_train
        self.attn_norm_type = attn_norm_type
        assert self.attn_norm_type in ["group", "no_norm"]
        self.use_tempoal_causal_attn = use_tempoal_causal_attn

        if self.attn_norm_type == "group":
            self.norm_s = normalization(channels)
            self.norm_t = normalization(channels)

        self.qkv_s = conv_nd(1, channels, channels * 3, 1)
        self.qkv_t = conv_nd(1, channels, channels * 3, 1)

        if self.img_video_joint_train:
            # positions >= temporal_length are masked out as queries and keys
            mask = torch.ones(
                [1, temporal_length + image_length, temporal_length + image_length]
            )
            mask[:, temporal_length:, :] = 0
            mask[:, :, temporal_length:] = 0
            self.register_buffer("mask", mask)
        else:
            self.mask = None

        if use_new_attention_order:
            # split qkv before splitting heads
            self.attention_s = QKVAttention(self.num_heads)
            self.attention_t = QKVAttention(self.num_heads)
        else:
            # split heads before splitting qkv
            self.attention_s = QKVAttentionLegacy(self.num_heads)
            self.attention_t = QKVAttentionLegacy(self.num_heads)

        if use_relative_position:
            self.relative_position_k = RelativePosition(
                num_units=channels // self.num_heads,
                max_relative_position=temporal_length,
            )
            self.relative_position_v = RelativePosition(
                num_units=channels // self.num_heads,
                max_relative_position=temporal_length,
            )

        self.proj_out_s = zero_module(conv_nd(1, channels, channels, 1))
        self.proj_out_t = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x, mask=None):
        """Apply spatial then temporal attention; the ``mask`` argument is unused."""
        b, c, t, h, w = x.shape

        # spatial stage: attend over the h*w positions of each frame
        out = rearrange(x, "b c t h w -> (b t) c (h w)")
        if self.attn_norm_type == "no_norm":
            qkv = self.qkv_s(out)
        else:
            qkv = self.qkv_s(self.norm_s(out))
        out = self.attention_s(qkv)
        out = self.proj_out_s(out)
        out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
        # Out-of-place: an in-place ``x += out`` would overwrite the caller's input tensor.
        x = x + out

        # temporal stage: attend over the t frames at each spatial position
        out = rearrange(x, "b c t h w -> (b h w) c t")
        if self.attn_norm_type == "no_norm":
            qkv = self.qkv_t(out)
        else:
            qkv = self.qkv_t(self.norm_t(out))

        attn_kwargs = {"rp": None, "mask": self.mask}
        if self.use_relative_position:
            len_q = qkv.size()[-1]
            len_k, len_v = len_q, len_q
            k_rp = self.relative_position_k(len_q, len_k)
            v_rp = self.relative_position_v(len_q, len_v)
            attn_kwargs["rp"] = (k_rp, v_rp)
        # Pass the causal flag only when set, so an attention module whose
        # forward() lacks this keyword still works in the default case.
        if self.use_tempoal_causal_attn:
            attn_kwargs["use_tempoal_causal_attn"] = True
        out = self.attention_t(qkv, **attn_kwargs)

        out = self.proj_out_t(out)
        out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
        return x + out
|
|
|
|
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.

    Heads are split before q/k/v. Relative positions, masks and causal
    masking are not supported.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param rp, mask, use_tempoal_causal_attn: accepted so callers can use
            the same keywords as ``QKVAttention.forward``; any non-default
            value raises ``NotImplementedError``.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        if rp is not None or mask is not None or use_tempoal_causal_attn:
            raise NotImplementedError
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # scale q and k by ch**-0.25 each (total 1/sqrt(ch)); smaller
        # intermediates than dividing the product afterwards
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        # count_flops_attn is not defined in the visible part of this file — TODO confirm it exists
        return count_flops_attn(model, _x, y)
|
|
|
|
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.

    ``qkv`` is split into q, k, v along the channel dim first, then into
    heads. This is the reverse of ``QKVAttentionLegacy``. Optional
    relative-position terms and a causal or user-supplied mask are supported.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :param rp: optional ``(k_rp, v_rp)`` relative-position embeddings,
            each indexed ``[t, s, c]`` with ``t == s == T``.
        :param mask: optional mask whose zero entries are blocked. It is cropped
            or zero-padded to ``T x T`` when its last dim differs from ``T``.
        :param use_tempoal_causal_attn: use a lower-triangular mask instead
            (cannot be combined with ``mask``).
        :return: an [N x (H * C) x T] tensor after attention, in the dtype of
            the attention logits.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        qkv = qkv.contiguous()
        q, k, v = qkv.chunk(3, dim=1)
        # q and k each scaled by ch**-0.25 (total 1/sqrt(ch))
        scale = 1 / math.sqrt(math.sqrt(ch))
        q_heads = (q * scale).view(bs * self.n_heads, ch, length)
        weight = torch.einsum(
            "bct,bcs->bts",
            q_heads,
            (k * scale).view(bs * self.n_heads, ch, length),
        )

        if rp is not None:
            k_rp, v_rp = rp
            # NOTE(review): the output subscripts are "bst" while ``weight`` is
            # indexed [b, t(query), s(key)], so this term is added transposed:
            # entry (i, j) receives q_j . k_rp[j, i]. Left unchanged because
            # existing checkpoints were presumably trained this way — confirm
            # before switching to "bct,tsc->bts".
            weight2 = torch.einsum("bct,tsc->bst", q_heads, k_rp)
            weight += weight2

        if use_tempoal_causal_attn:
            assert mask is None, f"Not implemented for merging two masks!"
            # Build on weight's device: masked_fill needs the mask on the same
            # device as the tensor being filled.
            mask = torch.tril(torch.ones(weight.shape, device=weight.device))
        elif mask is not None:
            # crop or zero-pad the mask to the actual sequence length
            c, t, _ = weight.shape
            if mask.shape[-1] > t:
                mask = mask[:, :t, :t]
            elif mask.shape[-1] < t:
                mask_ = torch.zeros([c, t, t]).to(mask.device)
                t_ = mask.shape[-1]
                mask_[:, :t_, :t_] = mask
                mask = mask_
            else:
                assert (
                    weight.shape[-1] == mask.shape[-1]
                ), f"weight={weight.shape}, mask={mask.shape}"

        # Remember the logits' dtype before masking upcasts them to float32,
        # so the attention weights keep matching v's dtype.
        out_dtype = weight.dtype
        if mask is not None:
            INF = -1e8
            weight = weight.float().masked_fill(mask == 0, INF)

        weight = F.softmax(weight.float(), dim=-1).type(out_dtype)
        a = torch.einsum(
            "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
        )

        if rp is not None:
            a2 = torch.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2)
            a += a2

        return a.reshape(bs, -1, length)
|
|
|
|
def silu(x):
    """SiLU (swish) activation applied elementwise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
|
class SiLU(nn.Module):
    """``nn.Module`` wrapper around the module-level ``silu`` function."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return silu(x)
|
|
|
|
def Normalize(in_channels, norm_type="group"):
    """Create a normalisation layer for ``in_channels`` features.

    ``"group"``: GroupNorm with 32 groups, eps=1e-6, learnable affine.
    ``"batch"``: ``torch.nn.SyncBatchNorm``.
    ``"layer"``: ``nn.LayerNorm`` over a last dim of size ``in_channels``.
    Any other ``norm_type`` fails the assertion.
    """
    assert norm_type in ["group", "batch",'layer']
    builders = {
        "group": lambda: torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        ),
        "batch": lambda: torch.nn.SyncBatchNorm(in_channels),
        "layer": lambda: nn.LayerNorm(in_channels),
    }
    return builders[norm_type]()
| |
class SamePadConv3d(nn.Module):
    """3-D convolution preceded by explicit "same"-style padding.

    Total padding per spatial dim is ``kernel - stride``. When that total is
    odd, the extra element goes on the leading side. Padding uses
    ``padding_type`` (default ``"replicate"``) and is computed in float32,
    then cast back to the input dtype before the convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        bias=True,
        padding_type="replicate",
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        # F.pad wants (last-dim lead, last-dim trail, ..., first-dim lead, first-dim trail)
        total_pad = tuple(k - s for k, s in zip(kernel_size, stride))
        self.pad_input = tuple(
            amount
            for p in reversed(total_pad)
            for amount in (p // 2 + p % 2, p // 2)
        )
        self.padding_type = padding_type

        self.conv = nn.Conv3d(
            in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias
        )

    def forward(self, x):
        in_dtype = x.dtype
        x_padded = F.pad(x.float(), self.pad_input, mode=self.padding_type)
        return self.conv(x_padded.to(in_dtype))
|
|
class TemporalAttention(nn.Module):
    """Multi-head attention over the frame axis of a ``(b, c, t, h, w)`` tensor, with residual.

    Each spatial position's frame sequence is normalised, projected to qkv
    with a zero-initialised 1x1 conv, and attended with ``QKVAttention`` plus
    learned relative positions. It then passes through a zero-initialised
    output conv.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        max_temporal_length=64,
    ):
        """
        a clean multi-head temporal attention
        """
        super().__init__()

        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels

        head_dim = channels // self.num_heads
        self.norm = Normalize(channels)
        self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
        self.attention = QKVAttention(self.num_heads)
        self.relative_position_k = RelativePosition(
            num_units=head_dim, max_relative_position=max_temporal_length
        )
        self.relative_position_v = RelativePosition(
            num_units=head_dim, max_relative_position=max_temporal_length
        )
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x, mask=None):
        """Return ``x`` plus temporal attention output; ``mask`` is accepted but unused."""
        b, c, t, h, w = x.shape
        frames = rearrange(x, "b c t h w -> (b h w) c t")
        qkv = self.qkv(self.norm(frames))

        n_frames = qkv.size()[-1]
        rel_k = self.relative_position_k(n_frames, n_frames)
        rel_v = self.relative_position_v(n_frames, n_frames)
        attended = self.attention(qkv, rp=(rel_k, rel_v))

        out = self.proj_out(attended)
        out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
        return x + out
class TemporalAttention_lin(nn.Module):
    """Residual multi-head self-attention along t, using Linear projections.

    Same data flow as a conv-based temporal attention, but with three
    differences:
    - ``nn.LayerNorm`` is applied over the channel (last) dimension of
      (b*h*w, t, c) tokens.
    - Q/K/V and the output projection are ``nn.Linear`` layers.
    - These layers keep their default initialization (no ``zero_module``).

    The fused qkv tensor is transposed to channels-first before being handed
    to ``QKVAttention`` together with key/value relative-position tables.

    Args:
        channels: channel count c.
        num_heads: head count, used when ``num_head_channels == -1``.
        num_head_channels: if not -1, the per-head width; the head count
            becomes ``channels // num_head_channels``.
        max_temporal_length: ``max_relative_position`` for both
            relative-position tables.
    """

    def __init__(
        self,
        channels,
        num_heads=8,
        num_head_channels=-1,
        max_temporal_length=64,
    ):
        super().__init__()
        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads

        # Submodules are created in the original order so parameter
        # registration (and seeded random init) is unchanged.
        self.norm = nn.LayerNorm(channels)
        self.qkv = nn.Linear(channels, channels * 3)
        self.attention = QKVAttention(self.num_heads)
        head_dim = channels // self.num_heads
        self.relative_position_k = RelativePosition(
            num_units=head_dim, max_relative_position=max_temporal_length
        )
        self.relative_position_v = RelativePosition(
            num_units=head_dim, max_relative_position=max_temporal_length
        )
        self.proj_out = nn.Linear(channels, channels)

    def forward(self, x, mask=None):
        """Return ``x`` plus temporal attention output. ``mask`` is accepted but unused."""
        batch, _, _, height, width = x.shape
        # Channels-last tokens per spatial location: (b*h*w, t, c).
        tokens = rearrange(x, "b c t h w -> (b h w) t c")
        # Project, then move channels back to dim 1: (b*h*w, 3c, t).
        qkv = self.qkv(self.norm(tokens)).transpose(-1, -2)

        seq_len = qkv.size()[-1]
        rel_k = self.relative_position_k(seq_len, seq_len)
        rel_v = self.relative_position_v(seq_len, seq_len)
        attended = self.attention(qkv, rp=(rel_k, rel_v))

        # nn.Linear acts on the last dim, so project in channels-last layout.
        projected = self.proj_out(attended.transpose(-1, -2)).transpose(-1, -2)
        delta = rearrange(
            projected, "(b h w) c t -> b c t h w", b=batch, h=height, w=width
        )
        return x + delta
| |
class AttnBlock3D(nn.Module):
    """Residual single-head spatial self-attention, applied to each frame independently.

    For input (b, c, t, h, w), every frame is handled as a set of h*w
    positions with c channels. Q/K/V and the output projection are 1x1x1
    ``Conv3d`` layers, and attention scores are scaled by ``c ** -0.5``.

    Args:
        in_channels: channel count c (also the attention width).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # 1x1x1 convolution: a per-position linear map over channels.
            return torch.nn.Conv3d(
                in_channels, in_channels, kernel_size=1, stride=1, padding=0
            )

        # Same construction order as before: norm, q, k, v, proj_out.
        self.norm = Normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        normed = self.norm(x)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)
        b, c, t, h, w = q.shape

        # Scores per frame: (b*t, hw, c) @ (b*t, c, hw) -> (b*t, hw, hw).
        queries = rearrange(q, "b c t h w -> (b t) (h w) c")
        keys = rearrange(k, "b c t h w -> (b t) c (h w)")
        weights = torch.bmm(queries, keys) * (int(c) ** (-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # Weighted sum of values: (b*t, c, hw) @ (b*t, hw, hw) -> (b*t, c, hw).
        values = rearrange(v, "b c t h w -> (b t) c (h w)")
        attended = torch.bmm(values, weights.permute(0, 2, 1))
        attended = rearrange(attended, "(b t) c (h w) -> b c t h w", b=b, h=h)

        return x + self.proj_out(attended)
| |
class MultiHeadAttention3D(nn.Module):
    """Residual multi-head spatial self-attention, applied to each frame independently.

    For input (b, c, t, h, w), every frame becomes a sequence of h*w tokens
    with c channels. The tokens are layer-normalized and projected to Q/K/V
    with ``nn.Linear``. Channels are split into ``num_heads`` heads of
    ``head_dim``, and the scaled dot-product attention output is projected
    back and added to the input.

    Args:
        in_channels: channel count c; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, in_channels, num_heads=8):
        super().__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads
        assert self.head_dim * num_heads == in_channels, "in_channels must be divisible by num_heads"

        # Same construction order as before: norm, q, k, v, proj_out.
        self.norm = nn.LayerNorm(in_channels)
        self.q_linear = nn.Linear(in_channels, in_channels)
        self.k_linear = nn.Linear(in_channels, in_channels)
        self.v_linear = nn.Linear(in_channels, in_channels)
        self.proj_out = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        b, c, t, h, w = x.shape
        # (b*t, h*w, c) tokens, normalized over channels.
        tokens = self.norm(rearrange(x, "b c t h w -> (b t) (h w) c"))

        def split_heads(projected):
            # (B, L, heads*d) -> (B, heads, L, d)
            return rearrange(projected, "b l (h d) -> b h l d", h=self.num_heads)

        q, k, v = (
            split_heads(layer(tokens))
            for layer in (self.q_linear, self.k_linear, self.v_linear)
        )

        scale = self.head_dim ** 0.5
        attn = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / scale, dim=-1)
        merged = rearrange(torch.matmul(attn, v), "b h l d -> b l (h d)")

        delta = rearrange(
            self.proj_out(merged), "(b t) (h w) c -> b c t h w", b=b, h=h, t=t
        )
        return x + delta
|
|
|
|
class SiglipAE(nn.Module):
    """Temporal compressor for (b, c, t, h, w) feature maps.

    A learned per-frame encoding of shape ``(num_frames, channels)`` is added
    to the input. The result then passes through two stages of
    ``AttnBlock3D -> TemporalAttention -> SamePadConv3d``. Each convolution
    has kernel 3 and strides by ``temporal_stride`` along t only, so with the
    defaults 4 input frames are reduced to 1.

    The defaults reproduce the previously hard-coded configuration
    (``temporal_encoding`` of shape (4, 1152), stride 2). Parameter names and
    shapes are therefore unchanged for existing state dicts.

    Args:
        channels: feature width c of the input.
        num_frames: number of frames t covered by the temporal encoding.
        temporal_stride: temporal stride of each ``SamePadConv3d``.
    """

    def __init__(self, channels=1152, num_frames=4, temporal_stride=2):
        super().__init__()
        # Created before the encoder, as in the original, so seeded random
        # initialization order is unchanged.
        self.temporal_encoding = nn.Parameter(torch.randn((num_frames, channels)))

        self.encoder = nn.Sequential(
            AttnBlock3D(channels),
            TemporalAttention(channels),
            SamePadConv3d(
                channels,
                channels,
                kernel_size=3,
                stride=(temporal_stride, 1, 1),
                padding_type="replicate",
            ),
            AttnBlock3D(channels),
            TemporalAttention(channels),
            SamePadConv3d(
                channels,
                channels,
                kernel_size=3,
                stride=(temporal_stride, 1, 1),
                padding_type="replicate",
            ),
        )

    def forward(self, x):
        """Add the temporal encoding to ``x`` (b, c, t, h, w) and run the encoder.

        Raises:
            ValueError: if ``x`` is not 5-D.
        """
        if x.dim() != 5:
            raise ValueError(
                f"SiglipAE expects a 5-D (b, c, t, h, w) input, got shape {tuple(x.shape)}"
            )
        # (t, c) -> (1, c, t, 1, 1) so the encoding broadcasts over batch and
        # space. Note that a single-frame input (t == 1) broadcasts against it
        # and becomes num_frames frames.
        encoding = self.temporal_encoding.transpose(0, 1)[None, :, :, None, None]
        return self.encoder(x + encoding)
|
|
|
|