From b51c1c1be3ce4f79ac22e645afdec3cb99144916 Mon Sep 17 00:00:00 2001 From: Jintao Date: Tue, 8 Oct 2024 21:46:55 +0800 Subject: [PATCH 1/3] fix bugs (#2207) --- swift/llm/deploy.py | 5 +--- swift/llm/export.py | 6 ++--- swift/llm/infer.py | 44 ++++++++++++++++++----------------- swift/llm/utils/template.py | 6 ++++- swift/llm/utils/vllm_utils.py | 26 --------------------- 5 files changed, 32 insertions(+), 55 deletions(-) diff --git a/swift/llm/deploy.py b/swift/llm/deploy.py index 24f322c34..a905bab5e 100644 --- a/swift/llm/deploy.py +++ b/swift/llm/deploy.py @@ -275,7 +275,7 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR request_id = request_info['request_id'] kwargs = {'max_tokens': request.max_tokens} - for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']: + for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty']: kwargs[key] = getattr(request, key) for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']: new_value = getattr(request, key) @@ -292,9 +292,6 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR kwargs['logprobs'] = max(1, request.top_logprobs) generation_config = VllmGenerationConfig(**kwargs) - if generation_config.use_beam_search and request.stream: - error_msg = 'Streaming generation does not support beam search.' - raise ValueError(error_msg) tokenizer = template.tokenizer if tokenizer.eos_token is not None and tokenizer.eos_token not in generation_config.stop: generation_config.stop.append(tokenizer.eos_token) diff --git a/swift/llm/export.py b/swift/llm/export.py index 3b04f5f76..0f85a7c5e 100644 --- a/swift/llm/export.py +++ b/swift/llm/export.py @@ -255,18 +255,18 @@ def llm_export(args: ExportArguments) -> None: if args.quant_method == 'awq': from awq import AutoAWQForCausalLM model, template = prepare_model_template( - args, device_map=args.quant_device_map, verbose=False, automodel_class=AutoAWQForCausalLM) + args, device_map=args.quant_device_map, task='export', automodel_class=AutoAWQForCausalLM) awq_model_quantize(model, template.tokenizer, args.quant_batch_size) model.save_quantized(args.quant_output_dir) elif args.quant_method == 'gptq': - model, template = prepare_model_template(args, device_map=args.quant_device_map, verbose=False) + model, template = prepare_model_template(args, device_map=args.quant_device_map, task='export') gptq_quantizer = gptq_model_quantize(model, template.tokenizer, args.quant_batch_size) model.config.quantization_config.pop('dataset', None) gptq_quantizer.save(model, args.quant_output_dir) elif args.quant_method == 'bnb': args.quantization_bit = args.quant_bits args.bnb_4bit_compute_dtype, args.load_in_4bit, args.load_in_8bit = args.select_bnb() - model, template = prepare_model_template(args, device_map=args.quant_device_map, verbose=False) + model, template = prepare_model_template(args, device_map=args.quant_device_map, task='export') model.save_pretrained(args.quant_output_dir) else: raise ValueError(f'args.quant_method: {args.quant_method}') diff --git a/swift/llm/infer.py b/swift/llm/infer.py index 3e1a1439e..8181a2f47 100644 --- a/swift/llm/infer.py +++ b/swift/llm/infer.py @@ -109,7 +109,7 @@ def merge_lora(args: InferArguments, if device_map is None: device_map = args.merge_device_map logger.info(f'merge_device_map: {device_map}') - model, template = prepare_model_template(args, device_map=device_map, verbose=False) + model, template = prepare_model_template(args, 
device_map=device_map, task='export') logger.info('Merge LoRA...') Swift.merge_and_unload(model) model = model.model @@ -133,7 +133,7 @@ def merge_lora(args: InferArguments, def prepare_model_template(args: InferArguments, *, device_map: Optional[str] = None, - verbose: bool = True, + task: Literal['infer', 'export'] = 'infer', automodel_class=None) -> Tuple[PreTrainedModel, Template]: from .sft import get_default_device_map if is_torch_npu_available(): @@ -188,25 +188,7 @@ def prepare_model_template(args: InferArguments, revision=args.model_revision, quant_method=args.quant_method, **kwargs) - if verbose: - logger.info(f'model_config: {model.config}') - - generation_config = GenerationConfig( - max_new_tokens=args.max_new_tokens, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - do_sample=args.do_sample, - repetition_penalty=args.repetition_penalty, - num_beams=args.num_beams, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id) - set_generation_config(model, generation_config) - logger.info(f'model.generation_config: {model.generation_config}') - if model.generation_config.num_beams != 1: - args.stream = False - logger.info('Setting args.stream: False') if model.max_model_len is None: model.max_model_len = args.max_model_len elif args.max_model_len is not None: @@ -215,6 +197,26 @@ def prepare_model_template(args: InferArguments, else: raise ValueError('args.max_model_len exceeds the maximum max_model_len supported by the model.' f'args.max_model_len: {args.max_model_len}, model.max_model_len: {model.max_model_len}') + if task == 'infer': + logger.info(f'model_config: {model.config}') + generation_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + do_sample=args.do_sample, + repetition_penalty=args.repetition_penalty, + num_beams=args.num_beams, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id) + model._generation_config_origin = model.generation_config + set_generation_config(model, generation_config) + logger.info(f'model.generation_config: {model.generation_config}') + + if model.generation_config.num_beams != 1: + args.stream = False + logger.info('Setting args.stream: False') + # Preparing LoRA if is_adapter(args.sft_type) and args.ckpt_dir is not None: if isinstance(args, DeployArguments) and args.lora_request_list is not None: @@ -227,7 +229,7 @@ def prepare_model_template(args: InferArguments, model = model.to(model.dtype) model.requires_grad_(False) - if verbose: + if task == 'infer': show_layers(model) logger.info(model) logger.info(get_model_info(model)) diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index ac3b95611..84c01a2b6 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -2028,6 +2028,10 @@ def _post_encode(self, model, data: Any) -> Dict[str, Any]: res['labels'] = labels[0] return res + @staticmethod + def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]: + return generate_ids + register_template(TemplateType.llama3_1_omni, Llama3_1OmniTemplate(), lazy_tokenize=True) @@ -2642,7 +2646,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An videos_path = example.get('videos') or [] if len(videos_path) > 0: video_processor = self.tokenizer.processor.video_processor - video_inputs = video_processor(videos, return_tensors='pt').to(self.model.dtype) + video_inputs = video_processor(videos_path, 
return_tensors='pt').to(self.model.dtype) inputs['pixel_values_videos'] = video_inputs['pixel_values_videos'] if len(images) > 0: image_processor = self.tokenizer.processor.image_processor diff --git a/swift/llm/utils/vllm_utils.py b/swift/llm/utils/vllm_utils.py index a7e36e870..8235fdcf3 100644 --- a/swift/llm/utils/vllm_utils.py +++ b/swift/llm/utils/vllm_utils.py @@ -204,7 +204,6 @@ def __init__( top_k: int = 50, # -1: all top_p: float = 1., repetition_penalty: float = 1., - num_beams: int = 1, *, n: int = 1, logprobs: Optional[int] = None, @@ -218,12 +217,6 @@ def __init__( max_new_tokens = kwargs.pop('max_new_tokens', None) if max_new_tokens is not None: max_tokens = max_new_tokens - if num_beams > 1: - top_k = -1 - top_p = 1 - temperature = 0 - logger.warning('The output of num_beams in vllm may not be consistent with ' - 'the output of num_beams in transformers.') if top_k == 0: top_k = -1 if stop is None: @@ -233,11 +226,6 @@ def __init__( kwargs['top_k'] = top_k kwargs['top_p'] = top_p kwargs['repetition_penalty'] = repetition_penalty - if num_beams > 1: - best_of = kwargs.get('best_of') - assert 'use_beam_search' not in kwargs and best_of is None - kwargs['use_beam_search'] = True - kwargs['best_of'] = num_beams kwargs['n'] = n kwargs['logprobs'] = logprobs kwargs['seed'] = seed @@ -260,7 +248,6 @@ class VllmGenerationConfig(_VllmGenerationConfigMixin, SamplingParams): top_k: int = 50 # -1: all top_p: float = 1. repetition_penalty: float = 1. - num_beams: int = 1 n: int = 1 logprobs: Optional[int] = None seed: Optional[int] = None @@ -269,15 +256,6 @@ class VllmGenerationConfig(_VllmGenerationConfigMixin, SamplingParams): skip_special_tokens: bool = False def __post_init__(self): - if self.num_beams > 1: - self.top_k = -1 - self.top_p = 1 - self.temperature = 0 - logger.warning('The output of num_beams in vllm may not be consistent with ' - 'the output of num_beams in transformers.') - assert self.best_of is None - self.use_beam_search = True - self.best_of = self.num_beams if self.top_k == 0: self.top_k = -1 if self.stop is None: @@ -435,10 +413,6 @@ def inference_stream_vllm( use_tqdm=use_tqdm, **kwargs) - if generation_config.use_beam_search: - error_msg = 'Streaming generation does not support beam search.' 
- raise ValueError(error_msg) - n_finished = 0 n_steps = 0 if flush_steps is None: From 1658ccb62c8b079daf15ed183e2784afb6bc593f Mon Sep 17 00:00:00 2001 From: Jintao Date: Wed, 9 Oct 2024 15:55:23 +0800 Subject: [PATCH 2/3] support telechat2 (#2210) --- ...222\214\346\225\260\346\215\256\351\233\206.md" | 5 +++-- .../Instruction/Supported-models-datasets.md | 5 +++-- swift/llm/utils/model.py | 14 ++++++++++++-- swift/llm/utils/template.py | 4 ++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index 5e59223ab..ea1f89b25 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -405,8 +405,9 @@ |mamba-2.8b|[AI-ModelScope/mamba-2.8b-hf](https://modelscope.cn/models/AI-ModelScope/mamba-2.8b-hf/summary)|in_proj, x_proj, embeddings, out_proj|default-generation|✘|✘|✘|✘|transformers>=4.39.0|-|[state-spaces/mamba-2.8b-hf](https://huggingface.co/state-spaces/mamba-2.8b-hf)| |telechat-7b|[TeleAI/TeleChat-7B](https://modelscope.cn/models/TeleAI/TeleChat-7B/summary)|key_value, query|telechat|✔|✘|✘|✘||-|[Tele-AI/telechat-7B](https://huggingface.co/Tele-AI/telechat-7B)| |telechat-12b|[TeleAI/TeleChat-12B](https://modelscope.cn/models/TeleAI/TeleChat-12B/summary)|key_value, query|telechat|✔|✘|✘|✘||-|[Tele-AI/TeleChat-12B](https://huggingface.co/Tele-AI/TeleChat-12B)| -|telechat-12b-v2|[TeleAI/TeleChat-12B-v2](https://modelscope.cn/models/TeleAI/TeleChat-12B-v2/summary)|key_value, query|telechat-v2|✔|✘|✘|✘||-|[Tele-AI/TeleChat-12B-v2](https://huggingface.co/Tele-AI/TeleChat-12B-v2)| -|telechat-12b-v2-gptq-int4|[swift/TeleChat-12B-V2-GPTQ-Int4](https://modelscope.cn/models/swift/TeleChat-12B-V2-GPTQ-Int4/summary)|key_value, query|telechat-v2|✔|✘|✘|✘|auto_gptq>=0.5|-|-| +|telechat-12b-v2|[TeleAI/TeleChat-12B-v2](https://modelscope.cn/models/TeleAI/TeleChat-12B-v2/summary)|key_value, query|telechat|✔|✘|✘|✘||-|[Tele-AI/TeleChat-12B-v2](https://huggingface.co/Tele-AI/TeleChat-12B-v2)| +|telechat-12b-v2-gptq-int4|[swift/TeleChat-12B-V2-GPTQ-Int4](https://modelscope.cn/models/swift/TeleChat-12B-V2-GPTQ-Int4/summary)|key_value, query|telechat|✔|✘|✘|✘|auto_gptq>=0.5|-|-| +|telechat2-115b|[TeleAI/TeleChat2-115B](https://modelscope.cn/models/TeleAI/TeleChat2-115B/summary)|key_value, query|telechat2|✔|✘|✘|✘||-|[Tele-AI/TeleChat2-115B](https://huggingface.co/Tele-AI/TeleChat2-115B)| |grok-1|[colossalai/grok-1-pytorch](https://modelscope.cn/models/colossalai/grok-1-pytorch/summary)|q_proj, k_proj, v_proj|default-generation|✘|✘|✘|✘||-|[hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1)| |dbrx-instruct|[AI-ModelScope/dbrx-instruct](https://modelscope.cn/models/AI-ModelScope/dbrx-instruct/summary)|attn.Wqkv|dbrx|✔|✔|✘|✘|transformers>=4.36|moe|[databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct)| |dbrx-base|[AI-ModelScope/dbrx-base](https://modelscope.cn/models/AI-ModelScope/dbrx-base/summary)|attn.Wqkv|dbrx|✔|✔|✘|✘|transformers>=4.36|moe|[databricks/dbrx-base](https://huggingface.co/databricks/dbrx-base)| diff --git 
a/docs/source_en/Instruction/Supported-models-datasets.md b/docs/source_en/Instruction/Supported-models-datasets.md index f45779768..08bc0b19a 100644 --- a/docs/source_en/Instruction/Supported-models-datasets.md +++ b/docs/source_en/Instruction/Supported-models-datasets.md @@ -405,8 +405,9 @@ The table below introcudes all models supported by SWIFT: |mamba-2.8b|[AI-ModelScope/mamba-2.8b-hf](https://modelscope.cn/models/AI-ModelScope/mamba-2.8b-hf/summary)|in_proj, x_proj, embeddings, out_proj|default-generation|✘|✘|✘|✘|transformers>=4.39.0|-|[state-spaces/mamba-2.8b-hf](https://huggingface.co/state-spaces/mamba-2.8b-hf)| |telechat-7b|[TeleAI/TeleChat-7B](https://modelscope.cn/models/TeleAI/TeleChat-7B/summary)|key_value, query|telechat|✔|✘|✘|✘||-|[Tele-AI/telechat-7B](https://huggingface.co/Tele-AI/telechat-7B)| |telechat-12b|[TeleAI/TeleChat-12B](https://modelscope.cn/models/TeleAI/TeleChat-12B/summary)|key_value, query|telechat|✔|✘|✘|✘||-|[Tele-AI/TeleChat-12B](https://huggingface.co/Tele-AI/TeleChat-12B)| -|telechat-12b-v2|[TeleAI/TeleChat-12B-v2](https://modelscope.cn/models/TeleAI/TeleChat-12B-v2/summary)|key_value, query|telechat-v2|✔|✘|✘|✘||-|[Tele-AI/TeleChat-12B-v2](https://huggingface.co/Tele-AI/TeleChat-12B-v2)| -|telechat-12b-v2-gptq-int4|[swift/TeleChat-12B-V2-GPTQ-Int4](https://modelscope.cn/models/swift/TeleChat-12B-V2-GPTQ-Int4/summary)|key_value, query|telechat-v2|✔|✘|✘|✘|auto_gptq>=0.5|-|-| +|telechat-12b-v2|[TeleAI/TeleChat-12B-v2](https://modelscope.cn/models/TeleAI/TeleChat-12B-v2/summary)|key_value, query|telechat|✔|✘|✘|✘||-|[Tele-AI/TeleChat-12B-v2](https://huggingface.co/Tele-AI/TeleChat-12B-v2)| +|telechat-12b-v2-gptq-int4|[swift/TeleChat-12B-V2-GPTQ-Int4](https://modelscope.cn/models/swift/TeleChat-12B-V2-GPTQ-Int4/summary)|key_value, query|telechat|✔|✘|✘|✘|auto_gptq>=0.5|-|-| +|telechat2-115b|[TeleAI/TeleChat2-115B](https://modelscope.cn/models/TeleAI/TeleChat2-115B/summary)|key_value, query|telechat2|✔|✘|✘|✘||-|[Tele-AI/TeleChat2-115B](https://huggingface.co/Tele-AI/TeleChat2-115B)| |grok-1|[colossalai/grok-1-pytorch](https://modelscope.cn/models/colossalai/grok-1-pytorch/summary)|q_proj, k_proj, v_proj|default-generation|✘|✘|✘|✘||-|[hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1)| |dbrx-instruct|[AI-ModelScope/dbrx-instruct](https://modelscope.cn/models/AI-ModelScope/dbrx-instruct/summary)|attn.Wqkv|dbrx|✔|✔|✘|✘|transformers>=4.36|moe|[databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct)| |dbrx-base|[AI-ModelScope/dbrx-base](https://modelscope.cn/models/AI-ModelScope/dbrx-base/summary)|attn.Wqkv|dbrx|✔|✔|✘|✘|transformers>=4.36|moe|[databricks/dbrx-base](https://huggingface.co/databricks/dbrx-base)| diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py index 579b3006b..84ed65ac7 100644 --- a/swift/llm/utils/model.py +++ b/swift/llm/utils/model.py @@ -599,6 +599,7 @@ class ModelType: telechat_12b = 'telechat-12b' telechat_12b_v2 = 'telechat-12b-v2' telechat_12b_v2_gptq_int4 = 'telechat-12b-v2-gptq-int4' + telechat2_115b = 'telechat2-115b' # grok-1 grok_1 = 'grok-1' # dbrx @@ -930,6 +931,14 @@ def _new_forward(self, x): support_vllm=True, support_flash_attn=True, hf_model_id='CohereForAI/c4ai-command-r-plus') +@register_model( + ModelType.telechat2_115b, + 'TeleAI/TeleChat2-115B', + LoRATM.telechat, + TemplateType.telechat2, + torch_dtype=torch.float16, + support_flash_attn=True, + hf_model_id='Tele-AI/TeleChat2-115B') def get_model_tokenizer_from_repo(model_dir: str, torch_dtype: Optional[torch.dtype], model_kwargs: 
Dict[str, Any], @@ -5829,7 +5838,7 @@ def get_model_tokenizer_codellama(model_dir: str, ModelType.telechat_12b_v2, 'TeleAI/TeleChat-12B-v2', LoRATM.telechat, - TemplateType.telechat_v2, + TemplateType.telechat, eos_token=2, support_flash_attn=True, hf_model_id='Tele-AI/TeleChat-12B-v2') @@ -5837,9 +5846,10 @@ def get_model_tokenizer_codellama(model_dir: str, ModelType.telechat_12b_v2_gptq_int4, 'swift/TeleChat-12B-V2-GPTQ-Int4', LoRATM.telechat, - TemplateType.telechat_v2, + TemplateType.telechat, eos_token=2, requires=['auto_gptq>=0.5'], + torch_dtype=torch.float16, support_flash_attn=True, function_kwargs={'gptq_bits': 4}) def get_model_tokenizer_phi(model_dir: str, diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index 84c01a2b6..40a2918e5 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -138,7 +138,7 @@ class TemplateType: phi3 = 'phi3' phi3_vl = 'phi3-vl' telechat = 'telechat' - telechat_v2 = 'telechat-v2' + telechat2 = 'telechat2' dbrx = 'dbrx' mengzi = 'mengzi' c4ai = 'c4ai' @@ -3448,7 +3448,7 @@ class MiniCPMV2_5Template(Llama3TemplateMixin, MiniCPMVTemplate): register_template(TemplateType.telechat, Template([], ['<_user>{{QUERY}}<_bot>'], ['<_end>'], ['<_end>'])) -register_template(TemplateType.telechat_v2, Template([], ['<_user> {{QUERY}}<_bot>'], [], ['<_end>'])) +register_template(TemplateType.telechat2, Template(['<_start>'], [[4], '{{QUERY}}', [5]], ['<_end>'], ['<_end>'])) DBRX_SYSTEM = ( 'You are DBRX, created by Databricks. You were last updated in December 2023. ' From ba7e07ba8965809c411f4ef033256acfe43db545 Mon Sep 17 00:00:00 2001 From: Jintao Date: Wed, 9 Oct 2024 16:40:33 +0800 Subject: [PATCH 3/3] Support ovis 1.6 (#2211) --- README.md | 4 +- README_CN.md | 4 +- ...14\346\225\260\346\215\256\351\233\206.md" | 1 + .../Instruction/Supported-models-datasets.md | 1 + swift/llm/utils/model.py | 48 +++++++++++++-- swift/llm/utils/template.py | 60 +++++++++++++++++++ swift/utils/import_utils.py | 7 +-- swift/utils/module_mapping.py | 6 ++ 8 files changed, 119 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 7b653b58c..ba95c0e0c 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ - [Citation](#-citation) ## 📝 Introduction -SWIFT supports training(PreTraining/Fine-tuning/RLHF), inference, evaluation and deployment of **350+ LLMs and 90+ MLLMs** (multimodal large models). Developers can directly apply our framework to their own research and production environments to realize the complete workflow from model training and evaluation to application. In addition to supporting the lightweight training solutions provided by [PEFT](https://github.com/huggingface/peft), we also provide a complete **Adapters library** to support the latest training techniques such as NEFTune, LoRA+, LLaMA-PRO, etc. This adapter library can be used directly in your own custom workflow without our training scripts. +SWIFT supports training(PreTraining/Fine-tuning/RLHF), inference, evaluation and deployment of **350+ LLMs and 100+ MLLMs** (multimodal large models). Developers can directly apply our framework to their own research and production environments to realize the complete workflow from model training and evaluation to application. In addition to supporting the lightweight training solutions provided by [PEFT](https://github.com/huggingface/peft), we also provide a complete **Adapters library** to support the latest training techniques such as NEFTune, LoRA+, LLaMA-PRO, etc. 
This adapter library can be used directly in your own custom workflow without our training scripts. To facilitate use by users unfamiliar with deep learning, we provide a Gradio web-ui for controlling training and inference, as well as accompanying deep learning courses and best practices for beginners. SWIFT web-ui is available both on [Huggingface space](https://huggingface.co/spaces/tastelikefeet/swift) and [ModelScope studio](https://www.modelscope.cn/studios/iic/Scalable-lightWeight-Infrastructure-for-Fine-Tuning/summary), please feel free to try! @@ -55,6 +55,7 @@ You can contact us and communicate with us by adding our group: | ## 🎉 News +- 2024.10.09: Support for training and deploying ovis1.6-gemma2 series models. Experience it using `swift infer --model_type ovis1_6-gemma2-9b`. - 2024.09.26: Support for training and deploying llama3.2-vision series models. Experience it using `swift infer --model_type llama3_2-11b-vision-instruct`. - 2024.09.26: Support for training and deploying llama3.2 series models. Experience it using `swift infer --model_type llama3_2-1b-instruct`. - 2024.09.25: Support for training to deployment with got-ocr2. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/2122). @@ -642,6 +643,7 @@ The complete list of supported models and datasets can be found at [Supported Mo | Idefics3 | [HuggingFaceM4](https://huggingface.co/HuggingFaceM4) | English | 8B | chat model | | Pixtral | [mistralai](https://huggingface.co/mistralai) | English | 12B | chat model | | Llama3.1-Omni | [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni) | English | 8B | chat model | +| Ovis | [Ovis](https://github.com/AIDC-AI/Ovis) | English | 9B | chat model | #### Diffusion Models diff --git a/README_CN.md b/README_CN.md index 23eb681bd..e0458d217 100644 --- a/README_CN.md +++ b/README_CN.md @@ -37,7 +37,7 @@ - [引用](#-引用) ## 📝 简介 -SWIFT支持**350+ LLM和90+ MLLM**(多模态大模型)的训练(预训练、微调、对齐)、推理、评测和部署。开发者可以直接将我们的框架应用到自己的Research和生产环境中,实现模型训练评测到应用的完整链路。我们除支持了[PEFT](https://github.com/huggingface/peft)提供的轻量训练方案外,也提供了一个完整的**Adapters库**以支持最新的训练技术,如NEFTune、LoRA+、LLaMA-PRO等,这个适配器库可以脱离训练脚本直接使用在自己的自定流程中。 +SWIFT支持**350+ LLM和100+ MLLM**(多模态大模型)的训练(预训练、微调、对齐)、推理、评测和部署。开发者可以直接将我们的框架应用到自己的Research和生产环境中,实现模型训练评测到应用的完整链路。我们除支持了[PEFT](https://github.com/huggingface/peft)提供的轻量训练方案外,也提供了一个完整的**Adapters库**以支持最新的训练技术,如NEFTune、LoRA+、LLaMA-PRO等,这个适配器库可以脱离训练脚本直接使用在自己的自定流程中。 为方便不熟悉深度学习的用户使用,我们提供了一个Gradio的web-ui用于控制训练和推理,并提供了配套的深度学习课程和最佳实践供新手入门。 可以在[Huggingface space](https://huggingface.co/spaces/tastelikefeet/swift) 和 [ModelScope创空间](https://www.modelscope.cn/studios/iic/Scalable-lightWeight-Infrastructure-for-Fine-Tuning/summary) 中体验SWIFT web-ui功能了。 @@ -56,6 +56,7 @@ SWIFT具有丰富全面的文档,请查看我们的文档网站: ## 🎉 新闻 +- 2024.10.09: 支持ovis1.6-gemma2的训练到部署. 使用`swift infer --model_type ovis1_6-gemma2-9b`进行体验. - 2024.09.26: 支持llama3.2-vision系列模型的训练到部署. 使用`swift infer --model_type llama3_2-11b-vision-instruct`进行体验. - 2024.09.26: 支持llama3.2系列模型的训练到部署. 使用`swift infer --model_type llama3_2-1b-instruct`进行体验. - 2024.09.25: 支持got-ocr2的训练到部署. 最佳实践可以查看[这里](https://github.com/modelscope/ms-swift/issues/2122). 
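The `swift infer` entry points announced in the news items above go through `prepare_model_template` in swift/llm/infer.py, which the first patch in this series refactors from a `verbose` flag to a `task` argument. A minimal Python sketch of the new call shape follows (illustrative only, not part of the diff; import paths mirror the file layout shown in these patches, and anything not visible in the hunks is an assumption):

```python
# Sketch only: exercising the refactored prepare_model_template signature
# from swift/llm/infer.py as modified in PATCH 1/3.
from swift.llm import InferArguments
from swift.llm.infer import prepare_model_template

args = InferArguments(model_type='ovis1_6-gemma2-9b')  # model type registered in PATCH 3/3

# task='infer' (the default) keeps the GenerationConfig setup and the
# num_beams -> args.stream = False handling; task='export' skips it, which is
# what merge_lora and llm_export now pass instead of verbose=False.
model, template = prepare_model_template(args, task='infer')
```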
@@ -635,6 +636,7 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \ | Idefics3 | [HuggingFaceM4](https://huggingface.co/HuggingFaceM4) | 英文 | 8B | chat模型 | | Pixtral | [mistralai](https://huggingface.co/mistralai) | 英文 | 12B | chat模型 | | Llama3.1-Omni | [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni) | 英文 | 8B | chat模型 | +| Ovis | [Ovis](https://github.com/AIDC-AI/Ovis) | 英文 | 9B | chat模型 | #### 扩散模型 diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index ea1f89b25..beee80e77 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -493,6 +493,7 @@ |internvl2-llama3-76b-awq|[OpenGVLab/InternVL2-Llama3-76B-AWQ](https://modelscope.cn/models/OpenGVLab/InternVL2-Llama3-76B-AWQ/summary)|^(language_model\|mlp1)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|internvl2|✔|✔|✔|✘|transformers>=4.36, timm|vision, video|[OpenGVLab/InternVL2-Llama3-76B-AWQ](https://huggingface.co/OpenGVLab/InternVL2-Llama3-76B-AWQ)| |deepseek-vl-1_3b-chat|[deepseek-ai/deepseek-vl-1.3b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-1.3b-chat/summary)|^(language_model\|aligner)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|deepseek-vl|✔|✘|✔|✘||vision|[deepseek-ai/deepseek-vl-1.3b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-chat)| |deepseek-vl-7b-chat|[deepseek-ai/deepseek-vl-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-7b-chat/summary)|^(language_model\|aligner)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|deepseek-vl|✔|✘|✔|✘||vision|[deepseek-ai/deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat)| +|ovis1_6-gemma2-9b|[AIDC-AI/Ovis1.6-Gemma2-9B](https://modelscope.cn/models/AIDC-AI/Ovis1.6-Gemma2-9B/summary)|^(llm)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|ovis1_6|✔|✘|✘|✘|transformers>=4.42|vision|[AIDC-AI/Ovis1.6-Gemma2-9B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B)| |paligemma-3b-pt-224|[AI-ModelScope/paligemma-3b-pt-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-224/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|✔|✔|✘|✘|transformers>=4.41|vision|[google/paligemma-3b-pt-224](https://huggingface.co/google/paligemma-3b-pt-224)| |paligemma-3b-pt-448|[AI-ModelScope/paligemma-3b-pt-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-448/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|✔|✔|✘|✘|transformers>=4.41|vision|[google/paligemma-3b-pt-448](https://huggingface.co/google/paligemma-3b-pt-448)| |paligemma-3b-pt-896|[AI-ModelScope/paligemma-3b-pt-896](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-896/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|✔|✔|✘|✘|transformers>=4.41|vision|[google/paligemma-3b-pt-896](https://huggingface.co/google/paligemma-3b-pt-896)| diff --git a/docs/source_en/Instruction/Supported-models-datasets.md b/docs/source_en/Instruction/Supported-models-datasets.md index 08bc0b19a..397b2e906 100644 --- 
a/docs/source_en/Instruction/Supported-models-datasets.md +++ b/docs/source_en/Instruction/Supported-models-datasets.md @@ -493,6 +493,7 @@ The table below introcudes all models supported by SWIFT: |internvl2-llama3-76b-awq|[OpenGVLab/InternVL2-Llama3-76B-AWQ](https://modelscope.cn/models/OpenGVLab/InternVL2-Llama3-76B-AWQ/summary)|^(language_model\|mlp1)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|internvl2|✔|✔|✔|✘|transformers>=4.36, timm|vision, video|[OpenGVLab/InternVL2-Llama3-76B-AWQ](https://huggingface.co/OpenGVLab/InternVL2-Llama3-76B-AWQ)| |deepseek-vl-1_3b-chat|[deepseek-ai/deepseek-vl-1.3b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-1.3b-chat/summary)|^(language_model\|aligner)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|deepseek-vl|✔|✘|✔|✘||vision|[deepseek-ai/deepseek-vl-1.3b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-chat)| |deepseek-vl-7b-chat|[deepseek-ai/deepseek-vl-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-7b-chat/summary)|^(language_model\|aligner)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|deepseek-vl|✔|✘|✔|✘||vision|[deepseek-ai/deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat)| +|ovis1_6-gemma2-9b|[AIDC-AI/Ovis1.6-Gemma2-9B](https://modelscope.cn/models/AIDC-AI/Ovis1.6-Gemma2-9B/summary)|^(llm)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|ovis1_6|✔|✘|✘|✘|transformers>=4.42|vision|[AIDC-AI/Ovis1.6-Gemma2-9B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B)| |paligemma-3b-pt-224|[AI-ModelScope/paligemma-3b-pt-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-224/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|✔|✔|✘|✘|transformers>=4.41|vision|[google/paligemma-3b-pt-224](https://huggingface.co/google/paligemma-3b-pt-224)| |paligemma-3b-pt-448|[AI-ModelScope/paligemma-3b-pt-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-448/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|✔|✔|✘|✘|transformers>=4.41|vision|[google/paligemma-3b-pt-448](https://huggingface.co/google/paligemma-3b-pt-448)| |paligemma-3b-pt-896|[AI-ModelScope/paligemma-3b-pt-896](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-896/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|✔|✔|✘|✘|transformers>=4.41|vision|[google/paligemma-3b-pt-896](https://huggingface.co/google/paligemma-3b-pt-896)| diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py index 84ed65ac7..327280c2c 100644 --- a/swift/llm/utils/model.py +++ b/swift/llm/utils/model.py @@ -458,6 +458,8 @@ class ModelType: gemma2_2b_instruct = 'gemma2-2b-instruct' gemma2_9b_instruct = 'gemma2-9b-instruct' gemma2_27b_instruct = 'gemma2-27b-instruct' + + ovis1_6_gemma2_9b = 'ovis1_6-gemma2-9b' # paligemma paligemma_3b_pt_224 = 'paligemma-3b-pt-224' paligemma_3b_pt_448 = 'paligemma-3b-pt-448' @@ -652,6 +654,7 @@ class LoRATM(NamedTuple): llama3_1_omni = 'llama3_1_omni' got_ocr2 = 'got_ocr2' llama3_2_vision = 'llama3_2_vision' + ovis1_6 = 'ovis1_6' # default lora target modules for nlp llms. 
minicpm3 = ['q_a_proj', 'q_b_proj', 'kv_a_proj_with_mqa', 'kv_b_proj'] baichuan = ['W_pack'] @@ -2745,6 +2748,41 @@ def get_model_tokenizer_with_flash_attn(model_dir: str, model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs) +@register_model( + ModelType.ovis1_6_gemma2_9b, + 'AIDC-AI/Ovis1.6-Gemma2-9B', + LoRATM.ovis1_6, + TemplateType.ovis1_6, + requires=['transformers>=4.42'], + support_flash_attn=True, + tags=['multi-modal', 'vision'], + hf_model_id='AIDC-AI/Ovis1.6-Gemma2-9B') +def get_model_tokenizer_ovis(*args, **kwargs): + model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs) + if model is not None: + func_list = ['generate', 'forward', 'get_input_embeddings'] + _use_submodel_func(model, 'llm', func_list) + embedding = model.get_input_embeddings() + embedding.register_forward_hook(_clone_hook) + model.config.keys_to_ignore_at_inference = ['past_key_values'] # fix prediction_step + try: + # fix device_map + from transformers.cache_utils import HybridCache + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args, + **kwargs) -> Tuple[torch.Tensor]: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) + return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs) + + if not hasattr(HybridCache, '_update_origin'): + HybridCache._update_origin = HybridCache.update + HybridCache.update = update + except ImportError: + pass + return model, tokenizer + + @register_model( ModelType.mplug_owl3_7b_chat, 'iic/mPLUG-Owl3-7B-240728', @@ -2762,8 +2800,9 @@ def get_model_tokenizer_mplug_owl3(model_dir: str, model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs) processor = model.init_processor(tokenizer) tokenizer.processor = processor - func_list = ['generate', 'forward'] - _use_submodel_func(model, 'language_model', func_list) + if model is not None: + func_list = ['generate', 'forward'] + _use_submodel_func(model, 'language_model', func_list) return model, tokenizer @@ -2958,8 +2997,9 @@ def get_model_tokenizer_florence(model_dir: str, model_dir, torch_dtype, model_kwargs, load_model, tokenizer=processor.tokenizer, **kwargs) tokenizer.processor = processor - # model.vision_tower.enable_checkpoint = True - _use_submodel_func(model, 'language_model', ['generate', 'forward']) + if model is not None: + model.vision_tower.enable_checkpoint = True + _use_submodel_func(model, 'language_model', ['generate', 'forward']) return model, tokenizer diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index 40a2918e5..42057a8e0 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -144,6 +144,7 @@ class TemplateType: c4ai = 'c4ai' chatml = 'chatml' got_ocr2 = 'got_ocr2' + ovis1_6 = 'ovis1_6' # compatibility. 
(Deprecated) default_generation_bos = 'default-generation-bos' yi = 'yi' @@ -1285,6 +1286,65 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = register_template(TemplateType.got_ocr2, GOT_OCR2Template(), lazy_tokenize=True, use_model=True) +class OVIS1_6Template(Template): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + example: Dict[str, Any]) -> List[Context]: + assert media_type == 'image' + return [[-200], '\n'] + + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + inputs, tokenizer_kwargs = super()._encode(example) + if len(inputs) == 0: + return inputs, {} + images = example['images'] + input_ids = inputs['input_ids'] + labels = inputs['labels'] + idx_list = _findall(input_ids, [-200]) + added_tokens_len = 0 + pixel_values = [] + for i, idx in enumerate(idx_list): + max_partition = get_env_args('max_partition', int, 9) + raw_pixel_values, image_placeholders = self.model.visual_tokenizer.preprocess_image( + images[i], max_partition=max_partition) + input_ids = input_ids[:idx] + image_placeholders + input_ids[idx + 1:] + if labels is not None: + labels = labels[:idx] + [-100] * len(image_placeholders) + labels[idx + 1:] + pixel_values.append(raw_pixel_values) + added_tokens_len += len(image_placeholders) - 1 + if pixel_values: + pixel_values = torch.cat(pixel_values, dim=0).to(self.model.visual_tokenizer.dtype) + else: + pixel_values = None + inputs = {'labels': labels} + if labels is not None: + labels = torch.tensor(labels)[None] + inputs['_data'] = {'input_ids': torch.tensor(input_ids)[None], 'labels': labels, 'pixel_values': [pixel_values]} + return inputs, {} + + def _post_encode(self, model, data: Any) -> Dict[str, Any]: + _, inputs_embeds, labels, _ = self.model.merge_multimodal( + text_input_ids=data['input_ids'], + text_attention_masks=torch.ones_like(data['input_ids']), # not use, only compat + text_labels=data['labels'], + pixel_values=data['pixel_values'], + left_padding=True) + return {'inputs_embeds': inputs_embeds[0], 'labels': labels} + + @staticmethod + def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]: + return generate_ids + + +register_template( + TemplateType.ovis1_6, + OVIS1_6Template([''], ['user\n{{QUERY}}\nmodel\n'], + ['\n'], [''], None, + ['system\n{{SYSTEM}}\n']), + lazy_tokenize=True, + use_model=True) + + class _QwenVLTemplateMixin: load_medias = False diff --git a/swift/utils/import_utils.py b/swift/utils/import_utils.py index 99991c646..6b5ce2791 100644 --- a/swift/utils/import_utils.py +++ b/swift/utils/import_utils.py @@ -60,12 +60,7 @@ def __getattr__(self, name: str) -> Any: return value def _get_module(self, module_name: str): - try: - return importlib.import_module('.' + module_name, self.__name__) - except Exception as e: - raise RuntimeError( - f'Failed to import {self.__name__}.{module_name} because of the following error (look up to see its' - f' traceback):\n{e}') from e + return importlib.import_module('.' 
+ module_name, self.__name__) def __reduce__(self): return self.__class__, (self._name, self.__file__, self._import_structure) diff --git a/swift/utils/module_mapping.py b/swift/utils/module_mapping.py index 6c0bf6b40..fc9fad880 100644 --- a/swift/utils/module_mapping.py +++ b/swift/utils/module_mapping.py @@ -302,6 +302,11 @@ def __post_init__(self): vision_tower='vision_model', ) +OVIS1_6 = MultiModelKeys( + language_model='llm', + vision_tower='visual_tokenizer', +) + MODEL_KEYS_MAPPING = OrderedDict([ # MLLM here ('qwen_audio', QWEN_AUDIO_KEYS), @@ -324,6 +329,7 @@ def __post_init__(self): ('llama3_1_omni', LLAMA3_1_OMNI), ('got_ocr2', GOT_OCR2), ('llama3_2_vision', LLAMA3_2_VISION), + ('ovis1_6', OVIS1_6), # LLM begins here ('llama', LLAMA_KEYS), ('mistral', LLAMA_KEYS),
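With beam search stripped out of `VllmGenerationConfig` in the first patch, only sampling parameters remain. Below is a minimal sketch of constructing the trimmed config, assuming the parameter names visible in the vllm_utils.py and deploy.py hunks above; the values are arbitrary and the import path simply mirrors the file location, so this is not a definitive API reference:

```python
# Sketch only: VllmGenerationConfig after this series no longer accepts
# num_beams / use_beam_search; parameter names are taken from the hunks in
# swift/llm/utils/vllm_utils.py and swift/llm/deploy.py above.
from swift.llm.utils.vllm_utils import VllmGenerationConfig

generation_config = VllmGenerationConfig(
    max_tokens=512,          # max_new_tokens is still accepted as an alias and mapped here
    temperature=0.3,
    top_k=50,                # top_k == 0 is normalized to -1, i.e. "consider all tokens"
    top_p=0.9,
    repetition_penalty=1.05,
    n=1,
)
# Beam search is no longer configured through this class, so the streaming
# path (inference_stream_vllm) and the deploy endpoint drop their
# use_beam_search guards, as shown in PATCH 1/3.
```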