
Commit

Merge branch 'feat/refactor3' of https://github.com/modelscope/swift into feat/refactor3
tastelikefeet committed Nov 27, 2024
2 parents 7b69078 + 04f01ef commit 75b0bcd
Showing 18 changed files with 82 additions and 166 deletions.
4 changes: 2 additions & 2 deletions swift/llm/argument/base_args/base_args.py
@@ -47,9 +47,9 @@ def __post_init__(self):
         if self.use_hf:
             os.environ['USE_HF'] = '1'
         self._init_model_kwargs()
-        self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting()
+        self.rank, self.local_rank, self.world_size, self.local_world_size = get_dist_setting()
         logger.info(f'rank: {self.rank}, local_rank: {self.local_rank}, '
-                    f'global_world_size: {self.global_world_size}, local_world_size: {self.local_world_size}')
+                    f'world_size: {self.world_size}, local_world_size: {self.local_world_size}')
         ModelArguments.__post_init__(self)
         QuantizeArguments.__post_init__(self)
         TemplateArguments.__post_init__(self)
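Note: the rename from `global_world_size` to `world_size` runs through the whole commit (see also infer.py below). For reference, a minimal sketch of what a helper like `get_dist_setting` returns, assuming it reads the standard torchrun environment variables; the exact defaults are an assumption:

```python
import os
from typing import Tuple


def get_dist_setting() -> Tuple[int, int, int, int]:
    """Return (rank, local_rank, world_size, local_world_size) from the
    torchrun environment; -1 / 1 stand in for single-process defaults."""
    rank = int(os.getenv('RANK', -1))
    local_rank = int(os.getenv('LOCAL_RANK', -1))
    world_size = int(os.getenv('WORLD_SIZE', 1))
    local_world_size = int(os.getenv('LOCAL_WORLD_SIZE', 1))
    return rank, local_rank, world_size, local_world_size
```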
7 changes: 5 additions & 2 deletions swift/llm/argument/train_args.py
@@ -7,7 +7,7 @@

 import torch
 import torch.distributed as dist
-from transformers import Seq2SeqTrainingArguments
+from transformers import Seq2SeqTrainingArguments, TrainingArguments
 from transformers.utils import is_torch_npu_available
 from transformers.utils.versions import require_version

@@ -38,13 +38,16 @@ class Seq2SeqTrainingOverrideArguments(Seq2SeqTrainingArguments):
     report_to: List[str] = field(default_factory=lambda: ['tensorboard'])
     remove_unused_columns: bool = False
     logging_first_step: bool = True
+    # Usually, the point where eval_loss is minimized does not represent the best model.
+    metric_for_best_model: str = 'loss'

     def _init_output_dir(self):
         if self.output_dir is not None:
             return
         self.output_dir = f'output/{self.model_name}'

     def __post_init__(self):
+        del TrainingArguments.world_size
         self._init_output_dir()

         if self.learning_rate is None:

@@ -122,8 +125,8 @@ def __post_init__(self) -> None:
             self.load_args_from_ckpt(self.resume_from_checkpoint)
         if self.train_type == 'full':
             self.model_id_or_path = self.resume_from_checkpoint
-        BaseArguments.__post_init__(self)
         Seq2SeqTrainingOverrideArguments.__post_init__(self)
+        BaseArguments.__post_init__(self)
         TunerArguments.__post_init__(self)
         TorchAccArguments.__post_init__(self)
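Note: transformers' `TrainingArguments` defines `world_size` as a read-only property, so the `del TrainingArguments.world_size` is presumably there to let the plain `world_size` attribute set in `BaseArguments.__post_init__` (base_args.py above) be assigned without tripping over it. A minimal sketch of that collision, with hypothetical class names `Base`/`Args`:

```python
class Base:
    @property
    def world_size(self) -> int:
        # read-only property, standing in for TrainingArguments.world_size
        return 1


class Args(Base):
    pass


args = Args()
try:
    args.world_size = 8  # blocked: the class-level property has no setter
except AttributeError as e:
    print(e)

del Base.world_size  # drop the property from the class...
args.world_size = 8  # ...and plain instance assignment now works
print(args.world_size)  # 8
```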
138 changes: 26 additions & 112 deletions swift/llm/dataset/dataset/llm.py
@@ -345,131 +345,45 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
 register_dataset(DatasetMeta(ms_dataset_id='swift/ToolBench', tags=['chat', 'agent', 'multi-round']))


-def _preprocess_hc3(dataset: DATASET_TYPE, **kwargs) -> DATASET_TYPE:
+class HC3Preprocessor(ResponsePreprocessor):
     prompt = """Classification Task: Are the following responses from a human or from ChatGPT?
 Question: {question}
 Answer: {answer}
 Category: Human, ChatGPT
 Output:"""
-    if isinstance(dataset, IterableDataset):
-
-        def generate_example(dataset):
-            for example in dataset:
-                question = example['question']
-                for h in example['human_answers']:
-                    yield {
-                        'messages': [{
-                            'role': 'user',
-                            'content': prompt.format(question=question, answer=h)
-                        }, {
-                            'role': 'assistant',
-                            'content': 'Human'
-                        }]
-                    }
-                for c in example['chatgpt_answers']:
-                    yield {
-                        'messages': [{
-                            'role': 'user',
-                            'content': prompt.format(question=question, answer=c)
-                        }, {
-                            'role': 'assistant',
-                            'content': 'ChatGPT'
-                        }]
-                    }
-
-        return IterableDataset.from_generator(generate_example, gen_kwargs={'dataset': dataset})
-
-    messages = []
-    for d in dataset:
-        question = d['question']
-        for h in d['human_answers']:
-            messages.append({
-                'messages': [{
-                    'role': 'user',
-                    'content': prompt.format(question=question, answer=h)
-                }, {
-                    'role': 'assistant',
-                    'content': 'Human'
-                }]
-            })
-        for c in d['chatgpt_answers']:
-            messages.append({
-                'messages': [{
-                    'role': 'user',
-                    'content': prompt.format(question=question, answer=c)
-                }, {
-                    'role': 'assistant',
-                    'content': 'ChatGPT'
-                }]
-            })
-    return HfDataset.from_list(messages)
-
-
-def _preprocess_hc3_cls(dataset: DATASET_TYPE, **kwargs) -> DATASET_TYPE:
-    prompt = """Classification Task: Are the following responses from a human or from ChatGPT?
-Question: {question}
-Answer: {answer}
-Category: 0 for Human, 1 for ChatGPT
-Output:"""
-    if isinstance(dataset, IterableDataset):
-
-        def generate_example(dataset):
-            for example in dataset:
-                question = example['question']
-                for h in example['human_answers']:
-                    yield {
-                        'messages': [{
-                            'role': 'user',
-                            'content': prompt.format(question=question, answer=h)
-                        }],
-                        'label': 0,
-                    }
-                for c in example['chatgpt_answers']:
-                    yield {
-                        'messages': [{
-                            'role': 'user',
-                            'content': prompt.format(question=question, answer=c)
-                        }],
-                        'label': 1,
-                    }
-
-        return IterableDataset.from_generator(generate_example, gen_kwargs={'dataset': dataset})
-
-    messages = []
-    for d in dataset:
-        question = d['question']
-        for h in d['human_answers']:
-            messages.append({
-                'messages': [{
-                    'role': 'user',
-                    'content': prompt.format(question=question, answer=h)
-                }],
-                'label': 0,
-            })
-        for c in d['chatgpt_answers']:
-            messages.append({
-                'messages': [{
-                    'role': 'user',
-                    'content': prompt.format(question=question, answer=c)
-                }],
-                'label': 1,
-            })
-    return HfDataset.from_list(messages)
+
+    def preprocess(self, row):
+        rows = []
+        for response in ['Human', 'ChatGPT']:
+            query = self.prompt.format(question=row['query'], answer=row[f'{response.lower()}_answers'])
+            rows.append(super().preprocess({'query': query, 'response': response}))
+        return rows
+
+
+class HC3ClsPreprocessor(HC3Preprocessor):
+
+    def preprocess(self, row):
+        rows = []
+        for i, response in enumerate(['Human', 'ChatGPT']):
+            query = self.prompt.format(question=row['query'], answer=row[f'{response.lower()}_answers'])
+            rows.append(ResponsePreprocessor.preprocess(self, {'query': query, 'label': i}))
+        return rows


 hc3_subset_names = ['baike', 'open_qa', 'nlpcc_dbqa', 'finance', 'medicine', 'law', 'psychology']
 hc3_subsets: List[SubsetDataset] = []
 for hc3_subset_name in hc3_subset_names:
-    hc3_subsets.append(SubsetDataset(
-        name=hc3_subset_name,
-        subset=hc3_subset_name,
-        preprocess_func=_preprocess_hc3,
-    ))
+    hc3_subsets.append(
+        SubsetDataset(
+            name=hc3_subset_name,
+            subset=hc3_subset_name,
+            preprocess_func=HC3Preprocessor(),
+        ))
     hc3_subsets.append(
         SubsetDataset(
             name=f'{hc3_subset_name}_cls',
             subset=hc3_subset_name,
-            preprocess_func=_preprocess_hc3_cls,
+            preprocess_func=HC3ClsPreprocessor(),
         ))

 register_dataset(

@@ -484,7 +398,7 @@ def generate_example(dataset):
         ms_dataset_id='simpleai/HC3',
         hf_dataset_id='Hello-SimpleAI/HC3',
         subsets=['finance', 'medicine'],
-        preprocess_func=_preprocess_hc3,
+        preprocess_func=HC3Preprocessor(),
         tags=['text-generation', 'classification', '🔥']))
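Note: the two module-level preprocessing functions collapse into preprocessor classes, and one HC3 row now fans out into multiple training rows (the list-return convention that core.py below starts to support). A runnable sketch of that fan-out, assuming `ResponsePreprocessor.preprocess` simply converts a query/response row into a chat-style `messages` row (the stand-in class below is an assumption, and the sample row uses single-string answers for brevity where the real dataset fields may be lists):

```python
from typing import Any, Dict, List


class ResponsePreprocessor:
    """Stand-in for swift's ResponsePreprocessor."""

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        messages = [{'role': 'user', 'content': row.pop('query')}]
        response = row.pop('response', None)
        if response is not None:
            messages.append({'role': 'assistant', 'content': response})
        return {'messages': messages, **row}


class HC3Preprocessor(ResponsePreprocessor):
    prompt = ('Classification Task: Are the following responses from a human or from ChatGPT?\n'
              'Question: {question}\nAnswer: {answer}\nCategory: Human, ChatGPT\nOutput:')

    def preprocess(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
        rows = []
        for response in ['Human', 'ChatGPT']:
            query = self.prompt.format(question=row['query'], answer=row[f'{response.lower()}_answers'])
            rows.append(super().preprocess({'query': query, 'response': response}))
        return rows


row = {'query': 'What is RLHF?',
       'human_answers': 'Reinforcement learning from human feedback.',
       'chatgpt_answers': 'RLHF is a technique that ...'}
for r in HC3Preprocessor().preprocess(row):  # one source row -> two training rows
    print(r['messages'][-1]['content'])      # 'Human', then 'ChatGPT'
```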
29 changes: 14 additions & 15 deletions swift/llm/dataset/preprocessor/core.py
@@ -116,9 +116,14 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
             for row in rows:
                 try:
                     row = self.preprocess(row)
-                    if row is not None:
-                        self.check_messages(row)
-                        self.check_rejected_response(row)
+                    # support [row1, row2, ...]
+                    if row is None:
+                        row = []
+                    if isinstance(row, dict):
+                        row = [row]
+                    for r in row:
+                        self.check_messages(r)
+                        self.check_rejected_response(r)
                 except Exception:
                     if strict:
                         logger.warning('To avoid errors, you can pass `strict=False`.')

@@ -128,10 +133,8 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
                         print(traceback.format_exc())
                         logger.error('👆👆👆There are errors in the dataset, the data will be deleted')
                         self._traceback_counter += 1
-                    row = None
-                if row is None:
-                    continue
-                new_rows.append(row)
+                    row = []
+                new_rows += row
             res = self.rows_to_batched(new_rows)

             if len(res) == 0:

@@ -247,14 +250,10 @@ def __init__(self, *, columns_mapping: Optional[Dict[str, str]] = None, **kwargs
     def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
         response = row.pop('response', None)
-        if response is None:
-            row.pop('query', None)
-            row.pop('history', None)
-            row.pop('system', None)
-            return
-        if isinstance(response, (list, tuple)):
-            # sometimes response is a list, pick one randomly
-            response = self.random_state.choice(response)
+        if response is not None:
+            if isinstance(response, (list, tuple)):
+                # sometimes response is a list, pick one randomly
+                response = self.random_state.choice(response)
         history = row.pop('history', None) or []
         query = row.pop('query', None)
         system = row.pop('system', None)
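Note: `preprocess` may now return `None` (drop the row), a single dict, or a list of dicts, and `batched_preprocess` normalizes all three cases before extending `new_rows`. A minimal sketch of that normalization:

```python
from typing import Any, Dict, List, Optional, Union

Row = Dict[str, Any]


def normalize(result: Union[None, Row, List[Row]]) -> List[Row]:
    # None -> drop the row; dict -> one row; list -> several rows
    if result is None:
        return []
    if isinstance(result, dict):
        return [result]
    return result


new_rows: List[Row] = []
for result in [None, {'messages': []}, [{'messages': []}, {'messages': []}]]:
    new_rows += normalize(result)
print(len(new_rows))  # 3: the None row vanished, the list fanned out into two rows
```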
8 changes: 4 additions & 4 deletions swift/llm/infer/infer.py
@@ -12,9 +12,9 @@

 from swift.llm import (InferArguments, InferRequest, Messages, Processor, SwiftPipeline, Template, get_template,
                        load_dataset, sample_dataset)
+from swift.plugin import extra_tuners
 from swift.tuners import Swift
 from swift.utils import get_logger, is_master, open_jsonl_writer
-from ...plugin import extra_tuners
 from .protocol import RequestConfig

 logger = get_logger()

@@ -290,9 +290,9 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
                 if self.jsonl_writer:
                     self.jsonl_writer.append(data)
         else:
-            is_dist = args.global_world_size > 1 and dist.is_initialized()
+            is_dist = args.world_size > 1 and dist.is_initialized()
             if is_dist:
-                val_dataset = val_dataset.shard(args.global_world_size, args.rank, contiguous=True)
+                val_dataset = val_dataset.shard(args.world_size, args.rank, contiguous=True)
             infer_requests = [InferRequest(**data) for i, data in enumerate(val_dataset)]

             resp_list = self.infer(infer_requests, request_config, template=self.template, use_tqdm=True)

@@ -301,7 +301,7 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
                 data = {'response': response, **data}
                 result_list.append(data)
             if is_dist:
-                total_result_list = [None for _ in range(args.global_world_size)] if args.rank == 0 else None
+                total_result_list = [None for _ in range(args.world_size)] if args.rank == 0 else None
                 dist.gather_object(result_list, total_result_list)
                 result_list = total_result_list and list(chain.from_iterable(total_result_list))
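Note: beyond the `world_size` rename, this is the standard shard-then-gather pattern: every rank runs inference on its own slice of the dataset and rank 0 reassembles the results. A minimal sketch independent of swift (assumes `dist.init_process_group` has already been called, e.g. under torchrun; the strided slice is for brevity where swift uses `Dataset.shard(..., contiguous=True)`):

```python
from itertools import chain

import torch.distributed as dist


def infer_sharded(val_data, rank: int, world_size: int):
    # Each rank takes its own slice of the data...
    shard = val_data[rank::world_size]
    results = [f'rank{rank}: {x}' for x in shard]  # stand-in for real inference

    # ...and rank 0 gathers every rank's result list, then flattens it.
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(results, gathered)
    return list(chain.from_iterable(gathered)) if rank == 0 else None
```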
10 changes: 5 additions & 5 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -95,7 +95,7 @@ def _add_stop_words(self, generation_config: _GenerationConfig, request_config:
         if generation_config.eos_token_id is None:
             generation_config.eos_token_id = self.tokenizer.eos_token_id
         if generation_config.pad_token_id is None:
-            generation_config.pad_token_id = template.pad_token_id
+            generation_config.pad_token_id = self.tokenizer.pad_token_id

     @staticmethod
     def preprocess_logits(batched_logits: Optional[List[torch.Tensor]], batched_generate_ids: torch.Tensor,

@@ -201,15 +201,15 @@ def _model_generate(*args, **kwargs):
             generate_ids = batched_generate_ids[i]

             # ignore pad_token
-            masks = generate_ids != template.pad_token_id
+            masks = generate_ids != self.tokenizer.pad_token_id
             generate_ids = generate_ids[masks].tolist()
             logprobs_list = None
             if batched_logprobs[i]:
                 logprobs_list = [logprobs for m, logprobs in zip(masks, batched_logprobs[i]) if m.item()]

             is_finished[i] = (
                 all_is_finished or is_finished[i]
-                or len(generate_ids) > 0 and generate_ids[-1] == template.pad_token_id)
+                or len(generate_ids) > 0 and generate_ids[-1] == self.tokenizer.pad_token_id)
             delta_text = infer_streamers[i].get_printable_text(generate_ids, is_finished[i])
             if not delta_text and not is_finished[i]:
                 res.append(None)

@@ -269,7 +269,7 @@ def _infer_full(
             generate_ids = batched_generate_ids[i]

             # ignore pad_token
-            masks = generate_ids != template.pad_token_id
+            masks = generate_ids != self.tokenizer.pad_token_id
             generate_ids = generate_ids[masks].tolist()
             logprobs_list = None
             if batched_logprobs is not None:

@@ -291,7 +291,7 @@ def _infer_full(
             elif isinstance(response, Image.Image):
                 res.append(
                     ImagesResponse(
-                        created=time.time(), data=[ImageObject(b64_json=MultiModalRequestMixin._to_base64(response))]))
+                        created=time.time(), data=[ImageObject(b64_json=MultiModalRequestMixin.to_base64(response))]))

         return res
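Note: `pad_token_id` now comes from the tokenizer rather than the template, and `get_model_tokenizer` (register.py below) guarantees it is set. The surrounding logic strips padding from the generated ids before decoding, roughly like this sketch (assuming right-padded batched generation and a pad id of 0):

```python
import torch

pad_token_id = 0
# one batch row from model.generate(), right-padded with pad_token_id
generate_ids = torch.tensor([15, 42, 7, 0, 0])

masks = generate_ids != pad_token_id  # tensor([True, True, True, False, False])
print(generate_ids[masks].tolist())   # [15, 42, 7] -> what actually gets decoded
```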
6 changes: 3 additions & 3 deletions swift/llm/infer/protocol.py
@@ -101,7 +101,7 @@ class MultiModalRequestMixin:
     videos: List[str] = field(default_factory=list)

     @staticmethod
-    def _to_base64(mm_data: Union[str, Image.Image, bytes]) -> str:
+    def to_base64(mm_data: Union[str, Image.Image, bytes]) -> str:
         if isinstance(mm_data, str) and not os.path.isfile(mm_data):
             # base64 or url
             return mm_data

@@ -125,7 +125,7 @@ def __post_init__(self):
                 values = [values]
             setattr(self, key, values)
             for i, val in enumerate(values):
-                values[i] = self._to_base64(val)
+                values[i] = self.to_base64(val)

@@ -173,7 +173,7 @@ def convert_to_base64(self):
                     suffix = 'jpeg'
                 else:
                     raise ValueError(f'value: {value}')
-                mm_data_base64 = self._to_base64(value)
+                mm_data_base64 = self.to_base64(value)
                 new_value = f'data:{key}/{suffix};base64,{mm_data_base64}'
                 if is_dict:
                     new_value = {'url': new_value}
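Note: the leading underscore is dropped because pt_engine.py (above) now calls the helper from outside the class. A sketch of the PIL branch of the conversion, assuming Pillow is installed (swift's version also accepts paths, URLs, and raw bytes):

```python
import base64
from io import BytesIO

from PIL import Image


def to_base64(image: Image.Image) -> str:
    # encode a PIL image as a base64 string
    buf = BytesIO()
    image.save(buf, format='PNG')
    return base64.b64encode(buf.getvalue()).decode('utf-8')


img = Image.new('RGB', (4, 4), 'red')
print(f'data:image/png;base64,{to_base64(img)[:32]}...')
```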
1 change: 0 additions & 1 deletion swift/llm/model/constant.py
@@ -34,7 +34,6 @@ class LLMModelType:
     modelscope_agent = 'modelscope_agent'
     qwen2 = 'qwen2'
     qwen2_5 = 'qwen2_5'
-    qwen2_5_cls = 'qwen2_5_cls'

     llama = 'llama'
     llama3 = 'llama3'
16 changes: 11 additions & 5 deletions swift/llm/model/register.py
@@ -457,12 +457,18 @@ def get_model_tokenizer(model_id_or_path: str,
     model_meta.check_requires()
     get_function = model_meta.get_function
     kwargs['automodel_class'] = automodel_class
-    model, tokenizer = get_function(model_dir, model_info, model_kwargs, load_model, **kwargs)
+    model, processor = get_function(model_dir, model_info, model_kwargs, load_model, **kwargs)

-    if not isinstance(tokenizer, PreTrainedTokenizerBase) and hasattr(tokenizer, 'tokenizer'):
-        patch_processor(tokenizer)
-    tokenizer.model_info = model_info
-    tokenizer.model_meta = model_meta
+    if not isinstance(processor, PreTrainedTokenizerBase) and hasattr(processor, 'tokenizer'):
+        tokenizer = processor.tokenizer
+        patch_processor(processor)
+    else:
+        tokenizer = processor
+    processor.model_info = model_info
+    processor.model_meta = model_meta
+    tokenizer.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
+    assert tokenizer.eos_token_id is not None
+    assert tokenizer.pad_token_id is not None

     if model is not None:
         model.model_info = model_info
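Note: `get_function` may return either a plain tokenizer or a multimodal processor that wraps one (hence the `hasattr(processor, 'tokenizer')` check), so the code unwraps the tokenizer, tags the processor with model metadata, and backfills `pad_token_id` from `eos_token_id`. A minimal sketch of the unwrapping with toy stand-in classes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTokenizer:
    eos_token_id: int = 2
    pad_token_id: Optional[int] = None


@dataclass
class ToyProcessor:
    # multimodal processors wrap a tokenizer; plain LLMs return the tokenizer itself
    tokenizer: ToyTokenizer


def unwrap_tokenizer(processor):
    # Unwrap if wrapped, then fall back to eos when no pad token is defined
    # (`is None` here, so a legitimate pad id of 0 would be preserved).
    tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


print(unwrap_tokenizer(ToyProcessor(ToyTokenizer())).pad_token_id)  # 2 (fell back to eos)
print(unwrap_tokenizer(ToyTokenizer(pad_token_id=0)).pad_token_id)  # 0 (kept as-is)
```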
(Diffs for the remaining 9 changed files are not shown.)