source: sasmodels/sasmodels/sasview_model.py @ 92d38285

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 92d38285 was 92d38285, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

manage set of models loaded by sasview; support loading sasview sum model

  • Property mode set to 100644
File size: 17.7 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17
18import numpy as np
19
20from . import core
21from . import custom
22from . import generate
23
24try:
25    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional
26    from .kernel import KernelModel
27    MultiplicityInfoType = NamedTuple(
28        'MuliplicityInfo',
29        [("number", int), ("control", str), ("choices", List[str]),
30         ("x_axis_label", str)])
31except ImportError:
32    pass
33
# TODO: separate x_axis_label from multiplicity info
# The x-axis label belongs with the profile generating function
#: Multiplicity description attached to each generated model class as
#: ``multiplicity_info``; fields are number, control, choices, x_axis_label.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: Registry of loaded model classes keyed by model name.  Filled by
#: load_standard_models and load_custom_model; queried by find_model.
MODELS = dict()  # type: Dict[str, type]
def find_model(modelname):
    """
    Return the model class for *modelname*.

    A name ending in ``.py`` is treated as the path of a custom model and is
    (re)loaded through load_custom_model.  Any other name must already be in
    the MODELS registry.

    :raises ValueError: if *modelname* is not a ``.py`` path and is not
        registered in MODELS.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r" % modelname)
    return MODELS[modelname]
51
def load_standard_models():
    """
    Build a sasview model class for every model named by core.list_models().

    Each successfully built class is registered in MODELS under its model
    name and included in the returned list.  A model that fails to build has
    its traceback logged at error level and is left out.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
            MODELS[model_name] = model_class
            loaded.append(model_class)
        except Exception:
            # Best effort: one broken model must not prevent the rest loading.
            logging.error(traceback.format_exc())
    return loaded
67
68
def load_custom_model(path):
    """
    Load the custom model defined in the Python file at *path*.

    If the loaded module defines a ``Model`` attribute, that is used as the
    model class directly.  Otherwise a SasviewModel subclass is generated from
    the module's model info.  The class is registered in MODELS under its
    ``name`` attribute and returned.
    """
    module = custom.load_custom_kernel_module(path)
    try:
        model_class = module.Model
    except AttributeError:
        # No hand-written Model class: generate one from the kernel module.
        info = generate.make_model_info(module)
        model_class = _make_model_from_info(info)
    MODELS[model_class.name] = model_class
    return model_class
82
83
def _make_standard_model(name):
    """
    Build the sasview model class for the model called *name*.

    *name* is handed to generate.load_kernel_module; presumably it is a
    standard model name such as ``'cylinder'`` (custom model files go through
    load_custom_model).  TODO confirm whether a path is also accepted here.

    Returns a class that can be used directly as a sasview model.
    """
    return _make_model_from_info(
        generate.make_model_info(generate.load_kernel_module(name)))
95
96
def _make_model_from_info(model_info):
    """
    Create a SasView model class from *model_info*.

    The class is a new subclass of SasviewModel named ``model_info['id']``.
    Its class attributes come from _generate_model_attributes, and its
    constructor takes an optional *multiplicity* (default 1).

    Note: ``model_info['variant_info']`` is set to None in place.
    """
    # Temporary hack for older sasview, which still reads 'variant_info'.
    model_info['variant_info'] = None

    def __init__(self, multiplicity=1):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = dict(_generate_model_attributes(model_info),
                       __init__=__init__)
    return type(model_info['id'], (SasviewModel,), class_attrs)
108
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for a generated SasView model.

    The attributes carry what is needed to query the model's details from
    the class itself, without creating an instance.  Sequence-valued
    attributes are stored as tuples so they cannot be changed by accident.

    :param model_info: model information mapping, e.g. the result of
        generate.make_model_info.
    :return: attribute name -> value, ready to be passed to ``type()``.
    """
    # Descriptive attributes copied straight from the model info.
    attrs = {
        '_model_info': model_info,
        'name': model_info['name'],
        'id': model_info['id'],
        'description': model_info['description'],
        'category': model_info['category'],
    }  # type: Dict[str, Any]

    # TODO: allow model to override axis labels input/output name/unit

    # Every generated model gets an empty placeholder multiplicity (number 0),
    # so is_multiplicity_model is always False here.
    no_variants = MultiplicityInfo(0, "", [], "")
    attrs.update(
        is_structure_factor=model_info['structure_factor'],
        is_form_factor=model_info['ER'] is not None,
        is_multiplicity_model=no_variants.number > 1,
        multiplicity_info=no_variants,
    )

    partype = model_info['partype']
    orientation = partype['orientation']
    magnetic = partype['magnetic']
    # Orientation parameters, their '.width' companions, then the magnetic ones.
    attrs['orientation_params'] = tuple(
        orientation + [par + '.width' for par in orientation] + magnetic)
    attrs['magnetic_params'] = tuple(magnetic)
    attrs['fixed'] = tuple(par + '.width' for par in partype['pd-2d'])
    attrs['non_fittable'] = ()
    return attrs
151
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are created by _make_model_from_info, which
    subclasses SasviewModel and fills in the class attributes produced by
    _generate_model_attributes.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: '<name>.width' entries for the 2-D polydisperse parameters; consulted
    #: by is_fittable
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Mulitplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Mapping[str, Tuple(str, float, float)]
    #: multiplicity used, or None if no multiplicity controls
    multiplicity = None     # type: Optional[int]

    def __init__(self, multiplicity):
        # type: (int) -> None
        """
        Initialize parameter values, details and dispersion settings from
        the class-level model info.

        :param multiplicity: stored unchanged as self.multiplicity
        """
        # The kernel is built lazily on first use (see calculate_Iq).
        self._model = None

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        self._persistency_dict = {}

        self.multiplicity = multiplicity

        self.params = collections.OrderedDict()
        self.dispersion = {}
        self.details = {}

        for p in self._model_info['parameters']:
            self.params[p.name] = p.default
            # [units, lower, upper]; list() also accepts limits given as a tuple
            self.details[p.name] = [p.units] + list(p.limits)

        # Every 2-D polydisperse parameter starts with a zero-width gaussian.
        for name in self._model_info['partype']['pd-2d']:
            self.dispersion[name] = {
                'width': 0,
                'npts': 35,
                'nsigmas': 3,
                'type': 'gaussian',
            }

    def __getstate__(self):
        """
        Return the instance state for pickle/copy, without the compiled kernel.

        The kernel is dropped because calculate_Iq rebuilds it whenever
        self._model is None.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        return state

    def __setstate__(self, state):
        """
        Restore the state produced by __getstate__.

        The kernel is reset so that it is rebuilt on demand.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the old (misspelled) method names, which
    # pickle/copy never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check

        Returns True when *par_name* matches, ignoring case, one of the names
        in self.fixed.
        """
        target = par_name.lower()
        return any(target == name.lower() for name in self.fixed)

    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def _find_param(self, name):
        """
        Locate parameter *name*, ignoring case.

        A name of the form 'par.attr' refers to self.dispersion[par][attr].
        Any other name refers to self.params[name].

        :return: (mapping, key) such that mapping[key] holds the value, or
            None if no such parameter exists.
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par, attr = toks[0].lower(), toks[1].lower()
            for item in self.dispersion.keys():
                if item.lower() == par:
                    for key in self.dispersion[item]:
                        if key.lower() == attr:
                            return self.dispersion[item], key
        else:
            # Look for standard parameter
            lname = name.lower()
            for item in self.params.keys():
                if item.lower() == lname:
                    return self.params, item
        return None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        :param name: name of the parameter, or 'par.attr' for a dispersion
            setting such as 'radius.width'; matched ignoring case
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        found = self._find_param(name)
        if found is None:
            raise ValueError("Model does not contain parameter %s" % name)
        mapping, key = found
        mapping[key] = value

    def getParam(self, name):
        """
        Get the value of a model parameter

        :param name: name of the parameter, or 'par.attr' for a dispersion
            setting such as 'radius.width'; matched ignoring case

        :raises ValueError: if the model has no such parameter
        """
        found = self._find_param(name)
        if found is None:
            raise ValueError("Model does not contain parameter %s" % name)
        mapping, key = found
        return mapping[key]

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        The model parameters come first, followed by the polydispersity
        parameters from getDispParamList.
        """
        # list() so this also works with Python 3 dict views.
        param_list = list(self.params.keys())
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of polydispersity parameters for the model

        Each 2-D polydisperse parameter contributes '<par>.npts',
        '<par>.nsigmas' and '<par>.width', with <par> lowercased.
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (d.lower(), p)
                for d in self._model_info['partype']['pd-2d']
                for p in ('npts', 'nsigmas', 'width')]

    def clone(self):
        """
        Return a identical copy of self

        The copy goes through __getstate__/__setstate__, so it does not share
        (or copy) the compiled kernel; it builds its own on first use.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither an array nor a list/tuple
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            partype = self._model_info['partype']
            if not partype['orientation'] and not partype['magnetic']:
                # No orientation or magnetic parameters: evaluate as a 1-D
                # model on |q|.  asarray lets plain lists through as well.
                qx, qy = np.asarray(qx), np.asarray(qy)
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        fn = self._model.make_kernel(q_vectors)
        try:
            pars = [self.params[v] for v in fn.fixed_pars]
            pd_pars = [self._get_weights(p) for p in fn.pd_pars]
            return fn(pars, pd_pars, self.cutoff)
        finally:
            # Release kernel resources even if the evaluation raised.
            fn.q_input.release()
            fn.release()

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        Returns 1.0 if the model info has no ER function.  Otherwise returns
        the weighted mean of ER over the dispersion mesh of the volume
        parameters.

        :return: the value of the effective radius
        """
        ER = self._model_info.get('ER', None)
        if ER is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            fv = ER(*values)
            return np.sum(weights * fv) / np.sum(weights)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        Returns 1.0 if the model info has no VR function.  Otherwise VR
        returns (whole, part) over the dispersion mesh, and the result is
        sum(w*part) / sum(w*whole).

        :return: the value of the volf ratio
        """
        VR = self._model_info.get('VR', None)
        if VR is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            whole, part = VR(*values)
            return np.sum(weights * part) / np.sum(weights * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]; matched against the
            model parameters ignoring case
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        matches = [name for name in self.params.keys()
                   if name.lower() == parameter.lower()]
        if not matches:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)
        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        from . import weights
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        # Store under the model's own spelling of the name so that lookups
        # by parameter name (e.g. _get_weights) find the new settings.
        self.dispersion[matches[0]] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = self._model_info['partype']['volume']
        return core.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return dispersion (values, weights) for parameter *par*.

        The weights are normalized to sum to 1.  Whether *par* is listed in
        partype['pd-rel'] is passed on to weights.get_weights.
        """
        from . import weights
        relative = self._model_info['partype']['pd-rel']
        limits = self._model_info['limits']
        dis = self.dispersion[par]
        value, weight = weights.get_weights(
            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
            self.params[par], limits[par], par in relative)
        return value, weight / np.sum(weight)
530
531
def test_model():
    """
    Test that a sasview model (cylinder) can be run.

    Evaluates the cylinder model at the single 2-D point (qx, qy) = (0.1, 0.1)
    and returns the intensity as a float.
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    # evalDistribution documents [qx, qy] as 1-D arrays, not bare scalars.
    qx = np.array([0.1])
    qy = np.array([0.1])
    return float(cylinder.evalDistribution([qx, qy])[0])
539
540
def test_model_list():
    """
    Check that every model from core.list_models() builds as a sasview model.

    A failure is annotated with the offending model name and re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except BaseException:
            annotate_exception("when loading " + model_name)
            raise
552
if __name__ == "__main__":
    # Script entry point: smoke-test the cylinder model once and report it.
    intensity = test_model()
    print("cylinder(0.1,0.1)=%g" % intensity)
Note: See TracBrowser for help on using the repository browser.