Changeset ec7e360 in sasmodels for sasmodels/bumps_model.py


Ignore:
Timestamp:
Dec 23, 2015 10:17:49 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
e21cc31
Parents:
ce166d3
Message:

refactor option processing for compare.py, allowing more flexible selection of calculation engines

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/bumps_model.py

    r9404dd3 rec7e360  
    1212""" 
    1313 
    14 import datetime 
    1514import warnings 
    1615 
    1716import numpy as np 
    1817 
    19 from . import sesans 
    20 from . import weights 
    2118from .data import plot_theory 
    2219from .direct_model import DataMixin 
     
    3128    return experiment 
    3229 
     30def create_parameters(model_info, **kwargs): 
     31    # lazy import; this allows the doc builder and nosetests to run even 
     32    # when bumps is not on the path. 
     33    from bumps.names import Parameter 
     34 
     35    partype = model_info['partype'] 
     36 
     37    pars = {} 
     38    for p in model_info['parameters']: 
     39        name, default, limits = p[0], p[2], p[3] 
     40        value = kwargs.pop(name, default) 
     41        pars[name] = Parameter.default(value, name=name, limits=limits) 
     42    for name in partype['pd-2d']: 
     43        for xpart, xdefault, xlimits in [ 
     44            ('_pd', 0., limits), 
     45            ('_pd_n', 35., (0, 1000)), 
     46            ('_pd_nsigma', 3., (0, 10)), 
     47        ]: 
     48            xname = name + xpart 
     49            xvalue = kwargs.pop(xname, xdefault) 
     50            pars[xname] = Parameter.default(xvalue, name=xname, limits=xlimits) 
     51 
     52    pd_types = {} 
     53    for name in partype['pd-2d']: 
     54        xname = name + '_pd_type' 
     55        xvalue = kwargs.pop(xname, 'gaussian') 
     56        pd_types[xname] = xvalue 
     57 
     58    if kwargs:  # args not corresponding to parameters 
     59        raise TypeError("unexpected parameters: %s" 
     60                        % (", ".join(sorted(kwargs.keys())))) 
     61 
     62    return pars, pd_types 
    3363 
    3464class Model(object): 
    35     def __init__(self, model, **kw): 
    36         # lazy import; this allows the doc builder and nosetests to run even 
    37         # when bumps is not on the path. 
    38         from bumps.names import Parameter 
    39  
     65    def __init__(self, model, **kwargs): 
    4066        self._sasmodel = model 
    41         partype = model.info['partype'] 
    42  
    43         pars = [] 
    44         for p in model.info['parameters']: 
    45             name, default, limits = p[0], p[2], p[3] 
    46             value = kw.pop(name, default) 
    47             setattr(self, name, Parameter.default(value, name=name, limits=limits)) 
    48             pars.append(name) 
    49         for name in partype['pd-2d']: 
    50             for xpart, xdefault, xlimits in [ 
    51                 ('_pd', 0, limits), 
    52                 ('_pd_n', 35, (0, 1000)), 
    53                 ('_pd_nsigma', 3, (0, 10)), 
    54                 ('_pd_type', 'gaussian', None), 
    55                 ]: 
    56                 xname = name + xpart 
    57                 xvalue = kw.pop(xname, xdefault) 
    58                 if xlimits is not None: 
    59                     xvalue = Parameter.default(xvalue, name=xname, limits=xlimits) 
    60                     pars.append(xname) 
    61                 setattr(self, xname, xvalue) 
    62         self._parameter_names = pars 
    63         if kw: 
    64             raise TypeError("unexpected parameters: %s" 
    65                             % (", ".join(sorted(kw.keys())))) 
     67        pars, pd_types = create_parameters(model.info, **kwargs) 
     68        for k,v in pars.items(): 
     69            setattr(self, k, v) 
     70        for k,v in pd_types.items(): 
     71            setattr(self, k, v) 
     72        self._parameter_names = list(pars.keys()) 
     73        self._pd_type_names = list(pd_types.keys()) 
    6674 
    6775    def parameters(self): 
     
    7179        return dict((k, getattr(self, k)) for k in self._parameter_names) 
    7280 
     81    def state(self): 
     82        pars = dict((k, getattr(self, k).value) for k in self._parameter_names) 
     83        pars.update((k, getattr(self, k)) for k in self._pd_type_names) 
     84        return pars 
    7385 
    7486class Experiment(DataMixin): 
     
    113125    def theory(self): 
    114126        if 'theory' not in self._cache: 
    115             pars = dict((k, v.value) for k,v in self.model.parameters().items()) 
     127            pars = self.model.state() 
    116128            self._cache['theory'] = self._calc_theory(pars, cutoff=self.cutoff) 
    117             """ 
    118             if self._fn is None: 
    119                 q_input = self.model.kernel.make_input(self._kernel_inputs) 
    120                 self._fn = self.model.kernel(q_input) 
    121  
    122             fixed_pars = [getattr(self.model, p).value for p in self._fn.fixed_pars] 
    123             pd_pars = [self._get_weights(p) for p in self._fn.pd_pars] 
    124             #print(fixed_pars,pd_pars) 
    125             Iq_calc = self._fn(fixed_pars, pd_pars, self.cutoff) 
    126             #self._theory[:] = self._fn.eval(pars, pd_pars) 
    127             if self.data_type == 'sesans': 
    128                 result = sesans.hankel(self.data.x, self.data.lam * 1e-9, 
    129                                        self.data.sample.thickness / 10, 
    130                                        self._kernel_inputs[0], Iq_calc) 
    131                 self._cache['theory'] = result 
    132             else: 
    133                 Iq = self.resolution.apply(Iq_calc) 
    134                 self._cache['theory'] = Iq 
    135             """ 
    136129        return self._cache['theory'] 
    137130 
     
    162155        pass 
    163156 
    164     def remove_get_weights(self, name): 
    165         """ 
    166         Get parameter dispersion weights 
    167         """ 
    168         info = self.model.kernel.info 
    169         relative = name in info['partype']['pd-rel'] 
    170         limits = info['limits'][name] 
    171         disperser, value, npts, width, nsigma = [ 
    172             getattr(self.model, name + ext) 
    173             for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')] 
    174         value, weight = weights.get_weights( 
    175             disperser, int(npts.value), width.value, nsigma.value, 
    176             value.value, limits, relative) 
    177         return value, weight / np.sum(weight) 
    178  
    179157    def __getstate__(self): 
    180158        # Can't pickle gpu functions, so instead make them lazy 
Note: See TracChangeset for help on using the changeset viewer.