source: sasmodels/sasmodels/sasview_model.py @ 4edec6f

core_shell_microgels costrafo411 magnetic_model release_v0.94 release_v0.95 ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since 4edec6f was 4edec6f, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix sasview for new kernel interface

  • Property mode set to 100644
File size: 22.0 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17
18import numpy as np  # type: ignore
19
20from . import core
21from . import custom
22from . import generate
23from . import weights
24from . import modelinfo
25from .details import build_details, dispersion_mesh
26
27try:
28    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
29    from .modelinfo import ModelInfo, Parameter
30    from .kernel import KernelModel
31    MultiplicityInfoType = NamedTuple(
32        'MuliplicityInfo',
33        [("number", int), ("control", str), ("choices", List[str]),
34         ("x_axis_label", str)])
35    SasviewModelType = Callable[[int], "SasviewModel"]
36except ImportError:
37    pass
38
39# TODO: separate x_axis_label from multiplicity info
40# The profile x-axis label belongs with the profile generating function
41MultiplicityInfo = collections.namedtuple(
42    'MultiplicityInfo',
43    ["number", "control", "choices", "x_axis_label"],
44)
45
46MODELS = {}
47def find_model(modelname):
48    # TODO: used by sum/product model to load an existing model
49    # TODO: doesn't handle custom models properly
50    if modelname.endswith('.py'):
51        return load_custom_model(modelname)
52    elif modelname in MODELS:
53        return MODELS[modelname]
54    else:
55        raise ValueError("unknown model %r"%modelname)
56
57
58# TODO: figure out how to say that the return type is a subclass
59def load_standard_models():
60    # type: () -> List[SasviewModelType]
61    """
62    Load and return the list of predefined models.
63
64    If there is an error loading a model, then a traceback is logged and the
65    model is not returned.
66    """
67    models = []
68    for name in core.list_models():
69        try:
70            MODELS[name] = _make_standard_model(name)
71            models.append(MODELS[name])
72        except Exception:
73            logging.error(traceback.format_exc())
74    return models
75
76
77def load_custom_model(path):
78    # type: (str) -> SasviewModelType
79    """
80    Load a custom model given the model path.
81    """
82    #print("load custom model", path)
83    kernel_module = custom.load_custom_kernel_module(path)
84    try:
85        model = kernel_module.Model
86    except AttributeError:
87        model_info = modelinfo.make_model_info(kernel_module)
88        model = _make_model_from_info(model_info)
89    MODELS[model.name] = model
90    return model
91
92
93def _make_standard_model(name):
94    # type: (str) -> SasviewModelType
95    """
96    Load the sasview model defined by *name*.
97
98    *name* can be a standard model name or a path to a custom model.
99
100    Returns a class that can be used directly as a sasview model.
101    """
102    kernel_module = generate.load_kernel_module(name)
103    model_info = modelinfo.make_model_info(kernel_module)
104    return _make_model_from_info(model_info)
105
106
107def _make_model_from_info(model_info):
108    # type: (ModelInfo) -> SasviewModelType
109    """
110    Convert *model_info* into a SasView model wrapper.
111    """
112    def __init__(self, multiplicity=None):
113        SasviewModel.__init__(self, multiplicity=multiplicity)
114    attrs = _generate_model_attributes(model_info)
115    attrs['__init__'] = __init__
116    ConstructedModel = type(model_info.name, (SasviewModel,), attrs) # type: SasviewModelType
117    return ConstructedModel
118
119def _generate_model_attributes(model_info):
120    # type: (ModelInfo) -> Dict[str, Any]
121    """
122    Generate the class attributes for the model.
123
124    This should include all the information necessary to query the model
125    details so that you do not need to instantiate a model to query it.
126
127    All the attributes should be immutable to avoid accidents.
128    """
129
130    # TODO: allow model to override axis labels input/output name/unit
131
132    # Process multiplicity
133    non_fittable = []  # type: List[str]
134    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
135    variants = MultiplicityInfo(0, "", [], xlabel)
136    for p in model_info.parameters.kernel_parameters:
137        if p.name == model_info.control:
138            non_fittable.append(p.name)
139            variants = MultiplicityInfo(
140                len(p.choices) if p.choices else int(p.limits[1]),
141                p.name, p.choices, xlabel
142            )
143            break
144
145    # Organize parameter sets
146    orientation_params = []
147    magnetic_params = []
148    fixed = []
149    for p in model_info.parameters.user_parameters():
150        if p.type == 'orientation':
151            orientation_params.append(p.name)
152            orientation_params.append(p.name+".width")
153            fixed.append(p.name+".width")
154        elif p.type == 'magnetic':
155            orientation_params.append(p.name)
156            magnetic_params.append(p.name)
157            fixed.append(p.name+".width")
158
159
160    # Build class dictionary
161    attrs = {}  # type: Dict[str, Any]
162    attrs['_model_info'] = model_info
163    attrs['name'] = model_info.name
164    attrs['id'] = model_info.id
165    attrs['description'] = model_info.description
166    attrs['category'] = model_info.category
167    attrs['is_structure_factor'] = model_info.structure_factor
168    attrs['is_form_factor'] = model_info.ER is not None
169    attrs['is_multiplicity_model'] = variants[0] > 1
170    attrs['multiplicity_info'] = variants
171    attrs['orientation_params'] = tuple(orientation_params)
172    attrs['magnetic_params'] = tuple(magnetic_params)
173    attrs['fixed'] = tuple(fixed)
174    attrs['non_fittable'] = tuple(non_fittable)
175
176    return attrs
177
178class SasviewModel(object):
179    """
180    Sasview wrapper for opencl/ctypes model.
181    """
182    # Model parameters for the specific model are set in the class constructor
183    # via the _generate_model_attributes function, which subclasses
184    # SasviewModel.  They are included here for typing and documentation
185    # purposes.
186    _model = None       # type: KernelModel
187    _model_info = None  # type: ModelInfo
188    #: load/save name for the model
189    id = None           # type: str
190    #: display name for the model
191    name = None         # type: str
192    #: short model description
193    description = None  # type: str
194    #: default model category
195    category = None     # type: str
196
197    #: names of the orientation parameters in the order they appear
198    orientation_params = None # type: Sequence[str]
199    #: names of the magnetic parameters in the order they appear
200    magnetic_params = None    # type: Sequence[str]
201    #: names of the fittable parameters
202    fixed = None              # type: Sequence[str]
203    # TODO: the attribute fixed is ill-named
204
205    # Axis labels
206    input_name = "Q"
207    input_unit = "A^{-1}"
208    output_name = "Intensity"
209    output_unit = "cm^{-1}"
210
211    #: default cutoff for polydispersity
212    cutoff = 1e-5
213
214    # Note: Use non-mutable values for class attributes to avoid errors
215    #: parameters that are not fitted
216    non_fittable = ()        # type: Sequence[str]
217
218    #: True if model should appear as a structure factor
219    is_structure_factor = False
220    #: True if model should appear as a form factor
221    is_form_factor = False
222    #: True if model has multiplicity
223    is_multiplicity_model = False
224    #: Mulitplicity information
225    multiplicity_info = None # type: MultiplicityInfoType
226
227    # Per-instance variables
228    #: parameter {name: value} mapping
229    params = None      # type: Dict[str, float]
230    #: values for dispersion width, npts, nsigmas and type
231    dispersion = None  # type: Dict[str, Any]
232    #: units and limits for each parameter
233    details = None     # type: Dict[str, Sequence[Any]]
234    #                  # actual type is Dict[str, List[str, float, float]]
235    #: multiplicity value, or None if no multiplicity on the model
236    multiplicity = None     # type: Optional[int]
237    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
238    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
239
240    def __init__(self, multiplicity=None):
241        # type: (Optional[int]) -> None
242
243        # TODO: _persistency_dict to persistency_dict throughout sasview
244        # TODO: refactor multiplicity to encompass variants
245        # TODO: dispersion should be a class
246        # TODO: refactor multiplicity info
247        # TODO: separate profile view from multiplicity
248        # The button label, x and y axis labels and scale need to be under
249        # the control of the model, not the fit page.  Maximum flexibility,
250        # the fit page would supply the canvas and the profile could plot
251        # how it wants, but this assumes matplotlib.  Next level is that
252        # we provide some sort of data description including title, labels
253        # and lines to plot.
254
255        # Get the list of hidden parameters given the mulitplicity
256        # Don't include multiplicity in the list of parameters
257        self.multiplicity = multiplicity
258        if multiplicity is not None:
259            hidden = self._model_info.get_hidden_parameters(multiplicity)
260            hidden |= set([self.multiplicity_info.control])
261        else:
262            hidden = set()
263
264        self._persistency_dict = {}
265        self.params = collections.OrderedDict()
266        self.dispersion = collections.OrderedDict()
267        self.details = {}
268        for p in self._model_info.parameters.user_parameters():
269            if p.name in hidden:
270                continue
271            self.params[p.name] = p.default
272            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
273            if p.polydisperse:
274                self.details[p.id+".width"] = [
275                    "", 0.0, 1.0 if p.relative_pd else np.inf
276                ]
277                self.dispersion[p.name] = {
278                    'width': 0,
279                    'npts': 35,
280                    'nsigmas': 3,
281                    'type': 'gaussian',
282                }
283
284    def __get_state__(self):
285        # type: () -> Dict[str, Any]
286        state = self.__dict__.copy()
287        state.pop('_model')
288        # May need to reload model info on set state since it has pointers
289        # to python implementations of Iq, etc.
290        #state.pop('_model_info')
291        return state
292
293    def __set_state__(self, state):
294        # type: (Dict[str, Any]) -> None
295        self.__dict__ = state
296        self._model = None
297
298    def __str__(self):
299        # type: () -> str
300        """
301        :return: string representation
302        """
303        return self.name
304
305    def is_fittable(self, par_name):
306        # type: (str) -> bool
307        """
308        Check if a given parameter is fittable or not
309
310        :param par_name: the parameter name to check
311        """
312        return par_name in self.fixed
313        #For the future
314        #return self.params[str(par_name)].is_fittable()
315
316
317    def getProfile(self):
318        # type: () -> (np.ndarray, np.ndarray)
319        """
320        Get SLD profile
321
322        : return: (z, beta) where z is a list of depth of the transition points
323                beta is a list of the corresponding SLD values
324        """
325        args = [] # type: List[Union[float, np.ndarray]]
326        for p in self._model_info.parameters.kernel_parameters:
327            if p.id == self.multiplicity_info.control:
328                args.append(float(self.multiplicity))
329            elif p.length == 1:
330                args.append(self.params.get(p.id, np.NaN))
331            else:
332                args.append([self.params.get(p.id+str(k), np.NaN)
333                             for k in range(1,p.length+1)])
334        return self._model_info.profile(*args)
335
336    def setParam(self, name, value):
337        # type: (str, float) -> None
338        """
339        Set the value of a model parameter
340
341        :param name: name of the parameter
342        :param value: value of the parameter
343
344        """
345        # Look for dispersion parameters
346        toks = name.split('.')
347        if len(toks) == 2:
348            for item in self.dispersion.keys():
349                if item == toks[0]:
350                    for par in self.dispersion[item]:
351                        if par == toks[1]:
352                            self.dispersion[item][par] = value
353                            return
354        else:
355            # Look for standard parameter
356            for item in self.params.keys():
357                if item == name:
358                    self.params[item] = value
359                    return
360
361        raise ValueError("Model does not contain parameter %s" % name)
362
363    def getParam(self, name):
364        # type: (str) -> float
365        """
366        Set the value of a model parameter
367
368        :param name: name of the parameter
369
370        """
371        # Look for dispersion parameters
372        toks = name.split('.')
373        if len(toks) == 2:
374            for item in self.dispersion.keys():
375                if item == toks[0]:
376                    for par in self.dispersion[item]:
377                        if par == toks[1]:
378                            return self.dispersion[item][par]
379        else:
380            # Look for standard parameter
381            for item in self.params.keys():
382                if item == name:
383                    return self.params[item]
384
385        raise ValueError("Model does not contain parameter %s" % name)
386
387    def getParamList(self):
388        # type: () -> Sequence[str]
389        """
390        Return a list of all available parameters for the model
391        """
392        param_list = list(self.params.keys())
393        # WARNING: Extending the list with the dispersion parameters
394        param_list.extend(self.getDispParamList())
395        return param_list
396
397    def getDispParamList(self):
398        # type: () -> Sequence[str]
399        """
400        Return a list of polydispersity parameters for the model
401        """
402        # TODO: fix test so that parameter order doesn't matter
403        ret = ['%s.%s' % (p.name, ext)
404               for p in self._model_info.parameters.user_parameters()
405               for ext in ('npts', 'nsigmas', 'width')
406               if p.polydisperse]
407        #print(ret)
408        return ret
409
410    def clone(self):
411        # type: () -> "SasviewModel"
412        """ Return a identical copy of self """
413        return deepcopy(self)
414
415    def run(self, x=0.0):
416        # type: (Union[float, (float, float), List[float]]) -> float
417        """
418        Evaluate the model
419
420        :param x: input q, or [q,phi]
421
422        :return: scattering function P(q)
423
424        **DEPRECATED**: use calculate_Iq instead
425        """
426        if isinstance(x, (list, tuple)):
427            # pylint: disable=unpacking-non-sequence
428            q, phi = x
429            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
430        else:
431            return self.calculate_Iq([x])[0]
432
433
434    def runXY(self, x=0.0):
435        # type: (Union[float, (float, float), List[float]]) -> float
436        """
437        Evaluate the model in cartesian coordinates
438
439        :param x: input q, or [qx, qy]
440
441        :return: scattering function P(q)
442
443        **DEPRECATED**: use calculate_Iq instead
444        """
445        if isinstance(x, (list, tuple)):
446            return self.calculate_Iq([x[0]], [x[1]])[0]
447        else:
448            return self.calculate_Iq([x])[0]
449
450    def evalDistribution(self, qdist):
451        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
452        r"""
453        Evaluate a distribution of q-values.
454
455        :param qdist: array of q or a list of arrays [qx,qy]
456
457        * For 1D, a numpy array is expected as input
458
459        ::
460
461            evalDistribution(q)
462
463          where *q* is a numpy array.
464
465        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
466
467        ::
468
469              qx = [ qx[0], qx[1], qx[2], ....]
470              qy = [ qy[0], qy[1], qy[2], ....]
471
472        If the model is 1D only, then
473
474        .. math::
475
476            q = \sqrt{q_x^2+q_y^2}
477
478        """
479        if isinstance(qdist, (list, tuple)):
480            # Check whether we have a list of ndarrays [qx,qy]
481            qx, qy = qdist
482            if not self._model_info.parameters.has_2d:
483                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
484            else:
485                return self.calculate_Iq(qx, qy)
486
487        elif isinstance(qdist, np.ndarray):
488            # We have a simple 1D distribution of q-values
489            return self.calculate_Iq(qdist)
490
491        else:
492            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
493                            % type(qdist))
494
495    def calculate_Iq(self, qx, qy=None):
496        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
497        """
498        Calculate Iq for one set of q with the current parameters.
499
500        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
501
502        This should NOT be used for fitting since it copies the *q* vectors
503        to the card for each evaluation.
504        """
505        #core.HAVE_OPENCL = False
506        if self._model is None:
507            self._model = core.build_model(self._model_info)
508        if qy is not None:
509            q_vectors = [np.asarray(qx), np.asarray(qy)]
510        else:
511            q_vectors = [np.asarray(qx)]
512        calculator = self._model.make_kernel(q_vectors)
513        parameters = self._model_info.parameters
514        pairs = [self._get_weights(p) for p in parameters.call_parameters]
515        call_details, values, is_magnetic = build_details(calculator, pairs)
516        #call_details.show()
517        #print("pairs", pairs)
518        #print("params", self.params)
519        #print("values", values)
520        #print("is_mag", is_magnetic)
521        result = calculator(call_details, values, cutoff=self.cutoff,
522                            magnetic=is_magnetic)
523        calculator.release()
524        return result
525
526    def calculate_ER(self):
527        # type: () -> float
528        """
529        Calculate the effective radius for P(q)*S(q)
530
531        :return: the value of the effective radius
532        """
533        if self._model_info.ER is None:
534            return 1.0
535        else:
536            value, weight = self._dispersion_mesh()
537            fv = self._model_info.ER(*value)
538            #print(values[0].shape, weights.shape, fv.shape)
539            return np.sum(weight * fv) / np.sum(weight)
540
541    def calculate_VR(self):
542        # type: () -> float
543        """
544        Calculate the volf ratio for P(q)*S(q)
545
546        :return: the value of the volf ratio
547        """
548        if self._model_info.VR is None:
549            return 1.0
550        else:
551            value, weight = self._dispersion_mesh()
552            whole, part = self._model_info.VR(*value)
553            return np.sum(weight * part) / np.sum(weight * whole)
554
555    def set_dispersion(self, parameter, dispersion):
556        # type: (str, weights.Dispersion) -> Dict[str, Any]
557        """
558        Set the dispersion object for a model parameter
559
560        :param parameter: name of the parameter [string]
561        :param dispersion: dispersion object of type Dispersion
562        """
563        if parameter in self.params:
564            # TODO: Store the disperser object directly in the model.
565            # The current method of relying on the sasview GUI to
566            # remember them is kind of funky.
567            # Note: can't seem to get disperser parameters from sasview
568            # (1) Could create a sasview model that has not yet # been
569            # converted, assign the disperser to one of its polydisperse
570            # parameters, then retrieve the disperser parameters from the
571            # sasview model.  (2) Could write a disperser parameter retriever
572            # in sasview.  (3) Could modify sasview to use sasmodels.weights
573            # dispersers.
574            # For now, rely on the fact that the sasview only ever uses
575            # new dispersers in the set_dispersion call and create a new
576            # one instead of trying to assign parameters.
577            from . import weights
578            disperser = weights.dispersers[dispersion.__class__.__name__]
579            dispersion = weights.MODELS[disperser]()
580            self.dispersion[parameter] = dispersion.get_pars()
581        else:
582            raise ValueError("%r is not a dispersity or orientation parameter")
583
584    def _dispersion_mesh(self):
585        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
586        """
587        Create a mesh grid of dispersion parameters and weights.
588
589        Returns [p1,p2,...],w where pj is a vector of values for parameter j
590        and w is a vector containing the products for weights for each
591        parameter set in the vector.
592        """
593        pars = [self._get_weights(p)
594                for p in self._model_info.parameters.call_parameters
595                if p.type == 'volume']
596        return dispersion_mesh(self._model_info, pars)
597
598    def _get_weights(self, par):
599        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
600        """
601        Return dispersion weights for parameter
602        """
603        if par.name not in self.params:
604            if par.name == self.multiplicity_info.control:
605                return [self.multiplicity], [1.0]
606            else:
607                return [np.NaN], [1.0]
608        elif par.polydisperse:
609            dis = self.dispersion[par.name]
610            value, weight = weights.get_weights(
611                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
612                self.params[par.name], par.limits, par.relative_pd)
613            return value, weight / np.sum(weight)
614        else:
615            return [self.params[par.name]], [1.0]
616
617def test_model():
618    # type: () -> float
619    """
620    Test that a sasview model (cylinder) can be run.
621    """
622    Cylinder = _make_standard_model('cylinder')
623    cylinder = Cylinder()
624    return cylinder.evalDistribution([0.1,0.1])
625
626def test_rpa():
627    # type: () -> float
628    """
629    Test that a sasview model (cylinder) can be run.
630    """
631    RPA = _make_standard_model('rpa')
632    rpa = RPA(3)
633    return rpa.evalDistribution([0.1,0.1])
634
635
636def test_model_list():
637    # type: () -> None
638    """
639    Make sure that all models build as sasview models.
640    """
641    from .exception import annotate_exception
642    for name in core.list_models():
643        try:
644            _make_standard_model(name)
645        except:
646            annotate_exception("when loading "+name)
647            raise
648
649if __name__ == "__main__":
650    print("cylinder(0.1,0.1)=%g"%test_model())
Note: See TracBrowser for help on using the repository browser.