source: sasmodels/sasmodels/sasview_model.py @ a430f5f

ticket-1257-vesicle-productticket_1156ticket_822_more_unit_tests
Last change on this file since a430f5f was b297ba9, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

lint

  • Property mode set to 100644
File size: 35.8 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import kernelcl
28from . import product
29from . import generate
30from . import weights
31from . import modelinfo
32from .details import make_kernel_args, dispersion_mesh
33
34# pylint: disable=unused-import
35try:
36    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
37                        List, Optional, Union, Callable)
38    from .modelinfo import ModelInfo, Parameter
39    from .kernel import KernelModel
40    MultiplicityInfoType = NamedTuple(
41        'MultiplicityInfo',
42        [("number", int), ("control", str), ("choices", List[str]),
43         ("x_axis_label", str)])
44    SasviewModelType = Callable[[int], "SasviewModel"]
45except ImportError:
46    pass
47# pylint: enable=unused-import
48
logger = logging.getLogger(__name__)

#: Module-wide lock acquired by SasviewModel.calculate_Iq and
#: SasviewModel.calc_composition_models around each kernel evaluation.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ["number", "control", "choices", "x_axis_label"],
)

#: defined models (standard and custom) as a {model name: model class} mapping
MODELS = {}  # type: Dict[str, SasviewModelType]
# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
#: Track modules that we have loaded so we can determine whether the model
#: has changed since we last reloaded.
_CACHED_MODULE = {}  # type: Dict[str, "module"]
71
def reset_environment():
    # type: () -> None
    """
    Reset the compute engine context so that the GUI can switch devices.

    Every compiled kernel is discarded, including those in use on fit
    pages; each one is rebuilt the next time it is needed.
    """
    kernelcl.reset_environment()
    # The compiled kernel is cached on each registered model class.
    for model_class in MODELS.values():
        model_class._model = None
83
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is loaded as a custom model; any other name
    must already be present in *MODELS*.

    Raises *ValueError* if the name is not recognized.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
98
99
100# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build the predefined models and register them in *MODELS*.

    A model that fails to load is skipped after its traceback is logged.
    Old-style plugin names are registered as well when
    *SUPPORT_OLD_STYLE_PLUGINS* is set.

    Returns the list of all model classes currently held in *MODELS*.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
118
119
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The kernel module at *path* is loaded and turned into a model class:
    the module's own *Model* attribute if it has one, otherwise a class
    generated from the module's model info.  The class is stored in
    *MODELS* when the module has changed or the name is not yet registered,
    and the class registered under the final model name is returned.
    """
    #logger.info("Loading model %s", path)

    # Load the kernel module.  This may already be cached by the loader, so
    # only requires checking the timestamps of the dependents.
    kernel_module = custom.load_custom_kernel_module(path)

    # Check if the module has changed since we last looked.
    reloaded = kernel_module != _CACHED_MODULE.get(path, None)
    _CACHED_MODULE[path] = kernel_module

    # Turn the module into a model.  We need to do this even if the
    # model has already been loaded so that we can determine the model
    # name and retrieve it from the MODELS cache.
    model = getattr(kernel_module, 'Model', None)
    if model is not None:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        # Fill in the filename and id attributes when the old-style class
        # does not define them; the id is the file name without extension.
        if not hasattr(model, 'filename'):
            model.filename = abspath(kernel_module.__file__).replace('.pyc', '.py')
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and not model.filename == MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and not model.filename == MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    # Only update the model if the module has changed
    if reloaded or model.name not in MODELS:
        MODELS[model.name] = model

    return MODELS[model.name]
172
173
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a SasviewModel subclass that wraps *model_info*.

    The class is named after *model_info.name* and carries the attributes
    produced by :func:`_generate_model_attributes` plus *filename*.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_dict = _generate_model_attributes(model_info)
    class_dict['__init__'] = __init__
    class_dict['filename'] = model_info.filename
    bases = (SasviewModel,)
    return type(model_info.name, bases, class_dict)
186
187
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Create the sasview model class for the kernel module called *name*.

    *name* can be a standard model name or a path to a custom model.

    The result is a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
200
201
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For each entry of the (3, 1, 2) conversion table a stand-in "module"
    class is created that exposes the new model class under its old name,
    and it is installed in *sys.modules* as ``sas.models.<old name>``.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        # The old name is the first element of short conversion entries and
        # the third element when the entry has three or more items.
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        module_attrs = {old_name: find_model(new_name)}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # NOTE(review): the attribute is set under the full dotted path
        # rather than under old_name -- confirm this is intended.
        setattr(sas.models, old_path, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
227
228
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Create a product model instance from *form_factor* and *structure_factor*.

    The new instance is constructed with the multiplicity of *form_factor*.
    """
    p_info = form_factor._model_info
    s_info = structure_factor._model_info
    product_info = product.make_product_info(p_info, s_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
238
239
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for a model.

    The attributes carry everything needed to query the model details
    without having to instantiate the model.

    All the attributes should be immutable to avoid accidents.
    """

    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters

    # Multiplicity is controlled by the first control parameter, if any.
    control_ids = [par.id for par in kernel_pars if par.is_control]
    control_id = control_ids[0] if control_ids else None
    non_fittable = []  # type: List[str]
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""
    variants = MultiplicityInfo(0, "", [], xlabel)
    control = next((par for par in kernel_pars if par.id == control_id), None)
    if control is not None:
        non_fittable.append(control.name)
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices, xlabel)

    # Only a single drop-down list parameter available: use the first
    # parameter that defines choices.
    fun_list = []
    chooser = next((par for par in kernel_pars if par.choices), None)
    if chooser is not None:
        fun_list = chooser.choices
        if chooser.length > 1:
            non_fittable.extend(chooser.id+str(k)
                                for k in range(1, chooser.length+1))

    # Sort the user-visible parameters into the sasview parameter sets.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        if par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(par.name+".width")
        elif par.type == 'orientation':
            orientation_params.append(par.name)
            orientation_params.append(par.name+".width")
            fixed.append(par.name+".width")

    # Class dictionary; sequences are frozen as tuples.
    attrs = {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.effective_radius_type is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]

    return attrs
311
312class SasviewModel(object):
313    """
314    Sasview wrapper for opencl/ctypes model.
315    """
316    # Model parameters for the specific model are set in the class constructor
317    # via the _generate_model_attributes function, which subclasses
318    # SasviewModel.  They are included here for typing and documentation
319    # purposes.
320    _model = None       # type: KernelModel
321    _model_info = None  # type: ModelInfo
322    #: load/save name for the model
323    id = None           # type: str
324    #: display name for the model
325    name = None         # type: str
326    #: short model description
327    description = None  # type: str
328    #: default model category
329    category = None     # type: str
330
331    #: names of the orientation parameters in the order they appear
332    orientation_params = None # type: List[str]
333    #: names of the magnetic parameters in the order they appear
334    magnetic_params = None    # type: List[str]
335    #: names of the fittable parameters
336    fixed = None              # type: List[str]
337    # TODO: the attribute fixed is ill-named
338
339    # Axis labels
340    input_name = "Q"
341    input_unit = "A^{-1}"
342    output_name = "Intensity"
343    output_unit = "cm^{-1}"
344
345    #: default cutoff for polydispersity
346    cutoff = 1e-5
347
348    # Note: Use non-mutable values for class attributes to avoid errors
349    #: parameters that are not fitted
350    non_fittable = ()        # type: Sequence[str]
351
352    #: True if model should appear as a structure factor
353    is_structure_factor = False
354    #: True if model should appear as a form factor
355    is_form_factor = False
356    #: True if model has multiplicity
357    is_multiplicity_model = False
358    #: Multiplicity information
359    multiplicity_info = None # type: MultiplicityInfoType
360
361    # Per-instance variables
362    #: parameter {name: value} mapping
363    params = None      # type: Dict[str, float]
364    #: values for dispersion width, npts, nsigmas and type
365    dispersion = None  # type: Dict[str, Any]
366    #: units and limits for each parameter
367    details = None     # type: Dict[str, Sequence[Any]]
368    #                  # actual type is Dict[str, List[str, float, float]]
369    #: multiplicity value, or None if no multiplicity on the model
370    multiplicity = None     # type: Optional[int]
371    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
372    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
373
374    def __init__(self, multiplicity=None):
375        # type: (Optional[int]) -> None
376
377        # TODO: _persistency_dict to persistency_dict throughout sasview
378        # TODO: refactor multiplicity to encompass variants
379        # TODO: dispersion should be a class
380        # TODO: refactor multiplicity info
381        # TODO: separate profile view from multiplicity
382        # The button label, x and y axis labels and scale need to be under
383        # the control of the model, not the fit page.  Maximum flexibility,
384        # the fit page would supply the canvas and the profile could plot
385        # how it wants, but this assumes matplotlib.  Next level is that
386        # we provide some sort of data description including title, labels
387        # and lines to plot.
388
389        # Get the list of hidden parameters given the multiplicity
390        # Don't include multiplicity in the list of parameters
391        self.multiplicity = multiplicity
392        if multiplicity is not None:
393            hidden = self._model_info.get_hidden_parameters(multiplicity)
394            hidden |= set([self.multiplicity_info.control])
395        else:
396            hidden = set()
397        if self._model_info.structure_factor:
398            hidden.add('scale')
399            hidden.add('background')
400
401        # Update the parameter lists to exclude any hidden parameters
402        self.magnetic_params = tuple(pname for pname in self.magnetic_params
403                                     if pname not in hidden)
404        self.orientation_params = tuple(pname for pname in self.orientation_params
405                                        if pname not in hidden)
406
407        self._persistency_dict = {}
408        self.params = collections.OrderedDict()
409        self.dispersion = collections.OrderedDict()
410        self.details = {}
411        for p in self._model_info.parameters.user_parameters({}, is2d=True):
412            if p.name in hidden:
413                continue
414            self.params[p.name] = p.default
415            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
416            if p.polydisperse:
417                self.details[p.id+".width"] = [
418                    "", 0.0, 1.0 if p.relative_pd else np.inf
419                ]
420                self.dispersion[p.name] = {
421                    'width': 0,
422                    'npts': 35,
423                    'nsigmas': 3,
424                    'type': 'gaussian',
425                }
426
427    def __get_state__(self):
428        # type: () -> Dict[str, Any]
429        state = self.__dict__.copy()
430        state.pop('_model')
431        # May need to reload model info on set state since it has pointers
432        # to python implementations of Iq, etc.
433        #state.pop('_model_info')
434        return state
435
    def __set_state__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Adopt *state* as the instance dictionary and clear the kernel.

        *state* is used directly rather than copied, so the *_model* entry
        set below is also written into the caller's dictionary.
        """
        self.__dict__ = state
        self._model = None
440
    def __str__(self):
        # type: () -> str
        """
        :return: string representation, which is the model *name*
        """
        return self.name
447
    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check

        :return: True if *par_name* is listed in the *fixed* attribute
        """
        # Membership in self.fixed is the whole test.
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()
458
459
460    def getProfile(self):
461        # type: () -> (np.ndarray, np.ndarray)
462        """
463        Get SLD profile
464
465        : return: (z, beta) where z is a list of depth of the transition points
466                beta is a list of the corresponding SLD values
467        """
468        args = {} # type: Dict[str, Any]
469        for p in self._model_info.parameters.kernel_parameters:
470            if p.id == self.multiplicity_info.control:
471                value = float(self.multiplicity)
472            elif p.length == 1:
473                value = self.params.get(p.id, np.NaN)
474            else:
475                value = np.array([self.params.get(p.id+str(k), np.NaN)
476                                  for k in range(1, p.length+1)])
477            args[p.id] = value
478
479        x, y = self._model_info.profile(**args)
480        return x, 1e-6*y
481
482    def setParam(self, name, value):
483        # type: (str, float) -> None
484        """
485        Set the value of a model parameter
486
487        :param name: name of the parameter
488        :param value: value of the parameter
489
490        """
491        # Look for dispersion parameters
492        toks = name.split('.')
493        if len(toks) == 2:
494            for item in self.dispersion.keys():
495                if item == toks[0]:
496                    for par in self.dispersion[item]:
497                        if par == toks[1]:
498                            self.dispersion[item][par] = value
499                            return
500        else:
501            # Look for standard parameter
502            for item in self.params.keys():
503                if item == name:
504                    self.params[item] = value
505                    return
506
507        raise ValueError("Model does not contain parameter %s" % name)
508
509    def getParam(self, name):
510        # type: (str) -> float
511        """
512        Set the value of a model parameter
513
514        :param name: name of the parameter
515
516        """
517        # Look for dispersion parameters
518        toks = name.split('.')
519        if len(toks) == 2:
520            for item in self.dispersion.keys():
521                if item == toks[0]:
522                    for par in self.dispersion[item]:
523                        if par == toks[1]:
524                            return self.dispersion[item][par]
525        else:
526            # Look for standard parameter
527            for item in self.params.keys():
528                if item == name:
529                    return self.params[item]
530
531        raise ValueError("Model does not contain parameter %s" % name)
532
533    def getParamList(self):
534        # type: () -> Sequence[str]
535        """
536        Return a list of all available parameters for the model
537        """
538        param_list = list(self.params.keys())
539        # WARNING: Extending the list with the dispersion parameters
540        param_list.extend(self.getDispParamList())
541        return param_list
542
543    def getDispParamList(self):
544        # type: () -> Sequence[str]
545        """
546        Return a list of polydispersity parameters for the model
547        """
548        # TODO: fix test so that parameter order doesn't matter
549        ret = ['%s.%s' % (p_name, ext)
550               for p_name in self.dispersion.keys()
551               for ext in ('npts', 'nsigmas', 'width')]
552        #print(ret)
553        return ret
554
    def clone(self):
        # type: () -> "SasviewModel"
        """ Return an identical copy of self, made with deepcopy """
        return deepcopy(self)
559
560    def run(self, x=0.0):
561        # type: (Union[float, (float, float), List[float]]) -> float
562        """
563        Evaluate the model
564
565        :param x: input q, or [q,phi]
566
567        :return: scattering function P(q)
568
569        **DEPRECATED**: use calculate_Iq instead
570        """
571        if isinstance(x, (list, tuple)):
572            # pylint: disable=unpacking-non-sequence
573            q, phi = x
574            result, _ = self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])
575            return result[0]
576        else:
577            result, _ = self.calculate_Iq([x])
578            return result[0]
579
580
581    def runXY(self, x=0.0):
582        # type: (Union[float, (float, float), List[float]]) -> float
583        """
584        Evaluate the model in cartesian coordinates
585
586        :param x: input q, or [qx, qy]
587
588        :return: scattering function P(q)
589
590        **DEPRECATED**: use calculate_Iq instead
591        """
592        if isinstance(x, (list, tuple)):
593            result, _ = self.calculate_Iq([x[0]], [x[1]])
594            return result[0]
595        else:
596            result, _ = self.calculate_Iq([x])
597            return result[0]
598
599    def evalDistribution(self, qdist):
600        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
601        r"""
602        Evaluate a distribution of q-values.
603
604        :param qdist: array of q or a list of arrays [qx,qy]
605
606        * For 1D, a numpy array is expected as input
607
608        ::
609
610            evalDistribution(q)
611
612          where *q* is a numpy array.
613
614        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
615
616        ::
617
618              qx = [ qx[0], qx[1], qx[2], ....]
619              qy = [ qy[0], qy[1], qy[2], ....]
620
621        If the model is 1D only, then
622
623        .. math::
624
625            q = \sqrt{q_x^2+q_y^2}
626
627        """
628        if isinstance(qdist, (list, tuple)):
629            # Check whether we have a list of ndarrays [qx,qy]
630            qx, qy = qdist
631            result, _ = self.calculate_Iq(qx, qy)
632            return result
633
634        elif isinstance(qdist, np.ndarray):
635            # We have a simple 1D distribution of q-values
636            result, _ = self.calculate_Iq(qdist)
637            return result
638
639        else:
640            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
641                            % type(qdist))
642
643    def calc_composition_models(self, qx):
644        """
645        returns parts of the composition model or None if not a composition
646        model.
647        """
648        # TODO: have calculate_Iq return the intermediates.
649        #
650        # The current interface causes calculate_Iq() to be called twice,
651        # once to get the combined result and again to get the intermediate
652        # results.  This is necessary for now.
653        # Long term, the solution is to change the interface to calculate_Iq
654        # so that it returns a results object containing all the bits:
655        #     the A, B, C, ... of the composition model (and any subcomponents?)
656        #     the P and S of the product model
657        #     the combined model before resolution smearing,
658        #     the sasmodel before sesans conversion,
659        #     the oriented 2D model used to fit oriented usans data,
660        #     the final I(q),
661        #     ...
662        #
663        # Have the model calculator add all of these blindly to the data
664        # tree, and update the graphs which contain them.  The fitter
665        # needs to be updated to use the I(q) value only, ignoring the rest.
666        #
667        # The simple fix of returning the existing intermediate results
668        # will not work for a couple of reasons: (1) another thread may
669        # sneak in to compute its own results before calc_composition_models
670        # is called, and (2) calculate_Iq is currently called three times:
671        # once with q, once with q values before qmin and once with q values
672        # after q max.  Both of these should be addressed before
673        # replacing this code.
674        composition = self._model_info.composition
675        if composition and composition[0] == 'product': # only P*S for now
676            with calculation_lock:
677                _, lazy_results = self._calculate_Iq(qx)
678                # for compatibility with sasview 4.x
679                results = lazy_results()
680                pq_data = results.get("P(Q)")
681                sq_data = results.get("S(Q)")
682                return pq_data, sq_data
683        else:
684            return None
685
    def calculate_Iq(self,
                     qx,     # type: Sequence[float]
                     qy=None # type: Optional[Sequence[float]]
                    ):
        # type: (...) -> Tuple[np.ndarray, Callable[[], collections.OrderedDict[str, np.ndarray]]]
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, pass *qx* only.  If 2D, pass *qx* and *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.

        The returned tuple contains the scattering intensity followed by a
        callable which returns a dictionary of intermediate data from
        ProductKernel.

        The calculation runs while holding the module-level
        *calculation_lock*.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)
711
    def _calculate_Iq(self, qx, qy=None):
        """
        Evaluate the kernel for *qx* (and *qy* if given) without locking.

        Builds the compiled model on first use, creates a kernel for the
        q vectors, evaluates it with the current parameter weights and
        releases the kernel.

        Returns the kernel result and a callable giving the intermediate
        results; the callable returns an empty OrderedDict when the kernel
        has no *results* attribute.
        """
        if self._model is None:
            # Only need one copy of the compiled kernel regardless of how many
            # times it is used, so store it in the class.  Also, to reset the
            # compute engine, need to clear out all existing compiled kernels,
            # which is much easier to do if we store them in the class.
            self.__class__._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        parameters = self._model_info.parameters
        # One weights entry per call parameter, in call order.
        pairs = [self._get_weights(p) for p in parameters.call_parameters]
        #weights.plot_weights(self._model_info, pairs)
        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
        #call_details.show()
        #print("================ parameters ==================")
        #for p, v in zip(parameters.call_parameters, pairs): print(p.name, v[0])
        #for k, p in enumerate(self._model_info.parameters.call_parameters):
        #    print(k, p.name, *pairs[k])
        #print("params", self.params)
        #print("values", values)
        #print("is_mag", is_magnetic)
        result = calculator(call_details, values, cutoff=self.cutoff,
                            magnetic=is_magnetic)
        # Fetch the results accessor before the kernel is released.
        lazy_results = getattr(calculator, 'results',
                               lambda: collections.OrderedDict())
        #print("result", result)

        calculator.release()
        #self._model.release()

        return result, lazy_results
746
747
    def calculate_ER(self, mode=1):
        # type: (int) -> float
        """
        Calculate the effective radius for P(q)*S(q)

        *mode* is the R_eff type, which defaults to 1 to match the ER
        calculation for sasview models from version 3.x.

        :return: the value of the effective radius

        Currently always raises *NotImplementedError*.
        """
        # ER and VR are only needed for old multiplication models, based on
        # sas.sascalc.fit.MultiplicationModel.  Fail for now.  If we want to
        # continue supporting them then add some test cases so that the code
        # is exercised.  We can access ER/VR using the kernel Fq function by
        # extending _calculate_Iq so that it calls:
        #    if er_mode > 0:
        #        res = calculator.Fq(call_details, values, cutoff=self.cutoff,
        #                            magnetic=False, effective_radius_type=mode)
        #        R_eff, form_shell_ratio = res[2], res[4]
        #        return R_eff, form_shell_ratio
        # Then use the following in calculate_ER:
        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=mode)
        #    return ER
        # Similarly, for calculate_VR:
        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=1)
        #    return VR
        # Obviously a combined calculate_ER_VR method would be better, but
        # we only need them to support very old models, so ignore the 2x
        # performance hit.
        raise NotImplementedError("ER function is no longer available.")
778
    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volume ratio for P(q)*S(q)

        :return: the value of the form:shell volume ratio

        Currently always raises *NotImplementedError*.
        """
        # See comments in calculate_ER.
        raise NotImplementedError("VR function is no longer available.")
788
789    def set_dispersion(self, parameter, dispersion):
790        # type: (str, weights.Dispersion) -> None
791        """
792        Set the dispersion object for a model parameter
793
794        :param parameter: name of the parameter [string]
795        :param dispersion: dispersion object of type Dispersion
796        """
797        if parameter in self.params:
798            # TODO: Store the disperser object directly in the model.
799            # The current method of relying on the sasview GUI to
800            # remember them is kind of funky.
801            # Note: can't seem to get disperser parameters from sasview
802            # (1) Could create a sasview model that has not yet been
803            # converted, assign the disperser to one of its polydisperse
804            # parameters, then retrieve the disperser parameters from the
805            # sasview model.
806            # (2) Could write a disperser parameter retriever in sasview.
807            # (3) Could modify sasview to use sasmodels.weights dispersers.
808            # For now, rely on the fact that the sasview only ever uses
809            # new dispersers in the set_dispersion call and create a new
810            # one instead of trying to assign parameters.
811            self.dispersion[parameter] = dispersion.get_pars()
812        else:
813            raise ValueError("%r is not a dispersity or orientation parameter"
814                             % parameter)
815
816    def _dispersion_mesh(self):
817        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
818        """
819        Create a mesh grid of dispersion parameters and weights.
820
821        Returns [p1,p2,...],w where pj is a vector of values for parameter j
822        and w is a vector containing the products for weights for each
823        parameter set in the vector.
824        """
825        pars = [self._get_weights(p)
826                for p in self._model_info.parameters.call_parameters
827                if p.type == 'volume']
828        return dispersion_mesh(self._model_info, pars)
829
830    def _get_weights(self, par):
831        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
832        """
833        Return dispersion weights for parameter
834        """
835        if par.name not in self.params:
836            if par.id == self.multiplicity_info.control:
837                return self.multiplicity, [self.multiplicity], [1.0]
838            else:
839                # For hidden parameters use default values.  This sets
840                # scale=1 and background=0 for structure factors
841                default = self._model_info.parameters.defaults.get(par.name, np.NaN)
842                return default, [default], [1.0]
843        elif par.polydisperse:
844            value = self.params[par.name]
845            dis = self.dispersion[par.name]
846            if dis['type'] == 'array':
847                dispersity, weight = dis['values'], dis['weights']
848            else:
849                dispersity, weight = weights.get_weights(
850                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
851                    value, par.limits, par.relative_pd)
852            return value, dispersity, weight
853        else:
854            value = self.params[par.name]
855            return value, [value], [1.0]
856
857    @classmethod
858    def runTests(cls):
859        """
860        Run any tests built into the model and captures the test output.
861
862        Returns success flag and output
863        """
864        from .model_test import check_model
865        return check_model(cls._model_info)
866
def test_cylinder():
    # type: () -> float
    """
    Check that the cylinder model can be built and evaluated.

    Returns the result of evaluating the model at [0.1,0.1].
    """
    model = _make_standard_model('cylinder')()
    q = [0.1, 0.1]
    return model.evalDistribution(q)
875
def test_structure_factor():
    # type: () -> None
    """
    Check that the hardsphere model runs and doesn't produce NaN.

    The model is evaluated once with the list [0.1, 0.1] and once with a
    one-element q array at 0.1*sqrt(2).

    :raises ValueError: if either evaluation yields NaN
    """
    hardsphere = _make_standard_model('hardsphere')()
    value2d = hardsphere.evalDistribution([0.1, 0.1])
    value1d = hardsphere.evalDistribution(np.array([0.1*np.sqrt(2)]))
    #print("hardsphere", value1d, value2d)
    # The 1-D result is checked first, then the 2-D result.
    for value in (value1d, value2d):
        if np.isnan(value):
            raise ValueError("hardsphere returns nan")
888
def test_product():
    # type: () -> None
    """
    Test that the cylinder*hayter_msa product model runs without giving NaN.

    A cylinder form factor is combined with a hayter_msa structure factor
    through MultiplicationModel, the radius mode parameter is set to 1.0 and
    the product is evaluated at [0.1, 0.1].

    :raises ValueError: if any returned value is NaN
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    model.setParam(product.RADIUS_MODE_ID, 1.0)
    value = model.evalDistribution([0.1, 0.1])
    # np.any keeps the check valid even if more than one value is returned;
    # a bare truth test on a multi-element array would raise instead.
    if np.any(np.isnan(value)):
        raise ValueError("cylinder*hayter_msa returns nan")
901
def test_rpa():
    # type: () -> float
    """
    Check that the RPA model can be built and evaluated.

    The model class is instantiated with the argument 3 and the result of
    evaluating it at [0.1, 0.1] is returned.
    """
    model = _make_standard_model('rpa')(3)
    q = [0.1, 0.1]
    return model.evalDistribution(q)
910
def test_empty_distribution():
    # type: () -> None
    """
    Check the result for a cylinder with a negative radius.

    With radius=-1 and background=0 the intensity evaluated at q=0.1 is
    expected to be exactly zero.
    """
    model = _make_standard_model('cylinder')()
    for name, value in (('radius', -1.0), ('background', 0.)):
        model.setParam(name, value)
    q = np.asarray([0.1])
    Iq = model.evalDistribution(q)
    assert Iq[0] == 0., "empty distribution fails"
922
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    Every name returned by ``core.list_models()`` is passed to
    ``_make_standard_model``.  When a model fails to build, the exception is
    annotated with the model name and re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt and SystemExit pass through untouched.
            annotate_exception("when loading "+name)
            raise
935
def test_old_name():
    # type: () -> None
    """
    Load and run the cylinder model under its old name,
    sas.models.CylinderModel.

    Nothing is done when old style plugins are not supported or when the
    sas package cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        # Probe for sasview; the imported name itself is not used.
        import sas
    except ImportError:
        # if sasview is not on the path then don't try to test it
        pass
    else:
        load_standard_models()
        from sas.models.CylinderModel import CylinderModel
        model = CylinderModel()
        model.evalDistribution([0.1, 0.1])
951
def test_structure_factor_background():
    # type: () -> None
    """
    Check that the sasview model and the direct model agree for hardsphere
    at q=0, both with background=0 given explicitly and with the direct
    model called without arguments.
    """
    from .data import empty_data1D
    from .core import load_model_info, build_model
    from .direct_model import DirectModel

    model_name = "hardsphere"
    q = [0.0]

    # Value from the sasview wrapper.
    wrapper = _make_standard_model(model_name)()
    sasview_value = wrapper.evalDistribution(np.array(q))[0]

    # Same model evaluated through DirectModel on an empty 1-D data set.
    direct_model = DirectModel(
        empty_data1D(q), build_model(load_model_info(model_name)))

    direct_value_zero_background = direct_model(background=0.0)
    assert sasview_value == direct_value_zero_background

    # Additionally check that direct value background defaults to zero
    direct_value_default = direct_model()
    assert sasview_value == direct_value_default
978
979
def magnetic_demo():
    """
    Demonstrate a call to a magnetic model.

    A sphere model with sld_M0 set to 8 is evaluated on a 500x500 grid of
    (qx, qy) points between -0.35 and 0.35, and the log of the result is
    displayed with pylab.
    """
    model = _make_standard_model('sphere')()
    model.setParam('sld_M0', 8)

    axis = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(axis, axis)
    # calculate_Iq returns a pair; only the first item is plotted.
    Iq, _ = model.calculate_Iq(qx.flatten(), qy.flatten())
    image = Iq.reshape(qx.shape)

    import pylab
    # The small offset keeps the log finite where the result is zero.
    pylab.imshow(np.log(image + 0.001))
    pylab.show()
995
if __name__ == "__main__":
    # Running the module as a script evaluates the cylinder model and prints
    # the result.  The remaining checks can be enabled by hand.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g"%cylinder_value)
    #magnetic_demo()
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
    #test_structure_factor_background()
Note: See TracBrowser for help on using the repository browser.