source: sasmodels/sasmodels/sasview_model.py @ 4e0968b

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 4e0968b was 4e0968b, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

spherical sld: doc cleanup (with syntax errors fixed)

  • Property mode set to 100644
File size: 22.0 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17
18import numpy as np  # type: ignore
19
20from . import core
21from . import custom
22from . import generate
23from . import weights
24from . import modelinfo
25from .details import build_details, dispersion_mesh
26
27try:
28    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
29    from .modelinfo import ModelInfo, Parameter
30    from .kernel import KernelModel
31    MultiplicityInfoType = NamedTuple(
32        'MuliplicityInfo',
33        [("number", int), ("control", str), ("choices", List[str]),
34         ("x_axis_label", str)])
35    SasviewModelType = Callable[[int], "SasviewModel"]
36except ImportError:
37    pass
38
# TODO: separate x_axis_label from multiplicity info
# The profile x-axis label belongs with the profile generating function

#: Multiplicity (variant) description for a model: the *number* of
#: variants, the name of the *control* parameter, the *choices* for that
#: control, and the *x_axis_label* used when plotting the model profile.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")
45
#: Registry of sasview model classes keyed by model name; populated by
#: load_standard_models and load_custom_model.
MODELS = {}


def find_model(modelname):
    """
    Return the sasview model class for *modelname*.

    A name ending in ``.py`` is treated as a path to a custom model and is
    loaded with :func:`load_custom_model`.  Any other name is looked up in
    :data:`MODELS`.

    :raises ValueError: if *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r" % modelname)
    return MODELS[modelname]
56
57
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every model in ``core.list_models()``.

    Each successfully built class is registered in :data:`MODELS` and
    included in the returned list.  A model that fails to build is skipped,
    and its traceback is logged at error level.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class
            loaded.append(model_class)
    return loaded
75
76
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load the custom model defined in the python file at *path*.

    If the loaded module defines a ``Model`` attribute, that is used as the
    model class.  Otherwise a sasview model class is generated from the
    module's model definition.  The class is registered in :data:`MODELS`
    under its ``name`` and returned.
    """
    module = custom.load_custom_kernel_module(path)
    try:
        model_class = module.Model
    except AttributeError:
        # No ready-made Model in the module: build one from its definition.
        model_class = _make_model_from_info(
            modelinfo.make_model_info(module))
    MODELS[model_class.name] = model_class
    return model_class
91
92
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    The kernel module is loaded with ``generate.load_kernel_module`` and
    turned into a model info record. That record is then wrapped as a class
    that can be used directly as a sasview model.
    """
    return _make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
105
106
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a new :class:`SasviewModel` subclass that wraps *model_info*.

    The subclass is named after the model. Its class attributes come from
    :func:`_generate_model_attributes`, and its constructor forwards
    *multiplicity* to :meth:`SasviewModel.__init__`.
    """
    class_attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_attrs)
118
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for a sasview model class.

    The attributes describe the model (name, id, description, category,
    kind of model, multiplicity variants and parameter groups). This lets
    callers query model details without creating an instance. Parameter
    name sequences are stored as tuples so they cannot be modified through
    the class.
    """
    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity: the kernel parameter named by model_info.control (if
    # any) selects the model variant and is excluded from fitting.
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""
    control = next((p for p in model_info.parameters.kernel_parameters
                    if p.name == model_info.control), None)
    if control is None:
        non_fittable = []  # type: List[str]
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable = [control.name]
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices,
                                    xlabel)

    # Group the user-visible parameters by kind.
    orientation_params = []  # type: List[str]
    magnetic_params = []  # type: List[str]
    fixed = []  # type: List[str]
    for par in model_info.parameters.user_parameters():
        width = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width])
            fixed.append(width)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width)

    # Class dictionary; is_form_factor is set when the model has an
    # effective radius function.
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
    }
177
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Model specific classes are created by :func:`_make_model_from_info`,
    which subclasses :class:`SasviewModel` and fills in the class attributes
    declared below from the model info.
    """
    # Model parameters for the specific model are set as class attributes by
    # _generate_model_attributes when _make_model_from_info builds the
    # subclass.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters (see :meth:`is_fittable`)
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize parameter values, dispersion settings and details.

        :param multiplicity: model variant for multiplicity models, or None.
            When given, the parameters hidden for that variant and the
            multiplicity control parameter itself are left out of *params*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                # Relative widths are bounded by 1; absolute widths are not.
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickling and deepcopy.

        The compiled kernel model is dropped (presumably it holds resources
        that cannot be copied); :meth:`calculate_Iq` rebuilds it on demand.
        """
        state = self.__dict__.copy()
        # _model only lives in the instance dict once a kernel was built.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance state produced by :meth:`__getstate__`.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previous (misspelled) names, which
    # pickle and copy never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in :attr:`fixed`
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        Calls the model's profile function with one keyword argument per
        kernel parameter.  The multiplicity control gets the current
        multiplicity.  Vector parameters get an array built from the
        numbered entries ``name1 .. nameN`` in *params*.  Any value missing
        from *params* is NaN.

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        return self._model_info.profile(**args)

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        A name of the form ``par.attr`` sets the dispersion attribute
        *attr* (e.g. width, npts) of polydisperse parameter *par*.

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            base, attr = toks
            if base in self.dispersion and attr in self.dispersion[base]:
                self.dispersion[base][attr] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        A name of the form ``par.attr`` returns the dispersion attribute
        *attr* (e.g. width, npts) of polydisperse parameter *par*.

        :param name: name of the parameter
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            base, attr = toks
            if base in self.dispersion and attr in self.dispersion[base]:
                return self.dispersion[base][attr]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model

        Each polydisperse parameter contributes ``name.npts``,
        ``name.nsigmas`` and ``name.width`` in that order.
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p.name, ext)
                for p in self._model_info.parameters.user_parameters()
                for ext in ('npts', 'nsigmas', 'width')
                if p.polydisperse]

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self (without the compiled kernel) """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]

    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither an array nor a [qx, qy] pair
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = build_details(calculator, pairs)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
        finally:
            # Release the kernel even if building details or evaluating fails.
            calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, or 1.0 if the model
            does not define an ER function
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            # Weighted average of ER over the volume-parameter dispersion mesh
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 if the model does not
            define a VR function
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        Only the class name of *dispersion* is used.  A fresh sasmodels
        disperser of the matching type is created, and its parameters are
        stored in ``self.dispersion[parameter]``.

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            disperser = weights.dispersers[dispersion.__class__.__name__]
            dispersion = weights.MODELS[disperser]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Parameters missing from *params* get the multiplicity (for the
        multiplicity control) or NaN (for hidden parameters) with weight 1.
        Polydisperse parameters get normalized weights from their dispersion
        settings; all others get their single value with weight 1.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                return [np.nan], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
618
def test_model():
    # type: () -> np.ndarray
    """
    Smoke test: build the cylinder sasview model and evaluate it at the
    single point (qx, qy) = (0.1, 0.1).

    Returns the computed intensity.
    """
    cylinder_class = _make_standard_model('cylinder')
    return cylinder_class().evalDistribution([0.1, 0.1])
627
def test_rpa():
    # type: () -> np.ndarray
    """
    Test that a sasview model (rpa) can be run.

    Builds the rpa model with multiplicity 3 and evaluates it at the single
    point (qx, qy) = (0.1, 0.1), returning the computed intensity.
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
636
637
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    Stops at the first model that fails to build.  That error is passed to
    ``annotate_exception`` with the model name and then re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than everything so that interrupts
            # (KeyboardInterrupt/SystemExit) propagate untouched.
            annotate_exception("when loading "+name)
            raise
650
if __name__ == "__main__":
    # Quick manual check: evaluate the cylinder model at (0.1, 0.1).
    _cylinder_iq = test_model()
    print("cylinder(0.1,0.1)=%g" % _cylinder_iq)
Note: See TracBrowser for help on using the repository browser.