Skip to content

Commit

Permalink
support change model rootpath to save models, v0.2.5
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Aug 31, 2022
1 parent 8b96d51 commit 7655b7c
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 21 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.4
0.2.5
5 changes: 3 additions & 2 deletions facexlib/alignment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
__all__ = ['FAN', 'landmark_98_to_68']


def init_alignment_model(model_name, half=False, device='cuda'):
def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
if model_name == 'awing_fan':
model = FAN(num_modules=4, num_landmarks=98)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')

model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None)
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
model.load_state_dict(torch.load(model_path)['state_dict'], strict=True)
model.eval()
model = model.to(device)
Expand Down
5 changes: 3 additions & 2 deletions facexlib/assessment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
from .hyperiqa_net import HyperIQA


def init_assessment_model(model_name, half=False, device='cuda'):
def init_assessment_model(model_name, half=False, device='cuda', model_rootpath=None):
if model_name == 'hypernet':
model = HyperIQA(16, 112, 224, 112, 56, 28, 14, 7)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/assessment_hyperIQA.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')

# load the pre-trained hypernet model
hypernet_model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None)
hypernet_model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
model.hypernet.load_state_dict((torch.load(hypernet_model_path, map_location=lambda storage, loc: storage)))
model = model.eval()
model = model.to(device)
Expand Down
6 changes: 4 additions & 2 deletions facexlib/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .retinaface import RetinaFace


def init_detection_model(model_name, half=False, device='cuda'):
def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None):
if model_name == 'retinaface_resnet50':
model = RetinaFace(network_name='resnet50', half=half)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
Expand All @@ -15,7 +15,9 @@ def init_detection_model(model_name, half=False, device='cuda'):
else:
raise NotImplementedError(f'{model_name} is not implemented.')

model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None)
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)

# TODO: clean pretrained model
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
# remove unnecessary 'module.'
Expand Down
5 changes: 3 additions & 2 deletions facexlib/headpose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from .hopenet_arch import HopeNet


def init_headpose_model(model_name, half=False, device='cuda'):
def init_headpose_model(model_name, half=False, device='cuda', model_rootpath=None):
if model_name == 'hopenet':
model = HopeNet('resnet', [3, 4, 6, 3], 66)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/headpose_hopenet.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')

model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None)
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)['params']
model.load_state_dict(load_net, strict=True)
model.eval()
Expand Down
5 changes: 3 additions & 2 deletions facexlib/matting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from .modnet import MODNet


def init_matting_model(model_name='modnet', half=False, device='cuda'):
def init_matting_model(model_name='modnet', half=False, device='cuda', model_rootpath=None):
if model_name == 'modnet':
model = MODNet(backbone_pretrained=False)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/matting_modnet_portrait.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')

model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None)
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
# TODO: clean pretrained model
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
# remove unnecessary 'module.'
Expand Down
5 changes: 3 additions & 2 deletions facexlib/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .parsenet import ParseNet


def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_rootpath=None):
if model_name == 'bisenet':
model = BiSeNet(num_class=19)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth'
Expand All @@ -15,7 +15,8 @@ def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
else:
raise NotImplementedError(f'{model_name} is not implemented.')

model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None)
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
model.load_state_dict(load_net, strict=True)
model.eval()
Expand Down
5 changes: 3 additions & 2 deletions facexlib/recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from .arcface_arch import Backbone


def init_recognition_model(model_name, half=False, device='cuda'):
def init_recognition_model(model_name, half=False, device='cuda', model_rootpath=None):
if model_name == 'arcface':
model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se').to('cuda').eval()
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')

model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None)
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)
Expand Down
7 changes: 4 additions & 3 deletions facexlib/utils/face_restoration_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(self,
template_3points=False,
pad_blur=False,
use_parse=False,
device=None):
device=None,
model_rootpath=None):
self.template_3points = template_3points # improve robustness
self.upscale_factor = upscale_factor
# the cropped face ratio based on the square face
Expand Down Expand Up @@ -95,11 +96,11 @@ def __init__(self,
self.device = device

# init face detection model
self.face_det = init_detection_model(det_model, half=False, device=self.device)
self.face_det = init_detection_model(det_model, half=False, device=self.device, model_rootpath=model_rootpath)

# init face parsing model
self.use_parse = use_parse
self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
self.face_parse = init_parsing_model(model_name='parsenet', device=self.device, model_rootpath=model_rootpath)

def set_upscale_factor(self, upscale_factor):
self.upscale_factor = upscale_factor
Expand Down
8 changes: 5 additions & 3 deletions facexlib/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,22 @@ def _totensor(img, bgr2rgb, float32):
return _totensor(imgs, bgr2rgb, float32)


def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
def load_file_from_url(url, model_dir=None, progress=True, file_name=None, save_dir=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')

os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
if save_dir is None:
save_dir = os.path.join(ROOT_DIR, model_dir)
os.makedirs(save_dir, exist_ok=True)

parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
cached_file = os.path.abspath(os.path.join(save_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
Expand Down

0 comments on commit 7655b7c

Please sign in to comment.