source: sasmodels/sasmodels/sasview_model.py @ ed10b57

core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since ed10b57 was ed10b57, checked in by mathieu, 8 years ago

Allow loading of custom model of the same name as a std model (append version). Re #673

  • Property mode set to 100644
File size: 24.6 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext
18
19import numpy as np  # type: ignore
20
21from . import core
22from . import custom
23from . import generate
24from . import weights
25from . import modelinfo
26from .details import make_kernel_args, dispersion_mesh
27
28try:
29    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
30    from .modelinfo import ModelInfo, Parameter
31    from .kernel import KernelModel
32    MultiplicityInfoType = NamedTuple(
33        'MuliplicityInfo',
34        [("number", int), ("control", str), ("choices", List[str]),
35         ("x_axis_label", str)])
36    SasviewModelType = Callable[[int], "SasviewModel"]
37except ImportError:
38    pass
39
40SUPPORT_OLD_STYLE_PLUGINS = True
41
42def _register_old_models():
43    # type: () -> None
44    """
45    Place the new models into sasview under the old names.
46
47    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
48    is available to the plugin modules.
49    """
50    import sys
51    import sas
52    import sas.sascalc.fit
53    sys.modules['sas.models'] = sas.sascalc.fit
54    sas.models = sas.sascalc.fit
55
56    import sas.models
57    from sasmodels.conversion_table import CONVERSION_TABLE
58    for new_name, conversion in CONVERSION_TABLE.items():
59        old_name = conversion[0]
60        module_attrs = {old_name: find_model(new_name)}
61        ConstructedModule = type(old_name, (), module_attrs)
62        old_path = 'sas.models.' + old_name
63        setattr(sas.models, old_path, ConstructedModule)
64        sys.modules[old_path] = ConstructedModule
65
66
67# TODO: separate x_axis_label from multiplicity info
68MultiplicityInfo = collections.namedtuple(
69    'MultiplicityInfo',
70    ["number", "control", "choices", "x_axis_label"],
71)
72
73MODELS = {}
74def find_model(modelname):
75    # type: (str) -> SasviewModelType
76    """
77    Find a model by name.  If the model name ends in py, try loading it from
78    custom models, otherwise look for it in the list of builtin models.
79    """
80    # TODO: used by sum/product model to load an existing model
81    # TODO: doesn't handle custom models properly
82    if modelname.endswith('.py'):
83        return load_custom_model(modelname)
84    elif modelname in MODELS:
85        return MODELS[modelname]
86    else:
87        raise ValueError("unknown model %r"%modelname)
88
89
90# TODO: figure out how to say that the return type is a subclass
91def load_standard_models():
92    # type: () -> List[SasviewModelType]
93    """
94    Load and return the list of predefined models.
95
96    If there is an error loading a model, then a traceback is logged and the
97    model is not returned.
98    """
99    models = []
100    for name in core.list_models():
101        try:
102            MODELS[name] = _make_standard_model(name)
103            models.append(MODELS[name])
104        except Exception:
105            logging.error(traceback.format_exc())
106    if SUPPORT_OLD_STYLE_PLUGINS:
107        _register_old_models()
108
109    return models
110
111
112def load_custom_model(path):
113    # type: (str) -> SasviewModelType
114    """
115    Load a custom model given the model path.
116    """
117    kernel_module = custom.load_custom_kernel_module(path)
118    try:
119        model = kernel_module.Model
120        # Old style models do not set the name in the class attributes, so
121        # set it here; this name will be overridden when the object is created
122        # with an instance variable that has the same value.
123        if model.name == "":
124            model.name = splitext(basename(path))[0]
125    except AttributeError:
126        model_info = modelinfo.make_model_info(kernel_module)
127        model = _make_model_from_info(model_info)
128
129    # If we are trying to load a model that already exists,
130    # append a version number to its name.
131    # Note that models appear to be periodically reloaded
132    # by SasView and keeping track of whether we are reloading
133    # a model or loading it for the first time is tricky.
134    # For now, just allow one custom model of a given name.
135    if model.name in MODELS:
136        model.name = "%s_v2" % model.name
137
138    MODELS[model.name] = model
139    return model
140
141
142def _make_standard_model(name):
143    # type: (str) -> SasviewModelType
144    """
145    Load the sasview model defined by *name*.
146
147    *name* can be a standard model name or a path to a custom model.
148
149    Returns a class that can be used directly as a sasview model.
150    """
151    kernel_module = generate.load_kernel_module(name)
152    model_info = modelinfo.make_model_info(kernel_module)
153    return _make_model_from_info(model_info)
154
155
156def _make_model_from_info(model_info):
157    # type: (ModelInfo) -> SasviewModelType
158    """
159    Convert *model_info* into a SasView model wrapper.
160    """
161    def __init__(self, multiplicity=None):
162        SasviewModel.__init__(self, multiplicity=multiplicity)
163    attrs = _generate_model_attributes(model_info)
164    attrs['__init__'] = __init__
165    ConstructedModel = type(model_info.name, (SasviewModel,), attrs) # type: SasviewModelType
166    return ConstructedModel
167
168def _generate_model_attributes(model_info):
169    # type: (ModelInfo) -> Dict[str, Any]
170    """
171    Generate the class attributes for the model.
172
173    This should include all the information necessary to query the model
174    details so that you do not need to instantiate a model to query it.
175
176    All the attributes should be immutable to avoid accidents.
177    """
178
179    # TODO: allow model to override axis labels input/output name/unit
180
181    # Process multiplicity
182    non_fittable = []  # type: List[str]
183    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
184    variants = MultiplicityInfo(0, "", [], xlabel)
185    for p in model_info.parameters.kernel_parameters:
186        if p.name == model_info.control:
187            non_fittable.append(p.name)
188            variants = MultiplicityInfo(
189                len(p.choices) if p.choices else int(p.limits[1]),
190                p.name, p.choices, xlabel
191            )
192            break
193
194    # Only a single drop-down list parameter available
195    fun_list = []
196    for p in model_info.parameters.kernel_parameters:
197        if p.choices:
198            fun_list = p.choices
199            if p.length > 1:
200                non_fittable.extend(p.id+str(k) for k in range(1, p.length+1))
201            break
202
203    # Organize parameter sets
204    orientation_params = []
205    magnetic_params = []
206    fixed = []
207    for p in model_info.parameters.user_parameters():
208        if p.type == 'orientation':
209            orientation_params.append(p.name)
210            orientation_params.append(p.name+".width")
211            fixed.append(p.name+".width")
212        elif p.type == 'magnetic':
213            orientation_params.append(p.name)
214            magnetic_params.append(p.name)
215            fixed.append(p.name+".width")
216
217
218    # Build class dictionary
219    attrs = {}  # type: Dict[str, Any]
220    attrs['_model_info'] = model_info
221    attrs['name'] = model_info.name
222    attrs['id'] = model_info.id
223    attrs['description'] = model_info.description
224    attrs['category'] = model_info.category
225    attrs['is_structure_factor'] = model_info.structure_factor
226    attrs['is_form_factor'] = model_info.ER is not None
227    attrs['is_multiplicity_model'] = variants[0] > 1
228    attrs['multiplicity_info'] = variants
229    attrs['orientation_params'] = tuple(orientation_params)
230    attrs['magnetic_params'] = tuple(magnetic_params)
231    attrs['fixed'] = tuple(fixed)
232    attrs['non_fittable'] = tuple(non_fittable)
233    attrs['fun_list'] = tuple(fun_list)
234
235    return attrs
236
237class SasviewModel(object):
238    """
239    Sasview wrapper for opencl/ctypes model.
240    """
241    # Model parameters for the specific model are set in the class constructor
242    # via the _generate_model_attributes function, which subclasses
243    # SasviewModel.  They are included here for typing and documentation
244    # purposes.
245    _model = None       # type: KernelModel
246    _model_info = None  # type: ModelInfo
247    #: load/save name for the model
248    id = None           # type: str
249    #: display name for the model
250    name = None         # type: str
251    #: short model description
252    description = None  # type: str
253    #: default model category
254    category = None     # type: str
255
256    #: names of the orientation parameters in the order they appear
257    orientation_params = None # type: Sequence[str]
258    #: names of the magnetic parameters in the order they appear
259    magnetic_params = None    # type: Sequence[str]
260    #: names of the fittable parameters
261    fixed = None              # type: Sequence[str]
262    # TODO: the attribute fixed is ill-named
263
264    # Axis labels
265    input_name = "Q"
266    input_unit = "A^{-1}"
267    output_name = "Intensity"
268    output_unit = "cm^{-1}"
269
270    #: default cutoff for polydispersity
271    cutoff = 1e-5
272
273    # Note: Use non-mutable values for class attributes to avoid errors
274    #: parameters that are not fitted
275    non_fittable = ()        # type: Sequence[str]
276
277    #: True if model should appear as a structure factor
278    is_structure_factor = False
279    #: True if model should appear as a form factor
280    is_form_factor = False
281    #: True if model has multiplicity
282    is_multiplicity_model = False
283    #: Mulitplicity information
284    multiplicity_info = None # type: MultiplicityInfoType
285
286    # Per-instance variables
287    #: parameter {name: value} mapping
288    params = None      # type: Dict[str, float]
289    #: values for dispersion width, npts, nsigmas and type
290    dispersion = None  # type: Dict[str, Any]
291    #: units and limits for each parameter
292    details = None     # type: Dict[str, Sequence[Any]]
293    #                  # actual type is Dict[str, List[str, float, float]]
294    #: multiplicity value, or None if no multiplicity on the model
295    multiplicity = None     # type: Optional[int]
296    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
297    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
298
299    def __init__(self, multiplicity=None):
300        # type: (Optional[int]) -> None
301
302        # TODO: _persistency_dict to persistency_dict throughout sasview
303        # TODO: refactor multiplicity to encompass variants
304        # TODO: dispersion should be a class
305        # TODO: refactor multiplicity info
306        # TODO: separate profile view from multiplicity
307        # The button label, x and y axis labels and scale need to be under
308        # the control of the model, not the fit page.  Maximum flexibility,
309        # the fit page would supply the canvas and the profile could plot
310        # how it wants, but this assumes matplotlib.  Next level is that
311        # we provide some sort of data description including title, labels
312        # and lines to plot.
313
314        # Get the list of hidden parameters given the mulitplicity
315        # Don't include multiplicity in the list of parameters
316        self.multiplicity = multiplicity
317        if multiplicity is not None:
318            hidden = self._model_info.get_hidden_parameters(multiplicity)
319            hidden |= set([self.multiplicity_info.control])
320        else:
321            hidden = set()
322
323        self._persistency_dict = {}
324        self.params = collections.OrderedDict()
325        self.dispersion = collections.OrderedDict()
326        self.details = {}
327        for p in self._model_info.parameters.user_parameters():
328            if p.name in hidden:
329                continue
330            self.params[p.name] = p.default
331            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
332            if p.polydisperse:
333                self.details[p.id+".width"] = [
334                    "", 0.0, 1.0 if p.relative_pd else np.inf
335                ]
336                self.dispersion[p.name] = {
337                    'width': 0,
338                    'npts': 35,
339                    'nsigmas': 3,
340                    'type': 'gaussian',
341                }
342
343    def __get_state__(self):
344        # type: () -> Dict[str, Any]
345        state = self.__dict__.copy()
346        state.pop('_model')
347        # May need to reload model info on set state since it has pointers
348        # to python implementations of Iq, etc.
349        #state.pop('_model_info')
350        return state
351
352    def __set_state__(self, state):
353        # type: (Dict[str, Any]) -> None
354        self.__dict__ = state
355        self._model = None
356
357    def __str__(self):
358        # type: () -> str
359        """
360        :return: string representation
361        """
362        return self.name
363
364    def is_fittable(self, par_name):
365        # type: (str) -> bool
366        """
367        Check if a given parameter is fittable or not
368
369        :param par_name: the parameter name to check
370        """
371        return par_name in self.fixed
372        #For the future
373        #return self.params[str(par_name)].is_fittable()
374
375
376    def getProfile(self):
377        # type: () -> (np.ndarray, np.ndarray)
378        """
379        Get SLD profile
380
381        : return: (z, beta) where z is a list of depth of the transition points
382                beta is a list of the corresponding SLD values
383        """
384        args = {} # type: Dict[str, Any]
385        for p in self._model_info.parameters.kernel_parameters:
386            if p.id == self.multiplicity_info.control:
387                value = float(self.multiplicity)
388            elif p.length == 1:
389                value = self.params.get(p.id, np.NaN)
390            else:
391                value = np.array([self.params.get(p.id+str(k), np.NaN)
392                                  for k in range(1, p.length+1)])
393            args[p.id] = value
394
395        x, y = self._model_info.profile(**args)
396        return x, 1e-6*y
397
398    def setParam(self, name, value):
399        # type: (str, float) -> None
400        """
401        Set the value of a model parameter
402
403        :param name: name of the parameter
404        :param value: value of the parameter
405
406        """
407        # Look for dispersion parameters
408        toks = name.split('.')
409        if len(toks) == 2:
410            for item in self.dispersion.keys():
411                if item == toks[0]:
412                    for par in self.dispersion[item]:
413                        if par == toks[1]:
414                            self.dispersion[item][par] = value
415                            return
416        else:
417            # Look for standard parameter
418            for item in self.params.keys():
419                if item == name:
420                    self.params[item] = value
421                    return
422
423        raise ValueError("Model does not contain parameter %s" % name)
424
425    def getParam(self, name):
426        # type: (str) -> float
427        """
428        Set the value of a model parameter
429
430        :param name: name of the parameter
431
432        """
433        # Look for dispersion parameters
434        toks = name.split('.')
435        if len(toks) == 2:
436            for item in self.dispersion.keys():
437                if item == toks[0]:
438                    for par in self.dispersion[item]:
439                        if par == toks[1]:
440                            return self.dispersion[item][par]
441        else:
442            # Look for standard parameter
443            for item in self.params.keys():
444                if item == name:
445                    return self.params[item]
446
447        raise ValueError("Model does not contain parameter %s" % name)
448
449    def getParamList(self):
450        # type: () -> Sequence[str]
451        """
452        Return a list of all available parameters for the model
453        """
454        param_list = list(self.params.keys())
455        # WARNING: Extending the list with the dispersion parameters
456        param_list.extend(self.getDispParamList())
457        return param_list
458
459    def getDispParamList(self):
460        # type: () -> Sequence[str]
461        """
462        Return a list of polydispersity parameters for the model
463        """
464        # TODO: fix test so that parameter order doesn't matter
465        ret = ['%s.%s' % (p_name, ext)
466               for p_name in self.dispersion.keys()
467               for ext in ('npts', 'nsigmas', 'width')]
468        #print(ret)
469        return ret
470
471    def clone(self):
472        # type: () -> "SasviewModel"
473        """ Return a identical copy of self """
474        return deepcopy(self)
475
476    def run(self, x=0.0):
477        # type: (Union[float, (float, float), List[float]]) -> float
478        """
479        Evaluate the model
480
481        :param x: input q, or [q,phi]
482
483        :return: scattering function P(q)
484
485        **DEPRECATED**: use calculate_Iq instead
486        """
487        if isinstance(x, (list, tuple)):
488            # pylint: disable=unpacking-non-sequence
489            q, phi = x
490            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
491        else:
492            return self.calculate_Iq([x])[0]
493
494
495    def runXY(self, x=0.0):
496        # type: (Union[float, (float, float), List[float]]) -> float
497        """
498        Evaluate the model in cartesian coordinates
499
500        :param x: input q, or [qx, qy]
501
502        :return: scattering function P(q)
503
504        **DEPRECATED**: use calculate_Iq instead
505        """
506        if isinstance(x, (list, tuple)):
507            return self.calculate_Iq([x[0]], [x[1]])[0]
508        else:
509            return self.calculate_Iq([x])[0]
510
511    def evalDistribution(self, qdist):
512        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
513        r"""
514        Evaluate a distribution of q-values.
515
516        :param qdist: array of q or a list of arrays [qx,qy]
517
518        * For 1D, a numpy array is expected as input
519
520        ::
521
522            evalDistribution(q)
523
524          where *q* is a numpy array.
525
526        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
527
528        ::
529
530              qx = [ qx[0], qx[1], qx[2], ....]
531              qy = [ qy[0], qy[1], qy[2], ....]
532
533        If the model is 1D only, then
534
535        .. math::
536
537            q = \sqrt{q_x^2+q_y^2}
538
539        """
540        if isinstance(qdist, (list, tuple)):
541            # Check whether we have a list of ndarrays [qx,qy]
542            qx, qy = qdist
543            if not self._model_info.parameters.has_2d:
544                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
545            else:
546                return self.calculate_Iq(qx, qy)
547
548        elif isinstance(qdist, np.ndarray):
549            # We have a simple 1D distribution of q-values
550            return self.calculate_Iq(qdist)
551
552        else:
553            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
554                            % type(qdist))
555
556    def calculate_Iq(self, qx, qy=None):
557        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
558        """
559        Calculate Iq for one set of q with the current parameters.
560
561        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
562
563        This should NOT be used for fitting since it copies the *q* vectors
564        to the card for each evaluation.
565        """
566        #core.HAVE_OPENCL = False
567        if self._model is None:
568            self._model = core.build_model(self._model_info)
569        if qy is not None:
570            q_vectors = [np.asarray(qx), np.asarray(qy)]
571        else:
572            q_vectors = [np.asarray(qx)]
573        calculator = self._model.make_kernel(q_vectors)
574        parameters = self._model_info.parameters
575        pairs = [self._get_weights(p) for p in parameters.call_parameters]
576        #weights.plot_weights(self._model_info, pairs)
577        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
578        #call_details.show()
579        #print("pairs", pairs)
580        #print("params", self.params)
581        #print("values", values)
582        #print("is_mag", is_magnetic)
583        result = calculator(call_details, values, cutoff=self.cutoff,
584                            magnetic=is_magnetic)
585        calculator.release()
586        try:
587            self._model.release()
588        except:
589            pass
590        return result
591
592    def calculate_ER(self):
593        # type: () -> float
594        """
595        Calculate the effective radius for P(q)*S(q)
596
597        :return: the value of the effective radius
598        """
599        if self._model_info.ER is None:
600            return 1.0
601        else:
602            value, weight = self._dispersion_mesh()
603            fv = self._model_info.ER(*value)
604            #print(values[0].shape, weights.shape, fv.shape)
605            return np.sum(weight * fv) / np.sum(weight)
606
607    def calculate_VR(self):
608        # type: () -> float
609        """
610        Calculate the volf ratio for P(q)*S(q)
611
612        :return: the value of the volf ratio
613        """
614        if self._model_info.VR is None:
615            return 1.0
616        else:
617            value, weight = self._dispersion_mesh()
618            whole, part = self._model_info.VR(*value)
619            return np.sum(weight * part) / np.sum(weight * whole)
620
621    def set_dispersion(self, parameter, dispersion):
622        # type: (str, weights.Dispersion) -> Dict[str, Any]
623        """
624        Set the dispersion object for a model parameter
625
626        :param parameter: name of the parameter [string]
627        :param dispersion: dispersion object of type Dispersion
628        """
629        if parameter in self.params:
630            # TODO: Store the disperser object directly in the model.
631            # The current method of relying on the sasview GUI to
632            # remember them is kind of funky.
633            # Note: can't seem to get disperser parameters from sasview
634            # (1) Could create a sasview model that has not yet been
635            # converted, assign the disperser to one of its polydisperse
636            # parameters, then retrieve the disperser parameters from the
637            # sasview model.
638            # (2) Could write a disperser parameter retriever in sasview.
639            # (3) Could modify sasview to use sasmodels.weights dispersers.
640            # For now, rely on the fact that the sasview only ever uses
641            # new dispersers in the set_dispersion call and create a new
642            # one instead of trying to assign parameters.
643            self.dispersion[parameter] = dispersion.get_pars()
644        else:
645            raise ValueError("%r is not a dispersity or orientation parameter")
646
647    def _dispersion_mesh(self):
648        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
649        """
650        Create a mesh grid of dispersion parameters and weights.
651
652        Returns [p1,p2,...],w where pj is a vector of values for parameter j
653        and w is a vector containing the products for weights for each
654        parameter set in the vector.
655        """
656        pars = [self._get_weights(p)
657                for p in self._model_info.parameters.call_parameters
658                if p.type == 'volume']
659        return dispersion_mesh(self._model_info, pars)
660
661    def _get_weights(self, par):
662        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
663        """
664        Return dispersion weights for parameter
665        """
666        if par.name not in self.params:
667            if par.name == self.multiplicity_info.control:
668                return [self.multiplicity], [1.0]
669            else:
670                return [np.NaN], [1.0]
671        elif par.polydisperse:
672            dis = self.dispersion[par.name]
673            if dis['type'] == 'array':
674                value, weight = dis['values'], dis['weights']
675            else:
676                value, weight = weights.get_weights(
677                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
678                    self.params[par.name], par.limits, par.relative_pd)
679            return value, weight / np.sum(weight)
680        else:
681            return [self.params[par.name]], [1.0]
682
683def test_model():
684    # type: () -> float
685    """
686    Test that a sasview model (cylinder) can be run.
687    """
688    Cylinder = _make_standard_model('cylinder')
689    cylinder = Cylinder()
690    return cylinder.evalDistribution([0.1, 0.1])
691
692def test_rpa():
693    # type: () -> float
694    """
695    Test that a sasview model (cylinder) can be run.
696    """
697    RPA = _make_standard_model('rpa')
698    rpa = RPA(3)
699    return rpa.evalDistribution([0.1, 0.1])
700
701
702def test_model_list():
703    # type: () -> None
704    """
705    Make sure that all models build as sasview models.
706    """
707    from .exception import annotate_exception
708    for name in core.list_models():
709        try:
710            _make_standard_model(name)
711        except:
712            annotate_exception("when loading "+name)
713            raise
714
715def test_old_name():
716    # type: () -> None
717    """
718    Load and run cylinder model from sas.models.CylinderModel
719    """
720    if not SUPPORT_OLD_STYLE_PLUGINS:
721        return
722    try:
723        # if sasview is not on the path then don't try to test it
724        import sas
725    except ImportError:
726        return
727    load_standard_models()
728    from sas.models.CylinderModel import CylinderModel
729    CylinderModel().evalDistribution([0.1, 0.1])
730
731if __name__ == "__main__":
732    print("cylinder(0.1,0.1)=%g"%test_model())
Note: See TracBrowser for help on using the repository browser.