Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/SasView/sasmodels.git
Browse files Browse the repository at this point in the history
  • Loading branch information
butlerpd committed Oct 2, 2016
2 parents 5f1acda + 79906d1 commit d247047
Show file tree
Hide file tree
Showing 14 changed files with 417 additions and 100 deletions.
2 changes: 1 addition & 1 deletion sascomp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main():
sys.path.insert(0, sasmodels)

import sasmodels.compare
sasmodels.compare.main()
sasmodels.compare.main(*sys.argv[1:])

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion sasmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
OpenCL drivers are available. See :mod:`generate` for details on
defining new models.
"""
__version__ = "0.93"
__version__ = "0.94"

def data_files():
"""
Expand Down
55 changes: 29 additions & 26 deletions sasmodels/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,41 +788,41 @@ def get_pars(model_info, use_demo=False):
return pars


def parse_opts():
# type: () -> Dict[str, Any]
def parse_opts(argv):
# type: (List[str]) -> Dict[str, Any]
"""
Parse command line options.
"""
MODELS = core.list_models()
flags = [arg for arg in sys.argv[1:]
flags = [arg for arg in argv
if arg.startswith('-')]
values = [arg for arg in sys.argv[1:]
values = [arg for arg in argv
if not arg.startswith('-') and '=' in arg]
args = [arg for arg in sys.argv[1:]
positional_args = [arg for arg in argv
if not arg.startswith('-') and '=' not in arg]
models = "\n ".join("%-15s"%v for v in MODELS)
if len(args) == 0:
if len(positional_args) == 0:
print(USAGE)
print("\nAvailable models:")
print(columnize(MODELS, indent=" "))
sys.exit(1)
if len(args) > 3:
return None
if len(positional_args) > 3:
print("expected parameters: model N1 N2")

name = args[0]
name = positional_args[0]
try:
model_info = core.load_model_info(name)
except ImportError as exc:
print(str(exc))
print("Could not find model; use one of:\n " + models)
sys.exit(1)
return None

invalid = [o[1:] for o in flags
if o[1:] not in NAME_OPTIONS
and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
if invalid:
print("Invalid options: %s"%(", ".join(invalid)))
sys.exit(1)
return None


# pylint: disable=bad-whitespace
Expand Down Expand Up @@ -897,8 +897,8 @@ def parse_opts():
elif len(engines) > 2:
del engines[2:]

n1 = int(args[1]) if len(args) > 1 else 1
n2 = int(args[2]) if len(args) > 2 else 1
n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
use_sasview = any(engine == 'sasview' and count > 0
for engine, count in zip(engines, [n1, n2]))

Expand All @@ -915,7 +915,7 @@ def parse_opts():
# extract base name without polydispersity info
s = set(p.split('_pd')[0] for p in pars)
print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
sys.exit(1)
return None
presets[k] = float(v) if not k.endswith('type') else v

# randomize parameters
Expand Down Expand Up @@ -969,16 +969,18 @@ def explore(opts):
from bumps.names import FitProblem # type: ignore
from bumps.gui.app_frame import AppFrame # type: ignore

problem = FitProblem(Explore(opts))
is_mac = "cocoa" in wx.version()
app = wx.App()
frame = AppFrame(parent=None, title="explore")
# Create an app if not running embedded
app = wx.App() if wx.GetApp() is None else None
problem = FitProblem(Explore(opts))
frame = AppFrame(parent=None, title="explore", size=(1000,700))
if not is_mac: frame.Show()
frame.panel.set_model(model=problem)
frame.panel.Layout()
frame.panel.aui.Split(0, wx.TOP)
if is_mac: frame.Show()
app.MainLoop()
# If running withing an app, start the main loop
if app: app.MainLoop()

class Explore(object):
"""
Expand Down Expand Up @@ -1047,16 +1049,17 @@ def plot(self, view='log'):
self.limits = vmax*1e-7, 1.3*vmax


def main():
# type: () -> None
def main(*argv):
# type: (*str) -> None
"""
Main program.
"""
opts = parse_opts()
if opts['explore']:
explore(opts)
else:
compare(opts)
opts = parse_opts(argv)
if opts is not None:
if opts['explore']:
explore(opts)
else:
compare(opts)

if __name__ == "__main__":
main()
main(*sys.argv[1:])
32 changes: 16 additions & 16 deletions sasmodels/compare_many.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,32 +228,32 @@ def print_help():
""")
print_models()

def main():
def main(argv):
"""
Main program.
"""
if len(sys.argv) not in (6, 7):
if len(argv) not in (5, 6):
print_help()
sys.exit(1)
return

model = sys.argv[1]
model = argv[0]
if not (model in MODELS) and (model != "all"):
print('Bad model %s. Use "all" or one of:'%model)
print_models()
sys.exit(1)
return
try:
count = int(sys.argv[2])
is2D = sys.argv[3].startswith('2d')
assert sys.argv[3][1] == 'd'
Nq = int(sys.argv[3][2:])
mono = sys.argv[4] == 'mono'
cutoff = float(sys.argv[4]) if not mono else 0
base = sys.argv[5]
comp = sys.argv[6] if len(sys.argv) > 6 else "sasview"
count = int(argv[1])
is2D = argv[2].startswith('2d')
assert argv[2][1] == 'd'
Nq = int(argv[2][2:])
mono = argv[3] == 'mono'
cutoff = float(argv[3]) if not mono else 0
base = argv[4]
comp = argv[5] if len(argv) > 5 else "sasview"
except Exception:
traceback.print_exc()
print_usage()
sys.exit(1)
return

data, index = make_data({'qmax':1.0, 'is2d':is2D, 'nq':Nq, 'res':0.,
'accuracy': 'Low', 'view':'log', 'zero': False})
Expand All @@ -264,5 +264,5 @@ def main():

if __name__ == "__main__":
#from .compare import push_seed
#with push_seed(1): main()
main()
#with push_seed(1): main(sys.argv[1:])
main(sys.argv[1:])
7 changes: 4 additions & 3 deletions sasmodels/custom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from importlib.util import spec_from_file_location, module_from_spec # type: ignore
def load_module_from_path(fullname, path):
"""load module from *path* as *fullname*"""
spec = spec_from_file_location(fullname, path)
spec = spec_from_file_location(fullname, os.path.expanduser(path))
module = module_from_spec(spec)
spec.loader.exec_module(module)
return module
Expand All @@ -25,7 +25,7 @@ def load_module_from_path(fullname, path):
import imp
def load_module_from_path(fullname, path):
"""load module from *path* as *fullname*"""
module = imp.load_source(fullname, path)
module = imp.load_source(fullname, os.path.expanduser(path))
#os.unlink(path+"c") # remove the automatic pyc file
return module

Expand All @@ -34,5 +34,6 @@ def load_custom_kernel_module(path):
# Pull off the last .ext if it exists; there may be others
name = basename(splitext(path)[0])
# Placing the model in the 'sasmodels.custom' name space.
kernel_module = load_module_from_path('sasmodels.custom.'+name, path)
kernel_module = load_module_from_path('sasmodels.custom.'+name,
os.path.expanduser(path))
return kernel_module
9 changes: 8 additions & 1 deletion sasmodels/direct_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def call_ER(model_info, pars):
"""
if model_info.ER is None:
return 1.0
elif not model_info.parameters.form_volume_parameters:
# handle the case where ER is provided but model is not polydisperse
return model_info.ER()
else:
value, weight = _vol_pars(model_info, pars)
individual_radii = model_info.ER(*value)
Expand All @@ -100,6 +103,9 @@ def call_VR(model_info, pars):
"""
if model_info.VR is None:
return 1.0
elif not model_info.parameters.form_volume_parameters:
# handle the case where ER is provided but model is not polydisperse
return model_info.VR()
else:
value, weight = _vol_pars(model_info, pars)
whole, part = model_info.VR(*value)
Expand Down Expand Up @@ -151,6 +157,7 @@ def _vol_pars(model_info, pars):
vol_pars = [get_weights(p, pars)
for p in model_info.parameters.call_parameters
if p.type == 'volume']
#import pylab; pylab.plot(vol_pars[0][0],vol_pars[0][1]); pylab.show()
value, weight = dispersion_mesh(model_info, vol_pars)
return value, weight

Expand Down Expand Up @@ -394,7 +401,7 @@ def main():
model_info = load_model_info(model_name)
model = build_model(model_info)
calculator = DirectModel(data, model)
pars = dict((k, float(v))
pars = dict((k, (float(v) if not k.endswith("_pd_type") else v))
for pair in sys.argv[3:]
for k, v in [pair.split('=')])
if call == "ER_VR":
Expand Down
Loading

0 comments on commit d247047

Please sign in to comment.