source: sasmodels/sasmodels/sasview_model.py @ be86916

release_v0.95
Last change on this file since be86916 was be86916, checked in by ajj, 7 years ago

hide structure factor background/scale from sasview gui. Fixes #657.

  • Property mode set to 100644
File size: 24.7 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext
18
19import numpy as np  # type: ignore
20
21from . import core
22from . import custom
23from . import generate
24from . import weights
25from . import modelinfo
26from .details import make_kernel_args, dispersion_mesh
27
28try:
29    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
30    from .modelinfo import ModelInfo, Parameter
31    from .kernel import KernelModel
32    MultiplicityInfoType = NamedTuple(
33        'MuliplicityInfo',
34        [("number", int), ("control", str), ("choices", List[str]),
35         ("x_axis_label", str)])
36    SasviewModelType = Callable[[int], "SasviewModel"]
37except ImportError:
38    pass
39
# When True, load_standard_models() finishes by calling _register_old_models()
# so that every converted model is also importable under its old sasview name
# (sas.models.<OldName>).
SUPPORT_OLD_STYLE_PLUGINS = True  # type: bool
41
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for each entry of the
    conversion table, expose a pseudo-module ``sas.models.<old_name>`` whose
    attribute ``<old_name>`` is the new model class, so that legacy code such
    as ``from sas.models.CylinderModel import CylinderModel`` keeps working.

    Conversion-table entries whose new model is not loaded (find_model raises
    ValueError) are logged and skipped, matching load_standard_models which
    logs and drops models that fail to build.
    """
    import sys
    import sas
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.items():
        old_name = conversion[0]
        try:
            model = find_model(new_name)
        except ValueError:
            # Skip this entry rather than aborting registration of the rest.
            logging.warning("cannot register old model name %r: model %r "
                            "is not loaded", old_name, new_name)
            continue
        ConstructedModule = type(old_name, (), {old_name: model})
        old_path = 'sas.models.' + old_name
        # Attribute under the short name so sas.models.<old_name> resolves by
        # attribute access (the dotted path is not a usable attribute name),
        # and sys.modules entry under the dotted path so imports find it.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
65
66
# TODO: separate x_axis_label from multiplicity info
#: Multiplicity (variant) description of a model: *number* of variants, name
#: of the *control* parameter, its list of *choices*, and the *x_axis_label*
#: used when plotting the model profile.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', 'number control choices x_axis_label')
72
#: Registry of loaded sasview model classes, keyed by model name.  Filled in
#: by load_standard_models() and load_custom_model().
MODELS = {}
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is treated as the path of a custom model and is
    loaded with load_custom_model(); any other name must already be present
    in MODELS.

    :raises ValueError: if *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
88
89
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every model in core.list_models().

    Each class that builds successfully is stored in MODELS under its model
    name and included in the returned list, in core.list_models() order.  If
    building a model raises, its traceback is logged and the model is left
    out.  When SUPPORT_OLD_STYLE_PLUGINS is set, the models are afterwards
    also registered under their old sasview names.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class
        loaded.append(model_class)

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return loaded
110
111
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    If the loaded module defines a ``Model`` class (old-style sasview
    plugin), that class is used directly, and it is given the file's base
    name when it has no name of its own.  Otherwise the module is treated as
    a sasmodels kernel definition and a SasviewModel subclass is generated
    from it.  The resulting class is registered in MODELS and returned.
    """
    #print("load custom", path)
    kernel_module = custom.load_custom_kernel_module(path)
    try:
        model = kernel_module.Model
    except AttributeError:
        # No Model class: build a model from the kernel definition.
        model_info = modelinfo.make_model_info(kernel_module)
        model = _make_model_from_info(model_info)
    else:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        # Only the Model lookup is guarded by the try above, so a missing or
        # None name no longer diverts an old-style plugin to make_model_info.
        if not getattr(model, 'name', ''):
            model.name = splitext(basename(path))[0]
    MODELS[model.name] = model
    return model
131
132
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Load the sasview model defined by *name*.

    *name* can be a standard model name or a path to a custom model.  The
    kernel module is loaded with generate.load_kernel_module(), converted to
    a ModelInfo, and wrapped as a SasviewModel subclass.

    Returns a class that can be used directly as a sasview model.
    """
    return _make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
145
146
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass named after *model_info*.

    The class carries the attributes produced by _generate_model_attributes()
    plus a constructor that forwards the optional *multiplicity* argument to
    SasviewModel.__init__.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = dict(_generate_model_attributes(model_info),
                       __init__=__init__)
    constructed = type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
    return constructed
158
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class-attribute dictionary for a SasviewModel subclass.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.
    """
    # TODO: allow model to override axis labels input/output name/unit

    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
    kernel_pars = model_info.parameters.kernel_parameters

    # Multiplicity: the first kernel parameter whose name matches
    # model_info.control selects the variant and is never fitted.  The number
    # of variants is its number of choices, else its upper limit.
    variants = MultiplicityInfo(0, "", [], xlabel)
    non_fittable = []  # type: List[str]
    control_par = next(
        (par for par in kernel_pars if par.name == model_info.control), None)
    if control_par is not None:
        non_fittable.append(control_par.name)
        if control_par.choices:
            n_variants = len(control_par.choices)
        else:
            n_variants = int(control_par.limits[1])
        variants = MultiplicityInfo(
            n_variants, control_par.name, control_par.choices, xlabel)

    # Only a single drop-down list parameter available: the first kernel
    # parameter with choices supplies fun_list; if it is a vector, its
    # numbered elements are not fitted.
    fun_list = []
    choice_par = next((par for par in kernel_pars if par.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(
                choice_par.id + str(k) for k in range(1, choice_par.length + 1))

    # Organize parameter sets.  Magnetic parameters are listed among the
    # orientation parameters too; both kinds add "<name>.width" to fixed.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters():
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params += [par.name, width_name]
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
        else:
            continue
        fixed.append(width_name)

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]
227
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are subclasses of this one whose class attributes
    are filled in from a ModelInfo.  Each instance holds its own parameter
    values (*params*), polydispersity settings (*dispersion*) and parameter
    units/limits (*details*).  The compute kernel (*_model*) is built on the
    first call to calculate_Iq and is left out of the pickled/copied state.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize the per-instance parameter tables.

        :param multiplicity: number of variants for a multiplicity model, or
            None.  When given, the parameters hidden at that multiplicity and
            the control parameter itself are left out of *params*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            # Copy, so the collection returned by the model info is never
            # modified in place (it may be shared).
            hidden = set(self._model_info.get_hidden_parameters(multiplicity))
            hidden.add(self.multiplicity_info.control)
        else:
            hidden = set()
        if self._model_info.structure_factor:
            # Hide the structure factor's own scale and background from the
            # gui.  Hidden parameters are evaluated at their defaults (see
            # _get_weights), so the background default is forced to zero.
            # NOTE(review): this modifies the ModelInfo shared by every
            # instance of the class, not only this instance.
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                # Relative dispersity widths are capped at 1; absolute
                # widths are unbounded.
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickle/deepcopy, without the kernel.

        The compute kernel (*_model*) is dropped; it is rebuilt on demand by
        calculate_Iq once the state has been restored.
        """
        state = self.__dict__.copy()
        # _model is only in the instance dict once a kernel has been built.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the state produced by __getstate__; the kernel is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the original misspelled names, which
    # pickle and copy never invoked.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in the *fixed* class attribute
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        The control parameter takes the current multiplicity when one is set.
        Parameters missing from *params* are passed as NaN.  The profile's
        second output is scaled by 1e-6 before being returned.
        """
        control = self.multiplicity_info.control
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == control and self.multiplicity is not None:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or "<parameter>.<field>" (for
            example "radius.width") for a dispersion setting
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                settings[toks[1]] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or "<parameter>.<field>" (for
            example "radius.width") for a dispersion setting

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                return settings[toks[1]]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return an identical copy of self.

        The copy goes through __getstate__/__setstate__, so it does not share
        or copy the compute kernel; the kernel is rebuilt on first use.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither an array nor a list/tuple
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
        finally:
            # Release the kernel even when argument setup or evaluation fails.
            calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius (1.0 if the model does
            not define ER), averaged over the volume-parameter dispersion
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio (1.0 if the model does not
            define VR), as a weighted ratio over the volume-parameter mesh
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a parameter of the model
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Returns (values, weights).  A parameter missing from *params* is
        hidden: the multiplicity control parameter takes the current
        multiplicity, and any other hidden parameter takes its default (NaN
        if there is none).  Both use a single weight of 1.  Polydisperse
        parameters use the configured distribution with weights normalized
        to sum to 1.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
675
def test_model():
    # type: () -> float
    """
    Check that the cylinder model can be built and evaluated at the single
    point (qx, qy) = (0.1, 0.1); the computed intensity is returned.
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1, 0.1])
684
def test_structure_factor():
    # type: () -> None
    """
    Test that a sasview structure factor model (hardsphere) can be run.

    :raises ValueError: if the intensity at (qx, qy) = (0.1, 0.1) is NaN.
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value = model.evalDistribution([0.1, 0.1])
    # np.any keeps the check well defined for scalar and array results.
    if np.any(np.isnan(value)):
        raise ValueError("hardsphere returns null")
695
def test_rpa():
    # type: () -> float
    """
    Test that a sasview multiplicity model (rpa, with multiplicity 3) can be
    run; the intensity at (qx, qy) = (0.1, 0.1) is returned.
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
704
705
def test_model_list():
    # type: () -> None
    """
    Check that every model in core.list_models() builds as a sasview model.

    Any failure is annotated with the name of the offending model and then
    re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except BaseException:
            annotate_exception("when loading " + model_name)
            raise
718
def test_old_name():
    # type: () -> None
    """
    Check that the cylinder model can be loaded and run through its old
    sasview name, sas.models.CylinderModel.

    Silently does nothing when old-style plugin support is switched off or
    when the sasview package (``sas``) cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        # Probe only: if sasview is not on the path then don't try to test it.
        import sas  # pylint: disable=unused-variable
    except ImportError:
        sasview_available = False
    else:
        sasview_available = True
    if sasview_available:
        load_standard_models()
        from sas.models.CylinderModel import CylinderModel
        CylinderModel().evalDistribution([0.1, 0.1])
734
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model at (qx, qy) = (0.1, 0.1).
    _cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g"%_cylinder_value)
Note: See TracBrowser for help on using the repository browser.