From 23be0d161205c5a48c8677abeece31195a8dd444 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 27 Nov 2024 21:44:02 +0800
Subject: [PATCH 1/4] fix tokenizer.pad_token; fix world_size

---
 swift/llm/argument/base_args/base_args.py |  4 ++--
 swift/llm/argument/train_args.py          |  5 +++--
 swift/llm/infer/infer.py                  |  6 +++---
 swift/llm/infer/infer_engine/pt_engine.py | 10 +++++-----
 swift/llm/infer/protocol.py               |  6 +++---
 swift/llm/model/register.py               | 18 ++++++++++++------
 swift/llm/template/base.py                |  6 ++----
 7 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py
index 5823f7221..ce012db0c 100644
--- a/swift/llm/argument/base_args/base_args.py
+++ b/swift/llm/argument/base_args/base_args.py
@@ -47,9 +47,9 @@ def __post_init__(self):
         if self.use_hf:
             os.environ['USE_HF'] = '1'
         self._init_model_kwargs()
-        self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting()
+        self.rank, self.local_rank, self.world_size, self.local_world_size = get_dist_setting()
         logger.info(f'rank: {self.rank}, local_rank: {self.local_rank}, '
-                    f'global_world_size: {self.global_world_size}, local_world_size: {self.local_world_size}')
+                    f'world_size: {self.world_size}, local_world_size: {self.local_world_size}')
         ModelArguments.__post_init__(self)
         QuantizeArguments.__post_init__(self)
         TemplateArguments.__post_init__(self)
diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py
index 1f936e8d0..825c23074 100644
--- a/swift/llm/argument/train_args.py
+++ b/swift/llm/argument/train_args.py
@@ -7,7 +7,7 @@
 import torch
 import torch.distributed as dist
-from transformers import Seq2SeqTrainingArguments
+from transformers import Seq2SeqTrainingArguments, TrainingArguments
 from transformers.utils import is_torch_npu_available
 from transformers.utils.versions import require_version
@@ -45,6 +45,7 @@ def _init_output_dir(self):
             self.output_dir = f'output/{self.model_name}'

     def __post_init__(self):
+        del TrainingArguments.world_size
        self._init_output_dir()

         if self.learning_rate is None:
@@ -122,8 +123,8 @@ def __post_init__(self) -> None:
             self.load_args_from_ckpt(self.resume_from_checkpoint)
             if self.train_type == 'full':
                 self.model_id_or_path = self.resume_from_checkpoint
-        BaseArguments.__post_init__(self)
         Seq2SeqTrainingOverrideArguments.__post_init__(self)
+        BaseArguments.__post_init__(self)
         TunerArguments.__post_init__(self)
         TorchAccArguments.__post_init__(self)
diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 355899c57..ea1976ac2 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -284,9 +284,9 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
                 if self.jsonl_writer:
                     self.jsonl_writer.append(data)
         else:
-            is_dist = args.global_world_size > 1 and dist.is_initialized()
+            is_dist = args.world_size > 1 and dist.is_initialized()
             if is_dist:
-                val_dataset = val_dataset.shard(args.global_world_size, args.rank, contiguous=True)
+                val_dataset = val_dataset.shard(args.world_size, args.rank, contiguous=True)

             infer_requests = [InferRequest(**data) for i, data in enumerate(val_dataset)]
             resp_list = self.infer(infer_requests, request_config, template=self.template, use_tqdm=True)
@@ -295,7 +295,7 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
                 data = {'response': response, **data}
                 result_list.append(data)
             if is_dist:
-                total_result_list = [None for _ in range(args.global_world_size)] if args.rank == 0 else None
+                total_result_list = [None for _ in range(args.world_size)] if args.rank == 0 else None
                 dist.gather_object(result_list, total_result_list)
                 result_list = total_result_list and list(chain.from_iterable(total_result_list))
diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py
index 656c29c94..718c2b526 100644
--- a/swift/llm/infer/infer_engine/pt_engine.py
+++ b/swift/llm/infer/infer_engine/pt_engine.py
@@ -95,7 +95,7 @@ def _add_stop_words(self, generation_config: _GenerationConfig, request_config:
         if generation_config.eos_token_id is None:
             generation_config.eos_token_id = self.tokenizer.eos_token_id
         if generation_config.pad_token_id is None:
-            generation_config.pad_token_id = template.pad_token_id
+            generation_config.pad_token_id = self.tokenizer.pad_token_id

     @staticmethod
     def preprocess_logits(batched_logits: Optional[List[torch.Tensor]], batched_generate_ids: torch.Tensor,
@@ -201,7 +201,7 @@ def _model_generate(*args, **kwargs):
             generate_ids = batched_generate_ids[i]

             # ignore pad_token
-            masks = generate_ids != template.pad_token_id
+            masks = generate_ids != self.tokenizer.pad_token_id
             generate_ids = generate_ids[masks].tolist()
             logprobs_list = None
             if batched_logprobs[i]:
@@ -209,7 +209,7 @@ def _model_generate(*args, **kwargs):
             is_finished[i] = (
                 all_is_finished or is_finished[i]
-                or len(generate_ids) > 0 and generate_ids[-1] == template.pad_token_id)
+                or len(generate_ids) > 0 and generate_ids[-1] == self.tokenizer.pad_token_id)
             delta_text = infer_streamers[i].get_printable_text(generate_ids, is_finished[i])
             if not delta_text and not is_finished[i]:
                 res.append(None)
@@ -269,7 +269,7 @@ def _infer_full(
             generate_ids = batched_generate_ids[i]

             # ignore pad_token
-            masks = generate_ids != template.pad_token_id
+            masks = generate_ids != self.tokenizer.pad_token_id
             generate_ids = generate_ids[masks].tolist()
             logprobs_list = None
             if batched_logprobs is not None:
@@ -291,7 +291,7 @@ def _infer_full(
         elif isinstance(response, Image.Image):
             res.append(
                 ImagesResponse(
-                    created=time.time(), data=[ImageObject(b64_json=MultiModalRequestMixin._to_base64(response))]))
+                    created=time.time(), data=[ImageObject(b64_json=MultiModalRequestMixin.to_base64(response))]))

     return res
diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py
index cfe41eea1..65bfcc696 100644
--- a/swift/llm/infer/protocol.py
+++ b/swift/llm/infer/protocol.py
@@ -101,7 +101,7 @@ class MultiModalRequestMixin:
     videos: List[str] = field(default_factory=list)

     @staticmethod
-    def _to_base64(mm_data: Union[str, Image.Image, bytes]) -> str:
+    def to_base64(mm_data: Union[str, Image.Image, bytes]) -> str:
         if isinstance(mm_data, str) and not os.path.isfile(mm_data):
             # base64 or url
             return mm_data
@@ -125,7 +125,7 @@ def __post_init__(self):
                 values = [values]
                 setattr(self, key, values)
             for i, val in enumerate(values):
-                values[i] = self._to_base64(val)
+                values[i] = self.to_base64(val)


 @dataclass
@@ -173,7 +173,7 @@ def convert_to_base64(self):
                     suffix = 'jpeg'
                 else:
                     raise ValueError(f'value: {value}')
-                mm_data_base64 = self._to_base64(value)
+                mm_data_base64 = self.to_base64(value)
                 new_value = f'data:{key}/{suffix};base64,{mm_data_base64}'
                 if is_dict:
                     new_value = {'url': new_value}
diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py
index 6418ac410..d8c8cfe9f 100644
--- a/swift/llm/model/register.py
+++ b/swift/llm/model/register.py
@@ -457,12 +457,18 @@ def get_model_tokenizer(model_id_or_path: str,
     model_meta.check_requires()
     get_function = model_meta.get_function
     kwargs['automodel_class'] = automodel_class
-    model, tokenizer = get_function(model_dir, model_info, model_kwargs, load_model, **kwargs)
+    model, processor = get_function(model_dir, model_info, model_kwargs, load_model, **kwargs)

-    if not isinstance(tokenizer, PreTrainedTokenizerBase) and hasattr(tokenizer, 'tokenizer'):
-        patch_processor(tokenizer)
-    tokenizer.model_info = model_info
-    tokenizer.model_meta = model_meta
+    if not isinstance(processor, PreTrainedTokenizerBase) and hasattr(processor, 'tokenizer'):
+        tokenizer = processor.tokenizer
+        patch_processor(processor)
+    else:
+        tokenizer = processor
+    processor.model_info = model_info
+    processor.model_meta = model_meta
+    tokenizer.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
+    assert tokenizer.eos_token_id is not None
+    assert tokenizer.pad_token_id is not None

     if model is not None:
         model.model_info = model_info
@@ -476,4 +482,4 @@ def get_model_tokenizer(model_id_or_path: str,
         model.generation_config = GenerationConfig.from_pretrained(model_dir)
         # fix llama2 warning
         fix_do_sample_warning(model.generation_config)
-    return model, tokenizer
+    return model, processor
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index b255849d6..fc6e3ad5e 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -69,8 +69,6 @@ def __init__(
         self.model_info = processor.model_info
         self.model_meta = processor.model_meta
         tokenizer = self.tokenizer
-        self.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
-        assert self.pad_token_id is not None

         if not use_chat_template:
             template_meta = template_meta.to_generate_template_meta()
@@ -797,7 +795,7 @@ def _data_collator(self,
         if len(batch) == 0:
             return {}
         from swift.utils import use_torchacc
-        assert self.pad_token_id is not None
+        assert self.tokenizer.pad_token_id is not None
         if padding_side is None:
             padding_side = self.padding_side
         padding_right = padding_side == 'right'
@@ -815,7 +813,7 @@ def _data_collator(self,
                 res[key] = val

         keys = ['input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids']
-        pad_value = [self.pad_token_id, 0., 0, -100, 0., -1]
+        pad_value = [self.tokenizer.pad_token_id, 0., 0, -100, 0., -1]
         # Convert to tensor and remove unnecessary dimensions.
         seq_lens = None
         for key in keys:
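A note on the two fixes in this patch. First, transformers' TrainingArguments exposes world_size as a read-only property, which appears to be why renaming swift's field from global_world_size to world_size requires the `del TrainingArguments.world_size` at class level, and why Seq2SeqTrainingOverrideArguments.__post_init__ now runs before BaseArguments.__post_init__ (the latter assigns self.world_size). Second, the pad_token fallback moves from Template into get_model_tokenizer, so every consumer of the tokenizer, not just templates, sees a usable pad_token_id. A minimal sketch of that fallback follows; the checkpoint name is illustrative, and note that the patch's `x or y` form would also replace a legitimate pad_token_id of 0, so the `is None` check below is the stricter variant:

```python
# Minimal sketch of the pad_token fallback (checkpoint name is illustrative).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # gpt2 defines eos but no pad token

# Fall back to eos before anything pads a batch; `is None` avoids clobbering
# a legitimate pad_token_id of 0, which `pad_token_id or eos_token_id` would.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
assert tokenizer.eos_token_id is not None
assert tokenizer.pad_token_id is not None  # 50256 for gpt2: eos reused as pad
```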
From c0d313d104a9b2ae66957e654790f890685f186f Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 27 Nov 2024 21:46:48 +0800
Subject: [PATCH 2/4] add metric_for_best_model

---
 swift/llm/argument/train_args.py | 2 ++
 tests/infer/test_agent.py        | 1 -
 tests/train/test_freeze.py       | 1 -
 tests/train/test_kto.py          | 1 -
 tests/train/test_pt.py           | 1 -
 tests/train/test_rlhf.py         | 1 -
 tests/train/test_sft.py          | 1 -
 7 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py
index 825c23074..154922946 100644
--- a/swift/llm/argument/train_args.py
+++ b/swift/llm/argument/train_args.py
@@ -38,6 +38,8 @@ class Seq2SeqTrainingOverrideArguments(Seq2SeqTrainingArguments):
     report_to: List[str] = field(default_factory=lambda: ['tensorboard'])
     remove_unused_columns: bool = False
     logging_first_step: bool = True
+    # Usually, the point where eval_loss is minimized does not represent the best model.
+    metric_for_best_model: str = 'loss'

     def _init_output_dir(self):
         if self.output_dir is not None:
diff --git a/tests/infer/test_agent.py b/tests/infer/test_agent.py
index dfc024a42..00d8d4678 100644
--- a/tests/infer/test_agent.py
+++ b/tests/infer/test_agent.py
@@ -9,7 +9,6 @@
     'save_steps': 50,
     'gradient_accumulation_steps': 4,
     'num_train_epochs': 1,
-    'metric_for_best_model': 'loss'
 }
diff --git a/tests/train/test_freeze.py b/tests/train/test_freeze.py
index 6f3537899..c229afab0 100644
--- a/tests/train/test_freeze.py
+++ b/tests/train/test_freeze.py
@@ -7,7 +7,6 @@
     'save_steps': 5,
     'gradient_accumulation_steps': 4,
     'num_train_epochs': 1,
-    'metric_for_best_model': 'loss'
 }
diff --git a/tests/train/test_kto.py b/tests/train/test_kto.py
index dc8fa1b6f..59601a541 100644
--- a/tests/train/test_kto.py
+++ b/tests/train/test_kto.py
@@ -7,7 +7,6 @@
     'save_steps': 5,
     'gradient_accumulation_steps': 4,
     'num_train_epochs': 1,
-    'metric_for_best_model': 'loss'
 }
diff --git a/tests/train/test_pt.py b/tests/train/test_pt.py
index 39aac8a50..7b402bd92 100644
--- a/tests/train/test_pt.py
+++ b/tests/train/test_pt.py
@@ -7,7 +7,6 @@
     'save_steps': 5,
     'gradient_accumulation_steps': 4,
     'num_train_epochs': 1,
-    'metric_for_best_model': 'loss'
 }
diff --git a/tests/train/test_rlhf.py b/tests/train/test_rlhf.py
index 6ce25727a..3b8389071 100644
--- a/tests/train/test_rlhf.py
+++ b/tests/train/test_rlhf.py
@@ -7,7 +7,6 @@
     'save_steps': 5,
     'gradient_accumulation_steps': 4,
     'num_train_epochs': 1,
-    'metric_for_best_model': 'loss'
 }
diff --git a/tests/train/test_sft.py b/tests/train/test_sft.py
index 37d1c38ee..9ef222a9f 100644
--- a/tests/train/test_sft.py
+++ b/tests/train/test_sft.py
@@ -9,7 +9,6 @@
     'save_steps': 5,
     'gradient_accumulation_steps': 4,
     'num_train_epochs': 1,
-    'metric_for_best_model': 'loss'
 }
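For context on the new default: transformers' Trainer only consults metric_for_best_model when checkpoints are compared, typically via load_best_model_at_end or an early-stopping callback; a bare name like 'loss' is looked up as 'eval_loss' in the evaluation metrics, and greater_is_better defaults to False for loss-like names. A hedged sketch of the interaction, with all values illustrative (eval_strategy is the newer spelling of evaluation_strategy in recent transformers):

```python
# Hedged sketch: how metric_for_best_model drives checkpoint selection.
from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir='output',
    eval_strategy='steps',         # evaluation must run for a "best" step to exist
    eval_steps=50,
    save_steps=50,                 # must line up with eval_steps for best-model tracking
    load_best_model_at_end=True,   # restore the winning checkpoint after training
    metric_for_best_model='loss',  # compared as 'eval_loss'
    greater_is_better=False,       # inferred for 'loss'; spelled out for clarity
)
```

With this field defaulted in Seq2SeqTrainingOverrideArguments, the per-test `'metric_for_best_model': 'loss'` entries become redundant, which is why the rest of the patch deletes them.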
From 7078593858ae5ee83b9c0367f0061f66f04d3f06 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 27 Nov 2024 22:57:55 +0800
Subject: [PATCH 3/4] support row_list & update hc3

---
 swift/llm/dataset/dataset/llm.py       | 156 +++++--------------------
 swift/llm/dataset/preprocessor/core.py |  29 +++--
 swift/llm/model/constant.py            |   1 -
 swift/llm/train/sft.py                 |   8 +-
 tests/general/test_dataset.py          |   8 +-
 5 files changed, 55 insertions(+), 147 deletions(-)

diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py
index 03aefe51b..477506e0f 100644
--- a/swift/llm/dataset/dataset/llm.py
+++ b/swift/llm/dataset/dataset/llm.py
@@ -345,131 +345,46 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
 register_dataset(DatasetMeta(ms_dataset_id='swift/ToolBench', tags=['chat', 'agent', 'multi-round']))


-def _preprocess_hc3(dataset: DATASET_TYPE, **kwargs) -> DATASET_TYPE:
+class HC3Preprocessor(ResponsePreprocessor):
     prompt = """Classification Task: Are the following responses from a human or from ChatGPT?
 Question: {question}
 Answer: {answer}
 Category: Human, ChatGPT
 Output:"""
-    if isinstance(dataset, IterableDataset):
-
-        def generate_example(dataset):
-            for example in dataset:
-                question = example['question']
-                for h in example['human_answers']:
-                    yield {
-                        'messages': [{
-                            'role': 'user',
-                            'content': prompt.format(question=question, answer=h)
-                        }, {
-                            'role': 'assistant',
-                            'content': 'Human'
-                        }]
-                    }
-                for c in example['chatgpt_answers']:
-                    yield {
-                        'messages': [{
-                            'role': 'user',
-                            'content': prompt.format(question=question, answer=c)
-                        }, {
-                            'role': 'assistant',
-                            'content': 'ChatGPT'
-                        }]
-                    }
-
-        return IterableDataset.from_generator(generate_example, gen_kwargs={'dataset': dataset})
-
-    messages = []
-    for d in dataset:
-        question = d['question']
-        for h in d['human_answers']:
-            messages.append({
-                'messages': [{
-                    'role': 'user',
-                    'content': prompt.format(question=question, answer=h)
-                }, {
-                    'role': 'assistant',
-                    'content': 'Human'
-                }]
-            })
-        for c in d['chatgpt_answers']:
-            messages.append({
-                'messages': [{
-                    'role': 'user',
-                    'content': prompt.format(question=question, answer=c)
-                }, {
-                    'role': 'assistant',
-                    'content': 'ChatGPT'
-                }]
-            })
-    return HfDataset.from_list(messages)
-
-
-def _preprocess_hc3_cls(dataset: DATASET_TYPE, **kwargs) -> DATASET_TYPE:
-    prompt = """Classification Task: Are the following responses from a human or from ChatGPT?
-Question: {question}
-Answer: {answer}
-Category: 0 for Human, 1 for ChatGPT
-Output:"""
-    if isinstance(dataset, IterableDataset):
-
-        def generate_example(dataset):
-            for example in dataset:
-                question = example['question']
-                for h in example['human_answers']:
-                    yield {
-                        'messages': [{
-                            'role': 'user',
-                            'content': prompt.format(question=question, answer=h)
-                        }],
-                        'label': 0,
-                    }
-                for c in example['chatgpt_answers']:
-                    yield {
-                        'messages': [{
-                            'role': 'user',
-                            'content': prompt.format(question=question, answer=c)
-                        }],
-                        'label': 1,
-                    }
-
-        return IterableDataset.from_generator(generate_example, gen_kwargs={'dataset': dataset})
-
-    messages = []
-    for d in dataset:
-        question = d['question']
-        for h in d['human_answers']:
-            messages.append({
-                'messages': [{
-                    'role': 'user',
-                    'content': prompt.format(question=question, answer=h)
-                }],
-                'label': 0,
-            })
-        for c in d['chatgpt_answers']:
-            messages.append({
-                'messages': [{
-                    'role': 'user',
-                    'content': prompt.format(question=question, answer=c)
-                }],
-                'label': 1,
-            })
-    return HfDataset.from_list(messages)
+
+    def preprocess(self, row):
+        rows = []
+        for response in ['Human', 'ChatGPT']:
+            query = self.prompt.format(question=row['query'], answer=row[f'{response.lower()}_answers'])
+            rows.append(super().preprocess({'query': query, 'response': response}))
+        return rows
+
+
+class HC3ClsPreprocessor(HC3Preprocessor):
+
+    def preprocess(self, row):
+        rows = []
+        for i, response in enumerate(['Human', 'ChatGPT']):
+            query = self.prompt.format(question=row['query'], answer=row[f'{response.lower()}_answers'])
+            rows.append(ResponsePreprocessor.preprocess(self, {'query': query, 'label': i}))
+        return rows


 hc3_subset_names = ['baike', 'open_qa', 'nlpcc_dbqa', 'finance', 'medicine', 'law', 'psychology']
 hc3_subsets: List[SubsetDataset] = []
 for hc3_subset_name in hc3_subset_names:
-    hc3_subsets.append(SubsetDataset(
-        name=hc3_subset_name,
-        subset=hc3_subset_name,
-        preprocess_func=_preprocess_hc3,
-    ))
-    hc3_subsets.append(SubsetDataset(
-        name=f'{hc3_subset_name}_cls',
-        subset=hc3_subset_name,
-        preprocess_func=_preprocess_hc3_cls,
-    ))
+    hc3_subsets.append(
+        SubsetDataset(
+            name=hc3_subset_name,
+            subset=hc3_subset_name,
+            preprocess_func=HC3Preprocessor(),
+        ))
+    hc3_subsets.append(
+        SubsetDataset(
+            name=f'{hc3_subset_name}_cls',
+            subset=hc3_subset_name,
+            preprocess_func=HC3ClsPreprocessor(),
+        ))

 register_dataset(
     DatasetMeta(
         ms_dataset_id='simpleai/HC3-Chinese',
         hf_dataset_id='Hello-SimpleAI/HC3-Chinese',
         subsets=hc3_subsets,
         tags=['text-generation', 'classification', 'πŸ”₯']))

-
-register_dataset(
-    DatasetMeta(
-        ms_dataset_id='simpleai/HC3-Chinese',
-        hf_dataset_id='Hello-SimpleAI/HC3-Chinese',
-        subsets=['baike', 'open_qa', 'nlpcc_dbqa', 'finance', 'medicine', 'law', 'psychology'],
-        preprocess_func=_preprocess_hc3,
-        tags=['text-generation', 'classification', 'πŸ”₯']))
-
 register_dataset(
     DatasetMeta(
         ms_dataset_id='simpleai/HC3',
         hf_dataset_id='Hello-SimpleAI/HC3',
         subsets=['finance', 'medicine'],
-        preprocess_func=_preprocess_hc3,
+        preprocess_func=HC3Preprocessor(),
         tags=['text-generation', 'classification', 'πŸ”₯']))

diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/llm/dataset/preprocessor/core.py
index 51439dc37..ce719eb9e 100644
--- a/swift/llm/dataset/preprocessor/core.py
+++ b/swift/llm/dataset/preprocessor/core.py
@@ -116,9 +116,14 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Di
         for row in rows:
             try:
                 row = self.preprocess(row)
-                if row is not None:
-                    self.check_messages(row)
-                    self.check_rejected_response(row)
+                # support [row1, row2, ...]
+                if row is None:
+                    row = []
+                if isinstance(row, dict):
+                    row = [row]
+                for r in row:
+                    self.check_messages(r)
+                    self.check_rejected_response(r)
             except Exception:
                 if strict:
                     logger.warning('To avoid errors, you can pass `strict=False`.')
@@ -128,10 +133,8 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Di
                     print(traceback.format_exc())
                     logger.error('πŸ‘†πŸ‘†πŸ‘†There are errors in the dataset, the data will be deleted')
                     self._traceback_counter += 1
-                row = None
-            if row is None:
-                continue
-            new_rows.append(row)
+                row = []
+            new_rows += row

         res = self.rows_to_batched(new_rows)
         if len(res) == 0:
@@ -247,14 +250,10 @@ def __init__(self, *, columns_mapping: Optional[Dict[str, str]] = None, **kwargs

     def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
         response = row.pop('response', None)
-        if response is None:
-            row.pop('query', None)
-            row.pop('history', None)
-            row.pop('system', None)
-            return
-        if isinstance(response, (list, tuple)):
-            # sometimes response is a list, pick one randomly
-            response = self.random_state.choice(response)
+        if response is not None:
+            if isinstance(response, (list, tuple)):
+                # sometimes response is a list, pick one randomly
+                response = self.random_state.choice(response)
         history = row.pop('history', None) or []
         query = row.pop('query', None)
         system = row.pop('system', None)
diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py
index 02f735f3e..4dc178466 100644
--- a/swift/llm/model/constant.py
+++ b/swift/llm/model/constant.py
@@ -34,7 +34,6 @@ class LLMModelType:
     modelscope_agent = 'modelscope_agent'
     qwen2 = 'qwen2'
     qwen2_5 = 'qwen2_5'
-    qwen2_5_cls = 'qwen2_5_cls'

     llama = 'llama'
     llama3 = 'llama3'
diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py
index eb6648bfb..55d8971eb 100644
--- a/swift/llm/train/sft.py
+++ b/swift/llm/train/sft.py
@@ -77,13 +77,11 @@ def _get_model_tokenizer(self, model, model_type, model_revision):
         model_kwargs['model_id_or_path'] = model
         model_kwargs['model_type'] = model_type
         model_kwargs['model_revision'] = model_revision
-        automodel_param = {}
         if args.num_labels is not None:
             from modelscope import AutoModelForSequenceClassification
-            automodel_param = {'automodel_class': AutoModelForSequenceClassification}
-        model, tokenizer = get_model_tokenizer(
-            **model_kwargs, use_unsloth=args.tuner_backend == 'unsloth', **automodel_param)
-        model.num_labels = args.num_labels
+            model_kwargs = {'automodel_class': AutoModelForSequenceClassification}
+        model, tokenizer = get_model_tokenizer(**model_kwargs, use_unsloth=args.tuner_backend == 'unsloth')
+        model.num_labels = args.num_labels  # TODO
         return model, tokenizer

     def _prepare_model_tokenizer(self):
diff --git a/tests/general/test_dataset.py b/tests/general/test_dataset.py
index bb43dff07..7c72a4298 100644
--- a/tests/general/test_dataset.py
+++ b/tests/general/test_dataset.py
@@ -59,6 +59,11 @@ def test_dataset_info():
     # _test_dataset(['codefuse-ai/CodeExercise-Python-27k'])


+def test_cls():
+    _test_dataset(['simpleai/HC3-Chinese:baike'])
+    _test_dataset(['simpleai/HC3-Chinese:baike_cls'])
+
+
 if __name__ == '__main__':
     # test_sft()
     # test_agent()
@@ -66,4 +71,5 @@ def test_dataset_info():
     # test_kto()
     # test_mllm()
     # test_pretrain()
-    test_dataset_info()
+    # test_dataset_info()
+    test_cls()

From 04f01ef71c632e5d2d34de95a0b3b05d8017ff99 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 27 Nov 2024 23:02:40 +0800
Subject: [PATCH 4/4] update

---
 swift/llm/infer/infer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 5acf140d4..0ff610244 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -12,9 +12,9 @@

 from swift.llm import (InferArguments, InferRequest, Messages, Processor, SwiftPipeline, Template, get_template,
                        load_dataset, sample_dataset)
+from swift.plugin import extra_tuners
 from swift.tuners import Swift
 from swift.utils import get_logger, is_master, open_jsonl_writer
-from ...plugin import extra_tuners
 from .protocol import RequestConfig

 logger = get_logger()
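A closing note on the row_list support in PATCH 3: the preprocessor contract loosens from "one row in, one row out (or None)" to "one row in, zero or more rows out", and batched_preprocess now flattens whatever comes back. That is what lets HC3Preprocessor fan a single HC3 record out into one Human sample and one ChatGPT sample. A self-contained sketch of the convention; the names here are illustrative rather than swift's actual API:

```python
# Illustrative sketch of the one-row-to-many-rows preprocessing convention.
from typing import Any, Dict, List, Union

Row = Dict[str, Any]

def hc3_style_preprocess(row: Row) -> List[Row]:
    # One HC3-style record fans out into one sample per answer source.
    out = []
    for label, key in (('Human', 'human_answers'), ('ChatGPT', 'chatgpt_answers')):
        query = f"Question: {row['question']}\nAnswer: {row[key]}"
        out.append({'query': query, 'response': label})
    return out

def normalize(processed: Union[Row, List[Row], None]) -> List[Row]:
    # Mirror the new batched_preprocess handling: None -> [], dict -> [dict],
    # a list is kept as-is, so the caller can simply do `new_rows += normalize(...)`.
    if processed is None:
        return []
    if isinstance(processed, dict):
        return [processed]
    return processed

new_rows: List[Row] = []
for record in [{'question': 'Is water wet?', 'human_answers': 'Yes.', 'chatgpt_answers': 'Arguably.'}]:
    new_rows += normalize(hc3_style_preprocess(record))
print(len(new_rows))  # 2: one 'Human' sample and one 'ChatGPT' sample
```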