source: sasmodels/sasmodels/sasview_model.py @ e7fe459

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since e7fe459 was e7fe459, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

spherical sld: sasview assumes profile needs to be scaled by 1e6

  • Property mode set to 100644
File size: 22.0 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17
18import numpy as np  # type: ignore
19
20from . import core
21from . import custom
22from . import generate
23from . import weights
24from . import modelinfo
25from .details import build_details, dispersion_mesh
26
27try:
28    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
29    from .modelinfo import ModelInfo, Parameter
30    from .kernel import KernelModel
31    MultiplicityInfoType = NamedTuple(
32        'MuliplicityInfo',
33        [("number", int), ("control", str), ("choices", List[str]),
34         ("x_axis_label", str)])
35    SasviewModelType = Callable[[int], "SasviewModel"]
36except ImportError:
37    pass
38
# TODO: separate x_axis_label from multiplicity info
# The profile x-axis label belongs with the profile generating function
#: Record describing the multiplicity (variant) control of a model:
#: number of variants, name of the control parameter, the list of choices
#: and the x-axis label used for the profile plot.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")
45
#: Registry of loaded model classes keyed by model name.
MODELS = {}  # type: Dict[str, SasviewModelType]
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is treated as a custom model path and loaded
    with :func:`load_custom_model`.  Any other name must already be present
    in :data:`MODELS`.

    :raises ValueError: if *modelname* is not a known model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
56
57
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a model class for every model reported by ``core.list_models()``.

    Each successfully built class is registered in :data:`MODELS` and
    included in the returned list.  A model that fails to build has its
    traceback logged at error level and is left out.
    """
    loaded = []  # type: List[SasviewModelType]
    for model_name in core.list_models():
        try:
            constructed = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
            continue
        MODELS[model_name] = constructed
        loaded.append(constructed)
    return loaded
75
76
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load the custom model defined in the python file at *path*.

    If the loaded module defines a ``Model`` attribute, that is used as the
    model class; otherwise a class is generated from the module's model
    info.  The class is registered in :data:`MODELS` under its ``name``.
    """
    custom_module = custom.load_custom_kernel_module(path)
    try:
        model_class = custom_module.Model
    except AttributeError:
        info = modelinfo.make_model_info(custom_module)
        model_class = _make_model_from_info(info)
    MODELS[model_class.name] = model_class
    return model_class
91
92
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* is resolved by ``generate.load_kernel_module``.  The original
    notes state it may be a standard model name or a path to a custom model
    (depends on ``load_kernel_module`` -- confirm there).

    Returns a class that can be used directly as a sasview model.
    """
    info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(info)
105
106
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a :class:`SasviewModel` subclass wrapping *model_info*.

    The new class is named after ``model_info.name``.  Its class attributes
    come from :func:`_generate_model_attributes`.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
118
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Compute the class attributes for the model described by *model_info*.

    The attributes carry everything needed to query the model details
    without instantiating the model.  All values are immutable (tuples,
    strings, namedtuples) to avoid accidental sharing between instances.
    """
    # TODO: allow model to override axis labels input/output name/unit

    # Profile x-axis label (empty when the model has no profile function).
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""

    # Multiplicity: the kernel parameter named by model_info.control, if any.
    control = next((par for par in model_info.parameters.kernel_parameters
                    if par.name == model_info.control), None)
    if control is None:
        non_fittable = []  # type: List[str]
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable = [control.name]
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices,
                                    xlabel)

    # Orientation and magnetic parameter sets.
    orientation_params = []  # type: List[str]
    magnetic_params = []  # type: List[str]
    fixed = []  # type: List[str]
    for par in model_info.parameters.user_parameters():
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params += [par.name, width_name]
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
        else:
            continue
        # NOTE(review): magnetic parameters also contribute a ".width" entry
        # to *fixed*; kept as in the original -- confirm it is intended.
        fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
    }  # type: Dict[str, Any]
177
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize parameter values, dispersion settings and parameter
        details from the model info.

        :param multiplicity: selected variant for multiplicity models, or
            None.  Parameters hidden for that multiplicity, and the control
            parameter itself, are left out of *params*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state used by pickle and :meth:`clone`.

        The kernel model built by :meth:`calculate_Iq` is dropped; it is
        rebuilt on demand after unpickling/copying.
        """
        state = self.__dict__.copy()
        # _model is only stored on the instance once calculate_Iq has run,
        # so it may legitimately be absent from __dict__.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance state; the kernel model is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previous names.  Pickle and copy
    # only use the dunder names above.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.NaN)
            else:
                value = np.array([self.params.get(p.id+str(k), np.NaN)
                                  for k in range(1,p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # The profile function returns SLD in units of 1e-6/A^2; sasview
        # expects the unscaled value.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; a dispersion setting is named
            "<parameter>.<setting>", e.g. "radius.width"
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            base, field = toks
            settings = self.dispersion.get(base)
            if settings is not None and field in settings:
                settings[field] = value
                return
        elif name in self.params:
            # Standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; a dispersion setting is named
            "<parameter>.<setting>", e.g. "radius.width"
        :return: the current value
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            base, field = toks
            settings = self.dispersion.get(base)
            if settings is not None and field in settings:
                return settings[field]
        elif name in self.params:
            # Standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p.name, ext)
                for p in self._model_info.parameters.user_parameters()
                for ext in ('npts', 'nsigmas', 'width')
                if p.polydisperse]

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither a sequence nor an ndarray
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                # Convert first so plain lists of q values also work here.
                qx, qy = np.asarray(qx), np.asarray(qy)
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = build_details(calculator, pairs)
            return calculator(call_details, values, cutoff=self.cutoff,
                              magnetic=is_magnetic)
        finally:
            # Release the kernel even if building details or evaluation fails.
            calculator.release()

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius; 1.0 if the model has no
            ER function
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            # Weighted average of the effective radius over the mesh.
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio; 1.0 if the model has no VR
            function
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            disperser = weights.dispersers[dispersion.__class__.__name__]
            dispersion = weights.MODELS[disperser]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion (values, weights) for parameter *par*.

        Parameters not in *params* yield the multiplicity (for the control
        parameter) or NaN, with unit weight.  Polydisperse parameters yield
        normalized weights from ``weights.get_weights``.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                return [np.NaN], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
619
def test_model():
    # type: () -> np.ndarray
    """
    Smoke test: build the standard cylinder model and evaluate it at the
    single point (qx, qy) = (0.1, 0.1).
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1, 0.1])
628
def test_rpa():
    # type: () -> np.ndarray
    """
    Test that a multiplicity model (rpa, with multiplicity 3) can be run.

    Evaluates the model at the single point (qx, qy) = (0.1, 0.1).
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
637
638
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    A failure is annotated with the name of the model being loaded and then
    re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Narrowed from a bare except: interpreter exits and keyboard
            # interrupts propagate untouched.
            annotate_exception("when loading "+name)
            raise
651
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model at a single (qx, qy) point.
    cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.