source: sasmodels/sasmodels/sasview_model.py @ ce176ca

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since ce176ca was ce176ca, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix parameter show/hide for multiplicity models

  • Property mode set to 100644
File size: 21.7 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17
18import numpy as np  # type: ignore
19
20from . import core
21from . import custom
22from . import generate
23from . import weights
24from . import modelinfo
25from . import kernel
26
try:
    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
    # Typing-only aliases.  They need nothing from the project modules, so
    # they are defined before the relative imports below.
    #: static type of MultiplicityInfo (fields: number, control, choices,
    #: x_axis_label)
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
    #: a sasview model class: called with an (optional) multiplicity
    SasviewModelType = Callable[[int], "SasviewModel"]
    from .modelinfo import ModelInfo, Parameter
    from .kernel import KernelModel
except ImportError:
    # Typing support is optional; when it is missing the "# type:" comments
    # in this module are documentation only.
    pass
38
# TODO: separate x_axis_label from multiplicity info
# The profile x-axis label belongs with the profile generating function

#: Multiplicity (variant) description attached to each model class:
#:   number       -- number of variants (0 when the model has no control)
#:   control      -- name of the control parameter ("" when there is none)
#:   choices      -- labels for the variants (may be empty)
#:   x_axis_label -- x-axis label for the profile plot
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: Registry of loaded sasview model classes, keyed by model name.
MODELS = dict()  # type: Dict[str, SasviewModelType]
def find_model(modelname):
    """
    Return the sasview model class for *modelname*.

    A name ending in ``.py`` is loaded as a custom model file through
    :func:`load_custom_model`; any other name must already be present in
    :data:`MODELS`.

    Raises ValueError when the name is not a known model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
56
57
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load and return the list of predefined models.

    Every model that loads successfully is also registered in
    :data:`MODELS` under its name.

    If there is an error loading a model, the error (with the name of the
    failing model and its traceback) is logged and the model is neither
    registered nor returned.
    """
    models = []
    for name in core.list_models():
        try:
            MODELS[name] = _make_standard_model(name)
            models.append(MODELS[name])
        except Exception:
            # Best effort: one broken model must not stop the others from
            # loading, but say which model failed along with the traceback.
            logging.exception("failed to load model %r", name)
    return models
75
76
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load the custom model defined in the python file at *path*.

    When the loaded module defines its own ``Model`` class, that class is
    used unchanged; otherwise a sasview model class is generated from the
    module's model info.  The class is registered in :data:`MODELS` under
    its ``name`` and returned.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model_class = kernel_module.Model
    else:
        info = modelinfo.make_model_info(kernel_module)
        model_class = _make_model_from_info(info)
    MODELS[model_class.name] = model_class
    return model_class
91
92
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    The kernel module is obtained from ``generate.load_kernel_module`` and
    converted to model info, from which a class usable directly as a
    sasview model is generated.  (Whether *name* may also be a path to a
    custom model depends on ``generate.load_kernel_module`` -- verify there.)
    """
    info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(info)
105
106
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a :class:`SasviewModel` subclass named after *model_info*.

    The class attributes come from :func:`_generate_model_attributes`; the
    generated ``__init__`` just forwards *multiplicity* to
    ``SasviewModel.__init__``.
    """
    class_attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs['__init__'] = __init__
    model_class = type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
    return model_class
118
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class-attribute dictionary for a sasview model class.

    The attributes carry everything needed to query the model (name, id,
    category, multiplicity, orientation/magnetic parameters, ...) without
    instantiating it.  The values are meant to be immutable -- parameter
    name lists are stored as tuples -- so that instances cannot
    accidentally share mutable state through the class.
    """
    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity: the control parameter (if any) selects the variant and
    # is never fitted.
    if model_info.profile is None:
        xlabel = ""
    else:
        xlabel = model_info.profile_axes[0]
    variants = MultiplicityInfo(0, "", [], xlabel)
    non_fittable = []  # type: List[str]
    control = next((par for par in model_info.parameters.kernel_parameters
                    if par.name == model_info.control), None)
    if control is not None:
        non_fittable.append(control.name)
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices, xlabel)

    # Sort the user parameters into orientation / magnetic groups.
    orientation_params = []  # type: List[str]
    magnetic_params = []     # type: List[str]
    fixed = []               # type: List[str]
    for par in model_info.parameters.user_parameters():
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
            fixed.append(width_name)
        elif par.type == 'magnetic':
            # magnetic parameters are also recorded as orientation parameters
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
    }  # type: Dict[str, Any]
176
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Model-specific subclasses are created by :func:`_make_model_from_info`,
    which fills in the class attributes declared below.  Each instance keeps
    its own parameter values (:attr:`params`), polydispersity settings
    (:attr:`dispersion`) and parameter units/limits (:attr:`details`).
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Set up the parameter values, details and dispersion settings.

        When *multiplicity* is given, the parameters hidden for that variant
        and the multiplicity control parameter itself are left out of
        :attr:`params`, :attr:`details` and :attr:`dispersion`.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the state used by pickle/copy.

        The built kernel model is dropped; :meth:`calculate_Iq` rebuilds it
        on demand.
        """
        state = self.__dict__.copy()
        # _model may exist only as the class default, so it need not be in
        # the instance dictionary.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the model from *state*; the kernel model is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previous, misspelled names (which
    # pickle and copy never called).
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        # Note: despite its name, self.fixed lists the fittable parameters.
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        :return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = [] # type: List[Union[float, np.ndarray]]
        for p in self._model_info.parameters.kernel_parameters:
            if (p.id == self.multiplicity_info.control
                    and self.multiplicity is not None):
                # With a multiplicity set, the control parameter is hidden
                # from params, so pass the multiplicity itself.
                args.append(float(self.multiplicity))
            elif p.length == 1:
                args.append(self.params.get(p.id, np.nan))
            else:
                # vector parameter: values are stored as p.id+"1" ... p.id+"n"
                args.append([self.params.get(p.id+str(k), np.nan)
                             for k in range(1, p.length+1)])
        return self._model_info.profile(*args)

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or "par.field" for a dispersion
            setting such as "radius.width"
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            base, field = toks
            if base in self.dispersion and field in self.dispersion[base]:
                self.dispersion[base][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or "par.field" for a dispersion
            setting such as "radius.width"

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            base, field = toks
            if base in self.dispersion and field in self.dispersion[base]:
                return self.dispersion[base][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model

        Parameters hidden by the current multiplicity have no dispersion
        entry and are therefore not listed.
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name, ext)
               for p in self._model_info.parameters.user_parameters()
               if p.polydisperse and p.name in self.dispersion
               for ext in ('npts', 'nsigmas', 'width')]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return a identical copy of self

        The copy does not share the built kernel model; it is rebuilt on
        first use (see :meth:`__getstate__`).
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]

    def runXY(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            # Accept plain sequences as well as arrays for qx and qy.
            qx, qy = np.asarray(qx), np.asarray(qy)
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            pairs = [self._get_weights(p)
                     for p in self._model_info.parameters.call_parameters]
            call_details, value = kernel.build_details(calculator, pairs)
            return calculator(call_details, value, cutoff=self.cutoff)
        finally:
            # Release the kernel even if building details or evaluating fails.
            calculator.release()

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            # weighted average over the volume-parameter dispersion mesh
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            disperser = weights.dispersers[dispersion.__class__.__name__]
            dispersion = weights.MODELS[disperser]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return kernel.dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Parameters missing from :attr:`params` yield the multiplicity (for
        the control parameter) or NaN; non-polydisperse parameters yield
        their single value with an empty weight list.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], []
            else:
                return [np.nan], []
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], []
608
def test_model():
    # type: () -> float
    """
    Smoke test: build the cylinder sasview model and evaluate it once.
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1,0.1])
617
def test_rpa():
    # type: () -> np.ndarray
    """
    Test that a multiplicity sasview model (rpa, multiplicity 3) can be run.
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1,0.1])
626
627
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    Any failure is re-raised with the name of the offending model added to
    the exception.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            annotate_exception("when loading "+name)
            raise
640
if __name__ == "__main__":
    # Command-line smoke test: evaluate cylinder at (qx, qy) = (0.1, 0.1).
    cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g"%cylinder_value)
Note: See TracBrowser for help on using the repository browser.