source: sasmodels/sasmodels/sasview_model.py @ 9457498

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9457498 was 9457498, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

support old style custom formula models (but not sum or product)

  • Property mode set to 100644
File size: 22.9 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext
18
19import numpy as np  # type: ignore
20
21# Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
22# is available to the plugin modules.
23import sys
24import sas
25import sas.sascalc.fit
26sys.modules['sas.models'] = sas.sascalc.fit
27sas.models = sas.sascalc.fit
28
29from . import core
30from . import custom
31from . import generate
32from . import weights
33from . import modelinfo
34from .details import build_details, dispersion_mesh
35
36try:
37    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
38    from .modelinfo import ModelInfo, Parameter
39    from .kernel import KernelModel
40    MultiplicityInfoType = NamedTuple(
41        'MuliplicityInfo',
42        [("number", int), ("control", str), ("choices", List[str]),
43         ("x_axis_label", str)])
44    SasviewModelType = Callable[[int], "SasviewModel"]
45except ImportError:
46    pass
47
48
# TODO: separate x_axis_label from multiplicity info
#: Multiplicity (model variant) description attached to each sasview model:
#:   number       -- how many variants the control parameter selects between
#:   control      -- name of the control parameter ("" when there is none)
#:   choices      -- variant names when the control is a choice list
#:   x_axis_label -- x-axis label used when plotting the model profile
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")
54
#: Registry of sasview model classes, keyed by model name.  Filled in by
#: load_standard_models and load_custom_model.
MODELS = {}
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a sasview model class by name.

    A name ending in ``.py`` is treated as the path to a custom model file
    and is loaded with :func:`load_custom_model`.  Any other name must
    already be registered in :data:`MODELS`.

    :raises ValueError: if the name is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
70
71
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every model listed by sasmodels.

    Each successfully built class is registered in :data:`MODELS` under its
    model name and collected into the returned list.  A model that fails to
    build is skipped, and its traceback is logged at error level.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class
        loaded.append(model_class)
    return loaded
89
90
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    If the loaded module defines a ``Model`` class, that class is treated as
    an old-style sasview plugin and used directly.  Otherwise the module is
    parsed as a sasmodels kernel definition and wrapped in a new
    SasviewModel subclass.  Either way, the class is registered in
    :data:`MODELS` under its name and returned.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    try:
        # Keep the try body to the lookup alone, so that unrelated
        # AttributeErrors cannot send an old-style module down the
        # new-style path.
        model = kernel_module.Model
    except AttributeError:
        # No Model class: this is a new style (sasmodels) kernel definition.
        model_info = modelinfo.make_model_info(kernel_module)
        model = _make_model_from_info(model_info)
    else:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.  A missing or
        # None name is treated like "" so the class is never registered
        # under a bogus key.
        if not getattr(model, 'name', ''):
            model.name = splitext(basename(path))[0]
    MODELS[model.name] = model
    return model
109
110
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* can be a standard model name or a path to a custom model (it is
    passed unchanged to :func:`generate.load_kernel_module`).

    Returns a class that can be used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    return _make_model_from_info(modelinfo.make_model_info(module))
123
124
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass that wraps *model_info*.

    The new class is named after the model.  It carries the class attributes
    produced by :func:`_generate_model_attributes`, plus an ``__init__``
    that forwards *multiplicity* to :meth:`SasviewModel.__init__`.
    """
    class_attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_attrs)
136
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Compute the class attribute dictionary for a sasview model class.

    The attributes describe the model (name, id, category, parameter
    groupings, multiplicity and drop-down choices).  A caller can therefore
    query model details from the class without creating an instance.

    The parameter groupings are stored as tuples so that they cannot be
    modified by accident.
    """
    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters

    # Multiplicity: the kernel parameter named by model_info.control, if any.
    non_fittable = []  # type: List[str]
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""
    control_par = next(
        (p for p in kernel_pars if p.name == model_info.control), None)
    if control_par is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(
            count, control_par.name, control_par.choices, xlabel)

    # Only a single drop-down list parameter available: use the first one.
    fun_list = []
    choice_par = next((p for p in kernel_pars if p.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(
                choice_par.id + str(index)
                for index in range(1, choice_par.length + 1))

    # Group orientation and magnetic parameters; their widths go in "fixed".
    orientation_params, magnetic_params, fixed = [], [], []
    for par in model_info.parameters.user_parameters():
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend((par.name, width_name))
            fixed.append(width_name)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    # Build class dictionary
    attrs = {}  # type: Dict[str, Any]
    attrs['_model_info'] = model_info
    attrs['name'] = model_info.name
    attrs['id'] = model_info.id
    attrs['description'] = model_info.description
    attrs['category'] = model_info.category
    attrs['is_structure_factor'] = model_info.structure_factor
    attrs['is_form_factor'] = model_info.ER is not None
    attrs['is_multiplicity_model'] = variants.number > 1
    attrs['multiplicity_info'] = variants
    attrs['orientation_params'] = tuple(orientation_params)
    attrs['magnetic_params'] = tuple(magnetic_params)
    attrs['fixed'] = tuple(fixed)
    attrs['non_fittable'] = tuple(non_fittable)
    attrs['fun_list'] = tuple(fun_list)
    return attrs
205
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Create a model instance with default parameter values.

        :param multiplicity: value for the multiplicity control parameter.
            When given, the control parameter and the parameters that
            ``_model_info.get_hidden_parameters(multiplicity)`` reports as
            hidden are left out of *params*, *dispersion* and *details*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            # Copy before adding so the set owned by model_info is not mutated.
            hidden = set(self._model_info.get_hidden_parameters(multiplicity))
            hidden.add(self.multiplicity_info.control)
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state used by pickle and copy.deepcopy.

        The compiled kernel in ``_model`` is dropped.  It is rebuilt on
        demand the next time :meth:`calculate_Iq` runs.
        """
        state = self.__dict__.copy()
        # _model is only an instance attribute after calculate_Iq has run.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance state; the kernel model is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # The previously misspelled names are kept as aliases for compatibility;
    # pickle and copy only look up __getstate__/__setstate__.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if (p.id == self.multiplicity_info.control
                    and self.multiplicity is not None):
                # The control parameter is hidden from params when the model
                # was built with a multiplicity, so use that value instead.
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # NOTE(review): presumably rescales SLD from 1e-6/Ang^2 units; confirm.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter.  A name with exactly one dot,
            such as "radius.width", addresses a dispersion setting.
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                settings[toks[1]] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter.  A name with exactly one dot,
            such as "radius.width", addresses a dispersion setting.

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                return settings[toks[1]]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model

        Parameters hidden by the multiplicity are excluded, so every name
        returned can be passed to :meth:`getParam`.
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name, ext)
               for p in self._model_info.parameters.user_parameters()
               if p.polydisperse and p.name in self.dispersion
               for ext in ('npts', 'nsigmas', 'width')]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        # Build the kernel model lazily; it is not kept across pickling.
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = build_details(calculator, pairs)
            return calculator(call_details, values, cutoff=self.cutoff,
                              magnetic=is_magnetic)
        finally:
            # Release the kernel even when the evaluation raises.
            calculator.release()

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, or 1.0 if the model
            does not define ER
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            # Weighted average of ER over the volume-parameter dispersion mesh.
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 if the model does not
            define VR
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            dispersion = weights.MODELS[dispersion.type]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.  Only volume parameters are included.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return (values, weights) for parameter *par*.

        Polydisperse parameters get their dispersion points with weights
        normalized to sum to 1.  Other parameters get a single point with
        weight 1.  A parameter missing from params is the multiplicity
        control (value self.multiplicity) or a hidden parameter (value NaN).
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                return [np.nan], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
645
def test_model():
    # type: () -> float
    """
    Smoke test: build the cylinder sasview model and evaluate it at
    (qx, qy) = (0.1, 0.1).
    """
    cylinder_model = _make_standard_model('cylinder')()
    return cylinder_model.evalDistribution([0.1, 0.1])
654
def test_rpa():
    # type: () -> float
    """
    Test that a sasview model (rpa) can be run.

    The rpa model is built with multiplicity 3 and evaluated at
    (qx, qy) = (0.1, 0.1).
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
663
664
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    A failure is annotated with the name of the model being loaded and then
    re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than everything, so that
            # KeyboardInterrupt/SystemExit propagate without annotation.
            annotate_exception("when loading "+name)
            raise
677
if __name__ == "__main__":
    # Run the cylinder smoke test when this module is executed directly.
    cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.