Changeset 17bbadd in sasmodels for sasmodels/compare.py


Ignore:
Timestamp:
Mar 15, 2016 10:47:12 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
754e27b
Parents:
5ceb7d0
Message:

refactor so all model defintion queries use model_info; better documentation of model_info structure; initial implementation of product model (broken)

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    r6869ceb r17bbadd  
    3838from . import core 
    3939from . import kerneldll 
    40 from . import generate 
     40from . import product 
    4141from .data import plot_theory, empty_data1D, empty_data2D 
    4242from .direct_model import DirectModel 
    43 from .convert import revert_model, constrain_new_to_old 
     43from .convert import revert_pars, constrain_new_to_old 
    4444 
    4545USAGE = """ 
     
    264264    return pars 
    265265 
    266 def constrain_pars(model_definition, pars): 
     266def constrain_pars(model_info, pars): 
    267267    """ 
    268268    Restrict parameters to valid values. 
     
    272272    cylinder radius in this case). 
    273273    """ 
    274     name = model_definition.name 
     274    name = model_info['id'] 
     275    # if it is a product model, then just look at the form factor since 
     276    # none of the structure factors need any constraints. 
     277    if '*' in name: 
     278        name = name.split('*')[0] 
     279 
    275280    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']: 
    276281        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius'] 
     
    340345    return pars 
    341346 
    342 def eval_sasview(model_definition, data): 
     347def eval_sasview(model_info, data): 
    343348    """ 
    344349    Return a model calculator using the SasView fitting engine. 
     
    349354    from sas.models.qsmearing import smear_selection 
    350355 
    351     # convert model parameters from sasmodel form to sasview form 
    352     #print("old",sorted(pars.items())) 
    353     modelname, _ = revert_model(model_definition, {}) 
    354     #print("new",sorted(_pars.items())) 
    355     sas = __import__('sas.models.'+modelname) 
    356     ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None) 
    357     if ModelClass is None: 
    358         raise ValueError("could not find model %r in sas.models"%modelname) 
    359     model = ModelClass() 
     356    def get_model(name): 
     357        #print("new",sorted(_pars.items())) 
     358        sas = __import__('sas.models.' + name) 
     359        ModelClass = getattr(getattr(sas.models, name, None), name, None) 
     360        if ModelClass is None: 
     361            raise ValueError("could not find model %r in sas.models"%name) 
     362        return ModelClass() 
     363 
     364    # grab the sasview model, or create it if it is a product model 
     365    if model_info['composition']: 
     366        composition_type, parts = model_info['composition'] 
     367        if composition_type == 'product': 
     368            from sas.models import MultiplicationModel 
     369            P, S = [get_model(p) for p in model_info['oldname']] 
     370            model = MultiplicationModel(P, S) 
     371        else: 
     372            raise ValueError("mixture models not handled yet") 
     373    else: 
     374        model = get_model(model_info['oldname']) 
     375 
     376    # build a smearer with which to call the model, if necessary 
    360377    smearer = smear_selection(data, model=model) 
    361  
    362378    if hasattr(data, 'qx_data'): 
    363379        q = np.sqrt(data.qx_data**2 + data.qy_data**2) 
     
    382398        """ 
    383399        # paying for parameter conversion each time to keep life simple, if not fast 
    384         _, pars = revert_model(model_definition, pars) 
     400        pars = revert_pars(model_info, pars) 
    385401        for k, v in pars.items(): 
    386402            parts = k.split('.')  # polydispersity components 
     
    405421    'longdouble': '128', 
    406422} 
    407 def eval_opencl(model_definition, data, dtype='single', cutoff=0.): 
     423def eval_opencl(model_info, data, dtype='single', cutoff=0.): 
    408424    """ 
    409425    Return a model calculator using the OpenCL calculation engine. 
    410426    """ 
    411     try: 
    412         model = core.load_model(model_definition, dtype=dtype, platform="ocl") 
    413     except Exception as exc: 
    414         print(exc) 
    415         print("... trying again with single precision") 
    416         dtype = 'single' 
    417         model = core.load_model(model_definition, dtype=dtype, platform="ocl") 
     427    def builder(model_info): 
     428        try: 
     429            return core.build_model(model_info, dtype=dtype, platform="ocl") 
     430        except Exception as exc: 
     431            print(exc) 
     432            print("... trying again with single precision") 
     433            dtype = 'single' 
     434            return core.build_model(model_info, dtype=dtype, platform="ocl") 
     435    if model_info['composition']: 
     436        composition_type, parts = model_info['composition'] 
     437        if composition_type == 'product': 
     438            P, S = [builder(p) for p in parts] 
     439            model = product.ProductModel(P, S) 
     440        else: 
     441            raise ValueError("mixture models not handled yet") 
     442    else: 
     443        model = builder(model_info) 
    418444    calculator = DirectModel(data, model, cutoff=cutoff) 
    419445    calculator.engine = "OCL%s"%DTYPE_MAP[dtype] 
    420446    return calculator 
    421447 
    422 def eval_ctypes(model_definition, data, dtype='double', cutoff=0.): 
     448def eval_ctypes(model_info, data, dtype='double', cutoff=0.): 
    423449    """ 
    424450    Return a model calculator using the DLL calculation engine. 
     
    426452    if dtype == 'quad': 
    427453        dtype = 'longdouble' 
    428     model = core.load_model(model_definition, dtype=dtype, platform="dll") 
     454    def builder(model_info): 
     455        return core.build_model(model_info, dtype=dtype, platform="dll") 
     456 
     457    if model_info['composition']: 
     458        composition_type, parts = model_info['composition'] 
     459        if composition_type == 'product': 
     460            P, S = [builder(p) for p in parts] 
     461            model = product.ProductModel(P, S) 
     462        else: 
     463            raise ValueError("mixture models not handled yet") 
     464    else: 
     465        model = builder(model_info) 
    429466    calculator = DirectModel(data, model, cutoff=cutoff) 
    430467    calculator.engine = "OMP%s"%DTYPE_MAP[dtype] 
     
    470507    return data, index 
    471508 
    472 def make_engine(model_definition, data, dtype, cutoff): 
     509def make_engine(model_info, data, dtype, cutoff): 
    473510    """ 
    474511    Generate the appropriate calculation engine for the given datatype. 
     
    478515    """ 
    479516    if dtype == 'sasview': 
    480         return eval_sasview(model_definition, data) 
     517        return eval_sasview(model_info, data) 
    481518    elif dtype.endswith('!'): 
    482         return eval_ctypes(model_definition, data, dtype=dtype[:-1], 
    483                            cutoff=cutoff) 
    484     else: 
    485         return eval_opencl(model_definition, data, dtype=dtype, 
    486                            cutoff=cutoff) 
     519        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff) 
     520    else: 
     521        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff) 
    487522 
    488523def compare(opts, limits=None): 
     
    642677 
    643678 
    644 def get_demo_pars(model_definition): 
     679def get_demo_pars(model_info): 
    645680    """ 
    646681    Extract demo parameters from the model definition. 
    647682    """ 
    648     info = generate.make_info(model_definition) 
    649683    # Get the default values for the parameters 
    650     pars = dict((p[0], p[2]) for p in info['parameters']) 
     684    pars = dict((p[0], p[2]) for p in model_info['parameters']) 
    651685 
    652686    # Fill in default values for the polydispersity parameters 
    653     for p in info['parameters']: 
     687    for p in model_info['parameters']: 
    654688        if p[4] in ('volume', 'orientation'): 
    655689            pars[p[0]+'_pd'] = 0.0 
     
    659693 
    660694    # Plug in values given in demo 
    661     pars.update(info['demo']) 
     695    pars.update(model_info['demo']) 
    662696    return pars 
     697 
    663698 
    664699def parse_opts(): 
     
    679714        print(columnize(MODELS, indent="  ")) 
    680715        sys.exit(1) 
    681  
    682     name = args[0] 
    683     try: 
    684         model_definition = core.load_model_definition(name) 
    685     except ImportError, exc: 
    686         print(str(exc)) 
    687         print("Use one of:\n    " + models) 
    688         sys.exit(1) 
    689716    if len(args) > 3: 
    690717        print("expected parameters: model N1 N2") 
     718 
     719    def load_model(name): 
     720        try: 
     721            model_info = core.load_model_info(name) 
     722        except ImportError, exc: 
     723            print(str(exc)) 
     724            print("Use one of:\n    " + models) 
     725            sys.exit(1) 
     726        return model_info 
     727 
     728    name = args[0] 
     729    if '*' in name: 
     730        parts = [load_model(k) for k in name.split('*')] 
     731        model_info = product.make_product_info(*parts) 
     732    else: 
     733        model_info = load_model(name) 
    691734 
    692735    invalid = [o[1:] for o in flags 
     
    770813    # Get demo parameters from model definition, or use default parameters 
    771814    # if model does not define demo parameters 
    772     pars = get_demo_pars(model_definition) 
     815    pars = get_demo_pars(model_info) 
    773816 
    774817    # Fill in parameters given on the command line 
     
    791834        pars = suppress_pd(pars) 
    792835    pars.update(presets)  # set value after random to control value 
    793     constrain_pars(model_definition, pars) 
    794     constrain_new_to_old(model_definition, pars) 
     836    constrain_pars(model_info, pars) 
     837    constrain_new_to_old(model_info, pars) 
    795838    if opts['show_pars']: 
    796839        print(str(parlist(pars))) 
     
    799842    data, _ = make_data(opts) 
    800843    if n1: 
    801         base = make_engine(model_definition, data, engines[0], opts['cutoff']) 
     844        base = make_engine(model_info, data, engines[0], opts['cutoff']) 
    802845    else: 
    803846        base = None 
    804847    if n2: 
    805         comp = make_engine(model_definition, data, engines[1], opts['cutoff']) 
     848        comp = make_engine(model_info, data, engines[1], opts['cutoff']) 
    806849    else: 
    807850        comp = None 
     
    811854    opts.update({ 
    812855        'name'      : name, 
    813         'def'       : model_definition, 
     856        'def'       : model_info, 
    814857        'n1'        : n1, 
    815858        'n2'        : n2, 
     
    854897        config_matplotlib() 
    855898        self.opts = opts 
    856         info = generate.make_info(opts['def']) 
    857         pars, pd_types = bumps_model.create_parameters(info, **opts['pars']) 
     899        model_info = opts['def'] 
     900        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars']) 
    858901        if not opts['is2d']: 
    859902            active = [base + ext 
    860                       for base in info['partype']['pd-1d'] 
     903                      for base in model_info['partype']['pd-1d'] 
    861904                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']] 
    862             active.extend(info['partype']['fixed-1d']) 
     905            active.extend(model_info['partype']['fixed-1d']) 
    863906            for k in active: 
    864907                v = pars[k] 
Note: See TracChangeset for help on using the changeset viewer.