Changeset 0e55afe in sasmodels for sasmodels/compare.py


Ignore:
Timestamp:
Nov 29, 2017 6:55:21 PM (6 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
4493288
Parents:
688d315 (diff), b669b49 (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

Merge branch 'master' into ticket-786

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    r376b0ee r0e55afe  
    4040from . import core 
    4141from . import kerneldll 
    42 from . import exception 
    4342from .data import plot_theory, empty_data1D, empty_data2D, load_data 
    4443from .direct_model import DirectModel, get_mesh 
    45 from .convert import revert_name, revert_pars, constrain_new_to_old 
    4644from .generate import FLOAT_RE 
    4745from .weights import plot_weights 
    4846 
     47# pylint: disable=unused-import 
    4948try: 
    5049    from typing import Optional, Dict, Any, Callable, Tuple 
    51 except Exception: 
     50except ImportError: 
    5251    pass 
    5352else: 
     
    5554    from .data import Data 
    5655    Calculator = Callable[[float], np.ndarray] 
     56# pylint: enable=unused-import 
    5757 
    5858USAGE = """ 
     
    9797    -single/-double/-half/-fast sets an OpenCL calculation engine 
    9898    -single!/-double!/-quad! sets an OpenMP calculation engine 
    99     -sasview sets the sasview calculation engine 
    10099 
    101100    === plotting === 
     
    150149kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 
    151150 
    152 # list of math functions for use in evaluating parameters 
    153 MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_')) 
     151def build_math_context(): 
     152    # type: () -> Dict[str, Callable] 
     153    """build dictionary of functions from math module""" 
     154    return dict((k, getattr(math, k)) 
     155                for k in dir(math) if not k.startswith('_')) 
     156 
     157#: list of math functions for use in evaluating parameters 
     158MATH = build_math_context() 
    154159 
    155160# CRUFT python 2.6 
     
    231236        pass 
    232237 
    233     def __exit__(self, exc_type, exc_value, traceback): 
     238    def __exit__(self, exc_type, exc_value, trace): 
    234239        # type: (Any, BaseException, Any) -> None 
    235         # TODO: better typing for __exit__ method 
    236240        np.random.set_state(self._state) 
    237241 
     
    252256    """ 
    253257    Add a beam stop of the given *radius*.  If *outer*, make an annulus. 
    254  
    255     Note: this function does not require sasview 
    256258    """ 
    257259    if hasattr(data, 'qx_data'): 
     
    374376 
    375377def _random_pd(model_info, pars): 
     378    # type: (ModelInfo, Dict[str, float]) -> None 
     379    """ 
     380    Generate a random dispersity distribution for the model. 
     381 
     382    1% no shape dispersity 
     383    85% single shape parameter 
     384    13% two shape parameters 
     385    1% three shape parameters 
     386 
     387    If oriented, then put dispersity in theta, add phi and psi dispersity 
     388    with 10% probability for each. 
     389    """ 
    376390    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse] 
    377391    pd_volume = [] 
     
    444458        value = pars[p.name] 
    445459        if p.units == 'Ang' and value > maxdim: 
    446             pars[p.name] = maxdim*10**np.random.uniform(-3,0) 
     460            pars[p.name] = maxdim*10**np.random.uniform(-3, 0) 
    447461 
    448462def constrain_pars(model_info, pars): 
     
    490504        if pars['radius'] < pars['thick_string']: 
    491505            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius'] 
    492         pass 
    493506 
    494507    elif name == 'rpa': 
     
    608621    return pars 
    609622 
    610 def eval_sasview(model_info, data): 
    611     # type: (Modelinfo, Data) -> Calculator 
    612     """ 
    613     Return a model calculator using the pre-4.0 SasView models. 
    614     """ 
    615     # importing sas here so that the error message will be that sas failed to 
    616     # import rather than the more obscure smear_selection not imported error 
    617     import sas 
    618     import sas.models 
    619     from sas.models.qsmearing import smear_selection 
    620     from sas.models.MultiplicationModel import MultiplicationModel 
    621     from sas.models.dispersion_models import models as dispersers 
    622  
    623     def get_model_class(name): 
    624         # type: (str) -> "sas.models.BaseComponent" 
    625         #print("new",sorted(_pars.items())) 
    626         __import__('sas.models.' + name) 
    627         ModelClass = getattr(getattr(sas.models, name, None), name, None) 
    628         if ModelClass is None: 
    629             raise ValueError("could not find model %r in sas.models"%name) 
    630         return ModelClass 
    631  
    632     # WARNING: ugly hack when handling model! 
    633     # Sasview models with multiplicity need to be created with the target 
    634     # multiplicity, so we cannot create the target model ahead of time for 
    635     # for multiplicity models.  Instead we store the model in a list and 
    636     # update the first element of that list with the new multiplicity model 
    637     # every time we evaluate. 
    638  
    639     # grab the sasview model, or create it if it is a product model 
    640     if model_info.composition: 
    641         composition_type, parts = model_info.composition 
    642         if composition_type == 'product': 
    643             P, S = [get_model_class(revert_name(p))() for p in parts] 
    644             model = [MultiplicationModel(P, S)] 
    645         else: 
    646             raise ValueError("sasview mixture models not supported by compare") 
    647     else: 
    648         old_name = revert_name(model_info) 
    649         if old_name is None: 
    650             raise ValueError("model %r does not exist in old sasview" 
    651                             % model_info.id) 
    652         ModelClass = get_model_class(old_name) 
    653         model = [ModelClass()] 
    654     model[0].disperser_handles = {} 
    655  
    656     # build a smearer with which to call the model, if necessary 
    657     smearer = smear_selection(data, model=model) 
    658     if hasattr(data, 'qx_data'): 
    659         q = np.sqrt(data.qx_data**2 + data.qy_data**2) 
    660         index = ((~data.mask) & (~np.isnan(data.data)) 
    661                  & (q >= data.qmin) & (q <= data.qmax)) 
    662         if smearer is not None: 
    663             smearer.model = model  # because smear_selection has a bug 
    664             smearer.accuracy = data.accuracy 
    665             smearer.set_index(index) 
    666             def _call_smearer(): 
    667                 smearer.model = model[0] 
    668                 return smearer.get_value() 
    669             theory = _call_smearer 
    670         else: 
    671             theory = lambda: model[0].evalDistribution([data.qx_data[index], 
    672                                                         data.qy_data[index]]) 
    673     elif smearer is not None: 
    674         theory = lambda: smearer(model[0].evalDistribution(data.x)) 
    675     else: 
    676         theory = lambda: model[0].evalDistribution(data.x) 
    677  
    678     def calculator(**pars): 
    679         # type: (float, ...) -> np.ndarray 
    680         """ 
    681         Sasview calculator for model. 
    682         """ 
    683         oldpars = revert_pars(model_info, pars) 
    684         # For multiplicity models, create a model with the correct multiplicity 
    685         control = oldpars.pop("CONTROL", None) 
    686         if control is not None: 
    687             # sphericalSLD has one fewer multiplicity.  This update should 
    688             # happen in revert_pars, but it hasn't been called yet. 
    689             model[0] = ModelClass(control) 
    690         # paying for parameter conversion each time to keep life simple, if not fast 
    691         for k, v in oldpars.items(): 
    692             if k.endswith('.type'): 
    693                 par = k[:-5] 
    694                 if v == 'gaussian': continue 
    695                 cls = dispersers[v if v != 'rectangle' else 'rectangula'] 
    696                 handle = cls() 
    697                 model[0].disperser_handles[par] = handle 
    698                 try: 
    699                     model[0].set_dispersion(par, handle) 
    700                 except Exception: 
    701                     exception.annotate_exception("while setting %s to %r" 
    702                                                  %(par, v)) 
    703                     raise 
    704  
    705  
    706         #print("sasview pars",oldpars) 
    707         for k, v in oldpars.items(): 
    708             name_attr = k.split('.')  # polydispersity components 
    709             if len(name_attr) == 2: 
    710                 par, disp_par = name_attr 
    711                 model[0].dispersion[par][disp_par] = v 
    712             else: 
    713                 model[0].setParam(k, v) 
    714         return theory() 
    715  
    716     calculator.engine = "sasview" 
    717     return calculator 
    718623 
    719624DTYPE_MAP = { 
     
    809714    than OpenCL. 
    810715    """ 
    811     if dtype == 'sasview': 
    812         return eval_sasview(model_info, data) 
    813     elif dtype is None or not dtype.endswith('!'): 
     716    if dtype is None or not dtype.endswith('!'): 
    814717        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff) 
    815718    else: 
     
    847750            # print a separate seed for each dataset for better reproducibility 
    848751            new_seed = np.random.randint(1000000) 
    849             print("Set %d uses -random=%i"%(k+1,new_seed)) 
     752            print("Set %d uses -random=%i"%(k+1, new_seed)) 
    850753            np.random.seed(new_seed) 
    851754        opts['pars'] = parse_pars(opts, maxdim=maxdim) 
     
    868771def run_models(opts, verbose=False): 
    869772    # type: (Dict[str, Any]) -> Dict[str, Any] 
     773    """ 
     774    Process a parameter set, return calculation results and times. 
     775    """ 
    870776 
    871777    base, comp = opts['engines'] 
     
    923829    # work with trimmed data, not the full set 
    924830    sorted_err = np.sort(abs(err.compressed())) 
    925     if len(sorted_err) == 0.: 
     831    if len(sorted_err) == 0: 
    926832        print(label + "  no valid values") 
    927833        return 
     
    941847def plot_models(opts, result, limits=None, setnum=0): 
    942848    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float] 
     849    """ 
     850    Plot the results from :func:`run_model`. 
     851    """ 
    943852    import matplotlib.pyplot as plt 
    944853 
     
    987896                errview = 'linear' 
    988897        if 0:  # 95% cutoff 
    989             sorted = np.sort(err.flatten()) 
    990             cutoff = sorted[int(sorted.size*0.95)] 
     898            sorted_err = np.sort(err.flatten()) 
     899            cutoff = sorted_err[int(sorted_err.size*0.95)] 
    991900            err[err > cutoff] = cutoff 
    992901        #err,errstr = base/comp,"ratio" 
     
    1051960    'engine=', 
    1052961    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!', 
    1053     'sasview',  # TODO: remove sasview 3.x support 
    1054962 
    1055963    # Output options 
     
    1057965    ] 
    1058966 
    1059 NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('=')) 
    1060 VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')] 
     967NAME_OPTIONS = (lambda: set(k for k in OPTIONS if not k.endswith('=')))() 
     968VALUE_OPTIONS = (lambda: [k[:-1] for k in OPTIONS if k.endswith('=')])() 
    1061969 
    1062970 
     
    11061014 
    11071015INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$") 
    1108 def isnumber(str): 
    1109     match = FLOAT_RE.match(str) 
    1110     isfloat = (match and not str[match.end():]) 
    1111     return isfloat or INTEGER_RE.match(str) 
     1016def isnumber(s): 
     1017    # type: (str) -> bool 
     1018    """Return True if string contains an int or float""" 
     1019    match = FLOAT_RE.match(s) 
     1020    isfloat = (match and not s[match.end():]) 
     1021    return isfloat or INTEGER_RE.match(s) 
    11121022 
    11131023# For distinguishing pairs of models for comparison 
     
    11481058    name = positional_args[-1] 
    11491059 
    1150     # pylint: disable=bad-whitespace 
     1060    # pylint: disable=bad-whitespace,C0321 
    11511061    # Interpret the flags 
    11521062    opts = { 
     
    12321142        elif arg == '-double!': opts['engine'] = 'double!' 
    12331143        elif arg == '-quad!':   opts['engine'] = 'quad!' 
    1234         elif arg == '-sasview': opts['engine'] = 'sasview' 
    12351144        elif arg == '-edit':    opts['explore'] = True 
    12361145        elif arg == '-demo':    opts['use_demo'] = True 
     
    12391148        elif arg == '-html':    opts['html'] = True 
    12401149        elif arg == '-help':    opts['html'] = True 
    1241     # pylint: enable=bad-whitespace 
     1150    # pylint: enable=bad-whitespace,C0321 
    12421151 
    12431152    # Magnetism forces 2D for now 
     
    13141223 
    13151224def set_spherical_integration_parameters(opts, steps): 
     1225    # type: (Dict[str, Any], int) -> None 
    13161226    """ 
    13171227    Set integration parameters for spherical integration over the entire 
     
    13371247            'psi_pd_type=rectangle', 
    13381248        ]) 
    1339         pass 
    13401249 
    13411250def parse_pars(opts, maxdim=np.inf): 
     1251    # type: (Dict[str, Any], float) -> Tuple[Dict[str, float], Dict[str, float]] 
     1252    """ 
     1253    Generate a parameter set. 
     1254 
     1255    The default values come from the model, or a randomized model if a seed 
     1256    value is given.  Next, evaluate any parameter expressions, constraining 
     1257    the value of the parameter within and between models.  If *maxdim* is 
     1258    given, limit parameters with units of Angstrom to this value. 
     1259 
     1260    Returns a pair of parameter dictionaries for base and comparison models. 
     1261    """ 
    13421262    model_info, model_info2 = opts['info'] 
    13431263 
     
    13781298            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s)))) 
    13791299            return None 
    1380         v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v) 
     1300        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v) 
    13811301        if v1 and k in pars: 
    13821302            presets[k] = float(v1) if isnumber(v1) else v1 
     
    14271347    show html docs for the model 
    14281348    """ 
    1429     import os 
    14301349    from .generate import make_html 
    14311350    from . import rst2html 
     
    14341353    html = make_html(info) 
    14351354    path = os.path.dirname(info.filename) 
    1436     url = "file://"+path.replace("\\","/")[2:]+"/" 
     1355    url = "file://" + path.replace("\\", "/")[2:] + "/" 
    14371356    rst2html.view_html_qtapp(html, url) 
    14381357 
     
    14581377    frame.panel.Layout() 
    14591378    frame.panel.aui.Split(0, wx.TOP) 
    1460     def reset_parameters(event): 
     1379    def _reset_parameters(event): 
    14611380        model.revert_values() 
    14621381        signal.update_parameters(problem) 
    1463     frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1)) 
    1464     if is_mac: frame.Show() 
     1382    frame.Bind(wx.EVT_TOOL, _reset_parameters, frame.ToolBar.GetToolByPos(1)) 
     1383    if is_mac: 
     1384        frame.Show() 
    14651385    # If running withing an app, start the main loop 
    14661386    if app: 
     
    15041424 
    15051425    def revert_values(self): 
     1426        # type: () -> None 
     1427        """ 
     1428        Restore starting values of the parameters. 
     1429        """ 
    15061430        for k, v in self.starting_values.items(): 
    15071431            self.pars[k].value = v 
    15081432 
    15091433    def model_update(self): 
     1434        # type: () -> None 
     1435        """ 
     1436        Respond to signal that model parameters have been changed. 
     1437        """ 
    15101438        pass 
    15111439 
Note: See TracChangeset for help on using the changeset viewer.