Skip to content

Commit

Permalink
fix default loading path
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold committed Jun 26, 2017
1 parent d8a359b commit ef7a773
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
6 changes: 3 additions & 3 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import tools.find_mxnet
import mxnet as mx
import os
import importlib
import sys
from detect.detector import Detector
from symbol.symbol_factory import get_symbol
Expand Down Expand Up @@ -52,7 +51,7 @@ def parse_args():
parser.add_argument('--epoch', dest='epoch', help='epoch of trained model',
default=0, type=int)
parser.add_argument('--prefix', dest='prefix', help='trained model prefix',
default=os.path.join(os.getcwd(), 'model', 'ssd_vgg16_reduced_300'),
default=os.path.join(os.getcwd(), 'model', 'ssd_'),
type=str)
parser.add_argument('--cpu', dest='cpu', help='(override GPU) use CPU to detect',
action='store_true', default=False)
Expand Down Expand Up @@ -112,7 +111,8 @@ def parse_class_names(class_names):

network = None if args.deploy_net else args.network
class_names = parse_class_names(args.class_names)
detector = get_detector(network, args.prefix, args.epoch,
prefix = args.prefix + args.network + '_' + str(args.data_shape)
detector = get_detector(network, prefix, args.epoch,
args.data_shape,
(args.mean_r, args.mean_g, args.mean_b),
ctx, len(class_names), args.nms_thresh, args.force_nms)
Expand Down
9 changes: 5 additions & 4 deletions deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

def parse_args():
parser = argparse.ArgumentParser(description='Convert a trained model to deploy model')
parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300',
choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use')
parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced',
help='which network to use')
parser.add_argument('--epoch', dest='epoch', help='epoch of trained model',
default=0, type=int)
parser.add_argument('--prefix', dest='prefix', help='trained model prefix',
default=os.path.join(os.getcwd(), 'model', 'ssd_300'), type=str)
default=os.path.join(os.getcwd(), 'model', 'ssd_'), type=str)
parser.add_argument('--data-shape', dest='data_shape', type=int, default=300,
help='data shape')
parser.add_argument('--num-class', dest='num_classes', help='number of classes',
Expand All @@ -33,7 +33,8 @@ def parse_args():
net = get_symbol(args.network).get_symbol(args.network, args.data_shape,
num_classes=args.num_classes, nms_thresh=args.nms_thresh,
force_suppress=args.force_nms, nms_topk=args.nms_topk)
_, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch)
prefix = args.prefix + args.network + '_' + str(args.data_shape)
_, arg_params, aux_params = mx.model.load_checkpoint(prefix, args.epoch)
# new name
tmp = args.prefix.rsplit('/', 1)
save_prefix = '/deploy_'.join(tmp)
Expand Down
8 changes: 4 additions & 4 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def parse_args():
default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str)
parser.add_argument('--list-path', dest='list_path', help='which list file to use',
default="", type=str)
parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300',
choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use')
parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced',
help='which network to use')
parser.add_argument('--batch-size', dest='batch_size', type=int, default=32,
help='evaluation batch size')
parser.add_argument('--num-class', dest='num_class', type=int, default=20,
Expand All @@ -25,7 +25,7 @@ def parse_args():
parser.add_argument('--epoch', dest='epoch', help='epoch of pretrained model',
default=0, type=int)
parser.add_argument('--prefix', dest='prefix', help='load model prefix',
default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str)
default=os.path.join(os.getcwd(), 'model', 'ssd_'), type=str)
parser.add_argument('--gpus', dest='gpu_id', help='GPU devices to evaluate with',
default='0', type=str)
parser.add_argument('--cpu', dest='cpu', help='use cpu to evaluate, this can be slow',
Expand Down Expand Up @@ -78,7 +78,7 @@ def parse_args():
network = None if args.deploy_net else args.network
evaluate_net(network, args.rec_path, num_class,
(args.mean_r, args.mean_g, args.mean_b), args.data_shape,
args.prefix, args.epoch, ctx, batch_size=args.batch_size,
args.prefix + args.network, args.epoch, ctx, batch_size=args.batch_size,
path_imglist=args.list_path, nms_thresh=args.nms_thresh,
force_nms=args.force_nms, ovp_thresh=args.overlap_thresh,
use_difficult=args.use_difficult, class_names=class_names,
Expand Down

0 comments on commit ef7a773

Please sign in to comment.