source: sasmodels/sasmodels/sasview_model.py @ 50ec515

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 50ec515 was 50ec515, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

spherical sld: document interface shape number→interface relationship since UI doesn't show dropdown list yet

  • Property mode set to 100644
File size: 22.3 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17
18import numpy as np  # type: ignore
19
20from . import core
21from . import custom
22from . import generate
23from . import weights
24from . import modelinfo
25from .details import build_details, dispersion_mesh
26
try:
    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
    # Typing-only aliases, referenced from "# type:" comments.  They only
    # need the typing module, so define them before the project imports.
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
    SasviewModelType = Callable[[int], "SasviewModel"]
    from .modelinfo import ModelInfo, Parameter
    from .kernel import KernelModel
except ImportError:
    # typing is unavailable (old python); the names are only used in
    # type comments, so nothing else depends on them at runtime.
    pass
38
# TODO: separate x_axis_label from multiplicity info
#: Multiplicity (model variant) description attached to each model class:
#: ``number`` of variants, name of the ``control`` parameter, its
#: ``choices`` (if any) and the ``x_axis_label`` used for the profile plot.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    "number control choices x_axis_label".split(),
)

#: Registry of sasview model classes keyed by model name; filled by
#: load_standard_models and load_custom_model, consulted by find_model.
MODELS = {}
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the sasview model class for *modelname*.

    A name ending in ``.py`` is treated as a path and loaded with
    load_custom_model.  Any other name must already be registered in MODELS.

    Raises ValueError when *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
55
56
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load and return the list of predefined models.

    Every model that loads is also registered in MODELS under its name.
    If a model fails to load, the error is logged together with the model
    name and traceback, and that model is left out of the returned list.
    """
    models = []
    for name in core.list_models():
        try:
            model = _make_standard_model(name)
        except Exception:
            # Best effort: one broken model must not prevent the rest from
            # loading.  logging.exception records the traceback, and the
            # message says which model it belongs to.
            logging.exception("failed to load sasview model %r", name)
            continue
        MODELS[name] = model
        models.append(model)
    return models
74
75
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model from the python file at *path*.

    If the loaded kernel module defines a ``Model`` class, that class is
    used as-is.  Otherwise a SasviewModel subclass is generated from the
    module's model info.  The model is registered in MODELS under its
    ``name`` and returned.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
    else:
        info = modelinfo.make_model_info(kernel_module)
        model = _make_model_from_info(info)
    MODELS[model.name] = model
    return model
90
91
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    The kernel module is found with generate.load_kernel_module, its model
    info is extracted, and the result is wrapped in a class that can be used
    directly as a sasview model.
    """
    # NOTE(review): earlier docs said *name* may also be a path to a custom
    # model; that depends on generate.load_kernel_module -- confirm there.
    return _make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
104
105
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass that wraps *model_info*.

    The class is named after the model.  It carries the attributes produced
    by _generate_model_attributes, so the model can be queried without
    being instantiated.
    """
    attrs = _generate_model_attributes(model_info)

    # Each generated class gets its own __init__, which forwards to the base.
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), attrs)
117
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for the sasview model of *model_info*.

    The attributes hold everything needed to query the model (names,
    category, parameter groupings, multiplicity), so the model does not
    have to be instantiated.  They are meant to be immutable to avoid
    accidents.
    """
    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters
    non_fittable = []  # type: List[str]

    # Multiplicity: the kernel parameter named by model_info.control (if
    # any) selects the model variant and is never fitted.
    if model_info.profile is None:
        xlabel = ""
    else:
        xlabel = model_info.profile_axes[0]
    variants = MultiplicityInfo(0, "", [], xlabel)
    control = next((p for p in kernel_pars if p.name == model_info.control),
                   None)
    if control is not None:
        non_fittable.append(control.name)
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices,
                                    xlabel)

    # Only a single drop-down list parameter is supported: use the first
    # kernel parameter that has choices.
    fun_list = []
    chooser = next((p for p in kernel_pars if p.choices), None)
    if chooser is not None:
        fun_list = chooser.choices
        if chooser.length > 1:
            non_fittable.extend(chooser.id + str(k)
                                for k in range(1, chooser.length + 1))

    # Group the user parameters by type.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters():
        width_name = p.name + ".width"
        if p.type == 'orientation':
            orientation_params.extend((p.name, width_name))
            fixed.append(width_name)
        elif p.type == 'magnetic':
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
186
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are created by _make_model_from_info, which fills in
    the class attributes below from the model info.  Instances hold the
    current parameter values and dispersion settings.  The kernel is built
    lazily, the first time calculate_Iq needs it.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters (as used by is_fittable); filled
    #: with the ".width" names of orientation and magnetic parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize per-instance parameter values, limits and dispersion.

        *multiplicity* selects the model variant.  When it is given, the
        parameters hidden for that variant and the multiplicity control
        parameter itself are left out of :attr:`params`, :attr:`details`
        and :attr:`dispersion`.  When it is None, every user parameter is
        included.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            # Copy so the set returned by the model info is never mutated.
            hidden = set(self._model_info.get_hidden_parameters(multiplicity))
            hidden.add(self.multiplicity_info.control)
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                # Width upper limit is 1.0 for relative widths
                # (p.relative_pd), unbounded otherwise.
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickle/copy, minus the compiled kernel.

        The kernel is rebuilt on demand by calculate_Iq, so it is never
        copied or pickled.
        """
        state = self.__dict__.copy()
        # _model only appears in the instance dict once a kernel was built.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance state; the kernel will be rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Misspelled names kept for backward compatibility; pickle and copy
    # only look for __getstate__/__setstate__.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        Returns True if *par_name* is listed in :attr:`fixed`.

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        The profile values returned by the model info are scaled by 1e-6.
        Parameters missing from :attr:`params` are passed as NaN.
        """
        control = self.multiplicity_info.control
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == control and self.multiplicity is not None:
                value = float(self.multiplicity)
            elif p.length == 1:
                # Also covers the control parameter when no multiplicity
                # was given, since it is then an ordinary entry in params.
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        A name of the form "par.attr" (e.g. "radius.width") sets a
        dispersion attribute; any other name sets a parameter value.

        :param name: name of the parameter
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            item, par = toks
            if item in self.dispersion and par in self.dispersion[item]:
                self.dispersion[item][par] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        A name of the form "par.attr" (e.g. "radius.width") returns a
        dispersion attribute; any other name returns a parameter value.

        :param name: name of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            item, par = toks
            if item in self.dispersion and par in self.dispersion[item]:
                return self.dispersion[item][par]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p.name, ext)
                for p in self._model_info.parameters.user_parameters()
                for ext in ('npts', 'nsigmas', 'width')
                if p.polydisperse]

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self (the kernel is rebuilt lazily) """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]

    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                # asarray so plain python lists of q values also work.
                qx, qy = np.asarray(qx), np.asarray(qy)
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = build_details(calculator, pairs)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
        finally:
            # Release the kernel even if the calculation fails.
            calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        The result is the dispersion-weighted average of the model's ER
        function over the volume-parameter mesh; 1.0 if the model has none.

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        value, weight = self._dispersion_mesh()
        fv = self._model_info.ER(*value)
        return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        The result is the weighted ratio sum(w*part)/sum(w*whole) over the
        volume-parameter mesh; 1.0 if the model has no VR function.

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        value, weight = self._dispersion_mesh()
        whole, part = self._model_info.VR(*value)
        return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            disperser = weights.dispersers[dispersion.__class__.__name__]
            dispersion = weights.MODELS[disperser]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Non-polydisperse parameters get a single point of weight 1.0.
        Parameters absent from :attr:`params` get NaN, except the
        multiplicity control, which gets the multiplicity value.
        Polydisperse weights are normalized to sum to 1.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                return [np.nan], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
628
def test_model():
    # type: () -> np.ndarray
    """
    Test that a sasview model (cylinder) can be run.

    Builds the cylinder model class, instantiates it with its default
    parameters, and evaluates it at the single point (qx, qy) = (0.1, 0.1).
    Returns the computed intensity array.
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    return cylinder.evalDistribution([0.1, 0.1])
637
def test_rpa():
    # type: () -> np.ndarray
    """
    Test that a multiplicity sasview model (rpa) can be run.

    Builds the rpa model class, instantiates it with multiplicity 3, and
    evaluates it at the single point (qx, qy) = (0.1, 0.1).  Returns the
    computed intensity array.
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
646
647
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    A failure is annotated (via annotate_exception) with the name of the
    model being loaded, then re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            annotate_exception("when loading "+name)
            raise
660
if __name__ == "__main__":
    # Smoke test when run as a script: evaluate cylinder at (0.1, 0.1).
    iq = test_model()
    print("cylinder(0.1,0.1)=%g" % iq)
Note: See TracBrowser for help on using the repository browser.