source: sasmodels/sasmodels/sasview_model.py @ fd7291e

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since fd7291e was fd7291e, checked in by Paul Kienzle <pkienzle@…>, 6 months ago

Merge branch 'beta_approx' into ticket-1022-sum_multiplicity

  • Property mode set to 100644
File size: 35.8 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import kernelcl
28from . import product
29from . import generate
30from . import weights
31from . import modelinfo
32from .details import make_kernel_args, dispersion_mesh
33from .kernelcl import reset_environment
34
35# pylint: disable=unused-import
36try:
37    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
38                        List, Optional, Union, Callable)
39    from .modelinfo import ModelInfo, Parameter
40    from .kernel import KernelModel
41    MultiplicityInfoType = NamedTuple(
42        'MultiplicityInfo',
43        [("number", int), ("control", str), ("choices", List[str]),
44         ("x_axis_label", str)])
45    SasviewModelType = Callable[[int], "SasviewModel"]
46except ImportError:
47    pass
48# pylint: enable=unused-import
49
50logger = logging.getLogger(__name__)
51
52calculation_lock = thread.allocate_lock()
53
54#: True if pre-existing plugins, with the old names and parameters, should
55#: continue to be supported.
56SUPPORT_OLD_STYLE_PLUGINS = True
57
58# TODO: separate x_axis_label from multiplicity info
59MultiplicityInfo = collections.namedtuple(
60    'MultiplicityInfo',
61    ["number", "control", "choices", "x_axis_label"],
62)
63
64#: set of defined models (standard and custom)
65MODELS = {}  # type: Dict[str, SasviewModelType]
66# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
67#: custom model {path: model} mapping so we can check timestamps
68MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
69#: Track modules that we have loaded so we can determine whether the model
70#: has changed since we last reloaded.
71_CACHED_MODULE = {}  # type: Dict[str, "module"]
72
73def reset_environment():
74    # type: () -> None
75    """
76    Clear the compute engine context so that the GUI can change devices.
77
78    This removes all compiled kernels, even those that are active on fit
79    pages, but they will be restored the next time they are needed.
80    """
81    kernelcl.reset_environment()
82    for model in MODELS.values():
83        model._model = None
84
85def find_model(modelname):
86    # type: (str) -> SasviewModelType
87    """
88    Find a model by name.  If the model name ends in py, try loading it from
89    custom models, otherwise look for it in the list of builtin models.
90    """
91    # TODO: used by sum/product model to load an existing model
92    # TODO: doesn't handle custom models properly
93    if modelname.endswith('.py'):
94        return load_custom_model(modelname)
95    elif modelname in MODELS:
96        return MODELS[modelname]
97    else:
98        raise ValueError("unknown model %r"%modelname)
99
100
101# TODO: figure out how to say that the return type is a subclass
102def load_standard_models():
103    # type: () -> List[SasviewModelType]
104    """
105    Load and return the list of predefined models.
106
107    If there is an error loading a model, then a traceback is logged and the
108    model is not returned.
109    """
110    for name in core.list_models():
111        try:
112            MODELS[name] = _make_standard_model(name)
113        except Exception:
114            logger.error(traceback.format_exc())
115    if SUPPORT_OLD_STYLE_PLUGINS:
116        _register_old_models()
117
118    return list(MODELS.values())
119
120
121def load_custom_model(path):
122    # type: (str) -> SasviewModelType
123    """
124    Load a custom model given the model path.
125    """
126    #logger.info("Loading model %s", path)
127
128    # Load the kernel module.  This may already be cached by the loader, so
129    # only requires checking the timestamps of the dependents.
130    kernel_module = custom.load_custom_kernel_module(path)
131
132    # Check if the module has changed since we last looked.
133    reloaded = kernel_module != _CACHED_MODULE.get(path, None)
134    _CACHED_MODULE[path] = kernel_module
135
136    # Turn the module into a model.  We need to do this in even if the
137    # model has already been loaded so that we can determine the model
138    # name and retrieve it from the MODELS cache.
139    model = getattr(kernel_module, 'Model', None)
140    if model is not None:
141        # Old style models do not set the name in the class attributes, so
142        # set it here; this name will be overridden when the object is created
143        # with an instance variable that has the same value.
144        if model.name == "":
145            model.name = splitext(basename(path))[0]
146        if not hasattr(model, 'filename'):
147            model.filename = abspath(kernel_module.__file__).replace('.pyc', '.py')
148        if not hasattr(model, 'id'):
149            model.id = splitext(basename(model.filename))[0]
150    else:
151        model_info = modelinfo.make_model_info(kernel_module)
152        model = make_model_from_info(model_info)
153
154    # If a model name already exists and we are loading a different model,
155    # use the model file name as the model name.
156    if model.name in MODELS and not model.filename == MODELS[model.name].filename:
157        _previous_name = model.name
158        model.name = model.id
159
160        # If the new model name is still in the model list (for instance,
161        # if we put a cylinder.py in our plug-in directory), then append
162        # an identifier.
163        if model.name in MODELS and not model.filename == MODELS[model.name].filename:
164            model.name = model.id + '_user'
165        logger.info("Model %s already exists: using %s [%s]",
166                    _previous_name, model.name, model.filename)
167
168    # Only update the model if the module has changed
169    if reloaded or model.name not in MODELS:
170        MODELS[model.name] = model
171
172    return MODELS[model.name]
173
174
175def make_model_from_info(model_info):
176    # type: (ModelInfo) -> SasviewModelType
177    """
178    Convert *model_info* into a SasView model wrapper.
179    """
180    def __init__(self, multiplicity=None):
181        SasviewModel.__init__(self, multiplicity=multiplicity)
182    attrs = _generate_model_attributes(model_info)
183    attrs['__init__'] = __init__
184    attrs['filename'] = model_info.filename
185    ConstructedModel = type(model_info.name, (SasviewModel,), attrs) # type: SasviewModelType
186    return ConstructedModel
187
188
189def _make_standard_model(name):
190    # type: (str) -> SasviewModelType
191    """
192    Load the sasview model defined by *name*.
193
194    *name* can be a standard model name or a path to a custom model.
195
196    Returns a class that can be used directly as a sasview model.
197    """
198    kernel_module = generate.load_kernel_module(name)
199    model_info = modelinfo.make_model_info(kernel_module)
200    return make_model_from_info(model_info)
201
202
203def _register_old_models():
204    # type: () -> None
205    """
206    Place the new models into sasview under the old names.
207
208    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
209    is available to the plugin modules.
210    """
211    import sys
212    import sas   # needed in order to set sas.models
213    import sas.sascalc.fit
214    sys.modules['sas.models'] = sas.sascalc.fit
215    sas.models = sas.sascalc.fit
216    import sas.models
217    from sasmodels.conversion_table import CONVERSION_TABLE
218
219    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
220        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
221        new_name = new_name.split(':')[0]
222        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
223        module_attrs = {old_name: find_model(new_name)}
224        ConstructedModule = type(old_name, (), module_attrs)
225        old_path = 'sas.models.' + old_name
226        setattr(sas.models, old_path, ConstructedModule)
227        sys.modules[old_path] = ConstructedModule
228
229
230def MultiplicationModel(form_factor, structure_factor):
231    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
232    """
233    Returns a constructed product model from form_factor and structure_factor.
234    """
235    model_info = product.make_product_info(form_factor._model_info,
236                                           structure_factor._model_info)
237    ConstructedModel = make_model_from_info(model_info)
238    return ConstructedModel(form_factor.multiplicity)
239
240
241def _generate_model_attributes(model_info):
242    # type: (ModelInfo) -> Dict[str, Any]
243    """
244    Generate the class attributes for the model.
245
246    This should include all the information necessary to query the model
247    details so that you do not need to instantiate a model to query it.
248
249    All the attributes should be immutable to avoid accidents.
250    """
251
252    # TODO: allow model to override axis labels input/output name/unit
253
254    # Process multiplicity
255    control_pars = [p.id for p in model_info.parameters.kernel_parameters
256                    if p.is_control]
257    control_id = control_pars[0] if control_pars else None
258    non_fittable = []  # type: List[str]
259    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
260    variants = MultiplicityInfo(0, "", [], xlabel)
261    for p in model_info.parameters.kernel_parameters:
262        if p.id == control_id:
263            non_fittable.append(p.name)
264            variants = MultiplicityInfo(
265                len(p.choices) if p.choices else int(p.limits[1]),
266                p.name, p.choices, xlabel
267            )
268            break
269
270    # Only a single drop-down list parameter available
271    fun_list = []
272    for p in model_info.parameters.kernel_parameters:
273        if p.choices:
274            fun_list = p.choices
275            if p.length > 1:
276                non_fittable.extend(p.id+str(k) for k in range(1, p.length+1))
277            break
278
279    # Organize parameter sets
280    orientation_params = []
281    magnetic_params = []
282    fixed = []
283    for p in model_info.parameters.user_parameters({}, is2d=True):
284        if p.type == 'orientation':
285            orientation_params.append(p.name)
286            orientation_params.append(p.name+".width")
287            fixed.append(p.name+".width")
288        elif p.type == 'magnetic':
289            orientation_params.append(p.name)
290            magnetic_params.append(p.name)
291            fixed.append(p.name+".width")
292
293
294    # Build class dictionary
295    attrs = {}  # type: Dict[str, Any]
296    attrs['_model_info'] = model_info
297    attrs['name'] = model_info.name
298    attrs['id'] = model_info.id
299    attrs['description'] = model_info.description
300    attrs['category'] = model_info.category
301    attrs['is_structure_factor'] = model_info.structure_factor
302    attrs['is_form_factor'] = model_info.effective_radius_type is not None
303    attrs['is_multiplicity_model'] = variants[0] > 1
304    attrs['multiplicity_info'] = variants
305    attrs['orientation_params'] = tuple(orientation_params)
306    attrs['magnetic_params'] = tuple(magnetic_params)
307    attrs['fixed'] = tuple(fixed)
308    attrs['non_fittable'] = tuple(non_fittable)
309    attrs['fun_list'] = tuple(fun_list)
310
311    return attrs
312
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are subclasses created by make_model_from_info,
    which fills in the model-specific class attributes produced by
    _generate_model_attributes. The defaults declared here document those
    attributes and give their types.
    """
    # ---- Per-model class attributes ----
    #: compiled kernel, built on first use and cached on the class
    _model = None       # type: KernelModel
    #: model definition this wrapper was generated from
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names reported as fittable by is_fittable (the ".width" entries of
    #: the orientation and magnetic parameters)
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # ---- Per-instance attributes (assigned in __init__) ----
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
364    params = None      # type: Dict[str, float]
365    #: values for dispersion width, npts, nsigmas and type
366    dispersion = None  # type: Dict[str, Any]
367    #: units and limits for each parameter
368    details = None     # type: Dict[str, Sequence[Any]]
369    #                  # actual type is Dict[str, List[str, float, float]]
370    #: multiplicity value, or None if no multiplicity on the model
371    multiplicity = None     # type: Optional[int]
372    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
373    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
374
375    def __init__(self, multiplicity=None):
376        # type: (Optional[int]) -> None
377
378        # TODO: _persistency_dict to persistency_dict throughout sasview
379        # TODO: refactor multiplicity to encompass variants
380        # TODO: dispersion should be a class
381        # TODO: refactor multiplicity info
382        # TODO: separate profile view from multiplicity
383        # The button label, x and y axis labels and scale need to be under
384        # the control of the model, not the fit page.  Maximum flexibility,
385        # the fit page would supply the canvas and the profile could plot
386        # how it wants, but this assumes matplotlib.  Next level is that
387        # we provide some sort of data description including title, labels
388        # and lines to plot.
389
390        # Get the list of hidden parameters given the multiplicity
391        # Don't include multiplicity in the list of parameters
392        self.multiplicity = multiplicity
393        if multiplicity is not None:
394            hidden = self._model_info.get_hidden_parameters(multiplicity)
395            hidden |= set([self.multiplicity_info.control])
396        else:
397            hidden = set()
398        if self._model_info.structure_factor:
399            hidden.add('scale')
400            hidden.add('background')
401
402        # Update the parameter lists to exclude any hidden parameters
403        self.magnetic_params = tuple(pname for pname in self.magnetic_params
404                                     if pname not in hidden)
405        self.orientation_params = tuple(pname for pname in self.orientation_params
406                                        if pname not in hidden)
407
408        self._persistency_dict = {}
409        self.params = collections.OrderedDict()
410        self.dispersion = collections.OrderedDict()
411        self.details = {}
412        for p in self._model_info.parameters.user_parameters({}, is2d=True):
413            if p.name in hidden:
414                continue
415            self.params[p.name] = p.default
416            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
417            if p.polydisperse:
418                self.details[p.id+".width"] = [
419                    "", 0.0, 1.0 if p.relative_pd else np.inf
420                ]
421                self.dispersion[p.name] = {
422                    'width': 0,
423                    'npts': 35,
424                    'nsigmas': 3,
425                    'type': 'gaussian',
426                }
427
428    def __get_state__(self):
429        # type: () -> Dict[str, Any]
430        state = self.__dict__.copy()
431        state.pop('_model')
432        # May need to reload model info on set state since it has pointers
433        # to python implementations of Iq, etc.
434        #state.pop('_model_info')
435        return state
436
437    def __set_state__(self, state):
438        # type: (Dict[str, Any]) -> None
439        self.__dict__ = state
440        self._model = None
441
442    def __str__(self):
443        # type: () -> str
444        """
445        :return: string representation
446        """
447        return self.name
448
449    def is_fittable(self, par_name):
450        # type: (str) -> bool
451        """
452        Check if a given parameter is fittable or not
453
454        :param par_name: the parameter name to check
455        """
456        return par_name in self.fixed
457        #For the future
458        #return self.params[str(par_name)].is_fittable()
459
460
461    def getProfile(self):
462        # type: () -> (np.ndarray, np.ndarray)
463        """
464        Get SLD profile
465
466        : return: (z, beta) where z is a list of depth of the transition points
467                beta is a list of the corresponding SLD values
468        """
469        args = {} # type: Dict[str, Any]
470        for p in self._model_info.parameters.kernel_parameters:
471            if p.id == self.multiplicity_info.control:
472                value = float(self.multiplicity)
473            elif p.length == 1:
474                value = self.params.get(p.id, np.NaN)
475            else:
476                value = np.array([self.params.get(p.id+str(k), np.NaN)
477                                  for k in range(1, p.length+1)])
478            args[p.id] = value
479
480        x, y = self._model_info.profile(**args)
481        return x, 1e-6*y
482
483    def setParam(self, name, value):
484        # type: (str, float) -> None
485        """
486        Set the value of a model parameter
487
488        :param name: name of the parameter
489        :param value: value of the parameter
490
491        """
492        # Look for dispersion parameters
493        toks = name.split('.')
494        if len(toks) == 2:
495            for item in self.dispersion.keys():
496                if item == toks[0]:
497                    for par in self.dispersion[item]:
498                        if par == toks[1]:
499                            self.dispersion[item][par] = value
500                            return
501        else:
502            # Look for standard parameter
503            for item in self.params.keys():
504                if item == name:
505                    self.params[item] = value
506                    return
507
508        raise ValueError("Model does not contain parameter %s" % name)
509
510    def getParam(self, name):
511        # type: (str) -> float
512        """
513        Set the value of a model parameter
514
515        :param name: name of the parameter
516
517        """
518        # Look for dispersion parameters
519        toks = name.split('.')
520        if len(toks) == 2:
521            for item in self.dispersion.keys():
522                if item == toks[0]:
523                    for par in self.dispersion[item]:
524                        if par == toks[1]:
525                            return self.dispersion[item][par]
526        else:
527            # Look for standard parameter
528            for item in self.params.keys():
529                if item == name:
530                    return self.params[item]
531
532        raise ValueError("Model does not contain parameter %s" % name)
533
534    def getParamList(self):
535        # type: () -> Sequence[str]
536        """
537        Return a list of all available parameters for the model
538        """
539        param_list = list(self.params.keys())
540        # WARNING: Extending the list with the dispersion parameters
541        param_list.extend(self.getDispParamList())
542        return param_list
543
544    def getDispParamList(self):
545        # type: () -> Sequence[str]
546        """
547        Return a list of polydispersity parameters for the model
548        """
549        # TODO: fix test so that parameter order doesn't matter
550        ret = ['%s.%s' % (p_name, ext)
551               for p_name in self.dispersion.keys()
552               for ext in ('npts', 'nsigmas', 'width')]
553        #print(ret)
554        return ret
555
556    def clone(self):
557        # type: () -> "SasviewModel"
558        """ Return a identical copy of self """
559        return deepcopy(self)
560
561    def run(self, x=0.0):
562        # type: (Union[float, (float, float), List[float]]) -> float
563        """
564        Evaluate the model
565
566        :param x: input q, or [q,phi]
567
568        :return: scattering function P(q)
569
570        **DEPRECATED**: use calculate_Iq instead
571        """
572        if isinstance(x, (list, tuple)):
573            # pylint: disable=unpacking-non-sequence
574            q, phi = x
575            result, _ = self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])
576            return result[0]
577        else:
578            result, _ = self.calculate_Iq([x])
579            return result[0]
580
581
582    def runXY(self, x=0.0):
583        # type: (Union[float, (float, float), List[float]]) -> float
584        """
585        Evaluate the model in cartesian coordinates
586
587        :param x: input q, or [qx, qy]
588
589        :return: scattering function P(q)
590
591        **DEPRECATED**: use calculate_Iq instead
592        """
593        if isinstance(x, (list, tuple)):
594            result, _ = self.calculate_Iq([x[0]], [x[1]])
595            return result[0]
596        else:
597            result, _ = self.calculate_Iq([x])
598            return result[0]
599
600    def evalDistribution(self, qdist):
601        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
602        r"""
603        Evaluate a distribution of q-values.
604
605        :param qdist: array of q or a list of arrays [qx,qy]
606
607        * For 1D, a numpy array is expected as input
608
609        ::
610
611            evalDistribution(q)
612
613          where *q* is a numpy array.
614
615        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
616
617        ::
618
619              qx = [ qx[0], qx[1], qx[2], ....]
620              qy = [ qy[0], qy[1], qy[2], ....]
621
622        If the model is 1D only, then
623
624        .. math::
625
626            q = \sqrt{q_x^2+q_y^2}
627
628        """
629        if isinstance(qdist, (list, tuple)):
630            # Check whether we have a list of ndarrays [qx,qy]
631            qx, qy = qdist
632            result, _ = self.calculate_Iq(qx, qy)
633            return result
634
635        elif isinstance(qdist, np.ndarray):
636            # We have a simple 1D distribution of q-values
637            result, _ = self.calculate_Iq(qdist)
638            return result
639
640        else:
641            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
642                            % type(qdist))
643
644    def calc_composition_models(self, qx):
645        """
646        returns parts of the composition model or None if not a composition
647        model.
648        """
649        # TODO: have calculate_Iq return the intermediates.
650        #
651        # The current interface causes calculate_Iq() to be called twice,
652        # once to get the combined result and again to get the intermediate
653        # results.  This is necessary for now.
654        # Long term, the solution is to change the interface to calculate_Iq
655        # so that it returns a results object containing all the bits:
656        #     the A, B, C, ... of the composition model (and any subcomponents?)
657        #     the P and S of the product model
658        #     the combined model before resolution smearing,
659        #     the sasmodel before sesans conversion,
660        #     the oriented 2D model used to fit oriented usans data,
661        #     the final I(q),
662        #     ...
663        #
664        # Have the model calculator add all of these blindly to the data
665        # tree, and update the graphs which contain them.  The fitter
666        # needs to be updated to use the I(q) value only, ignoring the rest.
667        #
668        # The simple fix of returning the existing intermediate results
669        # will not work for a couple of reasons: (1) another thread may
670        # sneak in to compute its own results before calc_composition_models
671        # is called, and (2) calculate_Iq is currently called three times:
672        # once with q, once with q values before qmin and once with q values
673        # after q max.  Both of these should be addressed before
674        # replacing this code.
675        composition = self._model_info.composition
676        if composition and composition[0] == 'product': # only P*S for now
677            with calculation_lock:
678                _, lazy_results = self._calculate_Iq(qx)
679                # for compatibility with sasview 4.x
680                results = lazy_results()
681                pq_data = results.get("P(Q)")
682                sq_data = results.get("S(Q)")
683                return pq_data, sq_data
684        else:
685            return None
686
687    def calculate_Iq(self,
688                     qx,     # type: Sequence[float]
689                     qy=None # type: Optional[Sequence[float]]
690                     ):
691        # type: (...) -> Tuple[np.ndarray, Callable[[], collections.OrderedDict[str, np.ndarray]]]
692        """
693        Calculate Iq for one set of q with the current parameters.
694
695        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
696
697        This should NOT be used for fitting since it copies the *q* vectors
698        to the card for each evaluation.
699
700        The returned tuple contains the scattering intensity followed by a
701        callable which returns a dictionary of intermediate data from
702        ProductKernel.
703        """
704        ## uncomment the following when trying to debug the uncoordinated calls
705        ## to calculate_Iq
706        #if calculation_lock.locked():
707        #    logger.info("calculation waiting for another thread to complete")
708        #    logger.info("\n".join(traceback.format_stack()))
709
710        with calculation_lock:
711            return self._calculate_Iq(qx, qy)
712
713    def _calculate_Iq(self, qx, qy=None):
714        if self._model is None:
715            # Only need one copy of the compiled kernel regardless of how many
716            # times it is used, so store it in the class.  Also, to reset the
717            # compute engine, need to clear out all existing compiled kernels,
718            # which is much easier to do if we store them in the class.
719            self.__class__._model = core.build_model(self._model_info)
720        if qy is not None:
721            q_vectors = [np.asarray(qx), np.asarray(qy)]
722        else:
723            q_vectors = [np.asarray(qx)]
724        calculator = self._model.make_kernel(q_vectors)
725        parameters = self._model_info.parameters
726        pairs = [self._get_weights(p) for p in parameters.call_parameters]
727        #weights.plot_weights(self._model_info, pairs)
728        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
729        #call_details.show()
730        #print("================ parameters ==================")
731        #for p, v in zip(parameters.call_parameters, pairs): print(p.name, v[0])
732        #for k, p in enumerate(self._model_info.parameters.call_parameters):
733        #    print(k, p.name, *pairs[k])
734        #print("params", self.params)
735        #print("values", values)
736        #print("is_mag", is_magnetic)
737        result = calculator(call_details, values, cutoff=self.cutoff,
738                            magnetic=is_magnetic)
739        lazy_results = getattr(calculator, 'results',
740                               lambda: collections.OrderedDict())
741        #print("result", result)
742
743        calculator.release()
744        #self._model.release()
745
746        return result, lazy_results
747
748
749    def calculate_ER(self, mode=1):
750        # type: () -> float
751        """
752        Calculate the effective radius for P(q)*S(q)
753
754        *mode* is the R_eff type, which defaults to 1 to match the ER
755        calculation for sasview models from version 3.x.
756
757        :return: the value of the effective radius
758        """
759        # ER and VR are only needed for old multiplication models, based on
760        # sas.sascalc.fit.MultiplicationModel.  Fail for now.  If we want to
761        # continue supporting them then add some test cases so that the code
762        # is exercised.  We can access ER/VR using the kernel Fq function by
763        # extending _calculate_Iq so that it calls:
764        #    if er_mode > 0:
765        #        res = calculator.Fq(call_details, values, cutoff=self.cutoff,
766        #                            magnetic=False, effective_radius_type=mode)
767        #        R_eff, form_shell_ratio = res[2], res[4]
768        #        return R_eff, form_shell_ratio
769        # Then use the following in calculate_ER:
770        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=mode)
771        #    return ER
772        # Similarly, for calculate_VR:
773        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=1)
774        #    return VR
775        # Obviously a combined calculate_ER_VR method would be better, but
776        # we only need them to support very old models, so ignore the 2x
777        # performance hit.
778        raise NotImplementedError("ER function is no longer available.")
779
780    def calculate_VR(self):
781        # type: () -> float
782        """
783        Calculate the volf ratio for P(q)*S(q)
784
785        :return: the value of the form:shell volume ratio
786        """
787        # See comments in calculate_ER.
788        raise NotImplementedError("VR function is no longer available.")
789
790    def set_dispersion(self, parameter, dispersion):
791        # type: (str, weights.Dispersion) -> None
792        """
793        Set the dispersion object for a model parameter
794
795        :param parameter: name of the parameter [string]
796        :param dispersion: dispersion object of type Dispersion
797        """
798        if parameter in self.params:
799            # TODO: Store the disperser object directly in the model.
800            # The current method of relying on the sasview GUI to
801            # remember them is kind of funky.
802            # Note: can't seem to get disperser parameters from sasview
803            # (1) Could create a sasview model that has not yet been
804            # converted, assign the disperser to one of its polydisperse
805            # parameters, then retrieve the disperser parameters from the
806            # sasview model.
807            # (2) Could write a disperser parameter retriever in sasview.
808            # (3) Could modify sasview to use sasmodels.weights dispersers.
809            # For now, rely on the fact that the sasview only ever uses
810            # new dispersers in the set_dispersion call and create a new
811            # one instead of trying to assign parameters.
812            self.dispersion[parameter] = dispersion.get_pars()
813        else:
814            raise ValueError("%r is not a dispersity or orientation parameter"
815                             % parameter)
816
817    def _dispersion_mesh(self):
818        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
819        """
820        Create a mesh grid of dispersion parameters and weights.
821
822        Returns [p1,p2,...],w where pj is a vector of values for parameter j
823        and w is a vector containing the products for weights for each
824        parameter set in the vector.
825        """
826        pars = [self._get_weights(p)
827                for p in self._model_info.parameters.call_parameters
828                if p.type == 'volume']
829        return dispersion_mesh(self._model_info, pars)
830
831    def _get_weights(self, par):
832        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
833        """
834        Return dispersion weights for parameter
835        """
836        if par.name not in self.params:
837            if par.id == self.multiplicity_info.control:
838                return self.multiplicity, [self.multiplicity], [1.0]
839            else:
840                # For hidden parameters use default values.  This sets
841                # scale=1 and background=0 for structure factors
842                default = self._model_info.parameters.defaults.get(par.name, np.NaN)
843                return default, [default], [1.0]
844        elif par.polydisperse:
845            value = self.params[par.name]
846            dis = self.dispersion[par.name]
847            if dis['type'] == 'array':
848                dispersity, weight = dis['values'], dis['weights']
849            else:
850                dispersity, weight = weights.get_weights(
851                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
852                    value, par.limits, par.relative_pd)
853            return value, dispersity, weight
854        else:
855            value = self.params[par.name]
856            return value, [value], [1.0]
857
858    @classmethod
859    def runTests(cls):
860        """
861        Run any tests built into the model and captures the test output.
862
863        Returns success flag and output
864        """
865        from .model_test import check_model
866        return check_model(cls._model_info)
867
def test_cylinder():
    # type: () -> float
    """
    Check that the cylinder model runs, returning its value at [0.1,0.1].
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
876
def test_structure_factor():
    # type: () -> None
    """
    Test that 2-D hardsphere model runs and doesn't produce NaN.

    The model is evaluated at the 2-D point (qx, qy) = (0.1, 0.1) and at
    the 1-D point q = 0.1*sqrt(2), which has the same magnitude.

    :raises ValueError: if either evaluation contains NaN.
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value2d = model.evalDistribution([0.1, 0.1])
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    # np.any keeps the check well defined whether a scalar or an array
    # of any length comes back, instead of relying on array truthiness.
    if np.any(np.isnan(value1d)) or np.any(np.isnan(value2d)):
        raise ValueError("hardsphere returns nan")
889
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    The product is built with MultiplicationModel, with the
    product.RADIUS_MODE_ID parameter set to 1.

    :raises ValueError: if the evaluation at (0.1, 0.1) contains NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    model.setParam(product.RADIUS_MODE_ID, 1.0)
    value = model.evalDistribution([0.1, 0.1])
    # np.any makes the check independent of the shape of the result.
    if np.any(np.isnan(value)):
        raise ValueError("cylinder*hayter_msa returns NaN")
902
def test_rpa():
    # type: () -> float
    """
    Check that the 2-D RPA model runs when constructed with 3, returning
    its value at [0.1,0.1].
    """
    rpa = _make_standard_model('rpa')(3)
    return rpa.evalDistribution([0.1, 0.1])
911
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that sasmodels returns zero, not NaN, when there are no
    polydispersity points.

    The cylinder radius is set to -1 so that no polydispersity points
    remain.  With the background set to 0, I(q) must then be exactly 0.
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    cylinder.setParam('radius', -1.0)
    cylinder.setParam('background', 0.)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
923
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    If a model fails to build, annotate_exception is called with the model
    name and the original exception is re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt/SystemExit pass straight through untouched.
            annotate_exception("when loading "+name)
            raise
936
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model as sas-models-CylinderModel.

    Does nothing unless SUPPORT_OLD_STYLE_PLUGINS is set and the ``sas``
    package (sasview) can be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    # Probe for sasview; when it is not on the path, skip the check.
    try:
        import sas  # pylint: disable=unused-import
    except ImportError:
        return

    load_standard_models()
    from sas.models.CylinderModel import CylinderModel as LegacyCylinder
    legacy = LegacyCylinder()
    legacy.evalDistribution([0.1, 0.1])
952
def test_structure_factor_background():
    # type: () -> None
    """
    Check that the sasview and direct hardsphere models agree at q=0.

    The direct model is compared once with background=0 given explicitly
    and once with its default background, which must therefore be zero.
    """
    from .data import empty_data1D
    from .core import load_model_info, build_model
    from .direct_model import DirectModel

    name = "hardsphere"
    q_values = [0.0]

    expected = _make_standard_model(name)().evalDistribution(np.array(q_values))[0]

    direct = DirectModel(empty_data1D(q_values),
                         build_model(load_model_info(name)))

    # An explicit zero background must reproduce the sasview value ...
    assert expected == direct(background=0.0)
    # ... and so must the direct model's default background.
    assert expected == direct()
979
980
def magnetic_demo():
    """
    Display a log-scaled 2-D image of the sphere model with sld_M0 = 8.

    The model is evaluated on a 500x500 (qx, qy) grid spanning
    [-0.35, 0.35] in each direction, and log(I + 0.001) is shown with
    pylab.
    """
    sphere = _make_standard_model('sphere')()
    sphere.setParam('sld_M0', 8)
    axis = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(axis, axis)
    image, _ = sphere.calculate_Iq(qx.flatten(), qy.flatten())
    image = image.reshape(qx.shape)

    import pylab
    # The small offset keeps log() finite wherever the intensity is zero.
    pylab.imshow(np.log(image + 0.001))
    pylab.show()
993
if __name__ == "__main__":
    # Smoke run: evaluate the cylinder model once.  Uncomment any of the
    # lines below to run the other checks or the magnetic demo as well.
    cylinder_iq = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_iq)
    #magnetic_demo()
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
    #test_structure_factor_background()
Note: See TracBrowser for help on using the repository browser.