source: sasmodels/sasmodels/sasview_model.py @ 9f37726

core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9f37726 was 9f37726, checked in by wojciech, 8 years ago

Removing try/except clause that wasn't properly handled

  • Property mode set to 100644
File size: 24.5 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext
18
19import numpy as np  # type: ignore
20
21from . import core
22from . import custom
23from . import generate
24from . import weights
25from . import modelinfo
26from .details import make_kernel_args, dispersion_mesh
27
28try:
29    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
30    from .modelinfo import ModelInfo, Parameter
31    from .kernel import KernelModel
32    MultiplicityInfoType = NamedTuple(
33        'MuliplicityInfo',
34        [("number", int), ("control", str), ("choices", List[str]),
35         ("x_axis_label", str)])
36    SasviewModelType = Callable[[int], "SasviewModel"]
37except ImportError:
38    pass
39
#: When true, load_standard_models() also registers every model under its
#: old sas.models name via _register_old_models().
SUPPORT_OLD_STYLE_PLUGINS = True
41
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for each entry in
    sasmodels.conversion_table.CONVERSION_TABLE, register a module-like class
    as ``sas.models.<old_name>`` whose ``<old_name>`` attribute is the new
    model class, so that ``from sas.models.<old_name> import <old_name>``
    resolves to the new model.

    A model that find_model() cannot locate (for example because it failed
    to load in load_standard_models) is skipped with a warning instead of
    aborting the registration of the remaining models.
    """
    import sys
    import sas  # needed so that sas.models can be assigned below
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.items():
        old_name = conversion[0]
        try:
            model = find_model(new_name)
        except ValueError as exc:
            # find_model raises ValueError for names missing from MODELS.
            logging.warning("cannot register old model name %s: %s",
                            old_name, exc)
            continue
        ConstructedModule = type(old_name, (), {old_name: model})
        old_path = 'sas.models.' + old_name
        # Kept for compatibility: earlier code stored the module-like class
        # using the full dotted path as the attribute name.
        setattr(sas.models, old_path, ConstructedModule)
        # Also expose it under its plain name, as the import system does for
        # a real submodule, unless that would shadow an existing attribute.
        if not hasattr(sas.models, old_name):
            setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
65
66
# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description of a model: the variant count, the name of the
#: control parameter, the control parameter's choices (if any) and the
#: x-axis label of the model's profile.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: Registry of loaded sasview model classes, keyed by model name.
MODELS = {}  # type: Dict[str, SasviewModelType]
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ".py" is treated as the path to a custom model file and
    is loaded with load_custom_model().  Any other name must already be
    registered in MODELS (by load_standard_models or load_custom_model).

    :raises ValueError: if *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
88
89
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every model in core.list_models().

    Each class is registered in MODELS under its model name and collected
    into the returned list.  A model that fails to build is skipped after
    its traceback is logged as an error.  When SUPPORT_OLD_STYLE_PLUGINS is
    set, the models are also registered under their old sas.models names.
    """
    loaded = []  # type: List[SasviewModelType]
    for name in core.list_models():
        try:
            model_class = _make_standard_model(name)
        except Exception:
            logging.error(traceback.format_exc())
            continue
        MODELS[name] = model_class
        loaded.append(model_class)

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return loaded
110
111
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load the custom model stored in the file at *path* and register it.

    If the loaded module provides a ``Model`` class with a ``name`` attribute,
    that class is used directly.  Otherwise (an AttributeError is raised
    while accessing either), a sasview model class is built from the
    module's model info.  A name that is already registered gets "_v2"
    appended before the model is stored in MODELS.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    try:
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # derive it from the file name; this name will be overridden when the
        # object is created with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
    except AttributeError:
        model = _make_model_from_info(modelinfo.make_model_info(kernel_module))

    # Models appear to be periodically reloaded by SasView, and telling a
    # reload apart from a first load is tricky, so for now only one extra
    # custom model of a given name is allowed: a clash gets a "_v2" suffix.
    if model.name in MODELS:
        model.name = "%s_v2" % model.name
    MODELS[model.name] = model
    return model
140
141
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model *name*.

    *name* is handed to generate.load_kernel_module() to obtain the kernel
    module, whose model info is then wrapped by _make_model_from_info().

    Returns a class that can be used directly as a sasview model.
    """
    model_info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(model_info)
154
155
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass named after *model_info*.

    The class attributes come from _generate_model_attributes(); the class
    gets its own __init__ that forwards *multiplicity* to SasviewModel.
    """
    attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), attrs)  # type: SasviewModelType
167
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for the model wrapping *model_info*.

    The attributes hold everything needed to query the model (names,
    description, category, multiplicity, parameter groupings) without
    instantiating it.  Sequences are stored as tuples so the class-level
    values are immutable and cannot be changed by accident.
    """
    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity: the kernel parameter named by model_info.control (if any)
    # selects the variant and is never fitted.
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]
    non_fittable = []  # type: List[str]
    control_par = next(
        (p for p in model_info.parameters.kernel_parameters
         if p.name == model_info.control),
        None)
    if control_par is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(count, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter available: the first kernel
    # parameter that has choices.
    fun_list = []
    dropdown = next(
        (p for p in model_info.parameters.kernel_parameters if p.choices),
        None)
    if dropdown is not None:
        fun_list = dropdown.choices
        if dropdown.length > 1:
            non_fittable.extend(dropdown.id + str(k)
                                for k in range(1, dropdown.length + 1))

    # Group the user parameters by type.
    orientation_params = []  # type: List[str]
    magnetic_params = []  # type: List[str]
    fixed = []  # type: List[str]
    for p in model_info.parameters.user_parameters():
        width_name = p.name + ".width"
        if p.type == 'orientation':
            orientation_params += [p.name, width_name]
            fixed.append(width_name)
        elif p.type == 'magnetic':
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
236
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Model-specific subclasses are created by _make_model_from_info(), which
    fills in the class attributes below via _generate_model_attributes().
    Each instance holds its own parameter values and dispersion settings.
    The computational kernel is built lazily by calculate_Iq().
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize parameter values, details and dispersion settings from
        the parameter defaults in the model info.

        :param multiplicity: multiplicity value, or None.  When given, the
            parameters reported by get_hidden_parameters(multiplicity) and
            the multiplicity control parameter are left out of *params*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Pickling/copy support: return the instance state without the built
        kernel model, which calculate_Iq() rebuilds on demand.
        """
        state = self.__dict__.copy()
        # _model is only in the instance dict once calculate_Iq() has built
        # it, so it may be absent.
        state.pop('_model', None)
        # NOTE: _model_info is kept.  It may need to be reloaded on set state
        # since it has pointers to python implementations of Iq, etc.
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Unpickling/copy support: restore *state* and mark the kernel model
        as not yet built.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previous, misspelled names (pickle
    # and copy only look for __getstate__/__setstate__).
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in the class attribute *fixed*
            (which, despite its name, holds the fittable parameter names)
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        control = self.multiplicity_info.control
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == control and self.multiplicity is not None:
                # With a multiplicity, the control parameter is hidden from
                # self.params, so its value is the multiplicity itself.
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameter stored as <id>1 ... <id>N in self.params.
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # The returned SLD values are scaled by 1e-6.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, Any) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; "par.field" (e.g. "radius.width")
            addresses a dispersion setting of parameter *par*
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                self.dispersion[par_name][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> Any
        """
        Get the value of a model parameter

        :param name: name of the parameter; "par.field" (e.g. "radius.width")
            reads a dispersion setting of parameter *par*

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                return self.dispersion[par_name][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return a identical copy of self

        The copy goes through __getstate__/__setstate__, so it rebuilds its
        kernel model on first use instead of sharing this one.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither a list/tuple nor an ndarray
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                # asarray lets plain sequences of floats be used for qx, qy.
                qx, qy = np.asarray(qx), np.asarray(qy)
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.

        The kernel and the model are released even if the evaluation raises.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        try:
            calculator = self._model.make_kernel(q_vectors)
            try:
                parameters = self._model_info.parameters
                pairs = [self._get_weights(p)
                         for p in parameters.call_parameters]
                call_details, values, is_magnetic = make_kernel_args(
                    calculator, pairs)
                result = calculator(call_details, values, cutoff=self.cutoff,
                                    magnetic=is_magnetic)
            finally:
                calculator.release()
        finally:
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, averaged over the
            volume-parameter dispersion weights (1.0 if the model has no ER)
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio (1.0 if the model has no VR)
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[Sequence[float], Sequence[float]]
        """
        Return dispersion (values, weights) for parameter *par*.

        Hidden parameters give their multiplicity (for the control
        parameter) or NaN with unit weight; non-polydisperse parameters give
        their current value with unit weight; polydisperse ones give the
        distribution with weights normalized to sum to 1.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                return [np.nan], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]

def test_model():
    # type: () -> np.ndarray
    """
    Smoke test: build the standard cylinder model and evaluate it at
    (qx, qy) = (0.1, 0.1), returning the result.
    """
    cylinder_class = _make_standard_model('cylinder')
    return cylinder_class().evalDistribution([0.1, 0.1])
688
def test_rpa():
    # type: () -> np.ndarray
    """
    Test that a sasview model (rpa, with multiplicity 3) can be run.

    Evaluates the model at (qx, qy) = (0.1, 0.1) and returns the result.
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
697
698
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    An exception raised while building a model is annotated with
    "when loading <name>" and re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            annotate_exception("when loading "+name)
            raise
711
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model from sas.models.CylinderModel

    Returns without testing anything when old-style plugin support is
    disabled or when sasview (the ``sas`` package) cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        import sas  # only checking that sasview is importable
    except ImportError:
        # sasview is not on the path, so there is nothing to test
        return
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel
    model = CylinderModel()
    model.evalDistribution([0.1, 0.1])
727
if __name__ == "__main__":
    # Script entry: print the cylinder model evaluated at (qx, qy) = (0.1, 0.1).
    cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.