Changeset 0e55afe in sasmodels for sasmodels/compare.py
- Timestamp:
- Nov 29, 2017 6:55:21 PM (6 years ago)
- Branches:
- master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
- Children:
- 4493288
- Parents:
- 688d315 (diff), b669b49 (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent. - File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/compare.py
r376b0ee r0e55afe 40 40 from . import core 41 41 from . import kerneldll 42 from . import exception43 42 from .data import plot_theory, empty_data1D, empty_data2D, load_data 44 43 from .direct_model import DirectModel, get_mesh 45 from .convert import revert_name, revert_pars, constrain_new_to_old46 44 from .generate import FLOAT_RE 47 45 from .weights import plot_weights 48 46 47 # pylint: disable=unused-import 49 48 try: 50 49 from typing import Optional, Dict, Any, Callable, Tuple 51 except Exception:50 except ImportError: 52 51 pass 53 52 else: … … 55 54 from .data import Data 56 55 Calculator = Callable[[float], np.ndarray] 56 # pylint: enable=unused-import 57 57 58 58 USAGE = """ … … 97 97 -single/-double/-half/-fast sets an OpenCL calculation engine 98 98 -single!/-double!/-quad! sets an OpenMP calculation engine 99 -sasview sets the sasview calculation engine100 99 101 100 === plotting === … … 150 149 kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 151 150 152 # list of math functions for use in evaluating parameters 153 MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_')) 151 def build_math_context(): 152 # type: () -> Dict[str, Callable] 153 """build dictionary of functions from math module""" 154 return dict((k, getattr(math, k)) 155 for k in dir(math) if not k.startswith('_')) 156 157 #: list of math functions for use in evaluating parameters 158 MATH = build_math_context() 154 159 155 160 # CRUFT python 2.6 … … 231 236 pass 232 237 233 def __exit__(self, exc_type, exc_value, trace back):238 def __exit__(self, exc_type, exc_value, trace): 234 239 # type: (Any, BaseException, Any) -> None 235 # TODO: better typing for __exit__ method236 240 np.random.set_state(self._state) 237 241 … … 252 256 """ 253 257 Add a beam stop of the given *radius*. If *outer*, make an annulus. 254 255 Note: this function does not require sasview256 258 """ 257 259 if hasattr(data, 'qx_data'): … … 374 376 375 377 def _random_pd(model_info, pars): 378 # type: (ModelInfo, Dict[str, float]) -> None 379 """ 380 Generate a random dispersity distribution for the model. 381 382 1% no shape dispersity 383 85% single shape parameter 384 13% two shape parameters 385 1% three shape parameters 386 387 If oriented, then put dispersity in theta, add phi and psi dispersity 388 with 10% probability for each. 389 """ 376 390 pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse] 377 391 pd_volume = [] … … 444 458 value = pars[p.name] 445 459 if p.units == 'Ang' and value > maxdim: 446 pars[p.name] = maxdim*10**np.random.uniform(-3, 0)460 pars[p.name] = maxdim*10**np.random.uniform(-3, 0) 447 461 448 462 def constrain_pars(model_info, pars): … … 490 504 if pars['radius'] < pars['thick_string']: 491 505 pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius'] 492 pass493 506 494 507 elif name == 'rpa': … … 608 621 return pars 609 622 610 def eval_sasview(model_info, data):611 # type: (Modelinfo, Data) -> Calculator612 """613 Return a model calculator using the pre-4.0 SasView models.614 """615 # importing sas here so that the error message will be that sas failed to616 # import rather than the more obscure smear_selection not imported error617 import sas618 import sas.models619 from sas.models.qsmearing import smear_selection620 from sas.models.MultiplicationModel import MultiplicationModel621 from sas.models.dispersion_models import models as dispersers622 623 def get_model_class(name):624 # type: (str) -> "sas.models.BaseComponent"625 #print("new",sorted(_pars.items()))626 __import__('sas.models.' + name)627 ModelClass = getattr(getattr(sas.models, name, None), name, None)628 if ModelClass is None:629 raise ValueError("could not find model %r in sas.models"%name)630 return ModelClass631 632 # WARNING: ugly hack when handling model!633 # Sasview models with multiplicity need to be created with the target634 # multiplicity, so we cannot create the target model ahead of time for635 # for multiplicity models. Instead we store the model in a list and636 # update the first element of that list with the new multiplicity model637 # every time we evaluate.638 639 # grab the sasview model, or create it if it is a product model640 if model_info.composition:641 composition_type, parts = model_info.composition642 if composition_type == 'product':643 P, S = [get_model_class(revert_name(p))() for p in parts]644 model = [MultiplicationModel(P, S)]645 else:646 raise ValueError("sasview mixture models not supported by compare")647 else:648 old_name = revert_name(model_info)649 if old_name is None:650 raise ValueError("model %r does not exist in old sasview"651 % model_info.id)652 ModelClass = get_model_class(old_name)653 model = [ModelClass()]654 model[0].disperser_handles = {}655 656 # build a smearer with which to call the model, if necessary657 smearer = smear_selection(data, model=model)658 if hasattr(data, 'qx_data'):659 q = np.sqrt(data.qx_data**2 + data.qy_data**2)660 index = ((~data.mask) & (~np.isnan(data.data))661 & (q >= data.qmin) & (q <= data.qmax))662 if smearer is not None:663 smearer.model = model # because smear_selection has a bug664 smearer.accuracy = data.accuracy665 smearer.set_index(index)666 def _call_smearer():667 smearer.model = model[0]668 return smearer.get_value()669 theory = _call_smearer670 else:671 theory = lambda: model[0].evalDistribution([data.qx_data[index],672 data.qy_data[index]])673 elif smearer is not None:674 theory = lambda: smearer(model[0].evalDistribution(data.x))675 else:676 theory = lambda: model[0].evalDistribution(data.x)677 678 def calculator(**pars):679 # type: (float, ...) -> np.ndarray680 """681 Sasview calculator for model.682 """683 oldpars = revert_pars(model_info, pars)684 # For multiplicity models, create a model with the correct multiplicity685 control = oldpars.pop("CONTROL", None)686 if control is not None:687 # sphericalSLD has one fewer multiplicity. This update should688 # happen in revert_pars, but it hasn't been called yet.689 model[0] = ModelClass(control)690 # paying for parameter conversion each time to keep life simple, if not fast691 for k, v in oldpars.items():692 if k.endswith('.type'):693 par = k[:-5]694 if v == 'gaussian': continue695 cls = dispersers[v if v != 'rectangle' else 'rectangula']696 handle = cls()697 model[0].disperser_handles[par] = handle698 try:699 model[0].set_dispersion(par, handle)700 except Exception:701 exception.annotate_exception("while setting %s to %r"702 %(par, v))703 raise704 705 706 #print("sasview pars",oldpars)707 for k, v in oldpars.items():708 name_attr = k.split('.') # polydispersity components709 if len(name_attr) == 2:710 par, disp_par = name_attr711 model[0].dispersion[par][disp_par] = v712 else:713 model[0].setParam(k, v)714 return theory()715 716 calculator.engine = "sasview"717 return calculator718 623 719 624 DTYPE_MAP = { … … 809 714 than OpenCL. 810 715 """ 811 if dtype == 'sasview': 812 return eval_sasview(model_info, data) 813 elif dtype is None or not dtype.endswith('!'): 716 if dtype is None or not dtype.endswith('!'): 814 717 return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff) 815 718 else: … … 847 750 # print a separate seed for each dataset for better reproducibility 848 751 new_seed = np.random.randint(1000000) 849 print("Set %d uses -random=%i"%(k+1, new_seed))752 print("Set %d uses -random=%i"%(k+1, new_seed)) 850 753 np.random.seed(new_seed) 851 754 opts['pars'] = parse_pars(opts, maxdim=maxdim) … … 868 771 def run_models(opts, verbose=False): 869 772 # type: (Dict[str, Any]) -> Dict[str, Any] 773 """ 774 Process a parameter set, return calculation results and times. 775 """ 870 776 871 777 base, comp = opts['engines'] … … 923 829 # work with trimmed data, not the full set 924 830 sorted_err = np.sort(abs(err.compressed())) 925 if len(sorted_err) == 0 .:831 if len(sorted_err) == 0: 926 832 print(label + " no valid values") 927 833 return … … 941 847 def plot_models(opts, result, limits=None, setnum=0): 942 848 # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float] 849 """ 850 Plot the results from :func:`run_model`. 851 """ 943 852 import matplotlib.pyplot as plt 944 853 … … 987 896 errview = 'linear' 988 897 if 0: # 95% cutoff 989 sorted = np.sort(err.flatten())990 cutoff = sorted [int(sorted.size*0.95)]898 sorted_err = np.sort(err.flatten()) 899 cutoff = sorted_err[int(sorted_err.size*0.95)] 991 900 err[err > cutoff] = cutoff 992 901 #err,errstr = base/comp,"ratio" … … 1051 960 'engine=', 1052 961 'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!', 1053 'sasview', # TODO: remove sasview 3.x support1054 962 1055 963 # Output options … … 1057 965 ] 1058 966 1059 NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))1060 VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]967 NAME_OPTIONS = (lambda: set(k for k in OPTIONS if not k.endswith('=')))() 968 VALUE_OPTIONS = (lambda: [k[:-1] for k in OPTIONS if k.endswith('=')])() 1061 969 1062 970 … … 1106 1014 1107 1015 INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$") 1108 def isnumber(str): 1109 match = FLOAT_RE.match(str) 1110 isfloat = (match and not str[match.end():]) 1111 return isfloat or INTEGER_RE.match(str) 1016 def isnumber(s): 1017 # type: (str) -> bool 1018 """Return True if string contains an int or float""" 1019 match = FLOAT_RE.match(s) 1020 isfloat = (match and not s[match.end():]) 1021 return isfloat or INTEGER_RE.match(s) 1112 1022 1113 1023 # For distinguishing pairs of models for comparison … … 1148 1058 name = positional_args[-1] 1149 1059 1150 # pylint: disable=bad-whitespace 1060 # pylint: disable=bad-whitespace,C0321 1151 1061 # Interpret the flags 1152 1062 opts = { … … 1232 1142 elif arg == '-double!': opts['engine'] = 'double!' 1233 1143 elif arg == '-quad!': opts['engine'] = 'quad!' 1234 elif arg == '-sasview': opts['engine'] = 'sasview'1235 1144 elif arg == '-edit': opts['explore'] = True 1236 1145 elif arg == '-demo': opts['use_demo'] = True … … 1239 1148 elif arg == '-html': opts['html'] = True 1240 1149 elif arg == '-help': opts['html'] = True 1241 # pylint: enable=bad-whitespace 1150 # pylint: enable=bad-whitespace,C0321 1242 1151 1243 1152 # Magnetism forces 2D for now … … 1314 1223 1315 1224 def set_spherical_integration_parameters(opts, steps): 1225 # type: (Dict[str, Any], int) -> None 1316 1226 """ 1317 1227 Set integration parameters for spherical integration over the entire … … 1337 1247 'psi_pd_type=rectangle', 1338 1248 ]) 1339 pass1340 1249 1341 1250 def parse_pars(opts, maxdim=np.inf): 1251 # type: (Dict[str, Any], float) -> Tuple[Dict[str, float], Dict[str, float]] 1252 """ 1253 Generate a parameter set. 1254 1255 The default values come from the model, or a randomized model if a seed 1256 value is given. Next, evaluate any parameter expressions, constraining 1257 the value of the parameter within and between models. If *maxdim* is 1258 given, limit parameters with units of Angstrom to this value. 1259 1260 Returns a pair of parameter dictionaries for base and comparison models. 1261 """ 1342 1262 model_info, model_info2 = opts['info'] 1343 1263 … … 1378 1298 print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s)))) 1379 1299 return None 1380 v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v)1300 v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v) 1381 1301 if v1 and k in pars: 1382 1302 presets[k] = float(v1) if isnumber(v1) else v1 … … 1427 1347 show html docs for the model 1428 1348 """ 1429 import os1430 1349 from .generate import make_html 1431 1350 from . import rst2html … … 1434 1353 html = make_html(info) 1435 1354 path = os.path.dirname(info.filename) 1436 url = "file://" +path.replace("\\","/")[2:]+"/"1355 url = "file://" + path.replace("\\", "/")[2:] + "/" 1437 1356 rst2html.view_html_qtapp(html, url) 1438 1357 … … 1458 1377 frame.panel.Layout() 1459 1378 frame.panel.aui.Split(0, wx.TOP) 1460 def reset_parameters(event):1379 def _reset_parameters(event): 1461 1380 model.revert_values() 1462 1381 signal.update_parameters(problem) 1463 frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1)) 1464 if is_mac: frame.Show() 1382 frame.Bind(wx.EVT_TOOL, _reset_parameters, frame.ToolBar.GetToolByPos(1)) 1383 if is_mac: 1384 frame.Show() 1465 1385 # If running withing an app, start the main loop 1466 1386 if app: … … 1504 1424 1505 1425 def revert_values(self): 1426 # type: () -> None 1427 """ 1428 Restore starting values of the parameters. 1429 """ 1506 1430 for k, v in self.starting_values.items(): 1507 1431 self.pars[k].value = v 1508 1432 1509 1433 def model_update(self): 1434 # type: () -> None 1435 """ 1436 Respond to signal that model parameters have been changed. 1437 """ 1510 1438 pass 1511 1439
Note: See TracChangeset
for help on using the changeset viewer.