Changeset 6d6508e in sasmodels for sasmodels/core.py


Ignore:
Timestamp:
Apr 7, 2016 6:57:33 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
d2fc9a4
Parents:
3707eee
Message:

refactor model_info from dictionary to class

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/core.py

    ree8f734 r6d6508e  
    22Core model handling routines. 
    33""" 
     4from __future__ import print_function 
     5 
    46__all__ = [ 
    5     "list_models", "load_model_info", "precompile_dll", 
    6     "build_model", "make_kernel", "call_kernel", "call_ER_VR", 
     7    "list_models", "load_model", "load_model_info", 
     8    "build_model", "precompile_dll", 
    79    ] 
    810 
    9 from os.path import basename, dirname, join as joinpath, splitext 
     11from os.path import basename, dirname, join as joinpath 
    1012from glob import glob 
    1113 
    1214import numpy as np 
    1315 
    14 from . import models 
    15 from . import weights 
    1616from . import generate 
    17 # TODO: remove circular references between product and core 
    18 # product uses call_ER/call_VR, core uses make_product_info/ProductModel 
    19 #from . import product 
     17from . import modelinfo 
     18from . import product 
    2019from . import mixture 
    2120from . import kernelpy 
     
    2625except Exception: 
    2726    HAVE_OPENCL = False 
    28  
    29 try: 
    30     np.meshgrid([]) 
    31     meshgrid = np.meshgrid 
    32 except ValueError: 
    33     # CRUFT: np.meshgrid requires multiple vectors 
    34     def meshgrid(*args): 
    35         if len(args) > 1: 
    36             return np.meshgrid(*args) 
    37         else: 
    38             return [np.asarray(v) for v in args] 
    3927 
    4028# TODO: refactor composite model support 
     
    8876    parts = model_name.split('*') 
    8977    if len(parts) > 1: 
    90         from . import product 
    91         # Note: currently have circular reference 
    9278        if len(parts) > 2: 
    9379            raise ValueError("use P*S to apply structure factor S to model P") 
     
    9682 
    9783    kernel_module = generate.load_kernel_module(model_name) 
    98     return generate.make_model_info(kernel_module) 
     84    return modelinfo.make_model_info(kernel_module) 
    9985 
    10086 
     
    118104    otherwise it uses the default "ocl". 
    119105    """ 
    120     composition = model_info.get('composition', None) 
     106    composition = model_info.composition 
    121107    if composition is not None: 
    122108        composition_type, parts = composition 
     
    137123    ##  4. rerun "python -m sasmodels.direct_model $MODELNAME" 
    138124    ##  5. uncomment open().read() so that source will be regenerated from model 
    139     # open(model_info['name']+'.c','w').write(source) 
    140     # source = open(model_info['name']+'.cl','r').read() 
     125    # open(model_info.name+'.c','w').write(source) 
     126    # source = open(model_info.name+'.cl','r').read() 
    141127    source = generate.make_source(model_info) 
    142128    if dtype is None: 
    143         dtype = 'single' if model_info['single'] else 'double' 
    144     if callable(model_info.get('Iq', None)): 
     129        dtype = 'single' if model_info.single else 'double' 
     130    if callable(model_info.Iq): 
    145131        return kernelpy.PyModel(model_info) 
    146132    if (platform == "dll" 
     
    168154    source = generate.make_source(model_info) 
    169155    return kerneldll.make_dll(source, model_info, dtype=dtype) if source else None 
    170  
    171  
    172 def get_weights(parameter, values): 
    173     """ 
    174     Generate the distribution for parameter *name* given the parameter values 
    175     in *pars*. 
    176  
    177     Uses "name", "name_pd", "name_pd_type", "name_pd_n", "name_pd_sigma" 
    178     from the *pars* dictionary for parameter value and parameter dispersion. 
    179     """ 
    180     value = float(values.get(parameter.name, parameter.default)) 
    181     relative = parameter.relative_pd 
    182     limits = parameter.limits 
    183     disperser = values.get(parameter.name+'_pd_type', 'gaussian') 
    184     npts = values.get(parameter.name+'_pd_n', 0) 
    185     width = values.get(parameter.name+'_pd', 0.0) 
    186     nsigma = values.get(parameter.name+'_pd_nsigma', 3.0) 
    187     if npts == 0 or width == 0: 
    188         return [value], [] 
    189     value, weight = weights.get_weights( 
    190         disperser, npts, width, nsigma, value, limits, relative) 
    191     return value, weight / np.sum(weight) 
    192  
    193 def dispersion_mesh(pars): 
    194     """ 
    195     Create a mesh grid of dispersion parameters and weights. 
    196  
    197     Returns [p1,p2,...],w where pj is a vector of values for parameter j 
    198     and w is a vector containing the products for weights for each 
    199     parameter set in the vector. 
    200     """ 
    201     value, weight = zip(*pars) 
    202     weight = [w if w else [1.] for w in weight] 
    203     value = [v.flatten() for v in meshgrid(*value)] 
    204     weight = np.vstack([v.flatten() for v in meshgrid(*weight)]) 
    205     weight = np.prod(weight, axis=0) 
    206     return value, weight 
    207  
    208 def call_kernel(kernel, pars, cutoff=0, mono=False): 
    209     """ 
    210     Call *kernel* returned from *model.make_kernel* with parameters *pars*. 
    211  
    212     *cutoff* is the limiting value for the product of dispersion weights used 
    213     to perform the multidimensional dispersion calculation more quickly at a 
    214     slight cost to accuracy. The default value of *cutoff=0* integrates over 
    215     the entire dispersion cube.  Using *cutoff=1e-5* can be 50% faster, but 
    216     with an error of about 1%, which is usually less than the measurement 
    217     uncertainty. 
    218  
    219     *mono* is True if polydispersity should be set to none on all parameters. 
    220     """ 
    221     parameters = kernel.info['parameters'] 
    222     if mono: 
    223         active = lambda name: False 
    224     elif kernel.dim == '1d': 
    225         active = lambda name: name in parameters.pd_1d 
    226     elif kernel.dim == '2d': 
    227         active = lambda name: name in parameters.pd_2d 
    228     else: 
    229         active = lambda name: True 
    230  
    231     vw_pairs = [(get_weights(p, pars) if active(p.name) 
    232                  else ([pars.get(p.name, p.default)], [])) 
    233                 for p in parameters.call_parameters] 
    234  
    235     details, weights, values = build_details(kernel, vw_pairs) 
    236     return kernel(details, weights, values, cutoff) 
    237  
    238 def build_details(kernel, pairs): 
    239     values, weights = zip(*pairs) 
    240     if max([len(w) for w in weights]) > 1: 
    241         details = generate.poly_details(kernel.info, weights) 
    242     else: 
    243         details = kernel.info['mono_details'] 
    244     weights, values = [np.hstack(v) for v in (weights, values)] 
    245     weights = weights.astype(dtype=kernel.dtype) 
    246     values = values.astype(dtype=kernel.dtype) 
    247     return details, weights, values 
    248  
    249  
    250 def call_ER_VR(model_info, vol_pars): 
    251     """ 
    252     Return effect radius and volume ratio for the model. 
    253  
    254     *info* is either *kernel.info* for *kernel=make_kernel(model,q)* 
    255     or *model.info*. 
    256  
    257     *pars* are the parameters as expected by :func:`call_kernel`. 
    258     """ 
    259     ER = model_info.get('ER', None) 
    260     VR = model_info.get('VR', None) 
    261     value, weight = dispersion_mesh(vol_pars) 
    262  
    263     individual_radii = ER(*value) if ER else 1.0 
    264     whole, part = VR(*value) if VR else (1.0, 1.0) 
    265  
    266     effect_radius = np.sum(weight*individual_radii) / np.sum(weight) 
    267     volume_ratio = np.sum(weight*part)/np.sum(weight*whole) 
    268     return effect_radius, volume_ratio 
    269  
    270  
    271 def call_ER(model_info, values): 
    272     """ 
    273     Call the model ER function using *values*. *model_info* is either 
    274     *model.info* if you have a loaded model, or *kernel.info* if you 
    275     have a model kernel prepared for evaluation. 
    276     """ 
    277     ER = model_info.get('ER', None) 
    278     if ER is None: 
    279         return 1.0 
    280     else: 
    281         vol_pars = [get_weights(parameter, values) 
    282                     for parameter in model_info['parameters'].call_parameters 
    283                     if parameter.type == 'volume'] 
    284         value, weight = dispersion_mesh(vol_pars) 
    285         individual_radii = ER(*value) 
    286         return np.sum(weight*individual_radii) / np.sum(weight) 
    287  
    288 def call_VR(model_info, values): 
    289     """ 
    290     Call the model VR function using *pars*. 
    291     *info* is either *model.info* if you have a loaded model, or *kernel.info* 
    292     if you have a model kernel prepared for evaluation. 
    293     """ 
    294     VR = model_info.get('VR', None) 
    295     if VR is None: 
    296         return 1.0 
    297     else: 
    298         vol_pars = [get_weights(parameter, values) 
    299                     for parameter in model_info['parameters'].call_parameters 
    300                     if parameter.type == 'volume'] 
    301         value, weight = dispersion_mesh(vol_pars) 
    302         whole, part = VR(*value) 
    303         return np.sum(weight*part)/np.sum(weight*whole) 
    304  
    305 # TODO: remove call_ER, call_VR 
    306  
Note: See TracChangeset for help on using the changeset viewer.