diff --git a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py
index ea773742a832bb..b44d10820462bf 100644
--- a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py
+++ b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py
@@ -587,7 +587,6 @@ def bytes_to_unicode():
     fname_middle = "mmproj-"
     has_text_encoder = False
     has_minicpmv_projector = True
-    minicpmv_version = 3
 elif args.vision_only:
     fname_middle = "vision-"
     has_text_encoder = False
diff --git a/examples/llava/minicpmv-convert/minicpmv2_0-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert/minicpmv2_0-convert-image-encoder-to-gguf.py
deleted file mode 100644
index ab5394900399ca..00000000000000
--- a/examples/llava/minicpmv-convert/minicpmv2_0-convert-image-encoder-to-gguf.py
+++ /dev/null
@@ -1,405 +0,0 @@
-import argparse
-import os
-import json
-import re
-
-import torch
-import numpy as np
-from gguf import *
-import timm
-
-TEXT = "clip.text"
-VISION = "clip.vision"
-
-
-def k(raw_key: str, arch: str) -> str:
-    return raw_key.format(arch=arch)
-
-
-def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool:
-    if name in (
-        "logit_scale",
-        "text_model.embeddings.position_ids",
-        "vision_model.embeddings.position_ids",
-    ):
-        return True
-
-    if has_minicpmv and name in ["visual_projection.weight"]:
-        return True
-
-    if name.startswith("v") and not has_vision:
-        return True
-
-    if name.startswith("t") and not has_text:
-        return True
-
-    return False
-
-
-def get_tensor_name(name: str) -> str:
-    if "projection" in name:
-        return name
-    if "mm_projector" in name:
-        name = name.replace("model.mm_projector", "mm")
-        name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
-        name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
-        return name
-
-    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
-
-
-def bytes_to_unicode():
-    """
-    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a significant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    bs = (
-        list(range(ord("!"), ord("~") + 1))
-        + list(range(ord("¡"), ord("¬") + 1))
-        + list(range(ord("®"), ord("ÿ") + 1))
-    )
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8 + n)
-            n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
-
-
-ap = argparse.ArgumentParser()
-ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
-ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
-ap.add_argument("--text-only", action="store_true", required=False,
-                help="Save a text-only model. It can't be used to encode images")
-ap.add_argument("--vision-only", action="store_true", required=False,
-                help="Save a vision-only model. It can't be used to encode texts")
-ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
-                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
-ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
-                help="The clip model is from openclip (for ViT-SO400M type)")
-ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for minicpmv models.")
-ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
-ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
-# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
-# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
-default_image_mean = [0.48145466, 0.4578275, 0.40821073]
-default_image_std = [0.26862954, 0.26130258, 0.27577711]
-ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor)', default=None)
-ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
-
-args = ap.parse_args()
-
-
-if args.text_only and args.vision_only:
-    print("--text-only and --vision-only arguments cannot be specified at the same time.")
-    exit(1)
-
-if args.use_f32:
-    print("WARNING: Weights for the convolution op are always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
-
-# output in the same directory as the model if output_dir is None
-dir_model = args.model_dir
-
-if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
-    vocab = None
-    tokens = None
-else:
-    with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
-        vocab = json.load(f)
-        tokens = [key for key in vocab]
-
-# possible data types
-#   ftype == 0 -> float32
-#   ftype == 1 -> float16
-#
-# map from ftype to string
-ftype_str = ["f32", "f16"]
-
-ftype = 1
-if args.use_f32:
-    ftype = 0
-
-# if args.clip_model_is_vision or args.clip_model_is_openclip:
-#     model = CLIPVisionModel.from_pretrained(dir_model)
-#     processor = None
-# else:
-#     model = CLIPModel.from_pretrained(dir_model)
-#     processor = CLIPProcessor.from_pretrained(dir_model)
-model = timm.create_model(
-    "vit_so400m_patch14_siglip_384.webli",
-    pretrained=False,
-    num_classes=0,
-    dynamic_img_size=True,
-    dynamic_img_pad=True,
-)
-processor = None
-if model.attn_pool is not None:
-    model.attn_pool = torch.nn.Identity()
-
-model.blocks = model.blocks[:-1]
-model.load_state_dict(torch.load(os.path.join(dir_model, "minicpmv.clip")))
-
-fname_middle = None
-has_text_encoder = True
-has_vision_encoder = True
-has_minicpmv_projector = False
-if args.text_only:
-    fname_middle = "text-"
-    has_vision_encoder = False
-elif args.minicpmv_projector is not None:
-    fname_middle = "mmproj-"
-    has_text_encoder = False
-    has_minicpmv_projector = True
-    minicpmv_version = 1
-elif args.vision_only:
-    fname_middle = "vision-"
-    has_text_encoder = False
-else:
-    fname_middle = ""
-
-output_dir = args.output_dir if args.output_dir is not None else dir_model
-os.makedirs(output_dir, exist_ok=True)
-output_prefix = os.path.basename(output_dir).replace("ggml_", "")
-fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
-fout = GGUFWriter(path=fname_out, arch="clip")
-
-fout.add_bool("clip.has_text_encoder", has_text_encoder)
-fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
-fout.add_bool("clip.has_minicpmv_projector", has_minicpmv_projector)
-fout.add_file_type(ftype)
-if args.text_only:
-    fout.add_description("text-only CLIP model")
-elif args.vision_only and not has_minicpmv_projector:
-    fout.add_description("vision-only CLIP model")
-elif has_minicpmv_projector:
-    fout.add_description("image encoder for MiniCPM-V")
-    # add projector type
-    fout.add_string("clip.projector_type", "resampler")
-    fout.add_int32("clip.minicpmv_version", minicpmv_version)
-else:
-    fout.add_description("two-tower CLIP model")
-
-if has_vision_encoder:
-    # vision_model hparams
-    fout.add_uint32("clip.vision.image_size", 448)
-    fout.add_uint32("clip.vision.patch_size", 14)
-    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), 1152)
-    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 4304)
-    fout.add_uint32("clip.vision.projection_dim", 0)
-    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
-    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
-    block_count = 26
-    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
-
-    if processor is not None:
-        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
-        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
-    else:
-        image_mean = args.image_mean if args.image_mean is not None else default_image_mean
-        image_std = args.image_std if args.image_std is not None else default_image_std
-    fout.add_array("clip.vision.image_mean", image_mean)
-    fout.add_array("clip.vision.image_std", image_std)
-
-use_gelu = True
-fout.add_bool("clip.use_gelu", use_gelu)
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-    """
-    embed_dim: output dimension for each position
-    pos: a list of positions to be encoded: size (M,)
-    out: (M, D)
-    """
-    assert embed_dim % 2 == 0
-    omega = np.arange(embed_dim // 2, dtype=np.float32)
-    omega /= embed_dim / 2.
-    omega = 1. / 10000 ** omega  # (D/2,)
-
-    pos = pos.reshape(-1)  # (M,)
-    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
-
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
-
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
-    assert embed_dim % 2 == 0
-
-    # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-    return emb
-
-
-# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
-    """
-    grid_size: int of the grid height and width
-    return:
-    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-    """
-    if isinstance(grid_size, int):
-        grid_h_size, grid_w_size = grid_size, grid_size
-    else:
-        grid_h_size, grid_w_size = grid_size[0], grid_size[1]
-
-    grid_h = np.arange(grid_h_size, dtype=np.float32)
-    grid_w = np.arange(grid_w_size, dtype=np.float32)
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
-
-    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-    if cls_token:
-        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
-    return pos_embed
-
-def _replace_name_resampler(s, v):
-    if re.match("resampler.pos_embed", s):
-        return {
-            s: v,
-            re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(2304, (448//14, 448//14))),
-        }
-    if re.match("resampler.proj", s):
-        return {
-            re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
-        }
-    if re.match("resampler.attn.in_proj_.*", s):
-        return {
-            re.sub("attn.in_proj_", "attn.q.", s): v.chunk(3, dim=0)[0],
-            re.sub("attn.in_proj_", "attn.k.", s): v.chunk(3, dim=0)[1],
-            re.sub("attn.in_proj_", "attn.v.", s): v.chunk(3, dim=0)[2],
-        }
-    return {s: v}
-
-if has_minicpmv_projector:
-    projector = torch.load(args.minicpmv_projector)
-    new_state_dict = {}
-    for k, v in projector.items():
-        kvs = _replace_name_resampler(k, v)
-        for nk, nv in kvs.items():
-            new_state_dict[nk] = nv
-    projector = new_state_dict
-    for name, data in projector.items():
-        name = get_tensor_name(name)
-        data = data.squeeze().numpy()
-
-        n_dims = len(data.shape)
-        if ftype == 1:
-            if name[-7:] == ".weight" and n_dims == 2:
-                print("  Converting to float16")
-                data = data.astype(np.float16)
-                ftype_cur = 1
-            else:
-                print("  Converting to float32")
-                data = data.astype(np.float32)
-                ftype_cur = 0
-        else:
-            if data.dtype != np.float32:
-                print("  Converting to float32")
-                data = data.astype(np.float32)
-                ftype_cur = 0
-
-        fout.add_tensor(name, data)
-        print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
-
-    print("Projector tensors added\n")
-
-def _replace_name(s, v):
-    if re.match("blocks.([0-9]+).attn.qkv.weight", s):
-        return {
-            re.sub("blocks.([0-9]+).attn.qkv.weight", "vision_model.encoder.layers.\\1.self_attn.q_proj.weight", s): v.chunk(3, dim=0)[0],
-            re.sub("blocks.([0-9]+).attn.qkv.weight", "vision_model.encoder.layers.\\1.self_attn.k_proj.weight", s): v.chunk(3, dim=0)[1],
-            re.sub("blocks.([0-9]+).attn.qkv.weight", "vision_model.encoder.layers.\\1.self_attn.v_proj.weight", s): v.chunk(3, dim=0)[2],
-        }
-    if re.match("blocks.([0-9]+).attn.qkv.bias", s):
-        return {
-            re.sub("blocks.([0-9]+).attn.qkv.bias", "vision_model.encoder.layers.\\1.self_attn.q_proj.bias", s): v.chunk(3, dim=0)[0],
-            re.sub("blocks.([0-9]+).attn.qkv.bias", "vision_model.encoder.layers.\\1.self_attn.k_proj.bias", s): v.chunk(3, dim=0)[1],
-            re.sub("blocks.([0-9]+).attn.qkv.bias", "vision_model.encoder.layers.\\1.self_attn.v_proj.bias", s): v.chunk(3, dim=0)[2],
-        }
-    if re.match("pos_embed", s):
-        from timm.layers import resample_abs_pos_embed
-        s = re.sub("pos_embed", "vision_model.embeddings.position_embedding", s)
-        v = resample_abs_pos_embed(v, (448//14, 448//14), num_prefix_tokens=0)
-        return {s: v}
-
-    s = re.sub("patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.proj.weight", s)
-    s = re.sub("patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.proj.bias", s)
-
-    # norm
-    s = re.sub("blocks.([0-9]+).norm([0-9]+).weight", "vision_model.encoder.layers.\\1.layer_norm\\2.weight", s)
-    s = re.sub("blocks.([0-9]+).norm([0-9]+).bias", "vision_model.encoder.layers.\\1.layer_norm\\2.bias", s)
-
-    s = re.sub("blocks.([0-9]+).attn.proj.weight", "vision_model.encoder.layers.\\1.self_attn.out_proj.weight", s)
-    s = re.sub("blocks.([0-9]+).attn.proj.bias", "vision_model.encoder.layers.\\1.self_attn.out_proj.bias", s)
-
-    s = re.sub("blocks.([0-9]+).mlp.fc([0-9]+).weight", "vision_model.encoder.layers.\\1.mlp.fc\\2.weight", s)
-    s = re.sub("blocks.([0-9]+).mlp.fc([0-9]+).bias", "vision_model.encoder.layers.\\1.mlp.fc\\2.bias", s)
-
-    s = re.sub("norm.weight", "vision_model.post_layernorm.weight", s)
-    s = re.sub("norm.bias", "vision_model.post_layernorm.bias", s)
-
-    return {s: v}
-
-state_dict = model.state_dict()
-new_state_dict = {}
-for k, v in state_dict.items():
-    kvs = _replace_name(k, v)
-    for nk, nv in kvs.items():
-        new_state_dict[nk] = nv
-state_dict = new_state_dict
-for name, data in state_dict.items():
-    if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector):
-        # we don't need this
-        print(f"skipping parameter: {name}")
-        continue
-
-    name = get_tensor_name(name)
-    data = data.squeeze().numpy()
-
-    n_dims = len(data.shape)
-
-    # ftype == 0 -> float32, ftype == 1 -> float16
-    ftype_cur = 0
-    if n_dims == 4:
-        print(f"tensor {name} is always saved in f16")
-        data = data.astype(np.float16)
-        ftype_cur = 1
-    elif ftype == 1:
-        if name[-7:] == ".weight" and n_dims == 2:
-            print("  Converting to float16")
-            data = data.astype(np.float16)
-            ftype_cur = 1
-        else:
-            print("  Converting to float32")
-            data = data.astype(np.float32)
-            ftype_cur = 0
-    else:
-        if data.dtype != np.float32:
-            print("  Converting to float32")
-            data = data.astype(np.float32)
-            ftype_cur = 0
-
-    print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
-    fout.add_tensor(name, data)
-
-
-fout.write_header_to_file()
-fout.write_kv_data_to_file()
-fout.write_tensors_to_file()
-fout.close()
-
-print("Done. Output file: " + fname_out)
Output file: " + fname_out) diff --git a/examples/llava/minicpmv-convert/minicpmv2_0-surgery.py b/examples/llava/minicpmv-convert/minicpmv2_0-surgery.py deleted file mode 100644 index cb6168e624e1d7..00000000000000 --- a/examples/llava/minicpmv-convert/minicpmv2_0-surgery.py +++ /dev/null @@ -1,48 +0,0 @@ -import argparse -import glob -import os -import torch -from transformers import AutoModel, AutoTokenizer - -ap = argparse.ArgumentParser() -ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.0 model") -args = ap.parse_args() - -# find the model part that includes the the multimodal projector weights -model = AutoModel.from_pretrained(args.model, trust_remote_code=True) -checkpoint = model.state_dict() - -# get a list of mm tensor names -mm_tensors = [k for k, v in checkpoint.items() if k.startswith("resampler")] - -# store these tensors in a new dictionary and torch.save them -projector = {name: checkpoint[name].float() for name in mm_tensors} -torch.save(projector, f"{args.model}/minicpmv.projector") - -clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")] -if len(clip_tensors) > 0: - clip = {name.replace("vpm.", ""): checkpoint[name].float() for name in clip_tensors} - torch.save(clip, f"{args.model}/minicpmv.clip") - - # added tokens should be removed to be able to convert Mistral models - if os.path.exists(f"{args.model}/added_tokens.json"): - with open(f"{args.model}/added_tokens.json", "w") as f: - f.write("{}\n") - -config = model.llm.config -config._name_or_path = "openbmb/CPM-2B" -config.auto_map = { - "AutoConfig": "configuration_minicpm.MiniCPMConfig", - "AutoModel": "modeling_minicpm.MiniCPMModel", - "AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM", - "AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM", - "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification" -} -model.llm.save_pretrained(f"{args.model}/model") -tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) -tok.save_pretrained(f"{args.model}/model") -# os.system(f"cp {args.model}/modeling_minicpm.py {args.model}/model/modeling_minicpm.py") - -print("Done!") -print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") -print(f"Also, use {args.model}/minicpmv.projector to prepare a minicpmv-encoder.gguf file.") diff --git a/examples/llava/minicpmv-convert/minicpmv2_5-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert/minicpmv2_5-convert-image-encoder-to-gguf.py deleted file mode 100644 index fa361bea3bfe3e..00000000000000 --- a/examples/llava/minicpmv-convert/minicpmv2_5-convert-image-encoder-to-gguf.py +++ /dev/null @@ -1,384 +0,0 @@ -import argparse -import os -import json -import re - -import torch -import numpy as np -from gguf import * -import timm -from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig - -TEXT = "clip.text" -VISION = "clip.vision" - - -def k(raw_key: str, arch: str) -> str: - return raw_key.format(arch=arch) - - -def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool: - if name in ( - "logit_scale", - "text_model.embeddings.position_ids", - "vision_model.embeddings.position_ids", - ): - return True - - if has_minicpmv and name in ["visual_projection.weight"]: - return True - - if name.startswith("v") and not has_vision: - return True - - if name.startswith("t") and not has_text: - return True - - return False - - -def get_tensor_name(name: str) -> str: - if "projection" in 
diff --git a/examples/llava/minicpmv-convert/minicpmv2_5-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert/minicpmv2_5-convert-image-encoder-to-gguf.py
deleted file mode 100644
index fa361bea3bfe3e..00000000000000
--- a/examples/llava/minicpmv-convert/minicpmv2_5-convert-image-encoder-to-gguf.py
+++ /dev/null
@@ -1,384 +0,0 @@
-import argparse
-import os
-import json
-import re
-
-import torch
-import numpy as np
-from gguf import *
-import timm
-from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig
-
-TEXT = "clip.text"
-VISION = "clip.vision"
-
-
-def k(raw_key: str, arch: str) -> str:
-    return raw_key.format(arch=arch)
-
-
-def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool:
-    if name in (
-        "logit_scale",
-        "text_model.embeddings.position_ids",
-        "vision_model.embeddings.position_ids",
-    ):
-        return True
-
-    if has_minicpmv and name in ["visual_projection.weight"]:
-        return True
-
-    if name.startswith("v") and not has_vision:
-        return True
-
-    if name.startswith("t") and not has_text:
-        return True
-
-    return False
-
-
-def get_tensor_name(name: str) -> str:
-    if "projection" in name:
-        return name
-    if "mm_projector" in name:
-        name = name.replace("model.mm_projector", "mm")
-        name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
-        name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
-        return name
-
-    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
-
-
-def bytes_to_unicode():
-    """
-    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a significant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    bs = (
-        list(range(ord("!"), ord("~") + 1))
-        + list(range(ord("¡"), ord("¬") + 1))
-        + list(range(ord("®"), ord("ÿ") + 1))
-    )
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8 + n)
-            n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
-
-
-ap = argparse.ArgumentParser()
-ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
-ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
-ap.add_argument("--text-only", action="store_true", required=False,
-                help="Save a text-only model. It can't be used to encode images")
-ap.add_argument("--vision-only", action="store_true", required=False,
-                help="Save a vision-only model. It can't be used to encode texts")
-ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
-                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
-ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
-                help="The clip model is from openclip (for ViT-SO400M type)")
-ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.")
-ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
-ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
-# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
-# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
-default_image_mean = [0.48145466, 0.4578275, 0.40821073]
-default_image_std = [0.26862954, 0.26130258, 0.27577711]
-ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor)', default=None)
-ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
-
-args = ap.parse_args()
-
-
-if args.text_only and args.vision_only:
-    print("--text-only and --vision-only arguments cannot be specified at the same time.")
-    exit(1)
-
-if args.use_f32:
-    print("WARNING: Weights for the convolution op are always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
-
-# output in the same directory as the model if output_dir is None
-dir_model = args.model_dir

-if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
-    vocab = None
-    tokens = None
-else:
-    with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
-        vocab = json.load(f)
-        tokens = [key for key in vocab]
-
-# possible data types
-#   ftype == 0 -> float32
-#   ftype == 1 -> float16
-#
-# map from ftype to string
-ftype_str = ["f32", "f16"]
-
-ftype = 1
-if args.use_f32:
-    ftype = 0
-
-# if args.clip_model_is_vision or args.clip_model_is_openclip:
-#     model = CLIPVisionModel.from_pretrained(dir_model)
-#     processor = None
-# else:
-#     model = CLIPModel.from_pretrained(dir_model)
-#     processor = CLIPProcessor.from_pretrained(dir_model)
-
-default_vision_config = {
-        "hidden_size": 1152,
-        "image_size": 980,
-        "intermediate_size": 4304,
-        "model_type": "idefics2",
-        "num_attention_heads": 16,
-        "num_hidden_layers": 27,
-        "patch_size": 14,
-    }
-vision_config = Idefics2VisionConfig(**default_vision_config)
-model = Idefics2VisionTransformer(vision_config)
-
-processor = None
-# if model.attn_pool is not None:
-#     model.attn_pool = torch.nn.Identity()
-
-# model.blocks = model.blocks[:-1]
-model.load_state_dict(torch.load(os.path.join(dir_model, "minicpmv.clip")))
-
-fname_middle = None
-has_text_encoder = True
-has_vision_encoder = True
-has_minicpmv_projector = False
-if args.text_only:
-    fname_middle = "text-"
-    has_vision_encoder = False
-elif args.minicpmv_projector is not None:
-    fname_middle = "mmproj-"
-    has_text_encoder = False
-    has_minicpmv_projector = True
-    minicpmv_version = 2
-elif args.vision_only:
-    fname_middle = "vision-"
-    has_text_encoder = False
-else:
-    fname_middle = ""
-
-output_dir = args.output_dir if args.output_dir is not None else dir_model
-os.makedirs(output_dir, exist_ok=True)
-output_prefix = os.path.basename(output_dir).replace("ggml_", "")
-fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
-fout = GGUFWriter(path=fname_out, arch="clip")
-
-fout.add_bool("clip.has_text_encoder", has_text_encoder)
-fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
-fout.add_bool("clip.has_minicpmv_projector", has_minicpmv_projector)
-fout.add_file_type(ftype)
-if args.text_only:
-    fout.add_description("text-only CLIP model")
-elif args.vision_only and not has_minicpmv_projector:
-    fout.add_description("vision-only CLIP model")
-elif has_minicpmv_projector:
-    fout.add_description("image encoder for MiniCPM-V")
-    # add projector type
-    fout.add_string("clip.projector_type", "resampler")
-    fout.add_int32("clip.minicpmv_version", minicpmv_version)
-else:
-    fout.add_description("two-tower CLIP model")
-
-if has_vision_encoder:
-    # vision_model hparams
-    fout.add_uint32("clip.vision.image_size", 448)
-    fout.add_uint32("clip.vision.patch_size", 14)
-    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), 1152)
-    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 4304)
-    fout.add_uint32("clip.vision.projection_dim", 0)
-    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
-    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
-    block_count = 26
-    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
-
-    if processor is not None:
-        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
-        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
-    else:
-        image_mean = args.image_mean if args.image_mean is not None else default_image_mean
-        image_std = args.image_std if args.image_std is not None else default_image_std
-    fout.add_array("clip.vision.image_mean", image_mean)
-    fout.add_array("clip.vision.image_std", image_std)
-
-use_gelu = True
-fout.add_bool("clip.use_gelu", use_gelu)
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-    """
-    embed_dim: output dimension for each position
-    pos: a list of positions to be encoded: size (M,)
-    out: (M, D)
-    """
-    assert embed_dim % 2 == 0
-    omega = np.arange(embed_dim // 2, dtype=np.float32)
-    omega /= embed_dim / 2.
-    omega = 1. / 10000 ** omega  # (D/2,)
-
-    pos = pos.reshape(-1)  # (M,)
-    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
-
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
-
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
-    assert embed_dim % 2 == 0
-
-    # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-    return emb
-
-
-# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
-    """
-    grid_size: int of the grid height and width
-    return:
-    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-    """
-    if isinstance(grid_size, int):
-        grid_h_size, grid_w_size = grid_size, grid_size
-    else:
-        grid_h_size, grid_w_size = grid_size[0], grid_size[1]
-
-    grid_h = np.arange(grid_h_size, dtype=np.float32)
-    grid_w = np.arange(grid_w_size, dtype=np.float32)
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
-
-    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-    if cls_token:
-        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
-    return pos_embed
-
-def _replace_name_resampler(s, v):
-    if re.match("resampler.pos_embed", s):
-        return {
-            s: v,
-            re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
-        }
-    if re.match("resampler.proj", s):
-        return {
-            re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
-            re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
-        }
-    if re.match("resampler.attn.in_proj_.*", s):
-        return {
-            re.sub("attn.in_proj_", "attn.q.", s): v.chunk(3, dim=0)[0],
-            re.sub("attn.in_proj_", "attn.k.", s): v.chunk(3, dim=0)[1],
-            re.sub("attn.in_proj_", "attn.v.", s): v.chunk(3, dim=0)[2],
-        }
-    return {s: v}
-
-if has_minicpmv_projector:
-    projector = torch.load(args.minicpmv_projector)
-    new_state_dict = {}
-    for k, v in projector.items():
-        kvs = _replace_name_resampler(k, v)
-        for nk, nv in kvs.items():
-            new_state_dict[nk] = nv
-    projector = new_state_dict
-    for name, data in projector.items():
-        name = get_tensor_name(name)
-        data = data.squeeze().numpy()
-
-        n_dims = len(data.shape)
-        if ftype == 1:
-            if name[-7:] == ".weight" and n_dims == 2:
-                print("  Converting to float16")
-                data = data.astype(np.float16)
-                ftype_cur = 1
-            else:
-                print("  Converting to float32")
-                data = data.astype(np.float32)
-                ftype_cur = 0
-        else:
-            if data.dtype != np.float32:
-                print("  Converting to float32")
-                data = data.astype(np.float32)
-                ftype_cur = 0
-
-        fout.add_tensor(name, data)
-        print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
-
-    print("Projector tensors added\n")
-
-def _replace_name(s, v):
-    s = "vision_model." + s
-    if re.match("vision_model.embeddings.position_embedding", s):
-        v = v.unsqueeze(0)
-        return {s: v}
-
-    return {s: v}
-
-state_dict = model.state_dict()
-new_state_dict = {}
-for k, v in state_dict.items():
-    kvs = _replace_name(k, v)
-    for nk, nv in kvs.items():
-        new_state_dict[nk] = nv
-state_dict = new_state_dict
-for name, data in state_dict.items():
-    if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector):
-        # we don't need this
-        print(f"skipping parameter: {name}")
-        continue
-
-    name = get_tensor_name(name)
-    data = data.squeeze().numpy()
-
-    n_dims = len(data.shape)
-
-    # ftype == 0 -> float32, ftype == 1 -> float16
-    ftype_cur = 0
-    if n_dims == 4:
-        print(f"tensor {name} is always saved in f16")
-        data = data.astype(np.float16)
-        ftype_cur = 1
-    elif ftype == 1:
-        if name[-7:] == ".weight" and n_dims == 2:
-            print("  Converting to float16")
-            data = data.astype(np.float16)
-            ftype_cur = 1
-        else:
-            print("  Converting to float32")
-            data = data.astype(np.float32)
-            ftype_cur = 0
-    else:
-        if data.dtype != np.float32:
-            print("  Converting to float32")
-            data = data.astype(np.float32)
-            ftype_cur = 0
-
-    print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
-    fout.add_tensor(name, data)
-
-
-fout.write_header_to_file()
-fout.write_kv_data_to_file()
-fout.write_tensors_to_file()
-fout.close()
-
-print("Done. Output file: " + fname_out)
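Editor's note: both converters split PyTorch's fused `in_proj_`/`qkv` tensors into separate q/k/v tensors with `chunk(3, dim=0)`, which is the per-projection layout the GGUF consumer expects. A small standalone sketch of that split with toy sizes (not part of the deleted scripts):

```python
import torch

embed_dim = 4
fused = torch.arange(3 * embed_dim * embed_dim, dtype=torch.float32).reshape(3 * embed_dim, embed_dim)

# rows 0..d-1 are the query projection, d..2d-1 the key, 2d..3d-1 the value
q, k, v = fused.chunk(3, dim=0)
assert q.shape == k.shape == v.shape == (embed_dim, embed_dim)

# the split is lossless: re-concatenating along dim 0 recovers the fused tensor
assert torch.equal(torch.cat([q, k, v], dim=0), fused)
```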
diff --git a/examples/llava/minicpmv-convert/minicpmv2_5-surgery.py b/examples/llava/minicpmv-convert/minicpmv2_5-surgery.py
deleted file mode 100644
index 7defb8ff7aa9a1..00000000000000
--- a/examples/llava/minicpmv-convert/minicpmv2_5-surgery.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import argparse
-import glob
-import os
-import torch
-from transformers import AutoModel, AutoTokenizer
-
-ap = argparse.ArgumentParser()
-ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.5 model")
-args = ap.parse_args()
-
-# find the model part that includes the multimodal projector weights
-model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
-checkpoint = model.state_dict()
-
-# get a list of mm tensor names
-mm_tensors = [k for k, v in checkpoint.items() if k.startswith("resampler")]
-
-# store these tensors in a new dictionary and torch.save them
-projector = {name: checkpoint[name].float() for name in mm_tensors}
-torch.save(projector, f"{args.model}/minicpmv.projector")
-
-clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")]
-if len(clip_tensors) > 0:
-    clip = {name.replace("vpm.", ""): checkpoint[name].float() for name in clip_tensors}
-    torch.save(clip, f"{args.model}/minicpmv.clip")
-
-    # added tokens should be removed to be able to convert Mistral models
-    if os.path.exists(f"{args.model}/added_tokens.json"):
-        with open(f"{args.model}/added_tokens.json", "w") as f:
-            f.write("{}\n")
-
-config = model.llm.config
-config._name_or_path = "openbmb/MiniCPM-Llama3-V-2.5"
-config.auto_map = {
-    "AutoConfig": "configuration_minicpm.MiniCPMConfig",
-    "AutoModel": "modeling_minicpm.MiniCPMModel",
-    "AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
-    "AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
-    "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
-}
-model.llm.save_pretrained(f"{args.model}/model")
-tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
-tok.save_pretrained(f"{args.model}/model")
-
-print("Done!")
-print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
-print(f"Also, use {args.model}/minicpmv.projector to prepare a minicpmv-encoder.gguf file.")
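Editor's note: all of the deleted converters share the same GGUF emission pattern — key/value metadata first, then tensors, then the fixed three-phase write. A stripped-down sketch using only the `gguf` calls the scripts themselves use; the output file name, keys, and tensor below are placeholders:

```python
import numpy as np
from gguf import GGUFWriter

fout = GGUFWriter(path="example-clip.gguf", arch="clip")

# metadata keys mirror the deleted converters
fout.add_description("toy example")
fout.add_bool("clip.has_vision_encoder", True)
fout.add_uint32("clip.vision.patch_size", 14)
fout.add_array("clip.vision.image_mean", [0.5, 0.5, 0.5])

# tensors are added as float32/float16 numpy arrays
fout.add_tensor("v.blk.0.attn_q.weight", np.zeros((4, 4), dtype=np.float32))

# header, KV data, and tensor data are written in this fixed order
fout.write_header_to_file()
fout.write_kv_data_to_file()
fout.write_tensors_to_file()
fout.close()
```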
""" -# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes - - -import os -import math -import warnings -from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn.init import _calculate_fan_in_and_fan_out - -from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from transformers.modeling_utils import PreTrainedModel -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - logging, - replace_return_docstrings, -) -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -class SiglipVisionConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a - Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip - [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - hidden_size (`int`, *optional*, defaults to 768): - Dimensionality of the encoder layers and the pooler layer. - intermediate_size (`int`, *optional*, defaults to 3072): - Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. - num_hidden_layers (`int`, *optional*, defaults to 12): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 12): - Number of attention heads for each attention layer in the Transformer encoder. - num_channels (`int`, *optional*, defaults to 3): - Number of channels in the input images. - image_size (`int`, *optional*, defaults to 224): - The size (resolution) of each image. - patch_size (`int`, *optional*, defaults to 16): - The size (resolution) of each patch. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. - layer_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the layer normalization layers. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- Example: - ```python - >>> from transformers import SiglipVisionConfig, SiglipVisionModel - >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration - >>> configuration = SiglipVisionConfig() - >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration - >>> model = SiglipVisionModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "siglip_vision_model" - - def __init__( - self, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, - num_channels=3, - image_size=224, - patch_size=16, - hidden_act="gelu_pytorch_tanh", - layer_norm_eps=1e-6, - attention_dropout=0.0, - **kwargs, - ): - super().__init__(**kwargs) - - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_channels = num_channels - self.patch_size = patch_size - self.image_size = image_size - self.attention_dropout = attention_dropout - self.layer_norm_eps = layer_norm_eps - self.hidden_act = hidden_act - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": - cls._set_token_in_kwargs(kwargs) - - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - - # get the vision config dict if we are loading from SiglipConfig - if config_dict.get("model_type") == "siglip": - config_dict = config_dict["vision_config"] - - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: - logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) - - return cls.from_dict(config_dict, **kwargs) - - -_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" - -SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "google/siglip-base-patch16-224", - # See all SigLIP models at https://huggingface.co/models?filter=siglip -] - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - if tensor.dtype in [torch.float16, torch.bfloat16]: - # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu - og_dtype = tensor.dtype - tensor = tensor.to(torch.float32) - tensor.erfinv_() - tensor = tensor.to(og_dtype) - else: - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - if tensor.dtype == torch.float16: - # The `clamp_` op is not (yet?) defined in float16+cpu - tensor = tensor.to(torch.float32) - tensor.clamp_(min=a, max=b) - tensor = tensor.to(torch.float16) - else: - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 -) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsquently scaled and shifted by the mean and std args. - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - - -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def default_flax_embed_init(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="normal") - - -@dataclass -# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip -class SiglipVisionModelOutput(ModelOutput): - """ - Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. 
- Args: - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - image_embeds: Optional[torch.FloatTensor] = None - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -class SiglipVisionEmbeddings(nn.Module): - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - ) - - self.num_patches_per_side = self.image_size // self.patch_size - self.num_patches = self.num_patches_per_side**2 - self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - - def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor]=None) -> torch.Tensor: - batch_size = pixel_values.size(0) - - patch_embeds = self.patch_embedding(pixel_values) - embeddings = patch_embeds.flatten(2).transpose(1, 2) - - max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) - max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) - position_ids = torch.full( - size=( - batch_size, - max_nb_patches_h * max_nb_patches_w, - ), - fill_value=0, - ) - - for batch_idx, p_attn_mask in enumerate(patch_attention_mask): - if tgt_sizes is not None: - nb_patches_h = tgt_sizes[batch_idx][0] - nb_patches_w = tgt_sizes[batch_idx][1] - else: - nb_patches_h = p_attn_mask[:, 0].sum() - nb_patches_w = p_attn_mask[0].sum() - - fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) - fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) - - bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) - bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) - - pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() - position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids - - 
position_ids = position_ids.to(self.position_embedding.weight.device) - - embeddings = embeddings + self.position_embedding(position_ids) - return embeddings - - -class SiglipAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): - raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class SiglipFlashAttention2(SiglipAttention): - """ - Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays - untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of
-    flash attention and deal with padding tokens in case the input contains any of them.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.is_causal = False  # Hack to make sure we don't use a causal mask
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        output_attentions = False
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        # if past_key_value is not None:
-        #     cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
-        # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        dropout_rate = self.dropout if self.training else 0.0
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in the correct dtype just to be sure everything works as expected.
-        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (LlamaRMSNorm handles it correctly)
-
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to the fact"
-                " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
-        attn_output = self.out_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights
-
-    def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-        """
-
-        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-        causal = self.is_causal and query_length != 1
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=softmax_scale,
-                causal=causal,
-            )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
-            )
-
-        return attn_output
-
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-        key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
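The `_upad_input` helper above packs only the non-padding tokens into one flat sequence and hands `flash_attn_varlen_func` cumulative sequence lengths instead of a mask. `_get_unpad_data` is imported from transformers and is not part of this diff; a rough sketch, under that assumption, of the values it is expected to produce for a 0/1 padding mask (illustrative only):

    import torch
    import torch.nn.functional as F

    # Hypothetical stand-in for _get_unpad_data (the real helper lives in transformers).
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])                # (batch_size=2, seq_len=4)
    seqlens = mask.sum(dim=-1, dtype=torch.int32)      # tensor([3, 2]): tokens per sequence
    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()  # flat positions of real tokens
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_in_batch = int(seqlens.max())

    print(cu_seqlens)           # tensor([0, 3, 5], dtype=torch.int32): sequence boundaries
    print(max_seqlen_in_batch)  # 3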
- """ - - config_class = SiglipVisionConfig - base_model_prefix = "siglip" - supports_gradient_checkpointing = True - - def _init_weights(self, module): - """Initialize the weights""" - - if isinstance(module, SiglipVisionEmbeddings): - width = self.config.hidden_size - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) - elif isinstance(module, nn.Embedding): - default_flax_embed_init(module.weight) - elif isinstance(module, SiglipAttention): - nn.init.normal_(module.q_proj.weight) - nn.init.normal_(module.k_proj.weight) - nn.init.normal_(module.v_proj.weight) - nn.init.normal_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) - elif isinstance(module, SiglipMLP): - nn.init.normal_(module.fc1.weight) - nn.init.normal_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -SIGLIP_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - Parameters: - config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -SIGLIP_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip -class SiglipEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`SiglipEncoderLayer`]. 
- Args: - config: SiglipConfig - """ - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - # Ignore copy - def forward( - self, - inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - -@add_start_docstrings( - """The vision model from SigLIP without any head or projection on top.""", - SIGLIP_START_DOCSTRING -) -class SiglipVisionTransformer(SiglipPreTrainedModel): - config_class = SiglipVisionConfig - main_input_name = "pixel_values" - _supports_flash_attn_2 = True - - def __init__(self, config: SiglipVisionConfig): - super().__init__(config) - self.config = config - embed_dim = config.hidden_size - - self.embeddings = SiglipVisionEmbeddings(config) - self.encoder = SiglipEncoder(config) - 
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.embeddings.patch_embedding - - @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) - def forward( - self, - pixel_values, - patch_attention_mask: Optional[torch.BoolTensor] = None, - tgt_sizes: Optional[torch.IntTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size = pixel_values.size(0) - if patch_attention_mask is None: - patch_attention_mask = torch.ones( - size=( - batch_size, - pixel_values.size(2) // self.config.patch_size, - pixel_values.size(3) // self.config.patch_size, - ), - dtype=torch.bool, - device=pixel_values.device, - ) - - hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes) - - patch_attention_mask = patch_attention_mask.view(batch_size, -1) - # The call to `_upad_input` in `_flash_attention_forward` is expensive - # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), - # avoiding passing the attention_mask, which is equivalent to attending to the full sequence - if not torch.any(~patch_attention_mask): - attention_mask=None - else: - attention_mask = ( - _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) - if not self._use_flash_attention_2 - else patch_attention_mask - ) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.post_layernorm(last_hidden_state) - - if not return_dict: - return (last_hidden_state, None) + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=None, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - -import argparse -import os -import json -import re - -import torch -import numpy as np -from gguf import * - -TEXT = "clip.text" -VISION = "clip.vision" - - -def k(raw_key: str, arch: str) -> str: - return raw_key.format(arch=arch) - - -def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool: - if name in ( - "logit_scale", - "text_model.embeddings.position_ids", - "vision_model.embeddings.position_ids", - ): - return True - - if has_minicpmv and name in ["visual_projection.weight"]: - return True - - if name.startswith("v") and not has_vision: - return True - - if name.startswith("t") and not has_text: - return True - - return False - - -def get_tensor_name(name: str) -> str: - if "projection" in name: - return name - if "mm_projector" in name: - name = name.replace("model.mm_projector", "mm") - name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) - name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) - return name - - return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") - - -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a significant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -ap = argparse.ArgumentParser() -ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) -ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") -ap.add_argument("--text-only", action="store_true", required=False, - help="Save a text-only model. 
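Taken together, the classes above form a self-contained SigLIP vision tower. A quick interface sketch, assuming the same hyperparameters as `default_vision_config` in the conversion code below (illustrative only; depending on the `SiglipVisionEmbeddings` implementation earlier in this file, `tgt_sizes` may also need to be supplied):

    import torch

    cfg = SiglipVisionConfig(
        hidden_size=1152, image_size=980, intermediate_size=4304,
        num_attention_heads=16, num_hidden_layers=27, patch_size=14,
    )
    tower = SiglipVisionTransformer(cfg).eval()

    # One 448x448 image -> a (448 // 14) ** 2 = 1024-token patch sequence.
    pixels = torch.randn(1, 3, 448, 448)
    with torch.no_grad():
        out = tower(pixel_values=pixels)
    print(out.last_hidden_state.shape)  # expected: torch.Size([1, 1024, 1152])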
-import argparse
-import os
-import json
-import re
-
-import torch
-import numpy as np
-from gguf import *
-
-TEXT = "clip.text"
-VISION = "clip.vision"
-
-
-def k(raw_key: str, arch: str) -> str:
-    return raw_key.format(arch=arch)
-
-
-def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool:
-    if name in (
-        "logit_scale",
-        "text_model.embeddings.position_ids",
-        "vision_model.embeddings.position_ids",
-    ):
-        return True
-
-    if has_minicpmv and name in ["visual_projection.weight"]:
-        return True
-
-    if name.startswith("v") and not has_vision:
-        return True
-
-    if name.startswith("t") and not has_text:
-        return True
-
-    return False
-
-
-def get_tensor_name(name: str) -> str:
-    if "projection" in name:
-        return name
-    if "mm_projector" in name:
-        name = name.replace("model.mm_projector", "mm")
-        name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
-        name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
-        return name
-
-    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
-
-
-def bytes_to_unicode():
-    """
-    Returns list of utf-8 byte and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a significant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    bs = (
-        list(range(ord("!"), ord("~") + 1))
-        + list(range(ord("¡"), ord("¬") + 1))
-        + list(range(ord("®"), ord("ÿ") + 1))
-    )
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8 + n)
-            n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
-
-
-ap = argparse.ArgumentParser()
-ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
-ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
-ap.add_argument("--text-only", action="store_true", required=False,
-                help="Save a text-only model. It can't be used to encode images")
-ap.add_argument("--vision-only", action="store_true", required=False,
-                help="Save a vision-only model. It can't be used to encode texts")
-ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
-                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
-ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
-                help="The clip model is from openclip (for ViT-SO400M type))")
-ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.")
-ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
-ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
-# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
-# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
-default_image_mean = [0.48145466, 0.4578275, 0.40821073]
-default_image_std = [0.26862954, 0.26130258, 0.27577711]
-ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
-ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
-
-# with proper
-args = ap.parse_args()
-
-
-if args.text_only and args.vision_only:
-    print("--text-only and --image-only arguments cannot be specified at the same time.")
-    exit(1)
-
-if args.use_f32:
-    print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
-
-# output in the same directory as the model if output_dir is None
-dir_model = args.model_dir
-
-if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
-    vocab = None
-    tokens = None
-else:
-    with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
-        vocab = json.load(f)
-        tokens = [key for key in vocab]
-
-# possible data types
-#   ftype == 0 -> float32
-#   ftype == 1 -> float16
-#
-# map from ftype to string
-ftype_str = ["f32", "f16"]
-
-ftype = 1
-if args.use_f32:
-    ftype = 0
-
-# if args.clip_model_is_vision or args.clip_model_is_openclip:
-#     model = CLIPVisionModel.from_pretrained(dir_model)
-#     processor = None
-# else:
-#     model = CLIPModel.from_pretrained(dir_model)
-#     processor = CLIPProcessor.from_pretrained(dir_model)
-
-default_vision_config = {
-    "hidden_size": 1152,
-    "image_size": 980,
-    "intermediate_size": 4304,
-    "model_type": "idefics2",
-    "num_attention_heads": 16,
-    "num_hidden_layers": 27,
-    "patch_size": 14,
-}
-
-vision_config = SiglipVisionConfig(**default_vision_config)
-model = SiglipVisionTransformer(vision_config)
-
-processor = None
-# if model.attn_pool is not None:
-#     model.attn_pool = torch.nn.Identity()
-
-# model.blocks = model.blocks[:-1]
-model.load_state_dict(torch.load(os.path.join(dir_model, "minicpmv.clip")))
-
-fname_middle = None
-has_text_encoder = True
-has_vision_encoder = True
-has_minicpmv_projector = False
-if args.text_only:
-    fname_middle = "text-"
-    has_vision_encoder = False
-elif args.minicpmv_projector is not None:
-    fname_middle = "mmproj-"
-    has_text_encoder = False
-    has_minicpmv_projector = True
-    minicpmv_version = 3
-elif args.vision_only:
-    fname_middle = "vision-"
-    has_text_encoder = False
-else:
-    fname_middle = ""
-
-output_dir = args.output_dir if args.output_dir is not None else dir_model
-os.makedirs(output_dir, exist_ok=True)
-output_prefix = os.path.basename(output_dir).replace("ggml_", "")
-fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
-fout = GGUFWriter(path=fname_out, arch="clip")
-
-fout.add_bool("clip.has_text_encoder", has_text_encoder)
-fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
-fout.add_bool("clip.has_minicpmv_projector", has_minicpmv_projector)
-fout.add_file_type(ftype)
-if args.text_only:
-    fout.add_description("text-only CLIP model")
-elif args.vision_only and not has_minicpmv_projector:
-    fout.add_description("vision-only CLIP model")
-elif has_minicpmv_projector:
-    fout.add_description("image encoder for MiniCPM-V")
-    # add projector type
-    fout.add_string("clip.projector_type", "resampler")
-    fout.add_int32("clip.minicpmv_version", minicpmv_version)
-else:
-    fout.add_description("two-tower CLIP model")
-
-if has_vision_encoder:
-    # vision_model hparams
-    fout.add_uint32("clip.vision.image_size", 448)
-    fout.add_uint32("clip.vision.patch_size", 14)
-    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), 1152)
-    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 4304)
-    fout.add_uint32("clip.vision.projection_dim", 0)
-    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
-    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
-    block_count = 26
-    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
-
-    if processor is not None:
-        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
-        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
-    else:
-        image_mean = args.image_mean if args.image_mean is not None else default_image_mean
-        image_std = args.image_std if args.image_std is not None else default_image_std
-    fout.add_array("clip.vision.image_mean", image_mean)
-    fout.add_array("clip.vision.image_std", image_std)
-
-use_gelu = True
-fout.add_bool("clip.use_gelu", use_gelu)
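The templated keys above expand through the `k()` helper defined at the top of this script; e.g., assuming gguf-py's `KEY_EMBEDDING_LENGTH` constant is the format string `"{arch}.embedding_length"`:

    KEY_EMBEDDING_LENGTH = "{arch}.embedding_length"  # assumed gguf-py value
    VISION = "clip.vision"                            # as defined at the top of this script

    def k(raw_key: str, arch: str) -> str:            # mirrors the helper above
        return raw_key.format(arch=arch)

    print(k(KEY_EMBEDDING_LENGTH, VISION))            # -> "clip.vision.embedding_length"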
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-    """
-    embed_dim: output dimension for each position
-    pos: a list of positions to be encoded: size (M,)
-    out: (M, D)
-    """
-    assert embed_dim % 2 == 0
-    omega = np.arange(embed_dim // 2, dtype=np.float32)
-    omega /= embed_dim / 2.
-    omega = 1. / 10000 ** omega  # (D/2,)
-
-    pos = pos.reshape(-1)  # (M,)
-    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
-
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
-
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
-    assert embed_dim % 2 == 0
-
-    # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-    return emb
-
-
-# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
-    """
-    grid_size: int of the grid height and width
-    return:
-    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-    """
-    if isinstance(grid_size, int):
-        grid_h_size, grid_w_size = grid_size, grid_size
-    else:
-        grid_h_size, grid_w_size = grid_size[0], grid_size[1]
-
-    grid_h = np.arange(grid_h_size, dtype=np.float32)
-    grid_w = np.arange(grid_w_size, dtype=np.float32)
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
-
-    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-    if cls_token:
-        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
-    return pos_embed
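These are the standard MAE-style 2D sine-cosine position embeddings. The resampler conversion below calls get_2d_sincos_pos_embed(3584, (70, 70)); a quick shape check of that call:

    import numpy as np

    pos_embed = get_2d_sincos_pos_embed(3584, (70, 70))
    print(pos_embed.shape)  # (4900, 3584): one 3584-dim row per position of the 70x70 grid,
                            # half of the dimensions encoding the row index, half the column index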
-
-def _replace_name_resampler(s, v):
-    if re.match("resampler.pos_embed", s):
-        return {
-            s: v,
-            re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(3584, (70, 70))),
-        }
-    if re.match("resampler.proj", s):
-        return {
-            re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(3584, (70, 70))),
-            re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
-        }
-    if re.match("resampler.attn.in_proj_.*", s):
-        return {
-            re.sub("attn.in_proj_", "attn.q.", s): v.chunk(3, dim=0)[0],
-            re.sub("attn.in_proj_", "attn.k.", s): v.chunk(3, dim=0)[1],
-            re.sub("attn.in_proj_", "attn.v.", s): v.chunk(3, dim=0)[2],
-        }
-    return {s: v}
-
-if has_minicpmv_projector:
-    projector = torch.load(args.minicpmv_projector)
-    new_state_dict = {}
-    for k, v in projector.items():
-        kvs = _replace_name_resampler(k, v)
-        for nk, nv in kvs.items():
-            new_state_dict[nk] = nv
-    projector = new_state_dict
-    for name, data in projector.items():
-        name = get_tensor_name(name)
-        data = data.squeeze().numpy()
-
-        n_dims = len(data.shape)
-        if ftype == 1:
-            if name[-7:] == ".weight" and n_dims == 2:
-                print("  Converting to float16")
-                data = data.astype(np.float16)
-                ftype_cur = 1
-            else:
-                print("  Converting to float32")
-                data = data.astype(np.float32)
-                ftype_cur = 0
-        else:
-            if data.dtype != np.float32:
-                print("  Converting to float32")
-                data = data.astype(np.float32)
-                ftype_cur = 0
-
-        fout.add_tensor(name, data)
-        print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
-
-    print("Projector tensors added\n")
-
-def _replace_name(s, v):
-    s = "vision_model." + s
-    if re.match("vision_model.embeddings.position_embedding", s):
-        v = v.unsqueeze(0)
-        return {s: v}
-
-    return {s: v}
-
-state_dict = model.state_dict()
-new_state_dict = {}
-for k, v in state_dict.items():
-    kvs = _replace_name(k, v)
-    for nk, nv in kvs.items():
-        new_state_dict[nk] = nv
-state_dict = new_state_dict
-for name, data in state_dict.items():
-    if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector):
-        # we don't need this
-        print(f"skipping parameter: {name}")
-        continue
-
-    name = get_tensor_name(name)
-    data = data.squeeze().numpy()
-
-    n_dims = len(data.shape)
-
-    # ftype == 0 -> float32, ftype == 1 -> float16
-    ftype_cur = 0
-    if n_dims == 4:
-        print(f"tensor {name} is always saved in f16")
-        data = data.astype(np.float16)
-        ftype_cur = 1
-    elif ftype == 1:
-        if name[-7:] == ".weight" and n_dims == 2:
-            print("  Converting to float16")
-            data = data.astype(np.float16)
-            ftype_cur = 1
-        else:
-            print("  Converting to float32")
-            data = data.astype(np.float32)
-            ftype_cur = 0
-    else:
-        if data.dtype != np.float32:
-            print("  Converting to float32")
-            data = data.astype(np.float32)
-            ftype_cur = 0
-
-    print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
-    fout.add_tensor(name, data)
-
-
-fout.write_header_to_file()
-fout.write_kv_data_to_file()
-fout.write_tensors_to_file()
-fout.close()
-
-print("Done. Output file: " + fname_out)
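Once the writer is closed, the file can be sanity-checked by reading it back. A sketch, assuming gguf-py's GGUFReader API:

    from gguf import GGUFReader

    reader = GGUFReader(fname_out)           # fname_out as computed above
    for field in reader.fields.values():     # metadata written via the add_* calls
        print(field.name)
    for tensor in reader.tensors:            # tensor names, shapes and stored types
        print(tensor.name, tensor.shape, tensor.tensor_type)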
diff --git a/examples/llava/minicpmv-convert/minicpmv2_6-surgery.py b/examples/llava/minicpmv-convert/minicpmv2_6-surgery.py
deleted file mode 100644
index cb4a75c6ac5ae5..00000000000000
--- a/examples/llava/minicpmv-convert/minicpmv2_6-surgery.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import argparse
-import glob
-import os
-import torch
-from transformers import AutoModel, AutoTokenizer
-
-ap = argparse.ArgumentParser()
-ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.6 model")
-args = ap.parse_args()
-
-# find the model part that includes the the multimodal projector weights
-model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
-checkpoint = model.state_dict()
-
-# get a list of mm tensor names
-mm_tensors = [k for k, v in checkpoint.items() if k.startswith("resampler")]
-
-# store these tensors in a new dictionary and torch.save them
-projector = {name: checkpoint[name].float() for name in mm_tensors}
-torch.save(projector, f"{args.model}/minicpmv.projector")
-
-clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")]
-if len(clip_tensors) > 0:
-    clip = {name.replace("vpm.", ""): checkpoint[name].float() for name in clip_tensors}
-    torch.save(clip, f"{args.model}/minicpmv.clip")
-
-    # added tokens should be removed to be able to convert Mistral models
-    if os.path.exists(f"{args.model}/added_tokens.json"):
-        with open(f"{args.model}/added_tokens.json", "w") as f:
-            f.write("{}\n")
-
-config = model.llm.config
-config._name_or_path = "openbmb/MiniCPM-V-2.6"
-config.auto_map = {
-    "AutoConfig": "configuration_minicpm.MiniCPMConfig",
-    "AutoModel": "modeling_minicpm.MiniCPMModel",
-    "AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
-    "AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
-    "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
-}
-model.llm.save_pretrained(f"{args.model}/model")
-tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
-tok.save_pretrained(f"{args.model}/model")
-
-print("Done!")
-print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
-print(f"Also, use {args.model}/minicpmv.projector to prepare a minicpmv-encoder.gguf file.")
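Before running the image-encoder conversion, the two torch.save artifacts produced by this surgery script can be inspected; a small sketch (model_dir stands in for the path passed via -m):

    import torch

    model_dir = "path/to/MiniCPM-V-2_6"  # placeholder for the surgery target directory
    projector = torch.load(f"{model_dir}/minicpmv.projector", map_location="cpu")
    clip = torch.load(f"{model_dir}/minicpmv.clip", map_location="cpu")
    print(len(projector), "resampler tensors,", len(clip), "vision tower tensors")
    for name, tensor in list(projector.items())[:5]:
        print(name, tuple(tensor.shape), tensor.dtype)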