source: sasmodels/sasmodels/sasview_model.py

Last change on this file was d0b0f5d, checked in by GitHub <noreply@…>, 5 years ago

Merge pull request #67 from SasView/ticket-608-user-defined-weights

Ticket 608 user defined weights

  • Property mode set to 100644
File size: 36.0 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import kernelcl
28from . import product
29from . import generate
30from . import weights
31from . import modelinfo
32from .details import make_kernel_args, dispersion_mesh
33
# Hack: register any user-supplied polydispersity distributions at import.
# The plugin files default to ~/.sasview/weights/*.py; setting the
# SASMODELS_WEIGHTS environment variable selects a different location.
# To load from somewhere else, call
# weights.load_weights(pattern="<weights_path>/*.py").
weights.load_weights()
38
39# pylint: disable=unused-import
40try:
41    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
42                        List, Optional, Union, Callable)
43    from .modelinfo import ModelInfo, Parameter
44    from .kernel import KernelModel
45    MultiplicityInfoType = NamedTuple(
46        'MultiplicityInfo',
47        [("number", int), ("control", str), ("choices", List[str]),
48         ("x_axis_label", str)])
49    SasviewModelType = Callable[[int], "SasviewModel"]
50except ImportError:
51    pass
52# pylint: enable=unused-import
53
logger = logging.getLogger(__name__)

#: Lock held while a kernel is evaluated (see calculate_Iq and
#: calc_composition_models) so only one calculation runs at a time.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Description of a model's multiplicity control: the number of variants,
#: the controlling parameter name, its choices and the profile x-axis label.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', 'number control choices x_axis_label')

#: set of defined models (standard and custom), keyed by model name
MODELS = {}  # type: Dict[str, SasviewModelType]
# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
#: {path: module} for custom modules already loaded, so load_custom_model
#: can tell whether a module changed since it was last seen.
_CACHED_MODULE = {}  # type: Dict[str, "module"]
76
def reset_environment():
    # type: () -> None
    """
    Clear the compute engine context so that the GUI can change devices.

    Every compiled kernel is discarded, including those used by models that
    are active on fit pages; each one is rebuilt the next time the model is
    evaluated.
    """
    kernelcl.reset_environment()
    for model_class in MODELS.values():
        # Compiled kernels are cached on the class (see
        # SasviewModel._calculate_Iq), so clearing the class attribute
        # invalidates every instance of that model at once.
        model_class._model = None
88
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is treated as the path to a custom model and
    loaded with :func:`load_custom_model`; any other name must already be
    registered in :data:`MODELS`.

    :raises ValueError: if *modelname* is not a known model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
103
104
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load and return the list of predefined models.

    Every model name reported by ``core.list_models()`` is turned into a
    SasView model class and registered in :data:`MODELS`.  If a model fails
    to load, an error naming the model is logged together with its
    traceback, and that model is left out instead of aborting the load.
    When :data:`SUPPORT_OLD_STYLE_PLUGINS` is set, the models are also
    registered under their old names.

    :return: all models currently in :data:`MODELS` (standard and custom).
    """
    for name in core.list_models():
        try:
            MODELS[name] = _make_standard_model(name)
        except Exception:
            # Deliberately broad: one broken model must not stop the others
            # from loading.  logger.exception records the traceback and,
            # unlike logging format_exc() alone, says which model failed.
            logger.exception("Unable to load model %r", name)
    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
123
124
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The module at *path* is loaded (possibly from the loader cache).  If it
    defines an old-style ``Model`` class, that class is used directly;
    otherwise a SasviewModel subclass is built from the module's model info.

    If a model with the same name but a different file is already
    registered, the new model is renamed to its file id, or to
    ``<id>_user`` if that name is also taken.  The entry in :data:`MODELS`
    is replaced only when the module changed since the previous call (or the
    name is new).

    :return: the model class registered in :data:`MODELS` under the model's
        name.
    """
    #logger.info("Loading model %s", path)

    # Load the kernel module.  This may already be cached by the loader, so
    # only requires checking the timestamps of the dependents.
    kernel_module = custom.load_custom_kernel_module(path)

    # Check if the module has changed since we last looked.
    reloaded = kernel_module != _CACHED_MODULE.get(path, None)
    _CACHED_MODULE[path] = kernel_module

    # Turn the module into a model.  We need to do this even if the
    # model has already been loaded so that we can determine the model
    # name and retrieve it from the MODELS cache.
    model = getattr(kernel_module, 'Model', None)
    if model is not None:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Point at the source rather than the compiled file.  Only a
            # trailing ".pyc" is touched: a blind str.replace would also
            # corrupt directory names such as "~/.pycharm/...".
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    previous_name, model.name, model.filename)

    # Only update the model if the module has changed
    if reloaded or model.name not in MODELS:
        MODELS[model.name] = model

    return MODELS[model.name]
177
178
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a SasviewModel subclass wrapping *model_info*.

    The new class is named after the model.  It carries the attributes
    produced by :func:`_generate_model_attributes` plus the model filename,
    and its constructor accepts an optional *multiplicity*.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update(__init__=__init__, filename=model_info.filename)
    constructed = type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
    return constructed
191
192
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for *name*.

    *name* can be a standard model name or a path to a custom model.  Its
    kernel module is loaded, the model info is extracted from it, and the
    result is wrapped by :func:`make_model_from_info`.

    Returns a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
205
206
def _register_old_models():
    # type: () -> None
    """
    Expose the new models under their old sasview names.

    ``sas.sascalc.fit`` is installed as ``sas.models`` (both in sys.modules
    and as an attribute of the ``sas`` package) so that plugin modules can
    still use ``sas.models.pluginmodel``.  Then, for every entry in the
    conversion table for version 3.1.2, a stand-in "module" class named
    after the old model is registered as ``sas.models.<OldName>``.  That
    class holds the new model class under the old name.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    old_conversions = CONVERSION_TABLE.get((3, 1, 2), {})
    for versioned_name, conversion in old_conversions.items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        current_name = versioned_name.split(':')[0]
        if len(conversion) < 3:
            old_name = conversion[0]
        else:
            old_name = conversion[2]
        stand_in = type(old_name, (), {old_name: find_model(current_name)})
        dotted_path = 'sas.models.' + old_name
        # NOTE(review): the attribute name is the full dotted path, so it is
        # not reachable as sas.models.<OldName>.  Presumably plugins reach
        # the stand-in through the sys.modules entry below; confirm before
        # changing it.
        setattr(sas.models, dotted_path, stand_in)
        sys.modules[dotted_path] = stand_in
232
233
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model built from two models.

    *form_factor* supplies P(Q), and its multiplicity is reused for the new
    instance; *structure_factor* supplies S(Q).
    """
    combined_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(combined_info)
    return product_class(form_factor.multiplicity)
243
244
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for the model in *model_info*.

    The attributes hold everything needed to query the model without
    instantiating it: names, description, category, multiplicity details,
    and the orientation, magnetic, fixed and non-fittable parameter names.
    The parameter-name collections are stored as tuples so they cannot be
    modified by accident.
    """
    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters
    non_fittable = []  # type: List[str]

    # Multiplicity: the first control parameter (if any) defines the variants.
    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
    variants = MultiplicityInfo(0, "", [], xlabel)
    control = next((p for p in kernel_pars if p.is_control), None)
    if control is not None:
        non_fittable.append(control.name)
        if control.choices:
            variant_count = len(control.choices)
        else:
            variant_count = int(control.limits[1])
        variants = MultiplicityInfo(variant_count, control.name,
                                    control.choices, xlabel)

    # Only a single drop-down list parameter available
    chooser = next((p for p in kernel_pars if p.choices), None)
    fun_list = chooser.choices if chooser is not None else []
    if chooser is not None and chooser.length > 1:
        non_fittable.extend(chooser.id + str(k)
                            for k in range(1, chooser.length + 1))

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters({}, is2d=True):
        if p.type == 'orientation':
            orientation_params.extend((p.name, p.name + ".width"))
            fixed.append(p.name + ".width")
        elif p.type == 'magnetic':
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(p.name + ".width")

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.radius_effective_modes is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]
316
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.
    """
    # The model-specific values of these attributes are supplied as class
    # attributes of the subclass built by make_model_from_info (from the
    # dictionary returned by _generate_model_attributes).  The defaults
    # below exist for typing and documentation purposes.

    #: compiled kernel, cached on the class and built on first evaluation
    _model = None       # type: KernelModel
    #: model definition this class wraps
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    # The parameter-name collections below are tuples in generated models.
    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: ".width" names of the orientation and magnetic parameters; these are
    #: the names that is_fittable() reports as fittable
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

379    def __init__(self, multiplicity=None):
380        # type: (Optional[int]) -> None
381
382        # TODO: _persistency_dict to persistency_dict throughout sasview
383        # TODO: refactor multiplicity to encompass variants
384        # TODO: dispersion should be a class
385        # TODO: refactor multiplicity info
386        # TODO: separate profile view from multiplicity
387        # The button label, x and y axis labels and scale need to be under
388        # the control of the model, not the fit page.  Maximum flexibility,
389        # the fit page would supply the canvas and the profile could plot
390        # how it wants, but this assumes matplotlib.  Next level is that
391        # we provide some sort of data description including title, labels
392        # and lines to plot.
393
394        # Get the list of hidden parameters given the multiplicity
395        # Don't include multiplicity in the list of parameters
396        self.multiplicity = multiplicity
397        if multiplicity is not None:
398            hidden = self._model_info.get_hidden_parameters(multiplicity)
399            hidden |= set([self.multiplicity_info.control])
400        else:
401            hidden = set()
402        if self._model_info.structure_factor:
403            hidden.add('scale')
404            hidden.add('background')
405
406        # Update the parameter lists to exclude any hidden parameters
407        self.magnetic_params = tuple(pname for pname in self.magnetic_params
408                                     if pname not in hidden)
409        self.orientation_params = tuple(pname for pname in self.orientation_params
410                                        if pname not in hidden)
411
412        self._persistency_dict = {}
413        self.params = collections.OrderedDict()
414        self.dispersion = collections.OrderedDict()
415        self.details = {}
416        for p in self._model_info.parameters.user_parameters({}, is2d=True):
417            if p.name in hidden:
418                continue
419            self.params[p.name] = p.default
420            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
421            if p.polydisperse:
422                self.details[p.id+".width"] = [
423                    "", 0.0, 1.0 if p.relative_pd else np.inf
424                ]
425                self.dispersion[p.name] = {
426                    'width': 0,
427                    'npts': 35,
428                    'nsigmas': 3,
429                    'type': 'gaussian',
430                }
431
432    def __get_state__(self):
433        # type: () -> Dict[str, Any]
434        state = self.__dict__.copy()
435        state.pop('_model')
436        # May need to reload model info on set state since it has pointers
437        # to python implementations of Iq, etc.
438        #state.pop('_model_info')
439        return state
440
441    def __set_state__(self, state):
442        # type: (Dict[str, Any]) -> None
443        self.__dict__ = state
444        self._model = None
445
446    def __str__(self):
447        # type: () -> str
448        """
449        :return: string representation
450        """
451        return self.name
452
453    def is_fittable(self, par_name):
454        # type: (str) -> bool
455        """
456        Check if a given parameter is fittable or not
457
458        :param par_name: the parameter name to check
459        """
460        return par_name in self.fixed
461        #For the future
462        #return self.params[str(par_name)].is_fittable()
463
464
465    def getProfile(self):
466        # type: () -> (np.ndarray, np.ndarray)
467        """
468        Get SLD profile
469
470        : return: (z, beta) where z is a list of depth of the transition points
471                beta is a list of the corresponding SLD values
472        """
473        args = {} # type: Dict[str, Any]
474        for p in self._model_info.parameters.kernel_parameters:
475            if p.id == self.multiplicity_info.control:
476                value = float(self.multiplicity)
477            elif p.length == 1:
478                value = self.params.get(p.id, np.NaN)
479            else:
480                value = np.array([self.params.get(p.id+str(k), np.NaN)
481                                  for k in range(1, p.length+1)])
482            args[p.id] = value
483
484        x, y = self._model_info.profile(**args)
485        return x, 1e-6*y
486
487    def setParam(self, name, value):
488        # type: (str, float) -> None
489        """
490        Set the value of a model parameter
491
492        :param name: name of the parameter
493        :param value: value of the parameter
494
495        """
496        # Look for dispersion parameters
497        toks = name.split('.')
498        if len(toks) == 2:
499            for item in self.dispersion.keys():
500                if item == toks[0]:
501                    for par in self.dispersion[item]:
502                        if par == toks[1]:
503                            self.dispersion[item][par] = value
504                            return
505        else:
506            # Look for standard parameter
507            for item in self.params.keys():
508                if item == name:
509                    self.params[item] = value
510                    return
511
512        raise ValueError("Model does not contain parameter %s" % name)
513
514    def getParam(self, name):
515        # type: (str) -> float
516        """
517        Set the value of a model parameter
518
519        :param name: name of the parameter
520
521        """
522        # Look for dispersion parameters
523        toks = name.split('.')
524        if len(toks) == 2:
525            for item in self.dispersion.keys():
526                if item == toks[0]:
527                    for par in self.dispersion[item]:
528                        if par == toks[1]:
529                            return self.dispersion[item][par]
530        else:
531            # Look for standard parameter
532            for item in self.params.keys():
533                if item == name:
534                    return self.params[item]
535
536        raise ValueError("Model does not contain parameter %s" % name)
537
538    def getParamList(self):
539        # type: () -> Sequence[str]
540        """
541        Return a list of all available parameters for the model
542        """
543        param_list = list(self.params.keys())
544        # WARNING: Extending the list with the dispersion parameters
545        param_list.extend(self.getDispParamList())
546        return param_list
547
548    def getDispParamList(self):
549        # type: () -> Sequence[str]
550        """
551        Return a list of polydispersity parameters for the model
552        """
553        # TODO: fix test so that parameter order doesn't matter
554        ret = ['%s.%s' % (p_name, ext)
555               for p_name in self.dispersion.keys()
556               for ext in ('npts', 'nsigmas', 'width')]
557        #print(ret)
558        return ret
559
560    def clone(self):
561        # type: () -> "SasviewModel"
562        """ Return a identical copy of self """
563        return deepcopy(self)
564
565    def run(self, x=0.0):
566        # type: (Union[float, (float, float), List[float]]) -> float
567        """
568        Evaluate the model
569
570        :param x: input q, or [q,phi]
571
572        :return: scattering function P(q)
573
574        **DEPRECATED**: use calculate_Iq instead
575        """
576        if isinstance(x, (list, tuple)):
577            # pylint: disable=unpacking-non-sequence
578            q, phi = x
579            result, _ = self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])
580            return result[0]
581        else:
582            result, _ = self.calculate_Iq([x])
583            return result[0]
584
585
586    def runXY(self, x=0.0):
587        # type: (Union[float, (float, float), List[float]]) -> float
588        """
589        Evaluate the model in cartesian coordinates
590
591        :param x: input q, or [qx, qy]
592
593        :return: scattering function P(q)
594
595        **DEPRECATED**: use calculate_Iq instead
596        """
597        if isinstance(x, (list, tuple)):
598            result, _ = self.calculate_Iq([x[0]], [x[1]])
599            return result[0]
600        else:
601            result, _ = self.calculate_Iq([x])
602            return result[0]
603
604    def evalDistribution(self, qdist):
605        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
606        r"""
607        Evaluate a distribution of q-values.
608
609        :param qdist: array of q or a list of arrays [qx,qy]
610
611        * For 1D, a numpy array is expected as input
612
613        ::
614
615            evalDistribution(q)
616
617          where *q* is a numpy array.
618
619        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
620
621        ::
622
623              qx = [ qx[0], qx[1], qx[2], ....]
624              qy = [ qy[0], qy[1], qy[2], ....]
625
626        If the model is 1D only, then
627
628        .. math::
629
630            q = \sqrt{q_x^2+q_y^2}
631
632        """
633        if isinstance(qdist, (list, tuple)):
634            # Check whether we have a list of ndarrays [qx,qy]
635            qx, qy = qdist
636            result, _ = self.calculate_Iq(qx, qy)
637            return result
638
639        elif isinstance(qdist, np.ndarray):
640            # We have a simple 1D distribution of q-values
641            result, _ = self.calculate_Iq(qdist)
642            return result
643
644        else:
645            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
646                            % type(qdist))
647
648    def calc_composition_models(self, qx):
649        """
650        returns parts of the composition model or None if not a composition
651        model.
652        """
653        # TODO: have calculate_Iq return the intermediates.
654        #
655        # The current interface causes calculate_Iq() to be called twice,
656        # once to get the combined result and again to get the intermediate
657        # results.  This is necessary for now.
658        # Long term, the solution is to change the interface to calculate_Iq
659        # so that it returns a results object containing all the bits:
660        #     the A, B, C, ... of the composition model (and any subcomponents?)
661        #     the P and S of the product model
662        #     the combined model before resolution smearing,
663        #     the sasmodel before sesans conversion,
664        #     the oriented 2D model used to fit oriented usans data,
665        #     the final I(q),
666        #     ...
667        #
668        # Have the model calculator add all of these blindly to the data
669        # tree, and update the graphs which contain them.  The fitter
670        # needs to be updated to use the I(q) value only, ignoring the rest.
671        #
672        # The simple fix of returning the existing intermediate results
673        # will not work for a couple of reasons: (1) another thread may
674        # sneak in to compute its own results before calc_composition_models
675        # is called, and (2) calculate_Iq is currently called three times:
676        # once with q, once with q values before qmin and once with q values
677        # after q max.  Both of these should be addressed before
678        # replacing this code.
679        composition = self._model_info.composition
680        if composition and composition[0] == 'product': # only P*S for now
681            with calculation_lock:
682                _, lazy_results = self._calculate_Iq(qx)
683                # for compatibility with sasview 4.x
684                results = lazy_results()
685                pq_data = results.get("P(Q)")
686                sq_data = results.get("S(Q)")
687                return pq_data, sq_data
688        else:
689            return None
690
691    def calculate_Iq(self,
692                     qx,     # type: Sequence[float]
693                     qy=None # type: Optional[Sequence[float]]
694                    ):
695        # type: (...) -> Tuple[np.ndarray, Callable[[], collections.OrderedDict[str, np.ndarray]]]
696        """
697        Calculate Iq for one set of q with the current parameters.
698
699        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
700
701        This should NOT be used for fitting since it copies the *q* vectors
702        to the card for each evaluation.
703
704        The returned tuple contains the scattering intensity followed by a
705        callable which returns a dictionary of intermediate data from
706        ProductKernel.
707        """
708        ## uncomment the following when trying to debug the uncoordinated calls
709        ## to calculate_Iq
710        #if calculation_lock.locked():
711        #    logger.info("calculation waiting for another thread to complete")
712        #    logger.info("\n".join(traceback.format_stack()))
713
714        with calculation_lock:
715            return self._calculate_Iq(qx, qy)
716
717    def _calculate_Iq(self, qx, qy=None):
718        if self._model is None:
719            # Only need one copy of the compiled kernel regardless of how many
720            # times it is used, so store it in the class.  Also, to reset the
721            # compute engine, need to clear out all existing compiled kernels,
722            # which is much easier to do if we store them in the class.
723            self.__class__._model = core.build_model(self._model_info)
724        if qy is not None:
725            q_vectors = [np.asarray(qx), np.asarray(qy)]
726        else:
727            q_vectors = [np.asarray(qx)]
728        calculator = self._model.make_kernel(q_vectors)
729        parameters = self._model_info.parameters
730        pairs = [self._get_weights(p) for p in parameters.call_parameters]
731        #weights.plot_weights(self._model_info, pairs)
732        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
733        #call_details.show()
734        #print("================ parameters ==================")
735        #for p, v in zip(parameters.call_parameters, pairs): print(p.name, v[0])
736        #for k, p in enumerate(self._model_info.parameters.call_parameters):
737        #    print(k, p.name, *pairs[k])
738        #print("params", self.params)
739        #print("values", values)
740        #print("is_mag", is_magnetic)
741        result = calculator(call_details, values, cutoff=self.cutoff,
742                            magnetic=is_magnetic)
743        lazy_results = getattr(calculator, 'results',
744                               lambda: collections.OrderedDict())
745        #print("result", result)
746
747        calculator.release()
748        #self._model.release()
749
750        return result, lazy_results
751
752
753    def calculate_ER(self, mode=1):
754        # type: () -> float
755        """
756        Calculate the effective radius for P(q)*S(q)
757
758        *mode* is the R_eff type, which defaults to 1 to match the ER
759        calculation for sasview models from version 3.x.
760
761        :return: the value of the effective radius
762        """
763        # ER and VR are only needed for old multiplication models, based on
764        # sas.sascalc.fit.MultiplicationModel.  Fail for now.  If we want to
765        # continue supporting them then add some test cases so that the code
766        # is exercised.  We can access ER/VR using the kernel Fq function by
767        # extending _calculate_Iq so that it calls:
768        #    if er_mode > 0:
769        #        res = calculator.Fq(call_details, values, cutoff=self.cutoff,
770        #                            magnetic=False, radius_effective_mode=mode)
771        #        R_eff, form_shell_ratio = res[2], res[4]
772        #        return R_eff, form_shell_ratio
773        # Then use the following in calculate_ER:
774        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=mode)
775        #    return ER
776        # Similarly, for calculate_VR:
777        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=1)
778        #    return VR
779        # Obviously a combined calculate_ER_VR method would be better, but
780        # we only need them to support very old models, so ignore the 2x
781        # performance hit.
782        raise NotImplementedError("ER function is no longer available.")
783
784    def calculate_VR(self):
785        # type: () -> float
786        """
787        Calculate the volf ratio for P(q)*S(q)
788
789        :return: the value of the form:shell volume ratio
790        """
791        # See comments in calculate_ER.
792        raise NotImplementedError("VR function is no longer available.")
793
794    def set_dispersion(self, parameter, dispersion):
795        # type: (str, weights.Dispersion) -> None
796        """
797        Set the dispersion object for a model parameter
798
799        :param parameter: name of the parameter [string]
800        :param dispersion: dispersion object of type Dispersion
801        """
802        if parameter in self.params:
803            # TODO: Store the disperser object directly in the model.
804            # The current method of relying on the sasview GUI to
805            # remember them is kind of funky.
806            # Note: can't seem to get disperser parameters from sasview
807            # (1) Could create a sasview model that has not yet been
808            # converted, assign the disperser to one of its polydisperse
809            # parameters, then retrieve the disperser parameters from the
810            # sasview model.
811            # (2) Could write a disperser parameter retriever in sasview.
812            # (3) Could modify sasview to use sasmodels.weights dispersers.
813            # For now, rely on the fact that the sasview only ever uses
814            # new dispersers in the set_dispersion call and create a new
815            # one instead of trying to assign parameters.
816            self.dispersion[parameter] = dispersion.get_pars()
817        else:
818            raise ValueError("%r is not a dispersity or orientation parameter"
819                             % parameter)
820
821    def _dispersion_mesh(self):
822        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
823        """
824        Create a mesh grid of dispersion parameters and weights.
825
826        Returns [p1,p2,...],w where pj is a vector of values for parameter j
827        and w is a vector containing the products for weights for each
828        parameter set in the vector.
829        """
830        pars = [self._get_weights(p)
831                for p in self._model_info.parameters.call_parameters
832                if p.type == 'volume']
833        return dispersion_mesh(self._model_info, pars)
834
835    def _get_weights(self, par):
836        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
837        """
838        Return dispersion weights for parameter
839        """
840        if par.name not in self.params:
841            if par.id == self.multiplicity_info.control:
842                return self.multiplicity, [self.multiplicity], [1.0]
843            else:
844                # For hidden parameters use default values.  This sets
845                # scale=1 and background=0 for structure factors
846                default = self._model_info.parameters.defaults.get(par.name, np.NaN)
847                return default, [default], [1.0]
848        elif par.polydisperse:
849            value = self.params[par.name]
850            dis = self.dispersion[par.name]
851            if dis['type'] == 'array':
852                dispersity, weight = dis['values'], dis['weights']
853            else:
854                dispersity, weight = weights.get_weights(
855                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
856                    value, par.limits, par.relative_pd)
857            return value, dispersity, weight
858        else:
859            value = self.params[par.name]
860            return value, [value], [1.0]
861
862    @classmethod
863    def runTests(cls):
864        """
865        Run any tests built into the model and captures the test output.
866
867        Returns success flag and output
868        """
869        from .model_test import check_model
870        return check_model(cls._model_info)
871
def test_cylinder():
    # type: () -> np.ndarray
    """
    Smoke test: build the cylinder model with default parameters and
    return its value at (qx, qy) = [0.1, 0.1].
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
880
def test_structure_factor():
    # type: () -> None
    """
    Test that the hardsphere model runs in 1-D and 2-D without producing NaN.

    The 1-D point is |q| = 0.1*sqrt(2), the magnitude of the 2-D point
    (0.1, 0.1), so both evaluations probe the same |q|.

    :raises ValueError: if either evaluation contains NaN.
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value2d = model.evalDistribution([0.1, 0.1])
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    # .any() keeps the check well defined even if a result has several points.
    if np.isnan(value1d).any() or np.isnan(value2d).any():
        raise ValueError("hardsphere returns nan")
893
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    :raises ValueError: if the evaluation contains NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    # Set the product's radius-mode parameter (product.RADIUS_MODE_ID) to 1.
    model.setParam(product.RADIUS_MODE_ID, 1.0)
    value = model.evalDistribution([0.1, 0.1])
    if np.isnan(value).any():
        raise ValueError("cylinder*hayter_msa returns NaN")
906
def test_rpa():
    # type: () -> np.ndarray
    """
    Smoke test: build the RPA model, constructed with argument 3
    (presumably the multiplicity), and return its 2-D value at
    (qx, qy) = [0.1, 0.1].
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
915
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that sasmodels returns 0 (not NaN) when there are no
    polydispersity points.

    The radius is set to -1, which is expected to leave no valid points in
    the distribution.  With background set to 0, the model value at q=0.1
    must then be exactly 0.
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    cylinder.setParam('radius', -1.0)
    cylinder.setParam('background', 0.)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
927
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    A failure is annotated with "when loading <name>" before being
    re-raised, so the offending model is identified in the error.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt/SystemExit are not annotated; the error is
            # always re-raised.
            annotate_exception("when loading "+name)
            raise
940
def test_old_name():
    # type: () -> None
    """
    Load the cylinder model through its legacy import path,
    sas.models.CylinderModel, and evaluate it at [0.1, 0.1].

    Returns early without testing when old-style plugins are unsupported
    or when the sasview package cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    # Probe for sasview; skip the test when it is not on the path.
    try:
        import sas  # pylint: disable=unused-import
    except ImportError:
        return
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel
    legacy_model = CylinderModel()
    legacy_model.evalDistribution([0.1, 0.1])
956
def test_structure_factor_background():
    # type: () -> None
    """
    Check that the sasview wrapper and the direct model agree for the
    hardsphere model at q=0.

    The comparison is made both with an explicit background of 0 and with
    the direct model's default background, which should also be 0.
    """
    from .data import empty_data1D
    from .core import load_model_info, build_model
    from .direct_model import DirectModel

    name = "hardsphere"
    qpoints = [0.0]

    wrapper = _make_standard_model(name)()
    expected = wrapper.evalDistribution(np.array(qpoints))[0]

    data = empty_data1D(qpoints)
    kernel_model = build_model(load_model_info(name))
    direct = DirectModel(data, kernel_model)

    # Explicit zero background must reproduce the sasview value.
    assert expected == direct(background=0.0)

    # The direct model's background should default to zero as well.
    assert expected == direct()
983
984
def magnetic_demo():
    """
    Demonstrate a call to a magnetic model.

    Evaluates the sphere model with sld_M0=8 on a 500x500 (qx, qy) grid
    spanning [-0.35, 0.35] and displays log(I + 0.001) as an image.
    """
    sphere = _make_standard_model('sphere')()
    sphere.setParam('sld_M0', 8)
    axis = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(axis, axis)
    intensity, _lazy = sphere.calculate_Iq(qx.flatten(), qy.flatten())
    image = intensity.reshape(qx.shape)

    import pylab
    pylab.imshow(np.log(image + 0.001))
    pylab.show()
1000
if __name__ == "__main__":
    # Quick smoke test; enable the other demos/tests below by hand.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
    #magnetic_demo()
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
    #test_structure_factor_background()
Note: See TracBrowser for help on using the repository browser.