source: sasmodels/sasmodels/sasview_model.py @ 6d6508e

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 6d6508e was 6d6508e, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor model_info from dictionary to class

  • Property mode set to 100644
File size: 14.7 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
Given the name of a model defining an OpenCL kernel, such as the cylinder
model in sasmodels.models.cylinder, create a sasview model class to run that
kernel as follows::

    from sasmodels.sasview_model import _make_standard_model
    CylinderModel = _make_standard_model('cylinder')

Use :func:`load_standard_models` to build classes for all predefined models,
or :func:`load_custom_model` to build one from a custom model path.
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
[4d76711]15from __future__ import print_function
[87985ca]16
[ce27e21]17import math
18from copy import deepcopy
[2622b3f]19import collections
[4d76711]20import traceback
21import logging
[ce27e21]22
23import numpy as np
24
[aa4946b]25from . import core
[4d76711]26from . import custom
[72a081d]27from . import generate
[fb5914f]28from . import weights
[6d6508e]29from . import details
30from . import modelinfo
[ff7119b]31
def load_standard_models():
    """
    Build a sasview model class for every model listed by sasmodels.

    A model that fails to build is skipped rather than aborting the whole
    load; the traceback of the failure is logged at error level.

    :return: list of model classes, in the order of core.list_models()
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            # Best effort: report the broken model and keep going.
            logging.error(traceback.format_exc())
            continue
        loaded.append(model_class)
    return loaded
[de97440]46
[4d76711]47
def load_custom_model(path):
    """
    Build a sasview model class from the custom model found at *path*.

    The kernel module is loaded with custom.load_custom_kernel_module, turned
    into a model info object, and wrapped as a SasviewModel subclass.
    """
    info = modelinfo.make_model_info(custom.load_custom_kernel_module(path))
    return _make_model_from_info(info)
55
[87985ca]56
def _make_standard_model(name):
    """
    Build the sasview model class for the model called *name*.

    *name* is handed to generate.load_kernel_module.  NOTE(review): earlier
    docs said *name* may also be a path to a custom model; confirm against
    generate.load_kernel_module before relying on that.

    Returns a class that can be used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    return _make_model_from_info(modelinfo.make_model_info(module))
[72a081d]68
69
def _make_model_from_info(model_info):
    """
    Create a SasviewModel subclass bound to *model_info*.

    The new class is named after ``model_info.name`` and carries the info as
    its class attribute ``_model_info``.  Its constructor accepts a
    *multfactor* argument but does not use it.
    """
    def __init__(self, multfactor=1):
        SasviewModel.__init__(self)

    namespace = {
        '__init__': __init__,
        '_model_info': model_info,
    }
    return type(model_info.name, (SasviewModel,), namespace)
79
[4d76711]80
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are generated by :func:`_make_model_from_info`, which
    sets the class attribute *_model_info* to the model description.  The
    compiled model (*_model*) is built lazily by :meth:`calculate_Iq` and is
    left out of pickled and copied state.
    """
    # Overridden with the model info object in each generated subclass.
    _model_info = {}

    def __init__(self):
        """
        Initialize the sasview-facing attributes from *self._model_info*.

        Fills *params* (default values in user-parameter order), *details*
        ([units] + limits per parameter), *dispersion* (default gaussian
        settings for each polydisperse parameter), the orientation and
        magnetic parameter lists, and *fixed*.
        """
        # Compiled kernel model; built on first use by calculate_Iq.
        self._model = None
        model_info = self._model_info
        parameters = model_info.parameters

        self.name = model_info.name
        self.description = model_info.description
        self.category = None
        #self.is_multifunc = False
        # For the first control parameter record
        # [upper limit, name, choices, first profile axis]; [] when none.
        for p in parameters.kernel_parameters:
            if p.is_control:
                profile_axes = model_info.profile_axes
                self.multiplicity_info = [
                    p.limits[1], p.name, p.choices, profile_axes[0]
                    ]
                break
        else:
            self.multiplicity_info = []

        ## interpret the parameters
        ## TODO: reorganize parameter handling
        self.details = dict()
        self.params = collections.OrderedDict()
        self.dispersion = dict()

        self.orientation_params = []
        self.magnetic_params = []
        self.fixed = []
        for p in parameters.user_parameters():
            self.params[p.name] = p.default
            self.details[p.name] = [p.units] + p.limits
            if p.polydisperse:
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }
            if p.type == 'orientation':
                self.orientation_params.append(p.name)
                self.orientation_params.append(p.name+".width")
                self.fixed.append(p.name+".width")
            if p.type == 'magnetic':
                self.orientation_params.append(p.name)
                self.magnetic_params.append(p.name)
                self.fixed.append(p.name+".width")

        self.non_fittable = []

        ## independent parameter name and unit [string]
        # Plain strings (a stray trailing comma used to make input_name a
        # one-element tuple).
        self.input_name = "Q"
        self.input_unit = "A^{-1}"
        self.output_name = "Intensity"
        self.output_unit = "cm^{-1}"

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        ## New fields introduced for opencl rewrite
        self.cutoff = 1e-5

    def __getstate__(self):
        """
        Return the instance state for pickling/copying, without the
        compiled model.

        *_model_info* is a class attribute, so it is not part of the
        instance state; the compiled *_model* is dropped and rebuilt on
        demand by calculate_Iq.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        """
        Restore the instance from *state*; the compiled model is reset so
        that it is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # The original (misspelled) names are kept for backward compatibility;
    # pickle and copy only look for __getstate__/__setstate__.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        Returns True when *par_name* matches, ignoring case, an entry of
        *self.fixed* (which __init__ fills with the ".width" names of the
        orientation and magnetic parameters).

        :param par_name: the parameter name to check
        """
        target = par_name.lower()
        return any(target == name.lower() for name in self.fixed)
        #For the future
        #return self.params[str(par_name)].is_fittable()


    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        "par.field" sets *field* of the dispersion settings of *par*.

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion:
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params:
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        """
        Get the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        "par.field" returns *field* of the dispersion settings of *par*.

        :param name: name of the parameter
        :raises ValueError: if the model has no such parameter
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion:
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params:
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        The plain parameter names come first, followed by the
        polydispersity parameter names from getDispParamList.
        """
        # list() is required on Python 3, where keys() is a view
        # without extend().
        param_list = list(self.params.keys())
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of polydispersity parameters for the model

        Each polydisperse parameter contributes "name.npts", "name.nsigmas"
        and "name.width", with the name lowercased.
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name.lower(), ext)
               for p in self._model_info.parameters.user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        return ret

    def clone(self):
        """ Return a identical copy of self (the compiled model is not
        copied; it is rebuilt on demand) """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]


    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither an ndarray nor a list/tuple
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.

        The kernel and its q input are released even if evaluation fails.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        kernel = self._model.make_kernel(q_vectors)
        try:
            pairs = [self._get_weights(p)
                     for p in self._model_info.parameters.call_parameters]
            call_details, weight_table, value_table = details.build_details(
                kernel, pairs)
            result = kernel(call_details, weight_table, value_table,
                            cutoff=self.cutoff)
        finally:
            kernel.q_input.release()
            kernel.release()
        return result

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        Returns 1.0 when the model defines no ER function; otherwise the
        weighted average of ER over the volume-parameter dispersion mesh.

        :return: the value of the effective radius
        """
        model_info = self._model_info
        if model_info.ER is None:
            return 1.0
        values, weight = self._dispersion_mesh()
        fv = model_info.ER(*values)
        return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        Returns 1.0 when the model defines no VR function; otherwise the
        ratio of the weighted sums of the (whole, part) values from VR.

        :return: the value of the volf ratio
        """
        model_info = self._model_info
        if model_info.VR is None:
            return 1.0
        values, weight = self._dispersion_mesh()
        whole, part = model_info.VR(*values)
        return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        The name is matched case-insensitively against the model parameters
        and the settings are stored under the model's own spelling of it.

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        target = parameter.lower()
        for name in self.params:
            if name.lower() == target:
                break
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)
        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        self.dispersion[name] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.  Only volume parameters are included.
        """
        pars = [p for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return details.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return dispersion weights for parameter

        For a polydisperse parameter, returns the (values, weights) from
        weights.get_weights with the weights normalized to sum to one.
        Otherwise returns the current value as a one-element list with an
        empty weight list.
        """
        if par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], []
[ce27e21]439
def test_model():
    """
    Check that a sasview model built from the cylinder kernel can be run.

    Evaluates the model at qx = qy = 0.1 and returns the computed value(s).
    """
    cylinder_class = _make_standard_model('cylinder')
    q_pair = [0.1, 0.1]
    return cylinder_class().evalDistribution(q_pair)
[de97440]447
[4d76711]448
def test_model_list():
    """
    Make sure that all models build as sasview models.

    Any failure is annotated with the name of the offending model and
    re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than everything, so that interrupts
            # (KeyboardInterrupt, SystemExit) propagate untouched.
            annotate_exception("when loading "+name)
            raise
460
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model at qx = qy = 0.1.
    cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.