Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load diffusers in native FP16/BF16 precision to reduce the memory usage #1033

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
43 changes: 43 additions & 0 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_openvino_tokenizers_available,
is_openvino_version,
is_transformers_version,
is_safetensors_available,
)
from optimum.intel.utils.modeling_utils import (
_infer_library_from_model_name_or_path,
Expand Down Expand Up @@ -332,6 +333,48 @@ class StoreAttr(object):
return model

GPTQQuantizer.post_init_model = post_init_model
elif library_name == "diffusers" and is_safetensors_available() and is_openvino_version(">=", "2024.6"):
if Path(model_name_or_path).is_dir():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please extract and encapsulate this code into a function with a meaningful name?

path = Path(model_name_or_path)
else:
from diffusers import DiffusionPipeline

path = DiffusionPipeline.download(
model_name_or_path,
revision=revision,
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
)
model_part_name = None
if (path / "transformer").is_dir():
model_part_name = "transformer"
elif (path / "unet").is_dir():
model_part_name = "unet"
dtype = None
if model_part_name:
directory = path / model_part_name
safetensors_files = [
filename for filename in directory.glob("*.safetensors") if len(filename.suffixes) == 1
]
safetensors_file = None
if len(safetensors_files) > 0:
safetensors_file = safetensors_files.pop(0)
if safetensors_file:
from safetensors import safe_open

with safe_open(safetensors_file, framework="pt", device="cpu") as f:
if len(f.keys()) > 0:
for key in f.keys():
tensor = f.get_tensor(key)
if tensor.dtype.is_floating_point:
dtype = tensor.dtype
break
if dtype in [torch.float16, torch.bfloat16]:
loading_kwargs["torch_dtype"] = dtype
patch_16bit = True

if library_name == "open_clip":
model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
Expand Down
Loading