diff --git a/VERSION b/VERSION index abd4105..3a4036f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.2.4 +0.2.5 diff --git a/facexlib/alignment/__init__.py b/facexlib/alignment/__init__.py index 42c9a9c..b4f6378 100644 --- a/facexlib/alignment/__init__.py +++ b/facexlib/alignment/__init__.py @@ -7,14 +7,15 @@ __all__ = ['FAN', 'landmark_98_to_68'] -def init_alignment_model(model_name, half=False, device='cuda'): +def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None): if model_name == 'awing_fan': model = FAN(num_modules=4, num_landmarks=98) model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth' else: raise NotImplementedError(f'{model_name} is not implemented.') - model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + model_path = load_file_from_url( + url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) model.load_state_dict(torch.load(model_path)['state_dict'], strict=True) model.eval() model = model.to(device) diff --git a/facexlib/assessment/__init__.py b/facexlib/assessment/__init__.py index 9e59af0..8f6f3a4 100644 --- a/facexlib/assessment/__init__.py +++ b/facexlib/assessment/__init__.py @@ -4,7 +4,7 @@ from .hyperiqa_net import HyperIQA -def init_assessment_model(model_name, half=False, device='cuda'): +def init_assessment_model(model_name, half=False, device='cuda', model_rootpath=None): if model_name == 'hypernet': model = HyperIQA(16, 112, 224, 112, 56, 28, 14, 7) model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/assessment_hyperIQA.pth' @@ -12,7 +12,8 @@ def init_assessment_model(model_name, half=False, device='cuda'): raise NotImplementedError(f'{model_name} is not implemented.') # load the pre-trained hypernet model - hypernet_model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + hypernet_model_path = load_file_from_url( + url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) model.hypernet.load_state_dict((torch.load(hypernet_model_path, map_location=lambda storage, loc: storage))) model = model.eval() model = model.to(device) diff --git a/facexlib/detection/__init__.py b/facexlib/detection/__init__.py index ce867f8..f5d34ea 100644 --- a/facexlib/detection/__init__.py +++ b/facexlib/detection/__init__.py @@ -5,7 +5,7 @@ from .retinaface import RetinaFace -def init_detection_model(model_name, half=False, device='cuda'): +def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None): if model_name == 'retinaface_resnet50': model = RetinaFace(network_name='resnet50', half=half) model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth' @@ -15,7 +15,9 @@ def init_detection_model(model_name, half=False, device='cuda'): else: raise NotImplementedError(f'{model_name} is not implemented.') - model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + model_path = load_file_from_url( + url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) + # TODO: clean pretrained model load_net = torch.load(model_path, map_location=lambda storage, loc: storage) # remove unnecessary 'module.' diff --git a/facexlib/headpose/__init__.py b/facexlib/headpose/__init__.py index e02a334..c4de5da 100644 --- a/facexlib/headpose/__init__.py +++ b/facexlib/headpose/__init__.py @@ -4,14 +4,15 @@ from .hopenet_arch import HopeNet -def init_headpose_model(model_name, half=False, device='cuda'): +def init_headpose_model(model_name, half=False, device='cuda', model_rootpath=None): if model_name == 'hopenet': model = HopeNet('resnet', [3, 4, 6, 3], 66) model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/headpose_hopenet.pth' else: raise NotImplementedError(f'{model_name} is not implemented.') - model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + model_path = load_file_from_url( + url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) load_net = torch.load(model_path, map_location=lambda storage, loc: storage)['params'] model.load_state_dict(load_net, strict=True) model.eval() diff --git a/facexlib/matting/__init__.py b/facexlib/matting/__init__.py index 3301590..02a573b 100644 --- a/facexlib/matting/__init__.py +++ b/facexlib/matting/__init__.py @@ -5,14 +5,15 @@ from .modnet import MODNet -def init_matting_model(model_name='modnet', half=False, device='cuda'): +def init_matting_model(model_name='modnet', half=False, device='cuda', model_rootpath=None): if model_name == 'modnet': model = MODNet(backbone_pretrained=False) model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/matting_modnet_portrait.pth' else: raise NotImplementedError(f'{model_name} is not implemented.') - model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + model_path = load_file_from_url( + url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) # TODO: clean pretrained model load_net = torch.load(model_path, map_location=lambda storage, loc: storage) # remove unnecessary 'module.' diff --git a/facexlib/parsing/__init__.py b/facexlib/parsing/__init__.py index b1725fb..9be36a3 100644 --- a/facexlib/parsing/__init__.py +++ b/facexlib/parsing/__init__.py @@ -5,7 +5,7 @@ from .parsenet import ParseNet -def init_parsing_model(model_name='bisenet', half=False, device='cuda'): +def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_rootpath=None): if model_name == 'bisenet': model = BiSeNet(num_class=19) model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth' @@ -15,7 +15,8 @@ def init_parsing_model(model_name='bisenet', half=False, device='cuda'): else: raise NotImplementedError(f'{model_name} is not implemented.') - model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + model_path = load_file_from_url( + url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) load_net = torch.load(model_path, map_location=lambda storage, loc: storage) model.load_state_dict(load_net, strict=True) model.eval() diff --git a/facexlib/recognition/__init__.py b/facexlib/recognition/__init__.py index 0d52949..1f65d2c 100644 --- a/facexlib/recognition/__init__.py +++ b/facexlib/recognition/__init__.py @@ -4,14 +4,15 @@ from .arcface_arch import Backbone -def init_recognition_model(model_name, half=False, device='cuda'): +def init_recognition_model(model_name, half=False, device='cuda', model_rootpath=None): if model_name == 'arcface': model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se').to('cuda').eval() model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth' else: raise NotImplementedError(f'{model_name} is not implemented.') - model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + model_path = load_file_from_url( + url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) diff --git a/facexlib/utils/face_restoration_helper.py b/facexlib/utils/face_restoration_helper.py index fcf1726..cf2254e 100644 --- a/facexlib/utils/face_restoration_helper.py +++ b/facexlib/utils/face_restoration_helper.py @@ -57,7 +57,8 @@ def __init__(self, template_3points=False, pad_blur=False, use_parse=False, - device=None): + device=None, + model_rootpath=None): self.template_3points = template_3points # improve robustness self.upscale_factor = upscale_factor # the cropped face ratio based on the square face @@ -95,11 +96,11 @@ def __init__(self, self.device = device # init face detection model - self.face_det = init_detection_model(det_model, half=False, device=self.device) + self.face_det = init_detection_model(det_model, half=False, device=self.device, model_rootpath=model_rootpath) # init face parsing model self.use_parse = use_parse - self.face_parse = init_parsing_model(model_name='parsenet', device=self.device) + self.face_parse = init_parsing_model(model_name='parsenet', device=self.device, model_rootpath=model_rootpath) def set_upscale_factor(self, upscale_factor): self.upscale_factor = upscale_factor diff --git a/facexlib/utils/misc.py b/facexlib/utils/misc.py index f0da6a2..b1a597c 100644 --- a/facexlib/utils/misc.py +++ b/facexlib/utils/misc.py @@ -56,20 +56,22 @@ def _totensor(img, bgr2rgb, float32): return _totensor(imgs, bgr2rgb, float32) -def load_file_from_url(url, model_dir=None, progress=True, file_name=None): +def load_file_from_url(url, model_dir=None, progress=True, file_name=None, save_dir=None): """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py """ if model_dir is None: hub_dir = get_dir() model_dir = os.path.join(hub_dir, 'checkpoints') - os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) + if save_dir is None: + save_dir = os.path.join(ROOT_DIR, model_dir) + os.makedirs(save_dir, exist_ok=True) parts = urlparse(url) filename = os.path.basename(parts.path) if file_name is not None: filename = file_name - cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) + cached_file = os.path.abspath(os.path.join(save_dir, filename)) if not os.path.exists(cached_file): print(f'Downloading: "{url}" to {cached_file}\n') download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)