source: sasmodels/sasmodels/sasview_model.py @ 6d6508e

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 6d6508e was 6d6508e, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor model_info from dictionary to class

  • Property mode set to 100644
File size: 14.7 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15from __future__ import print_function
16
17import math
18from copy import deepcopy
19import collections
20import traceback
21import logging
22
23import numpy as np
24
25from . import core
26from . import custom
27from . import generate
28from . import weights
29from . import details
30from . import modelinfo
31
def load_standard_models():
    """
    Build a sasview model class for every name in ``core.list_models()``.

    A model that fails to build is skipped: its traceback is written to the
    log and the remaining models are still processed.

    Returns the list of model classes that were built successfully.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            # Best effort: one broken model must not hide the others.
            logging.error(traceback.format_exc())
        else:
            loaded.append(model_class)
    return loaded
46
47
def load_custom_model(path):
    """
    Build a sasview model class from the custom model file at *path*.

    The kernel module is loaded with ``custom.load_custom_kernel_module``,
    turned into a model definition by ``modelinfo.make_model_info`` and then
    wrapped as a sasview model class.
    """
    module = custom.load_custom_kernel_module(path)
    info = modelinfo.make_model_info(module)
    return _make_model_from_info(info)
55
56
def _make_standard_model(name):
    """
    Build the sasview model class for the model called *name*.

    *name* is handed to ``generate.load_kernel_module``; the resulting
    kernel module is converted to a model definition and wrapped.

    Returns a class that can be used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    info = modelinfo.make_model_info(module)
    return _make_model_from_info(info)
68
69
def _make_model_from_info(model_info):
    """
    Create a :class:`SasviewModel` subclass bound to *model_info*.

    The new class is named after ``model_info.name`` and carries the model
    definition in its ``_model_info`` class attribute.
    """
    def __init__(self, multfactor=1):
        # *multfactor* is accepted but not used by the constructor.
        SasviewModel.__init__(self)

    namespace = {'__init__': __init__, '_model_info': model_info}
    return type(model_info.name, (SasviewModel,), namespace)
79
80
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are subclasses whose *_model_info* class attribute
    holds the model definition (see ``_make_model_from_info``).  Instances
    keep the current parameter values in *params* and the polydispersity
    settings in *dispersion*; the compiled kernel model is built lazily on
    the first call to :meth:`calculate_Iq`.
    """
    # Model definition; replaced by the real definition in each subclass.
    _model_info = {}

    def __init__(self):
        # Compiled kernel model, created on demand by calculate_Iq.
        self._model = None
        model_info = self._model_info
        parameters = model_info.parameters

        self.name = model_info.name
        self.description = model_info.description
        self.category = None
        #self.is_multifunc = False
        # The first control parameter (if any) defines the multiplicity info.
        for p in parameters.kernel_parameters:
            if p.is_control:
                profile_axes = model_info.profile_axes
                self.multiplicity_info = [
                    p.limits[1], p.name, p.choices, profile_axes[0]
                    ]
                break
        else:
            self.multiplicity_info = []

        ## interpret the parameters
        ## TODO: reorganize parameter handling
        self.details = dict()
        self.params = collections.OrderedDict()
        self.dispersion = dict()

        self.orientation_params = []
        self.magnetic_params = []
        self.fixed = []
        for p in parameters.user_parameters():
            self.params[p.name] = p.default
            # list() so that limits given as a tuple do not raise TypeError.
            self.details[p.name] = [p.units] + list(p.limits)
            if p.polydisperse:
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }
            if p.type == 'orientation':
                self.orientation_params.append(p.name)
                self.orientation_params.append(p.name+".width")
                self.fixed.append(p.name+".width")
            if p.type == 'magnetic':
                self.orientation_params.append(p.name)
                self.magnetic_params.append(p.name)
                self.fixed.append(p.name+".width")

        self.non_fittable = []

        ## independent parameter name and unit [string]
        # Note: no trailing comma here; a trailing comma would store the
        # tuple ("Q",) instead of the string.
        self.input_name = "Q" #model_info.get("input_name", "Q")
        self.input_unit = "A^{-1}" #model_info.get("input_unit", "A^{-1}")
        self.output_name = "Intensity" #model_info.get("output_name", "Intensity")
        self.output_unit = "cm^{-1}" #model_info.get("output_unit", "cm^{-1}")

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        ## New fields introduced for opencl rewrite
        self.cutoff = 1e-5

    def __get_state__(self):
        """
        Return the instance state with the compiled model removed.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __set_state__(self, state):
        """
        Restore the instance from *state*; the model is rebuilt on demand.
        """
        self.__dict__ = state
        self._model = None

    # The pickle/copy protocol looks for __getstate__/__setstate__ (no inner
    # underscore).  Keep the historical names and expose the protocol names
    # so that pickling and clone() never try to copy the compiled model.
    __getstate__ = __get_state__
    __setstate__ = __set_state__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        Returns True when the lower-cased *par_name* is listed in
        *self.fixed*.

        :param par_name: the parameter name to check
        """
        return par_name.lower() in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        This base implementation always returns ``(None, None)``.
        """
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        ``parameter.attribute`` addresses an entry of the dispersion
        settings for *parameter*; any other name addresses a plain model
        parameter.

        :param name: name of the parameter
        :param value: value of the parameter

        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        """
        Get the value of a model parameter

        Names are matched case-insensitively, with ``parameter.attribute``
        addressing the dispersion settings in the same way as
        :meth:`setParam`.

        :param name: name of the parameter

        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        The plain parameter names come first, followed by the names from
        :meth:`getDispParamList`.
        """
        # keys() is a view on Python 3 and has no extend(); copy to a list.
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of polydispersity parameters for the model

        Each polydisperse parameter contributes ``name.npts``,
        ``name.nsigmas`` and ``name.width`` with *name* in lower case.
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name.lower(), ext)
               for p in self._model_info.parameters.user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        #print(ret)
        return ret

    def clone(self):
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]


    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither a list/tuple nor an ndarray
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        kernel = self._model.make_kernel(q_vectors)
        try:
            pairs = [self._get_weights(p)
                     for p in self._model_info.parameters.call_parameters]
            call_details, weights, values = details.build_details(kernel, pairs)
            result = kernel(call_details, weights, values, cutoff=self.cutoff)
        finally:
            # Release the kernel resources even when the evaluation fails.
            kernel.q_input.release()
            kernel.release()
        return result

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, or 1.0 when the model
            definition has no *ER* function
        """
        model_info = self._model_info
        if model_info.ER is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            fv = model_info.ER(*values)
            #print(values[0].shape, weights.shape, fv.shape)
            return np.sum(weights * fv) / np.sum(weights)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 when the model
            definition has no *VR* function
        """
        model_info = self._model_info
        if model_info.VR is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            whole, part = model_info.VR(*values)
            return np.sum(weights * part) / np.sum(weights * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a parameter of the model
        """
        # Resolve the name as stored in self.params so that the settings
        # land under the key that _get_weights reads.
        if parameter in self.params:
            key = parameter
        else:
            lowered = parameter.lower()
            for key in self.params.keys():
                if key.lower() == lowered:
                    break
            else:
                raise ValueError("%r is not a dispersity or orientation parameter"
                                 % parameter)

        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        self.dispersion[key] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        # NOTE(review): this still uses dict-style ``partype['volume']``
        # access on the model definition; confirm that the class-based
        # model info provides it and that the entries are parameter objects
        # as expected by _get_weights.
        pars = self._model_info.partype['volume']
        return details.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return dispersion weights for parameter

        For a polydisperse parameter the result is ``(values, weights)``
        with the weights normalized to sum to one; otherwise it is the
        single current value with an empty weight list.
        """
        if par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], []
439
def test_model():
    """
    Build the cylinder sasview model and evaluate it at (0.1, 0.1).

    Returns the result of ``evalDistribution([0.1,0.1])``.
    """
    model_class = _make_standard_model('cylinder')
    instance = model_class()
    q_pair = [0.1, 0.1]
    return instance.evalDistribution(q_pair)
447
448
def test_model_list():
    """
    Check that every model in ``core.list_models()`` builds as a sasview model.

    When a model fails, ``annotate_exception`` is called with the model name
    and the exception is re-raised.
    """
    from .exception import annotate_exception
    model_names = core.list_models()
    for model_name in model_names:
        try:
            _make_standard_model(model_name)
        except:
            # Bare except is intentional: the exception is always re-raised.
            annotate_exception("when loading "+model_name)
            raise
460
if __name__ == "__main__":
    # Run the cylinder smoke test when executed as a script.
    result = test_model()
    print("cylinder(0.1,0.1)=%g" % result)
Note: See TracBrowser for help on using the repository browser.