source: sasmodels/sasmodels/sasview_model.py @ 839fd68

Branches: core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 839fd68 was 839fd68, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

cache modules by timestamp in the custom kernel loader

  • Property mode set to 100644
File size: 31.1 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import product
28from . import generate
29from . import weights
30from . import modelinfo
31from .details import make_kernel_args, dispersion_mesh
32
33# pylint: disable=unused-import
34try:
35    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
36                        List, Optional, Union, Callable)
37    from .modelinfo import ModelInfo, Parameter
38    from .kernel import KernelModel
39    MultiplicityInfoType = NamedTuple(
40        'MultiplicityInfo',
41        [("number", int), ("control", str), ("choices", List[str]),
42         ("x_axis_label", str)])
43    SasviewModelType = Callable[[int], "SasviewModel"]
44except ImportError:
45    pass
46# pylint: enable=unused-import
47
48logger = logging.getLogger(__name__)
49
50calculation_lock = thread.allocate_lock()
51
52#: True if pre-existing plugins, with the old names and parameters, should
53#: continue to be supported.
54SUPPORT_OLD_STYLE_PLUGINS = True
55
56# TODO: separate x_axis_label from multiplicity info
57MultiplicityInfo = collections.namedtuple(
58    'MultiplicityInfo',
59    ["number", "control", "choices", "x_axis_label"],
60)
61
62#: set of defined models (standard and custom)
63MODELS = {}  # type: Dict[str, SasviewModelType]
64# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
65#: custom model {path: model} mapping so we can check timestamps
66MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
67
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the model class registered for *modelname*.

    A name ending in ``.py`` is treated as the path of a custom model file
    and is handed to :func:`load_custom_model`.  Any other name must already
    be a key of :data:`MODELS`.

    :raises ValueError: if *modelname* is neither a ``.py`` path nor a
        registered model name.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    is_custom_path = modelname.endswith('.py')
    if is_custom_path:
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
82
83
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build every predefined model, register it in :data:`MODELS` and return
    all registered model classes.

    A model that fails to build is skipped: its traceback is logged at
    error level and it is not registered.  When
    :data:`SUPPORT_OLD_STYLE_PLUGINS` is set, the models are additionally
    exposed under their old sasview names.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
102
103
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The module at *path* is loaded with ``custom.load_custom_kernel_module``.
    If it defines a ``Model`` class (old-style sasview plugin), that class is
    used directly, after filling in missing ``name``, ``filename`` and ``id``
    attributes. Otherwise a model class is built from the module's model
    info.

    Another model may already be registered under the same name from a
    different file. In that case the new model is renamed to its ``id``,
    or to ``<id>_user`` if that name also clashes.

    :param path: path to the custom model file.
    :return: the model class, which is also stored in
        ``MODELS[model.name]``.
    """
    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Report the source file rather than the compiled one.  Only the
            # trailing extension is rewritten; str.replace() would also
            # mangle a '.pyc' appearing elsewhere in the path.
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    previous_name, model.name, model.filename)

    MODELS[model.name] = model
    return model
142
143
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a new :class:`SasviewModel` subclass describing *model_info*.

    The class is named after ``model_info.name``. It carries the attributes
    produced by :func:`_generate_model_attributes` plus ``filename``. Its
    constructor simply forwards *multiplicity* to ``SasviewModel.__init__``.
    """
    class_attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs.update(__init__=__init__, filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), class_attrs) # type: SasviewModelType
156
157
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Load the sasview model defined by *name*.

    *name* is resolved with ``generate.load_kernel_module``. The resulting
    kernel module is converted to model info and then to a sasview model
    class.

    Returns a class that can be used directly as a sasview model.
    """
    info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return make_model_from_info(info)
170
171
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    The conversion table has a ``(3, 1, 2)`` section. For each entry in it,
    a module-like class is built that holds the new model under its old
    name. That class is registered in two places:

    * in ``sys.modules`` as ``sas.models.<OldName>``, so that
      ``from sas.models.<OldName> import <OldName>`` works;
    * as the attribute ``<OldName>`` of ``sas.models``.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        module_attrs = {old_name: find_model(new_name)}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # Bind under the bare old name so ``sas.models.<OldName>`` attribute
        # access works; the full dotted path is only meaningful as a
        # sys.modules key (setattr with a dotted name creates an attribute
        # that ordinary attribute syntax can never reach).
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
197
198
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the P(q)*S(q) product of two sasview models.

    The product model class is built from the model info of *form_factor*
    and *structure_factor*. It is instantiated with the form factor's
    multiplicity.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
208
209
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.

    :param model_info: model description to summarize.
    :return: {attribute name: value} dictionary used as the class namespace.
    """
    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]

    # Multiplicity: the first kernel parameter named as the model's control
    # parameter sets the number of variants and is itself not fittable.
    non_fittable = []  # type: List[str]
    variants = MultiplicityInfo(0, "", [], xlabel)
    control_par = next(
        (par for par in kernel_pars if par.name == model_info.control), None)
    if control_par is not None:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(
            count, control_par.name, control_par.choices, xlabel)

    # Only a single drop-down list parameter available: use the first
    # parameter that has choices; vector entries of it are not fittable.
    fun_list = []
    choice_par = next((par for par in kernel_pars if par.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(
                choice_par.id + str(k)
                for k in range(1, choice_par.length + 1))

    # Organize parameter sets.  Magnetic parameters are also listed as
    # orientation parameters.
    orientation_params, magnetic_params, fixed = [], [], []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
            fixed.append(width_name)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]
278
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are built by :func:`make_model_from_info`, which
    fills in the class attributes below from a model info object.  Each
    instance holds its own parameter values (:attr:`params`), dispersion
    settings (:attr:`dispersion`) and parameter details (:attr:`details`).
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names for which is_fittable() returns True; _generate_model_attributes
    #: fills it with the ".width" names of orientation and magnetic parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: the calculator's ``results`` attribute captured by the most recent
    #: kernel evaluation (None if the calculator has no such attribute)
    _intermediate_results = None # type: Any

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize per-instance parameter values, dispersion settings and
        parameter details from the class's model info.

        :param multiplicity: value of the model's control parameter, or None.
            When given, parameters hidden at that multiplicity and the control
            parameter itself are left out of :attr:`params`.

        For structure factors, ``scale`` and ``background`` are hidden as
        well. Note that this also sets the shared model info's default
        background to 0.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickle/deepcopy (and hence
        :meth:`clone`).

        The kernel model is omitted; it is rebuilt on demand by
        :meth:`_calculate_Iq`.
        """
        state = self.__dict__.copy()
        # _model only lands in the instance dict once a calculation has run.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore state saved by :meth:`__getstate__`.

        The kernel model is left unset so that it is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previous (misspelled) names, which
    # the pickle and copy protocols never invoked.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in :attr:`fixed`
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        The model's ``profile`` function is called with the current parameter
        values. The control parameter is set to the multiplicity, and any
        missing parameter is passed as NaN. The returned beta is scaled by
        1e-6.
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # vector parameters are stored as name1, name2, ..., nameN
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; "par.field" addresses a field
            (e.g., width, npts, nsigmas, type) of the dispersion settings of par
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                settings[toks[1]] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; "par.field" addresses a field
            of the dispersion settings of par
        :return: the parameter value
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                return settings[toks[1]]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model

        Standard parameters come first, followed by the dispersion
        parameters from :meth:`getDispParamList`.
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model

        Each polydisperse parameter contributes "name.npts", "name.nsigmas"
        and "name.width", in that order.
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return a identical copy of self

        The copy is made with deepcopy, which uses :meth:`__getstate__`, so
        the clone rebuilds its kernel model on first use.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither an array nor a [qx, qy] pair.
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        # type: (Sequence[float]) -> Any
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Evaluate the kernel without taking the lock.

        Both callers in this class (:meth:`calculate_Iq` and
        :meth:`calc_composition_models`) hold *calculation_lock* while
        calling it.  The kernel model is built on first use.  The
        calculator's ``results`` attribute (if any) is saved in
        :attr:`_intermediate_results`, and the calculator is released even
        if evaluation fails.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
            self._intermediate_results = getattr(calculator, 'results', None)
        finally:
            calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, averaged over the
            volume-parameter dispersion mesh (1.0 if the model has no ER)
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, i.e. the weighted sum of the
            "part" values over the weighted sum of the "whole" values
            (1.0 if the model has no VR)
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.  Only volume parameters are included.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[Any, Sequence[float], Sequence[float]]
        """
        Return dispersion weights for parameter

        :return: (value, dispersity points, weights).  Parameters not in
            :attr:`params` are returned as a single point. That point is the
            multiplicity for the control parameter, and otherwise the model
            default (NaN if there is no default).
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return self.multiplicity, [self.multiplicity], [1.0]
            else:
                # For hidden parameters use default values.  This sets
                # scale=1 and background=0 for structure factors
                default = self._model_info.parameters.defaults.get(par.name, np.nan)
                return default, [default], [1.0]
        elif par.polydisperse:
            value = self.params[par.name]
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                dispersity, weight = dis['values'], dis['weights']
            else:
                dispersity, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    value, par.limits, par.relative_pd)
            return value, dispersity, weight
        else:
            value = self.params[par.name]
            return value, [value], [1.0]
781
def test_cylinder():
    # type: () -> float
    """
    Evaluate the standard cylinder model at (qx, qy) = (0.1, 0.1).

    :return: the value computed by evalDistribution at that point.
    """
    cylinder_class = _make_standard_model('cylinder')
    return cylinder_class().evalDistribution([0.1, 0.1])
790
def test_structure_factor():
    # type: () -> None
    """
    Test that 2-D hardsphere model runs and doesn't produce NaN.

    The 1-D evaluation uses |q| of the same (0.1, 0.1) point, so both the
    1-D and 2-D paths are checked.

    :raises ValueError: if either evaluation contains NaN.
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value2d = model.evalDistribution([0.1, 0.1])
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    #print("hardsphere", value1d, value2d)
    # np.any() keeps the check valid whatever the size of the returned arrays.
    if np.any(np.isnan(value1d)) or np.any(np.isnan(value2d)):
        raise ValueError("hardsphere returns nan")
803
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    :raises ValueError: if the product model evaluates to NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    if np.any(np.isnan(value)):
        raise ValueError("cylinder*hayter_msa returns nan")
815
def test_rpa():
    # type: () -> float
    """
    Run the 2-D RPA model, built with multiplicity 3, at (qx, qy) = (0.1, 0.1).

    :return: the value computed by evalDistribution at that point.
    """
    rpa_class = _make_standard_model('rpa')
    return rpa_class(3).evalDistribution([0.1, 0.1])
824
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that sasmodels returns 0 when there are no polydispersity points.

    The cylinder radius is set to -1, which presumably lies outside the
    radius limits so that no dispersity points remain (confirm against
    weights.get_weights). With the background set to 0, the intensity must
    then be exactly 0.
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    cylinder.setParam('radius', -1.0)
    cylinder.setParam('background', 0.)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
836
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models

    If a model fails to build, the in-flight exception is annotated with
    "when loading <name>" via annotate_exception and then re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than everything, so that
            # KeyboardInterrupt/SystemExit pass through untouched.
            annotate_exception("when loading "+name)
            raise
849
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model as sas-models-CylinderModel

    The test is skipped (returns early) when old-style plugin support is
    disabled or when sasview cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        import sas  # pylint: disable=unused-import
    except ImportError:
        # sasview is not on the path, so there is nothing to test
        return
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel as OldCylinderModel
    OldCylinderModel().evalDistribution([0.1, 0.1])
865
def magnetic_demo():
    """
    Plot the 2-D scattering of the sphere model with ``M0:sld`` set to 8.

    The model is evaluated on a 500x500 (qx, qy) grid spanning
    [-0.35, 0.35] in each direction, and log(I + 0.001) is shown as an
    image. Requires matplotlib and a display.
    """
    Model = _make_standard_model('sphere')
    model = Model()
    model.setParam('M0:sld', 8)
    q = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(q, q)
    result = model.calculate_Iq(qx.flatten(), qy.flatten())
    result = result.reshape(qx.shape)

    # pyplot rather than the discouraged pylab namespace.
    import matplotlib.pyplot as plt
    plt.imshow(np.log(result + 0.001))
    plt.show()
878
if __name__ == "__main__":
    # Quick smoke test.  The other checks in this module (magnetic_demo,
    # test_product, test_structure_factor, test_rpa, test_empty_distribution)
    # can be invoked by hand when needed.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.