source: sasmodels/sasmodels/sasview_model.py @ 5efe850

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 5efe850 was f36c1b7, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

suppress 'initializing model' print statement

  • Property mode set to 100644
File size: 17.3 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15from __future__ import print_function
16
17import math
18from copy import deepcopy
19import collections
20import traceback
21import logging
22
23import numpy as np
24
25from . import core
26from . import custom
27from . import generate
28
# Names imported in this guarded block are used for ``# type:`` comments in
# this file; the ImportError guard lets the module load when they are not
# importable.
try:
    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional
    # Typing-only counterpart of MultiplicityInfo below.  It is defined before
    # the optional kernel import so a failure there does not leave it undefined.
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
    from .kernel import KernelModel
except ImportError:
    pass

# TODO: separate x_axis_label from multiplicity info
# The x-axis label belongs with the profile generating function
#: Multiplicity description for a model: the number of variants, the name of
#: the controlling parameter, the list of choice labels and the x-axis label.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ["number", "control", "choices", "x_axis_label"],
)
45
def load_standard_models():
    """
    Load and return the list of predefined models.

    Each name reported by :func:`core.list_models` is turned into a sasview
    model class.  If there is an error loading a model, its traceback is
    logged at error level and that model is left out of the result, so one
    broken model does not prevent the others from loading.

    :return: list of model classes, in the order given by core.list_models()
    """
    models = []
    for name in core.list_models():
        try:
            models.append(_make_standard_model(name))
        except Exception:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt and SystemExit still stop the program.
            logging.error(traceback.format_exc())
    return models
60
61
def load_custom_model(path):
    """
    Build a sasview model class from the custom model found at *path*.

    The kernel module is loaded with ``custom.load_custom_kernel_module``,
    converted into model info with ``generate.make_model_info``, and then
    wrapped as a SasviewModel subclass.

    :param path: location of the custom model module
    :return: the generated model class
    """
    module = custom.load_custom_kernel_module(path)
    return _make_model_from_info(generate.make_model_info(module))
69
70
def _make_standard_model(name):
    """
    Build the sasview model class for the model called *name*.

    *name* is passed to ``generate.load_kernel_module``.  The original notes
    say it may also be a path to a custom model; whether that works depends
    on ``generate.load_kernel_module`` (NOTE(review): confirm).

    :return: a class that can be used directly as a sasview model
    """
    module = generate.load_kernel_module(name)
    info = generate.make_model_info(module)
    return _make_model_from_info(info)
82
83
def _make_model_from_info(model_info):
    """
    Wrap *model_info* in a new SasviewModel subclass.

    The class is named ``model_info['id']``.  Its class attributes come from
    :func:`_generate_model_attributes`, and its constructor accepts an
    optional *multiplicity* (default 1).

    Note: *model_info* is modified in place (``'variant_info'`` is set to
    None).
    """
    # Temporary hack for older sasview, which presumably still reads this key.
    model_info['variant_info'] = None

    def __init__(self, multiplicity=1):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs['__init__'] = __init__
    return type(model_info['id'], (SasviewModel,), class_attrs)
95
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class-attribute dictionary for a sasview model class.

    The attributes carry the model details (name, id, description, category,
    parameter groupings, form/structure-factor flags), so callers can query
    a model without instantiating it.

    Sequence-valued attributes are stored as tuples to avoid accidental
    mutation.
    """
    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity is not supported yet: a zero-variant placeholder.
    variants = MultiplicityInfo(0, "", [], "")

    attrs = {
        '_model_info': model_info,
        'name': model_info['name'],
        'id': model_info['id'],
        'description': model_info['description'],
        'category': model_info['category'],
        'is_structure_factor': model_info['structure_factor'],
        'is_form_factor': model_info['ER'] is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
    }  # type: Dict[str, Any]

    partype = model_info['partype']
    orientation = partype['orientation']
    magnetic = partype['magnetic']
    # The orientation list also carries the orientation widths and the
    # magnetic parameters.
    attrs['orientation_params'] = tuple(
        orientation + [p + '.width' for p in orientation] + magnetic)
    attrs['magnetic_params'] = tuple(magnetic)
    attrs['fixed'] = tuple(p + '.width' for p in partype['pd-2d'])
    attrs['non_fittable'] = ()
    return attrs
138
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are created by ``_make_model_from_info``, which
    subclasses this class using the attributes produced by
    ``_generate_model_attributes``.  Each instance keeps its own parameter
    values, dispersion settings and parameter details.  The compute kernel
    is built lazily on the first call to :meth:`calculate_Iq`.
    """
    # Model-specific class attributes are filled in when the subclass is
    # created (see _make_model_from_info / _generate_model_attributes).
    # They are declared here for typing and documentation purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters; _generate_model_attributes fills it
    #: with the ``<name>.width`` entries of the 2-D polydisperse parameters,
    #: and is_fittable() checks membership here
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter, as [units] + limits
    details = None     # type: Dict[str, List[Any]]
    #: multiplicity used, or None if no multiplicity controls
    multiplicity = None     # type: Optional[int]

    def __init__(self, multiplicity):
        # type: (int) -> None
        """
        Initialize parameter values, details and dispersion settings from
        the model info.

        :param multiplicity: multiplicity for the model (stored as given)
        """
        # The kernel is built on demand by calculate_Iq().
        self._model = None

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        self._persistency_dict = {}

        self.multiplicity = multiplicity

        self.params = collections.OrderedDict()
        self.dispersion = {}
        self.details = {}

        for p in self._model_info['parameters']:
            self.params[p.name] = p.default
            self.details[p.name] = [p.units] + p.limits

        for name in self._model_info['partype']['pd-2d']:
            self.dispersion[name] = {
                'width': 0,
                'npts': 35,
                'nsigmas': 3,
                'type': 'gaussian',
            }

    def __getstate__(self):
        """
        Return the instance state for pickling and copying.

        The built kernel model is left out so that a restored or cloned
        instance rebuilds it on demand instead of copying it.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # NOTE: _model_info is kept.  It may hold pointers to python
        # implementations of Iq, etc., so it might need to be reloaded on
        # set state instead.
        return state

    def __setstate__(self, state):
        """
        Restore the instance state; the kernel model is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the original (misspelled) names.  The
    # pickle/copy protocols only look for __getstate__/__setstate__.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if the lower-cased *par_name* is listed in *fixed*
        """
        # TODO: For the future
        # return self.params[str(par_name)].is_fittable()
        return par_name.lower() in self.fixed

    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def _find_parameter(self, name):
        """
        Locate parameter *name* case-insensitively.

        A name with exactly one dot (``<par>.<field>``) is looked up in the
        dispersion settings.  Any other name is looked up in the ordinary
        parameter values.

        :return: (container, key) such that ``container[key]`` holds the
            value, or (None, None) if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            target, field = toks[0].lower(), toks[1].lower()
            for item, settings in self.dispersion.items():
                if item.lower() == target:
                    for par in settings:
                        if par.lower() == field:
                            return settings, par
        else:
            target = name.lower()
            for item in self.params:
                if item.lower() == target:
                    return self.params, item
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        :param name: name of the parameter (case-insensitive); use
            ``<par>.<field>`` for dispersion settings such as ``radius.width``
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        container, key = self._find_parameter(name)
        if container is None:
            raise ValueError("Model does not contain parameter %s" % name)
        container[key] = value

    def getParam(self, name):
        """
        Get the value of a model parameter

        :param name: name of the parameter (case-insensitive); use
            ``<par>.<field>`` for dispersion settings such as ``radius.width``
        :return: the current value
        :raises ValueError: if the model has no such parameter
        """
        container, key = self._find_parameter(name)
        if container is None:
            raise ValueError("Model does not contain parameter %s" % name)
        return container[key]

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        The ordinary parameters come first, followed by the dispersion
        parameters from :meth:`getDispParamList`.
        """
        # list() is required: on Python 3 keys() is a view without extend().
        param_list = list(self.params.keys())
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of polydispersity parameters for the model

        One ``<par>.npts``, ``<par>.nsigmas`` and ``<par>.width`` entry is
        produced for each 2-D polydisperse parameter, with *par* lower-cased.
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (d.lower(), p)
                for d in self._model_info['partype']['pd-2d']
                for p in ('npts', 'nsigmas', 'width')]

    def clone(self):
        """ Return a identical copy of self (the kernel is rebuilt lazily) """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither an array nor a [qx, qy] pair
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            # Accept plain sequences as well as arrays for qx and qy.
            qx, qy = np.asarray(qx), np.asarray(qy)
            partype = self._model_info['partype']
            if not partype['orientation'] and not partype['magnetic']:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        fn = self._model.make_kernel(q_vectors)
        try:
            pars = [self.params[v] for v in fn.fixed_pars]
            pd_pars = [self._get_weights(p) for p in fn.pd_pars]
            return fn(pars, pd_pars, self.cutoff)
        finally:
            # Release the input and kernel even if the evaluation fails.
            fn.q_input.release()
            fn.release()

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, averaged over the
            volume-parameter dispersion weights; 1.0 if the model has no ER
        """
        ER = self._model_info.get('ER', None)
        if ER is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            fv = ER(*values)
            return np.sum(weights * fv) / np.sum(weights)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio (weighted sum of the part
            volume over the weighted sum of the whole volume); 1.0 if the
            model has no VR
        """
        VR = self._model_info.get('VR', None)
        if VR is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            whole, part = VR(*values)
            return np.sum(weights * part) / np.sum(weights * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter.lower() not in (s.lower() for s in self.params.keys()):
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)
        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        from . import weights
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        self.dispersion[parameter] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = self._model_info['partype']['volume']
        return core.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return dispersion weights for parameter

        :return: (values, weights) with the weights normalized to sum to 1
        """
        from . import weights
        relative = self._model_info['partype']['pd-rel']
        limits = self._model_info['limits']
        dis = self.dispersion[par]
        value, weight = weights.get_weights(
            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
            self.params[par], limits[par], par in relative)
        return value, weight / np.sum(weight)
507        Return dispersion weights for parameter
508        """
509        from . import weights
510        relative = self._model_info['partype']['pd-rel']
511        limits = self._model_info['limits']
512        dis = self.dispersion[par]
513        value, weight = weights.get_weights(
514            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
515            self.params[par], limits[par], par in relative)
516        return value, weight / np.sum(weight)
517
518
def test_model():
    """
    Test that a sasview model (cylinder) can be run.

    Evaluates the cylinder model at the single point (qx, qy) = (0.1, 0.1).

    :return: the computed intensity as a float
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    # evalDistribution documents [qx, qy] as 1-D arrays, not bare scalars.
    qx, qy = np.array([0.1]), np.array([0.1])
    return float(cylinder.evalDistribution([qx, qy])[0])
526
527
def test_model_list():
    """
    Check that every model from core.list_models() builds as a sasview model.

    A failure is re-raised after annotating it with the name of the model
    that could not be loaded.
    """
    from .exception import annotate_exception
    model_names = core.list_models()
    for model_name in model_names:
        try:
            _make_standard_model(model_name)
        except BaseException:
            annotate_exception("when loading " + model_name)
            raise
539
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model at (qx, qy) = (0.1, 0.1).
    intensity = test_model()
    print("cylinder(0.1,0.1)=%g" % intensity)
Note: See TracBrowser for help on using the repository browser.