source: sasmodels/sasmodels/sasview_model.py @ a738209

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since a738209 was a738209, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

simplify kernels by remove coordination parameter logic

  • Property mode set to 100644
File size: 21.8 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17
18import numpy as np  # type: ignore
19
20from . import core
21from . import custom
22from . import generate
23from . import weights
24from . import modelinfo
25from . import kernel
26
27try:
28    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
29    from .modelinfo import ModelInfo, Parameter
30    from .kernel import KernelModel
31    MultiplicityInfoType = NamedTuple(
32        'MuliplicityInfo',
33        [("number", int), ("control", str), ("choices", List[str]),
34         ("x_axis_label", str)])
35    SasviewModelType = Callable[[int], "SasviewModel"]
36except ImportError:
37    pass
38
# TODO: separate x_axis_label from multiplicity info
# The profile x-axis label belongs with the profile generating function

# Lightweight record describing the multiplicity (variant) control of a model:
#   number       -- count of variants (0 when the model has no control)
#   control      -- name of the controlling parameter ("" when there is none)
#   choices      -- the choices attribute of the controlling parameter
#   x_axis_label -- label for the x axis of the model profile plot
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ("number", "control", "choices", "x_axis_label"),
)
45
# Registry of the model classes loaded so far, keyed by model name.
MODELS = {}
def find_model(modelname):
    """
    Return the sasview model class for *modelname*.

    A name ending in ``.py`` is treated as the path to a custom model and
    is loaded with :func:`load_custom_model`.  Any other name must already
    be present in the :data:`MODELS` registry.

    Raises *ValueError* if the name is neither a ``.py`` path nor a
    registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
56
57
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every name in *core.list_models()*.

    Each class that builds successfully is stored in :data:`MODELS` under
    its name and appended to the returned list.  A model that fails to
    build is skipped after its traceback has been written to the log.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            # One broken model should not prevent the others from loading.
            logging.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class
        loaded.append(model_class)
    return loaded
75
76
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load the custom model stored at *path* and register it in MODELS.

    If the loaded module provides a *Model* attribute it is used as the
    model class directly; otherwise a model class is constructed from the
    model info derived from the module.  The class is stored in
    :data:`MODELS` under its *name* attribute and returned.
    """
    module = custom.load_custom_kernel_module(path)
    try:
        model_class = module.Model
    except AttributeError:
        # No ready-made Model class in the module, so wrap its definition.
        model_class = _make_model_from_info(modelinfo.make_model_info(module))
    MODELS[model_class.name] = model_class
    return model_class
91
92
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    The kernel module is located with *generate.load_kernel_module*, its
    model info is created with *modelinfo.make_model_info*, and the result
    is wrapped by :func:`_make_model_from_info`.

    Returns a class that can be used directly as a sasview model.
    """
    info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(info)
105
106
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass from *model_info*.

    The new class is named *model_info.name*, carries the class attributes
    produced by :func:`_generate_model_attributes`, and has an *__init__*
    that simply forwards *multiplicity* to *SasviewModel.__init__*.
    """
    class_dict = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_dict['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_dict) # type: SasviewModelType
118
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Return the class dictionary for the model described by *model_info*.

    The dictionary holds everything needed to query the model details
    without creating an instance: name, id, description, category, the
    structure/form factor flags, the multiplicity information and the
    parameter groupings (orientation, magnetic, fixed, non-fittable).

    Sequence-valued attributes are stored as tuples so that they are not
    modified by accident.
    """

    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity: the first kernel parameter that is either the named
    # control parameter or flagged as a control determines the variants.
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""
    non_fittable = []  # type: List[str]
    variants = MultiplicityInfo(0, "", [], xlabel)
    for par in model_info.parameters.kernel_parameters:
        if par.name == model_info.control:
            count = len(par.choices)
        elif par.is_control:
            count = int(par.limits[1])
        else:
            continue
        non_fittable.append(par.name)
        variants = MultiplicityInfo(count, par.name, par.choices, xlabel)
        break

    # Parameter groupings
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters():
        width = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width])
            fixed.append(width)
        if par.type == 'magnetic':
            # Magnetic parameters are listed with the orientation
            # parameters as well as in their own group.
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width)

    # Class dictionary
    attrs = {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
    }  # type: Dict[str, Any]

    return attrs
181
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are created by subclassing this class with the
    attributes returned by _generate_model_attributes.  Instances hold the
    current parameter values (*params*), the polydispersity settings
    (*dispersion*) and the units/limits of each parameter (*details*).
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialise the parameter tables from the class's model info.

        :param multiplicity: variant selector for multiplicity models, or
            None.  When given, the parameters hidden for that multiplicity
            and the control parameter itself are left out of *params*,
            *details* and *dispersion*.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                # Width limits: [0, 1] for relative polydispersity,
                # otherwise unbounded above.
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __get_state__(self):
        # type: () -> Dict[str, Any]
        """
        Return a copy of the instance dictionary without the kernel model.

        NOTE(review): the pickle/copy hooks are spelled ``__getstate__`` and
        ``__setstate__``; with the current names these methods are only
        run when called explicitly.
        """
        state = self.__dict__.copy()
        # _model only lives in the instance dict once calculate_Iq has
        # built it, so do not fail when it is absent.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __set_state__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Replace the instance dictionary with *state* and clear the kernel
        model so that it is rebuilt on the next evaluation.
        """
        self.__dict__ = state
        self._model = None

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        Returns True when *par_name* is listed in the *fixed* attribute.

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        Calls the model's profile function with one argument per kernel
        parameter: the multiplicity for the control parameter, the current
        value for a scalar parameter, or a list of the values of
        ``<id>1 .. <id>length`` for a vector parameter.  Parameters which
        are not in *params* are passed as NaN.

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = [] # type: List[Union[float, np.ndarray]]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                args.append(float(self.multiplicity))
            elif p.length == 1:
                args.append(self.params.get(p.id, np.nan))
            else:
                args.append([self.params.get(p.id+str(k), np.nan)
                             for k in range(1, p.length+1)])
        return self._model_info.profile(*args)

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or ``<parameter>.<field>`` to
            set a field (e.g. width, npts) of the parameter's dispersion
        :param value: value of the parameter

        Raises *ValueError* if the model has no such parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Dispersion parameter: <parameter>.<field>
            par, field = toks
            if par in self.dispersion and field in self.dispersion[par]:
                self.dispersion[par][field] = value
                return
        elif name in self.params:
            # Standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or ``<parameter>.<field>`` to
            read a field (e.g. width, npts) of the parameter's dispersion

        Raises *ValueError* if the model has no such parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Dispersion parameter: <parameter>.<field>
            par, field = toks
            if par in self.dispersion and field in self.dispersion[par]:
                return self.dispersion[par][field]
        elif name in self.params:
            # Standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model

        The list holds the keys of *params* followed by the dispersion
        parameter names from :meth:`getDispParamList`.
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model

        For each polydisperse user parameter the names ``<name>.npts``,
        ``<name>.nsigmas`` and ``<name>.width`` are listed, in that order.
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p.name, ext)
                for p in self._model_info.parameters.user_parameters()
                if p.polydisperse
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # Polar input: convert (q, phi) to (qx, qy)
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        Raises *TypeError* if *qdist* is not a list, tuple or numpy array.
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            return self.calculate_Iq(qx, qy)

        if isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                        % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            # Build the kernel model on first use.
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            pairs = [self._get_weights(p)
                     for p in self._model_info.parameters.call_parameters]
            call_details, value = kernel.build_details(calculator, pairs)
            return calculator(call_details, value, cutoff=self.cutoff)
        finally:
            # Release the kernel even if the evaluation fails.
            calculator.release()

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        The model's ER function is evaluated over the dispersion mesh and
        averaged using the mesh weights.  Returns 1.0 when the model does
        not define ER.

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        value, weight = self._dispersion_mesh()
        fv = self._model_info.ER(*value)
        return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        The model's VR function returns (whole, part) over the dispersion
        mesh; the result is the weighted sum of *part* divided by the
        weighted sum of *whole*.  Returns 1.0 when the model does not
        define VR.

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        value, weight = self._dispersion_mesh()
        whole, part = self._model_info.VR(*value)
        return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        Only the class name of *dispersion* is used: a new disperser of
        the corresponding sasmodels type is created and its parameters
        (from *get_pars()*) are stored in ``self.dispersion[parameter]``.

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        Raises *ValueError* if *parameter* is not in *params*.
        """
        if parameter not in self.params:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

        # TODO: Store the disperser object directly in the model.
        # The current method of relying on the sasview GUI to
        # remember them is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.MODELS[disperser]()
        self.dispersion[parameter] = dispersion.get_pars()

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Only the call parameters of type 'volume' are included.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return kernel.dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Returns a (values, weights) pair.  A parameter that is not in
        *params* yields the multiplicity (for the control parameter) or
        NaN, with no weights.  A polydisperse parameter yields the values
        and normalized weights of its distribution.  Any other parameter
        yields its current value with no weights.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], []
            return [np.nan], []
        if par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        return [self.params[par.name]], []
def test_model():
    # type: () -> float
    """
    Check that the cylinder sasview model can be built and evaluated.

    Returns the model evaluated with *[0.1, 0.1]*, which evalDistribution
    interprets as the pair *[qx, qy]*.
    """
    model_class = _make_standard_model('cylinder')
    instance = model_class()
    return instance.evalDistribution([0.1,0.1])
622
def test_rpa():
    # type: () -> float
    """
    Check that the rpa sasview model can be built with multiplicity 3
    and evaluated.

    Returns the model evaluated with *[0.1, 0.1]*, which evalDistribution
    interprets as the pair *[qx, qy]*.
    """
    model_class = _make_standard_model('rpa')
    instance = model_class(3)
    return instance.evalDistribution([0.1,0.1])
631
632
def test_model_list():
    # type: () -> None
    """
    Build every model named by *core.list_models()* as a sasview model.

    If a model fails to build, the exception is annotated with the model
    name and re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except:
            # Record which model was being loaded, then propagate.
            annotate_exception("when loading "+model_name)
            raise
645
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model and show the result.
    cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g"%cylinder_value)
Note: See TracBrowser for help on using the repository browser.