diff --git a/extensions/tokenizers/src/main/python/djl_converter/huggingface_converter.py b/extensions/tokenizers/src/main/python/djl_converter/huggingface_converter.py
index 3b15a150507..0f40c2d77c5 100644
--- a/extensions/tokenizers/src/main/python/djl_converter/huggingface_converter.py
+++ b/extensions/tokenizers/src/main/python/djl_converter/huggingface_converter.py
@@ -52,18 +52,19 @@ def __init__(self):
         self.outputs = None
         self.api = HfApi()
 
-    def save_model(self, model_info, args: Namespace, temp_dir: str,
+    def save_model(self, model_info, task: str, args: Namespace, temp_dir: str,
                    model_zoo: bool):
         if args.output_format == "OnnxRuntime":
-            return self.save_onnx_model(model_info, args, temp_dir, model_zoo)
+            return self.save_onnx_model(model_info, task, args, temp_dir,
+                                        model_zoo)
         elif args.output_format == "Rust":
             return self.save_rust_model(model_info, args, temp_dir, model_zoo)
         else:
             return self.save_pytorch_model(model_info, args, temp_dir,
                                            model_zoo)
 
-    def save_onnx_model(self, model_info, args: Namespace, temp_dir: str,
-                        model_zoo: bool):
+    def save_onnx_model(self, model_info, task: str, args: Namespace,
+                        temp_dir: str, model_zoo: bool):
         model_id = model_info.modelId
 
         if not os.path.exists(temp_dir):
@@ -82,6 +83,8 @@ def save_onnx_model(self, model_info, args: Namespace, temp_dir: str,
             sys.argv.extend(["--dtype", args.dtype])
         if args.trust_remote_code:
             sys.argv.append("--trust-remote-code")
+        if os.path.exists(model_id):
+            sys.argv.extend(["--task", task])
         sys.argv.append(temp_dir)
 
         main()
@@ -135,29 +138,46 @@ def save_rust_model(self, model_info, args: Namespace, temp_dir: str,
             return False, "Failed to save tokenizer", -1
 
         # Save config.json
-        config_file = hf_hub_download(repo_id=model_id, filename="config.json")
+        if os.path.exists(model_id):
+            config_file = os.path.join(model_id, "config.json")
+        else:
+            config_file = hf_hub_download(repo_id=model_id,
+                                          filename="config.json")
+
         shutil.copyfile(config_file, os.path.join(temp_dir, "config.json"))
 
         target = os.path.join(temp_dir, "model.safetensors")
-        model = self.api.model_info(model_id, files_metadata=True)
-        has_sf_file = False
-        has_pt_file = False
-        for sibling in model.siblings:
-            if sibling.rfilename == "model.safetensors":
-                has_sf_file = True
-            elif sibling.rfilename == "pytorch_model.bin":
-                has_pt_file = True
-
-        if has_sf_file:
-            file = hf_hub_download(repo_id=model_id,
-                                   filename="model.safetensors")
-            shutil.copyfile(file, target)
-        elif has_pt_file:
-            file = hf_hub_download(repo_id=model_id,
-                                   filename="pytorch_model.bin")
-            convert_file(file, target)
+
+        if os.path.exists(model_id):
+            file = os.path.join(model_id, "model.safetensors")
+            if os.path.exists(file):
+                shutil.copyfile(file, target)
+            else:
+                file = os.path.join(model_id, "pytorch_model.bin")
+                if os.path.exists(file):
+                    convert_file(file, target)
+                else:
+                    return False, f"No model file found for: {model_id}", -1
         else:
-            return False, f"No model file found for: {model_id}", -1
+            model = self.api.model_info(model_id, files_metadata=True)
+            has_sf_file = False
+            has_pt_file = False
+            for sibling in model.siblings:
+                if sibling.rfilename == "model.safetensors":
+                    has_sf_file = True
+                elif sibling.rfilename == "pytorch_model.bin":
+                    has_pt_file = True
+
+            if has_sf_file:
+                file = hf_hub_download(repo_id=model_id,
+                                       filename="model.safetensors")
+                shutil.copyfile(file, target)
+            elif has_pt_file:
+                file = hf_hub_download(repo_id=model_id,
+                                       filename="pytorch_model.bin")
+                convert_file(file, target)
+            else:
+                return False, f"No model file found for: {model_id}", -1
 
         arguments = self.save_serving_properties(model_info, "Rust", temp_dir,
                                                  hf_pipeline, include_types)
@@ -191,8 +211,13 @@ def save_pytorch_model(self, model_info, args: Namespace, temp_dir: str,
             return False, "Failed to save tokenizer", -1
 
         # Save config.json just for reference
-        config = hf_hub_download(repo_id=model_id, filename="config.json")
-        shutil.copyfile(config, os.path.join(temp_dir, "config.json"))
+        if os.path.exists(model_id):
+            config_file = os.path.join(model_id, "config.json")
+        else:
+            config_file = hf_hub_download(repo_id=model_id,
+                                          filename="config.json")
+
+        shutil.copyfile(config_file, os.path.join(temp_dir, "config.json"))
 
         # Save jit traced .pt file to temp dir
         include_types = "token_type_ids" in hf_pipeline.tokenizer.model_input_names
diff --git a/extensions/tokenizers/src/main/python/djl_converter/model_converter.py b/extensions/tokenizers/src/main/python/djl_converter/model_converter.py
index 54c4aa113bc..802439c67de 100644
--- a/extensions/tokenizers/src/main/python/djl_converter/model_converter.py
+++ b/extensions/tokenizers/src/main/python/djl_converter/model_converter.py
@@ -10,17 +10,24 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
+import json
 import logging
 import os
 import sys
 
-from huggingface_hub import HfApi
-
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 
 from djl_converter.arg_parser import converter_args
 
 
+class ModelInfoHolder(object):
+
+    def __init__(self, model_id: str):
+        self.modelId = model_id
+        with open(os.path.join(model_id, "config.json")) as f:
+            self.config = json.load(f)
+
+
 def main():
     logging.basicConfig(stream=sys.stdout,
                         format="%(message)s",
@@ -38,10 +45,17 @@ def main():
         logging.error(f"output directory: {output_dir} is not empty.")
         return
 
-    api = HfApi()
-    model_info = api.model_info(args.model_id,
-                                revision=args.revision,
-                                token=args.token)
+    if os.path.exists(args.model_id):
+        logging.info(f"converting local model: {args.model_id}")
+        model_info = ModelInfoHolder(args.model_id)
+    else:
+        logging.info(f"converting HuggingFace hub model: {args.model_id}")
+        from huggingface_hub import HfApi
+
+        api = HfApi()
+        model_info = api.model_info(args.model_id,
+                                    revision=args.revision,
+                                    token=args.token)
 
     from djl_converter.huggingface_models import HuggingfaceModels, SUPPORTED_TASKS
 
@@ -51,14 +65,14 @@ def main():
             task = "sentence-similarity"
     if not task:
         logging.error(
-            f"Unsupported model architecture: {arch} for {model_id}.")
+            f"Unsupported model architecture: {arch} for {args.model_id}.")
         return
 
     converter = SUPPORTED_TASKS[task]
 
     try:
-        result, reason, _ = converter.save_model(model_info, args, output_dir,
-                                                 False)
+        result, reason, _ = converter.save_model(model_info, task, args,
+                                                 output_dir, False)
         if result:
             logging.info(f"Convert model {model_info.modelId} finished.")
         else:
diff --git a/extensions/tokenizers/src/main/python/djl_converter/model_zoo_importer.py b/extensions/tokenizers/src/main/python/djl_converter/model_zoo_importer.py
index 7e569e832cd..909563ced6b 100644
--- a/extensions/tokenizers/src/main/python/djl_converter/model_zoo_importer.py
+++ b/extensions/tokenizers/src/main/python/djl_converter/model_zoo_importer.py
@@ -43,7 +43,7 @@ def main():
         try:
             result, reason, size = converter.save_model(
-                model_info, args, temp_dir, True)
+                model_info, task, args, temp_dir, True)
             if not result:
                 logging.error(f"{model_info.modelId}: {reason}")
         except Exception as e:
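
All three files hinge on the same dispatch rule: a model_id that exists as a local path is treated as a pre-downloaded checkout (config.json and weight files are read from disk, and an explicit --task is passed to the ONNX exporter because it cannot be inferred from the Hub), while any other model_id still goes through the Hub API. Below is a minimal standalone sketch of that rule for readers tracing it outside the diff; ModelInfoHolder is paraphrased from model_converter.py above, resolve_model_info is an illustrative helper (not part of the patch), and the example directory name is hypothetical.

import json
import os


class ModelInfoHolder(object):
    # Stand-in for Hub metadata when model_id is a local directory
    # (paraphrased from the class added to model_converter.py above).

    def __init__(self, model_id: str):
        self.modelId = model_id
        with open(os.path.join(model_id, "config.json")) as f:
            self.config = json.load(f)


def resolve_model_info(model_id: str, revision=None, token=None):
    if os.path.exists(model_id):
        # Local checkout: no network call; config.json is read from disk.
        return ModelInfoHolder(model_id)
    # Hub ID: import huggingface_hub lazily, as the patched main() does,
    # so the Hub client is only loaded when a remote lookup is needed.
    from huggingface_hub import HfApi
    return HfApi().model_info(model_id, revision=revision, token=token)


if __name__ == "__main__":
    # Hypothetical local checkout containing config.json plus
    # model.safetensors or pytorch_model.bin:
    info = resolve_model_info("./my-bert-checkout")
    print(info.modelId)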