source: sasmodels/sasmodels/sasview_model.py @ fd19811

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since fd19811 was fd19811, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

lint

  • Property mode set to 100644
File size: 24.0 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext
18
19import numpy as np  # type: ignore
20
21from . import core
22from . import custom
23from . import generate
24from . import weights
25from . import modelinfo
26from .details import build_details, dispersion_mesh
27
28try:
29    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
30    from .modelinfo import ModelInfo, Parameter
31    from .kernel import KernelModel
32    MultiplicityInfoType = NamedTuple(
33        'MuliplicityInfo',
34        [("number", int), ("control", str), ("choices", List[str]),
35         ("x_axis_label", str)])
36    SasviewModelType = Callable[[int], "SasviewModel"]
37except ImportError:
38    pass
39
#: When true, load_standard_models() also calls _register_old_models() so the
#: models are additionally importable under their old sas.models.* names.
SUPPORT_OLD_STYLE_PLUGINS = True
41
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for every entry in the
    conversion table, build a synthetic module ``sas.models.<OldName>`` whose
    single attribute ``<OldName>`` is the new model class, so that e.g.
    ``from sas.models.CylinderModel import CylinderModel`` keeps working.

    Requires sasview (the ``sas`` package) to be importable.
    """
    import sys
    import sas
    import sas.sascalc.fit
    # Alias the fit package both in the import cache and as an attribute of
    # the sas package so both "import sas.models" and "sas.models" resolve.
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.items():
        old_name = conversion[0]
        module_attrs = {old_name: find_model(new_name)}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # Attach the submodule under its short name.  Using the dotted path
        # as the attribute name made it unreachable as sas.models.<OldName>.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
65
66
# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description stored on each model class: the *number* of
#: variants, the name of the *control* parameter, the *choices* for that
#: parameter, and the *x_axis_label* taken from the model's profile axes.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', 'number control choices x_axis_label')

#: Registry of model classes keyed by model name; filled in by
#: load_standard_models() and load_custom_model(), read by find_model().
MODELS = dict()  # type: Dict[str, SasviewModelType]
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the sasview model class called *modelname*.

    A name ending in ``.py`` is treated as a path and loaded with
    load_custom_model(); any other name must already be registered in
    MODELS (for example by load_standard_models()).

    :raises ValueError: if *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    is_custom_path = modelname.endswith('.py')
    if is_custom_path:
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
88
89
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every model in core.list_models().

    Each class that builds successfully is registered in MODELS and included
    in the returned list.  A model that fails to build is skipped and its
    traceback is logged.  When SUPPORT_OLD_STYLE_PLUGINS is set, the models
    are also registered under their old sas.models names.
    """
    loaded = []  # type: List[SasviewModelType]
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class
        loaded.append(model_class)

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()
    return loaded
110
111
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    If the module loaded from *path* defines a ``Model`` class (an old-style
    sasview plugin), that class is used as is.  Otherwise a sasview model
    class is generated from the module's model definition.  Either way the
    class is registered in MODELS under its name and returned.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    try:
        model = kernel_module.Model
    except AttributeError:
        # New-style plugin: generate the wrapper from the model definition.
        model_info = modelinfo.make_model_info(kernel_module)
        model = _make_model_from_info(model_info)
    else:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.  A missing name
        # is treated like an empty one rather than raising AttributeError and
        # misrouting the old-style class through make_model_info.
        if getattr(model, 'name', "") == "":
            model.name = splitext(basename(path))[0]
    MODELS[model.name] = model
    return model
130
131
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model *name*.

    The kernel module is obtained from generate.load_kernel_module(name) and
    turned into a model info structure, from which the class is generated.
    The returned class can be used directly as a sasview model.
    """
    model_info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(model_info)
144
145
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass that wraps *model_info*.

    The subclass is named after the model.  It carries the class attributes
    produced by _generate_model_attributes(), plus an ``__init__`` that
    forwards *multiplicity* to SasviewModel.__init__.
    """
    class_attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_attrs)
157
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.
    """
    # TODO: allow model to override axis labels input/output name/unit

    non_fittable = []  # type: List[str]

    # x-axis label for the profile plot, if the model defines a profile.
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""

    # Multiplicity: the kernel parameter named by model_info.control (if any)
    # selects the model variant and is never fitted.
    control = next((p for p in model_info.parameters.kernel_parameters
                    if p.name == model_info.control), None)
    if control is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control.name)
        if control.choices:
            num_variants = len(control.choices)
        else:
            num_variants = int(control.limits[1])
        variants = MultiplicityInfo(num_variants, control.name,
                                    control.choices, xlabel)

    # Only a single drop-down list parameter available: use the first kernel
    # parameter that has choices.
    fun_list = []
    choice_par = next((p for p in model_info.parameters.kernel_parameters
                       if p.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(k)
                                for k in range(1, choice_par.length + 1))

    # Organize parameter sets
    orientation_params = []  # type: List[str]
    magnetic_params = []  # type: List[str]
    fixed = []  # type: List[str]
    for par in model_info.parameters.user_parameters():
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
            fixed.append(width_name)
        elif par.type == 'magnetic':
            # magnetic parameters are also listed with the orientation ones
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
226
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes subclass SasviewModel (see _make_model_from_info)
    and fill in the class attributes declared below.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names reported as fittable by is_fittable(); for generated models this
    #: holds the ".width" names of the orientation and magnetic parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize parameter values, dispersion settings and parameter
        details from the model info.

        :param multiplicity: value passed to model_info.get_hidden_parameters,
            or None.  When given, the hidden parameters and the multiplicity
            control parameter are left out of *params*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickle/copy, minus the compiled kernel
        model, which calculate_Iq rebuilds on demand.
        """
        state = self.__dict__.copy()
        # _model only becomes an instance attribute once calculate_Iq runs,
        # so it may legitimately be absent.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance state; the kernel model is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previous (misspelled) names, which
    # the pickle/copy machinery never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        # TODO: in the future, ask the parameter itself whether it is fittable
        return par_name in self.fixed

    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # NOTE(review): the 1e-6 factor presumably converts SLD from
        # 1e-6/Ang^2 units to 1/Ang^2 -- confirm against the profile callers.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; a name with exactly one dot, such
            as "radius.width", addresses a field of the dispersion settings
        :param value: value of the parameter

        :raises ValueError: if the model does not contain the parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                self.dispersion[par_name][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; a name with exactly one dot, such
            as "radius.width", addresses a field of the dispersion settings

        :raises ValueError: if the model does not contain the parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                return self.dispersion[par_name][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name, ext)
               for p in self._model_info.parameters.user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = build_details(calculator, pairs)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
        finally:
            # Release the kernel even when building details or evaluating fails.
            calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            dispersion = weights.MODELS[dispersion.type]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Parameters missing from *params* get a single NaN value, except the
        multiplicity control parameter, which gets the current multiplicity.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                return [np.nan], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
666
def test_model():
    # type: () -> np.ndarray
    """
    Check that a sasview model (cylinder) can be built and evaluated at a
    single (qx, qy) point; the computed intensity is returned.
    """
    return _make_standard_model('cylinder')().evalDistribution([0.1, 0.1])
675
def test_rpa():
    # type: () -> np.ndarray
    """
    Test that a sasview model (rpa) can be run with multiplicity 3.

    Returns the intensity evaluated at the single point (qx, qy) = (0.1, 0.1).
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
684
685
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    Any failure is re-raised after annotating it with the name of the model
    that could not be loaded.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except BaseException:
            annotate_exception("when loading " + model_name)
            raise
698
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model from sas.models.CylinderModel

    Skipped silently when old-style plugins are disabled or when sasview
    (the ``sas`` package) cannot be imported.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            import sas  # only probes whether sasview is on the path
        except ImportError:
            pass
        else:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            model = CylinderModel()
            model.evalDistribution([0.1, 0.1])
714
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model at (qx, qy) = (0.1, 0.1).
    iq_cylinder = test_model()
    print("cylinder(0.1,0.1)=%g" % iq_cylinder)
Note: See TracBrowser for help on using the repository browser.