source: sasmodels/sasmodels/sasview_model.py @ 3bcb88c

Branches and tags: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 3bcb88c was 3bcb88c, checked in by Paul Kienzle <pkienzle@…>, 3 years ago

allow fitting of models with hidden polydisperse parameters (e.g., core_multi_shell)

  • Property mode set to 100644
File size: 24.0 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext
18
19import numpy as np  # type: ignore
20
21from . import core
22from . import custom
23from . import generate
24from . import weights
25from . import modelinfo
26from .details import make_kernel_args, dispersion_mesh
27
28try:
29    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
30    from .modelinfo import ModelInfo, Parameter
31    from .kernel import KernelModel
32    MultiplicityInfoType = NamedTuple(
33        'MuliplicityInfo',
34        [("number", int), ("control", str), ("choices", List[str]),
35         ("x_axis_label", str)])
36    SasviewModelType = Callable[[int], "SasviewModel"]
37except ImportError:
38    pass
39
# When True, load_standard_models() also calls _register_old_models() so the
# models are reachable under their old sasview names (sas.models.<OldName>).
SUPPORT_OLD_STYLE_PLUGINS = True
41
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for every entry of the
    conversion table, register a pseudo-module ``sas.models.<OldName>``
    whose attribute ``<OldName>`` is the new model class, so that
    ``from sas.models.CylinderModel import CylinderModel`` keeps working.

    Models that :func:`find_model` cannot find (for example because they
    failed to build in :func:`load_standard_models`) are skipped with a
    logged warning instead of aborting the whole registration.
    """
    import sys
    import sas
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.items():
        old_name = conversion[0]
        try:
            model = find_model(new_name)
        except ValueError:
            # load_standard_models logs build failures and carries on; do the
            # same here rather than raising for a model that is not loaded.
            logging.warning("cannot register old model name %s: model %s "
                            "is not loaded", old_name, new_name)
            continue
        ConstructedModule = type(old_name, (), {old_name: model})
        old_path = 'sas.models.' + old_name
        # Bind under the plain name so attribute access
        # (sas.models.<OldName>) agrees with the sys.modules entry; the
        # previous code used the dotted path as the attribute name.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
65
66
# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description attached to each model class: the number of
#: variants, the name of the control parameter, its choices, and the x-axis
#: label taken from the model's profile axes.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', 'number control choices x_axis_label')
72
#: Registry of loaded model classes, keyed by model name.
MODELS = {}
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the model class registered for *modelname*.

    A name ending in ``.py`` is treated as a path and loaded through
    :func:`load_custom_model`; any other name must already be present in
    the MODELS registry, otherwise ValueError is raised.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r" % modelname)
    return MODELS[modelname]
88
89
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every name in ``core.list_models()``.

    Each successfully built class is stored in MODELS and included in the
    returned list.  A model that fails to build has its traceback logged
    at error level and is left out.  If SUPPORT_OLD_STYLE_PLUGINS is set,
    the old sasview names are registered afterwards.
    """
    loaded = []
    for name in core.list_models():
        try:
            model_class = _make_standard_model(name)
        except Exception:
            logging.error(traceback.format_exc())
            continue
        MODELS[name] = model_class
        loaded.append(model_class)

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()
    return loaded
110
111
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    If the loaded module defines a ``Model`` class (an old-style plugin),
    that class is used directly, and it is named after the file's base
    name when its ``name`` is empty.  Otherwise a wrapper class is generated
    from the module's model info.  The class is registered in MODELS under
    its name and returned.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    try:
        sasview_class = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; the instance later gets an attribute with the same
        # value when it is created.
        if sasview_class.name == "":
            stem, _ = splitext(basename(path))
            sasview_class.name = stem
    except AttributeError:
        # No usable Model class: build the wrapper from the module itself.
        sasview_class = _make_model_from_info(
            modelinfo.make_model_info(kernel_module))
    MODELS[sasview_class.name] = sasview_class
    return sasview_class
131
132
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Load the sasview model defined by *name*.

    *name* is handed to ``generate.load_kernel_module``; it can be a
    standard model name or a path to a custom model.

    Returns a class that can be used directly as a sasview model.
    """
    info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(info)
145
146
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Convert *model_info* into a SasView model wrapper.

    The result is a new subclass of SasviewModel, named after the model,
    whose class attributes come from _generate_model_attributes.
    """
    class_attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
158
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for a sasview model wrapper.

    The attributes carry everything needed to query the model (names,
    category, parameter groupings, multiplicity and drop-down choices)
    without creating an instance.  Sequences are stored as tuples so the
    shared class attributes cannot be mutated by accident.
    """
    # TODO: allow model to override axis labels input/output name/unit

    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""

    # Multiplicity: the kernel parameter named by model_info.control, if
    # any, is excluded from fitting and described by a MultiplicityInfo.
    non_fittable = []  # type: List[str]
    control_par = next((par for par in model_info.parameters.kernel_parameters
                        if par.name == model_info.control), None)
    if control_par is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(count, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter is supported: the first kernel
    # parameter with choices.  Its numbered copies are not fittable.
    fun_list = []
    choice_par = next((par for par in model_info.parameters.kernel_parameters
                       if par.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(k)
                                for k in range(1, choice_par.length + 1))

    # Group the user-visible parameters by type.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters():
        if par.type == 'orientation':
            orientation_params.extend([par.name, par.name + ".width"])
            fixed.append(par.name + ".width")
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(par.name + ".width")

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants[0] > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
227
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Model-specific subclasses are created by _make_model_from_info, which
    fills in the class attributes below.  Instances hold the current
    parameter values and dispersity settings; the compiled kernel is built
    lazily on the first calculation.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable dispersity parameters ("<par>.width"); this is
    #: the list consulted by is_fittable
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Create a model with every visible parameter at its default value.

        :param multiplicity: value of the multiplicity control parameter.
            When given, the parameters hidden for that value (and the control
            parameter itself) are left out of params, details and dispersion.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            # Copy so that adding the control name cannot modify a set that
            # model_info may hand out again.
            hidden = set(self._model_info.get_hidden_parameters(multiplicity))
            hidden.add(self.multiplicity_info.control)
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state used by pickle and deepcopy.

        The compiled kernel is left out; it is rebuilt on demand by
        calculate_Iq, so copies never share or copy it.
        """
        state = self.__dict__.copy()
        # _model is only an instance attribute after the first calculation.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore *state*; the kernel is rebuilt by the next calculation."""
        self.__dict__ = state
        self._model = None

    # The original (misspelled) names were never used by pickle/copy; keep
    # them as aliases for any caller that invoked them directly.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in the ``fixed`` attribute
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        The multiplicity control parameter takes the instance multiplicity
        when one was given; otherwise its value comes from params like any
        other parameter.  Missing parameters are passed as NaN.
        """
        args = {} # type: Dict[str, Any]
        control = self.multiplicity_info.control
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == control and self.multiplicity is not None:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # NOTE(review): presumably converts SLD from 1e-6/A^2 units; confirm.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; "<par>.<field>" (for example
            "radius.width") addresses a dispersity setting
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                settings[toks[1]] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; "<par>.<field>" (for example
            "radius.width") addresses a dispersity setting
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                return settings[toks[1]]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return an independent copy of self.

        The compiled kernel is not copied; the clone builds its own on the
        first calculation.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            return calculator(call_details, values, cutoff=self.cutoff,
                              magnetic=is_magnetic)
        finally:
            # Release the kernel even when weight generation or the
            # evaluation itself raises.
            calculator.release()

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a parameter of the model
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            dispersion = weights.MODELS[dispersion.type]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return (values, weights) for parameter *par*.

        Parameters hidden by the multiplicity evaluate to the multiplicity
        (for the control parameter) or NaN, with unit weight.  Polydisperse
        parameters use the current dispersion settings, with weights
        normalized to sum to one.  Other parameters return their current
        value with unit weight.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # np.nan rather than np.NaN: the alias was removed in NumPy 2.
                return [np.nan], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
666
def test_model():
    # type: () -> float
    """
    Build the cylinder wrapper with default parameters and evaluate it at
    (qx, qy) = (0.1, 0.1), returning the result.
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1, 0.1])
675
def test_rpa():
    # type: () -> float
    """
    Test that a sasview model (rpa, built with multiplicity 3) can be run
    at (qx, qy) = (0.1, 0.1).
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
684
685
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    A failure is annotated with the name of the offending model and then
    re-raised unchanged.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except BaseException:
            annotate_exception("when loading " + model_name)
            raise
698
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model from sas.models.CylinderModel

    Does nothing when old-style plugin support is disabled or when sasview
    is not importable.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        # if sasview is not on the path then don't try to test it
        try:
            import sas
        except ImportError:
            sas = None
        if sas is not None:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            CylinderModel().evalDistribution([0.1, 0.1])
714
if __name__ == "__main__":
    # Quick smoke check: evaluate the default cylinder at (0.1, 0.1).
    result = test_model()
    print("cylinder(0.1,0.1)=%g" % result)
Note: See TracBrowser for help on using the repository browser.