source: sasmodels/sasmodels/sasview_model.py @ 72a081d

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 72a081d was 72a081d, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor product/mixture; add load model from path; default compare to -cutoff=0

  • Property mode set to 100644
File size: 13.1 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15
16import math
17from copy import deepcopy
18import collections
19
20import numpy as np
21
22from . import core
23from . import generate
24from . import custom
25
def standard_models():
    """
    Return a sasview model class for every model known to sasmodels.

    One class is built, via :func:`make_class`, for each name reported by
    :func:`core.list_models`, preserving that order.
    """
    model_classes = []
    for name in core.list_models():
        model_classes.append(make_class(name))
    return model_classes
28
# TODO: rename to make_class_from_name and update sasview
def make_class(model_name):
    """
    Load the sasview model named *model_name*.

    The name is resolved with :func:`core.load_model_info` and the resulting
    model info is turned into a class by :func:`make_class_from_info`.

    Returns a class that can be used directly as a sasview model.
    """
    model_info = core.load_model_info(model_name)
    return make_class_from_info(model_info)
42
43def make_class_from_file(path):
44    model_info = core.load_model_info_from_path(path)
45    return make_class_from_info(model_info)
46
47def make_class_from_info(model_info):
48    def __init__(self, multfactor=1):
49        SasviewModel.__init__(self)
50    attrs = dict(__init__=__init__, _model_info=model_info)
51    ConstructedModel = type(model_info['name'], (SasviewModel,), attrs)
52    return ConstructedModel
53
54class SasviewModel(object):
55    """
56    Sasview wrapper for opencl/ctypes model.
57    """
58    def __init__(self):
59        self._kernel = None
60        model_info = self._model_info
61
62        self.name = model_info['name']
63        self.oldname = model_info['oldname']
64        self.description = model_info['description']
65        self.category = None
66        self.multiplicity_info = None
67        self.is_multifunc = False
68
69        ## interpret the parameters
70        ## TODO: reorganize parameter handling
71        self.details = dict()
72        self.params = collections.OrderedDict()
73        self.dispersion = dict()
74        partype = model_info['partype']
75
76        for p in model_info['parameters']:
77            self.params[p.name] = p.default
78            self.details[p.name] = [p.units] + p.limits
79
80        for name in partype['pd-2d']:
81            self.dispersion[name] = {
82                'width': 0,
83                'npts': 35,
84                'nsigmas': 3,
85                'type': 'gaussian',
86            }
87
88        self.orientation_params = (
89            partype['orientation']
90            + [n + '.width' for n in partype['orientation']]
91            + partype['magnetic'])
92        self.magnetic_params = partype['magnetic']
93        self.fixed = [n + '.width' for n in partype['pd-2d']]
94        self.non_fittable = []
95
96        ## independent parameter name and unit [string]
97        self.input_name = model_info.get("input_name", "Q")
98        self.input_unit = model_info.get("input_unit", "A^{-1}")
99        self.output_name = model_info.get("output_name", "Intensity")
100        self.output_unit = model_info.get("output_unit", "cm^{-1}")
101
102        ## _persistency_dict is used by sas.perspectives.fitting.basepage
103        ## to store dispersity reference.
104        ## TODO: _persistency_dict to persistency_dict throughout sasview
105        self._persistency_dict = {}
106
107        ## New fields introduced for opencl rewrite
108        self.cutoff = 1e-5
109
110    def __get_state__(self):
111        state = self.__dict__.copy()
112        model_id = self._model_info['id']
113        state.pop('_kernel')
114        # May need to reload model info on set state since it has pointers
115        # to python implementations of Iq, etc.
116        #state.pop('_model_info')
117        return state
118
119    def __set_state__(self, state):
120        self.__dict__ = state
121        self._kernel = None
122
123    def __str__(self):
124        """
125        :return: string representation
126        """
127        return self.name
128
129    def is_fittable(self, par_name):
130        """
131        Check if a given parameter is fittable or not
132
133        :param par_name: the parameter name to check
134        """
135        return par_name.lower() in self.fixed
136        #For the future
137        #return self.params[str(par_name)].is_fittable()
138
139
140    # pylint: disable=no-self-use
141    def getProfile(self):
142        """
143        Get SLD profile
144
145        : return: (z, beta) where z is a list of depth of the transition points
146                beta is a list of the corresponding SLD values
147        """
148        return None, None
149
150    def setParam(self, name, value):
151        """
152        Set the value of a model parameter
153
154        :param name: name of the parameter
155        :param value: value of the parameter
156
157        """
158        # Look for dispersion parameters
159        toks = name.split('.')
160        if len(toks) == 2:
161            for item in self.dispersion.keys():
162                if item.lower() == toks[0].lower():
163                    for par in self.dispersion[item]:
164                        if par.lower() == toks[1].lower():
165                            self.dispersion[item][par] = value
166                            return
167        else:
168            # Look for standard parameter
169            for item in self.params.keys():
170                if item.lower() == name.lower():
171                    self.params[item] = value
172                    return
173
174        raise ValueError("Model does not contain parameter %s" % name)
175
176    def getParam(self, name):
177        """
178        Set the value of a model parameter
179
180        :param name: name of the parameter
181
182        """
183        # Look for dispersion parameters
184        toks = name.split('.')
185        if len(toks) == 2:
186            for item in self.dispersion.keys():
187                if item.lower() == toks[0].lower():
188                    for par in self.dispersion[item]:
189                        if par.lower() == toks[1].lower():
190                            return self.dispersion[item][par]
191        else:
192            # Look for standard parameter
193            for item in self.params.keys():
194                if item.lower() == name.lower():
195                    return self.params[item]
196
197        raise ValueError("Model does not contain parameter %s" % name)
198
199    def getParamList(self):
200        """
201        Return a list of all available parameters for the model
202        """
203        param_list = self.params.keys()
204        # WARNING: Extending the list with the dispersion parameters
205        param_list.extend(self.getDispParamList())
206        return param_list
207
208    def getDispParamList(self):
209        """
210        Return a list of all available parameters for the model
211        """
212        # TODO: fix test so that parameter order doesn't matter
213        ret = ['%s.%s' % (d.lower(), p)
214               for d in self._model_info['partype']['pd-2d']
215               for p in ('npts', 'nsigmas', 'width')]
216        #print(ret)
217        return ret
218
219    def clone(self):
220        """ Return a identical copy of self """
221        return deepcopy(self)
222
223    def run(self, x=0.0):
224        """
225        Evaluate the model
226
227        :param x: input q, or [q,phi]
228
229        :return: scattering function P(q)
230
231        **DEPRECATED**: use calculate_Iq instead
232        """
233        if isinstance(x, (list, tuple)):
234            # pylint: disable=unpacking-non-sequence
235            q, phi = x
236            return self.calculate_Iq([q * math.cos(phi)],
237                                     [q * math.sin(phi)])[0]
238        else:
239            return self.calculate_Iq([float(x)])[0]
240
241
242    def runXY(self, x=0.0):
243        """
244        Evaluate the model in cartesian coordinates
245
246        :param x: input q, or [qx, qy]
247
248        :return: scattering function P(q)
249
250        **DEPRECATED**: use calculate_Iq instead
251        """
252        if isinstance(x, (list, tuple)):
253            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
254        else:
255            return self.calculate_Iq([float(x)])[0]
256
257    def evalDistribution(self, qdist):
258        r"""
259        Evaluate a distribution of q-values.
260
261        :param qdist: array of q or a list of arrays [qx,qy]
262
263        * For 1D, a numpy array is expected as input
264
265        ::
266
267            evalDistribution(q)
268
269          where *q* is a numpy array.
270
271        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
272
273        ::
274
275              qx = [ qx[0], qx[1], qx[2], ....]
276              qy = [ qy[0], qy[1], qy[2], ....]
277
278        If the model is 1D only, then
279
280        .. math::
281
282            q = \sqrt{q_x^2+q_y^2}
283
284        """
285        if isinstance(qdist, (list, tuple)):
286            # Check whether we have a list of ndarrays [qx,qy]
287            qx, qy = qdist
288            partype = self._model_info['partype']
289            if not partype['orientation'] and not partype['magnetic']:
290                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
291            else:
292                return self.calculate_Iq(qx, qy)
293
294        elif isinstance(qdist, np.ndarray):
295            # We have a simple 1D distribution of q-values
296            return self.calculate_Iq(qdist)
297
298        else:
299            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
300                            % type(qdist))
301
302    def calculate_Iq(self, *args):
303        """
304        Calculate Iq for one set of q with the current parameters.
305
306        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
307
308        This should NOT be used for fitting since it copies the *q* vectors
309        to the card for each evaluation.
310        """
311        if self._kernel is None:
312            self._kernel = core.build_model(self._model_info)
313        q_vectors = [np.asarray(q) for q in args]
314        fn = self._kernel(q_vectors)
315        pars = [self.params[v] for v in fn.fixed_pars]
316        pd_pars = [self._get_weights(p) for p in fn.pd_pars]
317        result = fn(pars, pd_pars, self.cutoff)
318        fn.q_input.release()
319        fn.release()
320        return result
321
322    def calculate_ER(self):
323        """
324        Calculate the effective radius for P(q)*S(q)
325
326        :return: the value of the effective radius
327        """
328        ER = self._model_info.get('ER', None)
329        if ER is None:
330            return 1.0
331        else:
332            values, weights = self._dispersion_mesh()
333            fv = ER(*values)
334            #print(values[0].shape, weights.shape, fv.shape)
335            return np.sum(weights * fv) / np.sum(weights)
336
337    def calculate_VR(self):
338        """
339        Calculate the volf ratio for P(q)*S(q)
340
341        :return: the value of the volf ratio
342        """
343        VR = self._model_info.get('VR', None)
344        if VR is None:
345            return 1.0
346        else:
347            values, weights = self._dispersion_mesh()
348            whole, part = VR(*values)
349            return np.sum(weights * part) / np.sum(weights * whole)
350
351    def set_dispersion(self, parameter, dispersion):
352        """
353        Set the dispersion object for a model parameter
354
355        :param parameter: name of the parameter [string]
356        :param dispersion: dispersion object of type Dispersion
357        """
358        if parameter.lower() in (s.lower() for s in self.params.keys()):
359            # TODO: Store the disperser object directly in the model.
360            # The current method of creating one on the fly whenever it is
361            # needed is kind of funky.
362            # Note: can't seem to get disperser parameters from sasview
363            # (1) Could create a sasview model that has not yet # been
364            # converted, assign the disperser to one of its polydisperse
365            # parameters, then retrieve the disperser parameters from the
366            # sasview model.  (2) Could write a disperser parameter retriever
367            # in sasview.  (3) Could modify sasview to use sasmodels.weights
368            # dispersers.
369            # For now, rely on the fact that the sasview only ever uses
370            # new dispersers in the set_dispersion call and create a new
371            # one instead of trying to assign parameters.
372            from . import weights
373            disperser = weights.dispersers[dispersion.__class__.__name__]
374            dispersion = weights.models[disperser]()
375            self.dispersion[parameter] = dispersion.get_pars()
376        else:
377            raise ValueError("%r is not a dispersity or orientation parameter")
378
379    def _dispersion_mesh(self):
380        """
381        Create a mesh grid of dispersion parameters and weights.
382
383        Returns [p1,p2,...],w where pj is a vector of values for parameter j
384        and w is a vector containing the products for weights for each
385        parameter set in the vector.
386        """
387        pars = self._model_info['partype']['volume']
388        return core.dispersion_mesh([self._get_weights(p) for p in pars])
389
390    def _get_weights(self, par):
391        """
392            Return dispersion weights
393            :param par parameter name
394        """
395        from . import weights
396
397        relative = self._model_info['partype']['pd-rel']
398        limits = self._model_info['limits']
399        dis = self.dispersion[par]
400        value, weight = weights.get_weights(
401            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
402            self.params[par], limits[par], par in relative)
403        return value, weight / np.sum(weight)
404
Note: See TracBrowser for help on using the repository browser.