source: sasmodels/sasmodels/sasview_model.py @ b32dafd

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since b32dafd was b32dafd, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

lint

  • Property mode set to 100644
File size: 22.4 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17
18import numpy as np  # type: ignore
19
20from . import core
21from . import custom
22from . import generate
23from . import weights
24from . import modelinfo
25from .details import build_details, dispersion_mesh
26
# Names below are only needed by the "# type:" comments in this module, so
# a missing typing module (e.g. a bare python 2 install) is not an error.
try:
    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
    # The aliases depend on typing alone, so define them before pulling in
    # the project types.  The type name matches the runtime namedtuple
    # defined below (it was previously misspelled 'MuliplicityInfo').
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
    SasviewModelType = Callable[[int], "SasviewModel"]
    from .modelinfo import ModelInfo, Parameter
    from .kernel import KernelModel
except ImportError:
    pass

# TODO: separate x_axis_label from multiplicity info
#: Runtime record describing the multiplicity (variant) control of a model:
#: number of variants, name of the control parameter, the list of choices
#: (if any) and the x-axis label used for the profile plot.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ["number", "control", "choices", "x_axis_label"],
)

#: Registry of model classes, keyed by model name.  Filled in by
#: load_standard_models() and load_custom_model(); read by find_model().
MODELS = {}
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the sasview model class registered under *modelname*.

    A name ending in ``.py`` is treated as the path of a custom model and is
    loaded with :func:`load_custom_model`.  Any other name must already be
    present in the *MODELS* registry; otherwise *ValueError* is raised.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
60
61
62# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every name in ``core.list_models()``.

    Each class that builds successfully is stored in the *MODELS* registry
    under its name and included in the returned list.  A model that fails
    to build is skipped after its traceback has been logged as an error.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            # One broken model must not prevent the others from loading.
            logging.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class
            loaded.append(model_class)
    return loaded
79
80
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load the custom model stored at *path* and register it in *MODELS*.

    If the loaded module provides a ``Model`` attribute it is used as the
    model class directly; otherwise the class is constructed from the model
    info derived from the module.
    """
    #print("load custom model", path)
    module = custom.load_custom_kernel_module(path)
    try:
        model_class = module.Model
    except AttributeError:
        # No ready-made class in the module: wrap its kernel definition.
        info = modelinfo.make_model_info(module)
        model_class = _make_model_from_info(info)
    MODELS[model_class.name] = model_class
    return model_class
95
96
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* is handed to ``generate.load_kernel_module``; the resulting
    kernel module is converted to model info and then wrapped in a class
    that can be used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    return _make_model_from_info(modelinfo.make_model_info(module))
109
110
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a :class:`SasviewModel` subclass from *model_info*.

    The new class is named after ``model_info.name`` and carries the class
    attributes produced by :func:`_generate_model_attributes`.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_dict = _generate_model_attributes(model_info)
    class_dict['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_dict)
122
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Return the class dictionary for the model described by *model_info*.

    The attributes carry everything needed to query the model details
    without instantiating it.  Sequence-valued attributes are stored as
    tuples so that they are not modified by accident.
    """

    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters

    # Multiplicity: the kernel parameter named by model_info.control (if
    # any) selects the variant and is therefore never fitted.
    non_fittable = []  # type: List[str]
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""
    variants = MultiplicityInfo(0, "", [], xlabel)
    control = next((p for p in kernel_pars if p.name == model_info.control),
                   None)
    if control is not None:
        non_fittable.append(control.name)
        if control.choices:
            number = len(control.choices)
        else:
            number = int(control.limits[1])
        variants = MultiplicityInfo(number, control.name, control.choices,
                                    xlabel)

    # Only a single drop-down list parameter available: use the first
    # kernel parameter that defines choices.
    fun_list = []
    chooser = next((p for p in kernel_pars if p.choices), None)
    if chooser is not None:
        fun_list = chooser.choices
        if chooser.length > 1:
            non_fittable.extend(chooser.id+str(k)
                                for k in range(1, chooser.length+1))

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters():
        if par.type == 'orientation':
            orientation_params.extend((par.name, par.name+".width"))
            fixed.append(par.name+".width")
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(par.name+".width")

    # Build class dictionary
    attrs = {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants[0] > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]

    return attrs
191
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are created by *_make_model_from_info*, which
    subclasses this class and fills in the class attributes listed below.
    Instances hold the current parameter values (*params*), the
    polydispersity settings (*dispersion*) and the parameter units/limits
    (*details*); the compute kernel is built lazily on first evaluation.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialise the parameter tables with the model defaults.

        :param multiplicity: variant selector for multiplicity models, or
            None.  When given, the parameters hidden for that multiplicity
            and the control parameter itself are left out of *params*,
            *details* and *dispersion*.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state used by pickle/copy, without the kernel.

        The built kernel model is dropped; it is rebuilt on demand by
        :meth:`calculate_Iq`.
        """
        state = self.__dict__.copy()
        # _model only lives in the instance dict once a kernel has been
        # built, so it may legitimately be absent here.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance from *state*, leaving the kernel unbuilt.
        """
        self.__dict__ = state
        self._model = None

    # The hooks used to be spelled with an extra underscore, which the
    # pickle/copy machinery never calls.  Keep the old spellings as aliases
    # for any code that invoked them explicitly.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        The check is a membership test against the *fixed* attribute.

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                # Parameters missing from params are passed as NaN.
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameters are stored as name1, name2, ...
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # The profile values are scaled by 1e-6 before being returned.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or ``parameter.field`` for a
            dispersion setting such as ``radius.width``
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par, field = toks
            if par in self.dispersion and field in self.dispersion[par]:
                self.dispersion[par][field] = value
                return
        elif name in self.params:
            # Standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or ``parameter.field`` for a
            dispersion setting such as ``radius.width``

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par, field = toks
            if par in self.dispersion and field in self.dispersion[par]:
                return self.dispersion[par][field]
        elif name in self.params:
            # Standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model

        Each polydisperse parameter contributes ``name.npts``,
        ``name.nsigmas`` and ``name.width``, in that order.
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name, ext)
               for p in self._model_info.parameters.user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        #print(ret)
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither a list/tuple nor an ndarray
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        #core.HAVE_OPENCL = False
        if self._model is None:
            # Build the kernel model lazily and keep it for later calls.
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = build_details(calculator, pairs)
            #call_details.show()
            #print("pairs", pairs)
            #print("params", self.params)
            #print("values", values)
            #print("is_mag", is_magnetic)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
        finally:
            # Release the kernel even when the evaluation fails.
            calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            #print(values[0].shape, weights.shape, fv.shape)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not one of the model parameters
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            disperser = weights.dispersers[dispersion.__class__.__name__]
            dispersion = weights.MODELS[disperser]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % (parameter,))

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        The result is a (values, weights) pair.  For polydisperse
        parameters the weights are normalised to sum to one.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # Parameter is not part of this model instance.
                return [np.nan], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
632
def test_model():
    # type: () -> float
    """
    Build the cylinder model as a sasview model and evaluate it once.

    Returns the result of ``evalDistribution([0.1, 0.1])``.
    """
    model_class = _make_standard_model('cylinder')
    instance = model_class()
    return instance.evalDistribution([0.1, 0.1])
641
def test_rpa():
    # type: () -> float
    """
    Build the rpa model with multiplicity 3 and evaluate it once.

    Returns the result of ``evalDistribution([0.1, 0.1])``.
    """
    model_class = _make_standard_model('rpa')
    instance = model_class(3)
    return instance.evalDistribution([0.1, 0.1])
650
651
def test_model_list():
    # type: () -> None
    """
    Check that every model in ``core.list_models()`` builds as a sasview
    model.

    A failure is re-raised after the exception has been annotated with the
    name of the model being loaded.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except:
            # Tag the error with the offending model, then let it propagate.
            annotate_exception("when loading "+model_name)
            raise
if __name__ == "__main__":
    # Quick smoke test when the module is run as a script.
    cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g"%cylinder_value)
Note: See TracBrowser for help on using the repository browser.