Skip to content

Commit

Permalink
[djl-convert] Support convert local model to DJL format (#3386)
Browse files Browse the repository at this point in the history
Co-authored-by: nobody <nobody@localhost>
  • Loading branch information
frankfliu and nobody authored Aug 5, 2024
1 parent a03ab68 commit d7c8a74
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,19 @@ def __init__(self):
self.outputs = None
self.api = HfApi()

def save_model(self, model_info, args: Namespace, temp_dir: str,
def save_model(self, model_info, task: str, args: Namespace, temp_dir: str,
model_zoo: bool):
if args.output_format == "OnnxRuntime":
return self.save_onnx_model(model_info, args, temp_dir, model_zoo)
return self.save_onnx_model(model_info, task, args, temp_dir,
model_zoo)
elif args.output_format == "Rust":
return self.save_rust_model(model_info, args, temp_dir, model_zoo)
else:
return self.save_pytorch_model(model_info, args, temp_dir,
model_zoo)

def save_onnx_model(self, model_info, args: Namespace, temp_dir: str,
model_zoo: bool):
def save_onnx_model(self, model_info, task: str, args: Namespace,
temp_dir: str, model_zoo: bool):
model_id = model_info.modelId

if not os.path.exists(temp_dir):
Expand All @@ -82,6 +83,8 @@ def save_onnx_model(self, model_info, args: Namespace, temp_dir: str,
sys.argv.extend(["--dtype", args.dtype])
if args.trust_remote_code:
sys.argv.append("--trust-remote-code")
if os.path.exists(model_id):
sys.argv.extend(["--task", task])
sys.argv.append(temp_dir)

main()
Expand Down Expand Up @@ -135,29 +138,46 @@ def save_rust_model(self, model_info, args: Namespace, temp_dir: str,
return False, "Failed to save tokenizer", -1

# Save config.json
config_file = hf_hub_download(repo_id=model_id, filename="config.json")
if os.path.exists(model_id):
config_file = os.path.join(model_id, "config.json")
else:
config_file = hf_hub_download(repo_id=model_id,
filename="config.json")

shutil.copyfile(config_file, os.path.join(temp_dir, "config.json"))

target = os.path.join(temp_dir, "model.safetensors")
model = self.api.model_info(model_id, files_metadata=True)
has_sf_file = False
has_pt_file = False
for sibling in model.siblings:
if sibling.rfilename == "model.safetensors":
has_sf_file = True
elif sibling.rfilename == "pytorch_model.bin":
has_pt_file = True

if has_sf_file:
file = hf_hub_download(repo_id=model_id,
filename="model.safetensors")
shutil.copyfile(file, target)
elif has_pt_file:
file = hf_hub_download(repo_id=model_id,
filename="pytorch_model.bin")
convert_file(file, target)

if os.path.exists(model_id):
file = os.path.join(model_id, "model.safetensors")
if os.path.exists(file):
shutil.copyfile(file, target)
else:
file = os.path.join(model_id, "pytorch_model.bin")
if os.path.exists(file):
convert_file(file, target)
else:
return False, f"No model file found for: {model_id}", -1
else:
return False, f"No model file found for: {model_id}", -1
model = self.api.model_info(model_id, files_metadata=True)
has_sf_file = False
has_pt_file = False
for sibling in model.siblings:
if sibling.rfilename == "model.safetensors":
has_sf_file = True
elif sibling.rfilename == "pytorch_model.bin":
has_pt_file = True

if has_sf_file:
file = hf_hub_download(repo_id=model_id,
filename="model.safetensors")
shutil.copyfile(file, target)
elif has_pt_file:
file = hf_hub_download(repo_id=model_id,
filename="pytorch_model.bin")
convert_file(file, target)
else:
return False, f"No model file found for: {model_id}", -1

arguments = self.save_serving_properties(model_info, "Rust", temp_dir,
hf_pipeline, include_types)
Expand Down Expand Up @@ -191,8 +211,13 @@ def save_pytorch_model(self, model_info, args: Namespace, temp_dir: str,
return False, "Failed to save tokenizer", -1

# Save config.json just for reference
config = hf_hub_download(repo_id=model_id, filename="config.json")
shutil.copyfile(config, os.path.join(temp_dir, "config.json"))
if os.path.exists(model_id):
config_file = os.path.join(model_id, "config.json")
else:
config_file = hf_hub_download(repo_id=model_id,
filename="config.json")

shutil.copyfile(config_file, os.path.join(temp_dir, "config.json"))

# Save jit traced .pt file to temp dir
include_types = "token_type_ids" in hf_pipeline.tokenizer.model_input_names
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import json
import logging
import os
import sys

from huggingface_hub import HfApi

sys.path.append(os.path.dirname(os.path.realpath(__file__)))

from djl_converter.arg_parser import converter_args


class ModelInfoHolder(object):

def __init__(self, model_id: str):
self.modelId = model_id
with open(os.path.join(model_id, "config.json")) as f:
self.config = json.load(f)


def main():
logging.basicConfig(stream=sys.stdout,
format="%(message)s",
Expand All @@ -38,10 +45,17 @@ def main():
logging.error(f"output directory: {output_dir} is not empty.")
return

api = HfApi()
model_info = api.model_info(args.model_id,
revision=args.revision,
token=args.token)
if os.path.exists(args.model_id):
logging.info(f"converting local model: {args.model_id}")
model_info = ModelInfoHolder(args.model_id)
else:
logging.info(f"converting HuggingFace hub model: {args.model_id}")
from huggingface_hub import HfApi

api = HfApi()
model_info = api.model_info(args.model_id,
revision=args.revision,
token=args.token)

from djl_converter.huggingface_models import HuggingfaceModels, SUPPORTED_TASKS

Expand All @@ -51,14 +65,14 @@ def main():
task = "sentence-similarity"
if not task:
logging.error(
f"Unsupported model architecture: {arch} for {model_id}.")
f"Unsupported model architecture: {arch} for {args.model_id}.")
return

converter = SUPPORTED_TASKS[task]

try:
result, reason, _ = converter.save_model(model_info, args, output_dir,
False)
result, reason, _ = converter.save_model(model_info, task, args,
output_dir, False)
if result:
logging.info(f"Convert model {model_info.modelId} finished.")
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main():

try:
result, reason, size = converter.save_model(
model_info, args, temp_dir, True)
model_info, task, args, temp_dir, True)
if not result:
logging.error(f"{model_info.modelId}: {reason}")
except Exception as e:
Expand Down

0 comments on commit d7c8a74

Please sign in to comment.