source: sasmodels/sasmodels/sasview_model.py @ a38b065

core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since a38b065 was a38b065, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

limit sasview to a single model calculation across all threads

  • Property mode set to 100644
File size: 27.5 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
[92d38285]7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
[87985ca]9"""
from __future__ import print_function

import math
from copy import deepcopy
import collections
import traceback
import logging
from os.path import basename, splitext
try:
    # Python 3 renamed the low-level thread module to _thread.
    import _thread as thread
except ImportError:
    import thread

import numpy as np  # type: ignore

from . import core
from . import custom
from . import product
from . import generate
from . import weights
from . import modelinfo
from .details import make_kernel_args, dispersion_mesh
[ff7119b]29
try:
    # These names are only used by the "# type:" comments in this module, so
    # a failed import (no typing module, or no package context) is ignored.
    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
    from .modelinfo import ModelInfo, Parameter
    from .kernel import KernelModel
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
    SasviewModelType = Callable[[int], "SasviewModel"]
except ImportError:
    pass

#: Module-wide lock: SasviewModel.calculate_Iq holds it for the duration of
#: each calculation, so only one model calculation runs at a time across all
#: threads.
calculation_lock = thread.allocate_lock()

#: When True, load_standard_models also registers the models under their
#: old sasview names (see _register_old_models).
SUPPORT_OLD_STYLE_PLUGINS = True
45
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For each entry of the 3.1.2 conversion table, a stand-in module class
    named ``<OldName>`` is registered in :data:`sys.modules` as
    ``sas.models.<OldName>``.  Its ``<OldName>`` attribute is the new model
    class, so ``from sas.models.<OldName> import <OldName>`` keeps working.

    Entries whose new model cannot be found by find_model are logged and
    skipped.  This happens, for example, when the model failed to load.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        try:
            model = find_model(new_name)
        except ValueError:
            # load_standard_models logs and skips models that fail to load,
            # so one missing model must not abort registering the others.
            logging.warning("Cannot register old model name %s: model %s "
                            "is not available", old_name, new_name)
            continue
        ConstructedModule = type(old_name, (), {old_name: model})
        old_path = 'sas.models.' + old_name
        # NOTE(review): this sets an attribute literally named
        # 'sas.models.<OldName>' on sas.models.  The import machinery relies
        # on the sys.modules entry below.
        setattr(sas.models, old_path, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
71
[9457498]72
# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description stored on each model class: the number of
#: variants, the name of the controlling parameter, that parameter's
#: choices, and the x-axis label used for the model profile.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', 'number control choices x_axis_label')
78
#: Model classes by name; filled by load_standard_models and load_custom_model.
MODELS = {}
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    Names ending in ``.py`` are treated as paths and loaded as custom models.
    Any other name must already be registered in MODELS.

    :raises ValueError: if *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
94
[56b2687]95
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load and return the list of predefined models.

    Every model named by core.list_models() is built and registered in
    MODELS.  If building a model raises, its traceback is logged and the
    model is left out of the returned list.  When SUPPORT_OLD_STYLE_PLUGINS
    is set, the models are also registered under their old sasview names.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class
            loaded.append(model_class)
    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return loaded
[de97440]116
[4d76711]117
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    If the module at *path* defines a ``Model`` class (an old-style model),
    that class is used after any missing *name*, *filename* and *id* are
    filled in.  Otherwise the module is converted into a new SasviewModel
    subclass.

    The model is registered in MODELS.  Suppose a model from a different
    file already uses the same name.  The new model is then renamed to its
    *id*, or to ``<id>_user`` if that name is also taken by another file.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    try:
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            model.filename = kernel_module.__file__
            # For old models, treat .pyc and .py files interchangeably.
            # This is needed because of the Sum|Multi(p1,p2) types of models
            # and the convoluted way in which they are created.
            if model.filename.endswith(".py"):
                logging.info("Loading %s as .pyc", model.filename)
                model.filename = model.filename+'c'
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    except AttributeError:
        # No usable old-style Model class: build a model from the module.
        # NOTE(review): AttributeErrors raised while patching an old-style
        # Model also end up here.
        model = _make_model_from_info(modelinfo.make_model_info(kernel_module))

    def _taken_by_other_file(candidate):
        # True when *candidate* is registered for a model from another file.
        return (candidate in MODELS
                and not model.filename == MODELS[candidate].filename)

    if _taken_by_other_file(model.name):
        requested_name = model.name
        model.name = model.id
        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if _taken_by_other_file(model.name):
            model.name = model.id + '_user'
        logging.info("Model %s already exists: using %s [%s]",
                     requested_name, model.name, model.filename)

    MODELS[model.name] = model
    return model
[4d76711]160
[87985ca]161
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Load the sasview model defined by *name*.

    The kernel module for *name* is loaded with
    generate.load_kernel_module, turned into model info, and wrapped as a
    SasviewModel subclass.

    Returns a class that can be used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    info = modelinfo.make_model_info(module)
    return _make_model_from_info(info)
[72a081d]174
175
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model built from the model info of
    *form_factor* and *structure_factor*.

    The combined model info comes from product.make_product_info.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = _make_model_from_info(product_info)
    return product_class()
182
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Convert *model_info* into a SasView model wrapper.

    Returns a new SasviewModel subclass named *model_info.name*.  Its class
    attributes come from _generate_model_attributes, plus the model
    *filename* and an ``__init__`` that forwards *multiplicity*.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update(__init__=__init__, filename=model_info.filename)
    model_class = type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
    return model_class
195
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.

    Three parts of *model_info* feed the result:

    * The multiplicity description comes from the kernel parameter named by
      *model_info.control*.
    * *fun_list* comes from the first kernel parameter that has choices.
    * The orientation, magnetic and fixed lists come from the 2-D user
      parameters.
    """
    # TODO: allow model to override axis labels input/output name/unit

    non_fittable = []  # type: List[str]

    # Multiplicity: described by the kernel parameter named in control.
    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
    control_par = next((p for p in model_info.parameters.kernel_parameters
                        if p.name == model_info.control), None)
    if control_par is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control_par.name)
        if control_par.choices:
            n_variants = len(control_par.choices)
        else:
            n_variants = int(control_par.limits[1])
        variants = MultiplicityInfo(n_variants, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter available
    fun_list = []
    list_par = next((p for p in model_info.parameters.kernel_parameters
                     if p.choices), None)
    if list_par is not None:
        fun_list = list_par.choices
        if list_par.length > 1:
            # Each element of a vector choice parameter is marked non-fittable.
            non_fittable.extend(list_par.id + str(k)
                                for k in range(1, list_par.length + 1))

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        if par.type == 'orientation':
            orientation_params.extend([par.name, par.name + ".width"])
            fixed.append(par.name + ".width")
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(par.name + ".width")

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
[4d76711]264
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete models are subclasses created by _make_model_from_info, which
    fills in the class attributes below from the model info.  Each instance
    holds three things:

    * *params*: the current parameter values.
    * *dispersion*: the polydispersity settings.
    * *details*: the units and limits of each parameter.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: "<name>.width" entries for the orientation and magnetic parameters
    #: (filled in by _generate_model_attributes); is_fittable() tests
    #: membership in this sequence
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Set up the parameter values, details and dispersion settings from
        the class's model info.

        Some parameters are hidden and do not appear in *params*,
        *details* or *dispersion*:

        * When *multiplicity* is given, the parameters the model info hides
          for that multiplicity, plus the multiplicity control parameter.
        * For structure factors, *scale* and *background*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            # NOTE: this updates the defaults on the shared model info.
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state used by pickle and copy.

        The kernel model is left out.  It is rebuilt on the next calculation.
        """
        state = self.__dict__.copy()
        # _model is only in the instance dict after a calculation has run.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the state produced by __getstate__ with no kernel model.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the original (misspelled) names, which
    # pickle and copy never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check

        NOTE(review): this returns whether *par_name* is in *fixed*; see the
        TODO on *fixed* about its naming.
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        The kernel parameters are passed by id to the model's profile
        function:

        * The multiplicity control parameter is passed as a float.
        * Vector parameters are passed as arrays of their numbered
          components.
        * Values missing from *params* are passed as NaN.

        The second returned array is multiplied by 1e-6.
        """
        args = {}  # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or "<par>.<field>" for a
            dispersion setting such as "radius.width"
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par, field = toks
            if par in self.dispersion and field in self.dispersion[par]:
                self.dispersion[par][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or "<par>.<field>" for a
            dispersion setting such as "radius.width"

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par, field = toks
            if par in self.dispersion and field in self.dispersion[par]:
                return self.dispersion[par][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return an identical copy of self.

        The copy is made with deepcopy, which goes through __getstate__, so
        the copy has no kernel model until its next calculation.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def get_composition_models(self):
        # type: () -> Tuple[Optional["SasviewModel"], Optional["SasviewModel"]]
        """
            Returns usable models that compose this model

        Returns (p_model, s_model).  Both are instances built from
        *_model_info.composition[1]*, or both are None when the model info
        has no composition.
        """
        s_model = None
        p_model = None
        if hasattr(self._model_info, "composition") \
           and self._model_info.composition is not None:
            p_model = _make_model_from_info(self._model_info.composition[1][0])()
            s_model = _make_model_from_info(self._model_info.composition[1][1])()
        return p_model, s_model

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.

        The module-level calculation_lock is held during the calculation, so
        only one model calculation runs at a time across all threads.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logging.info("calculation waiting for another thread to complete")
        #    logging.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Unlocked implementation of calculate_Iq.

        The kernel model is built on first use.  The kernel and the model are
        released after each call, even when the calculation raises.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        try:
            calculator = self._model.make_kernel(q_vectors)
            try:
                parameters = self._model_info.parameters
                pairs = [self._get_weights(p) for p in parameters.call_parameters]
                call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
                result = calculator(call_details, values, cutoff=self.cutoff,
                                    magnetic=is_magnetic)
            finally:
                calculator.release()
        finally:
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        Returns 1.0 when the model has no ER function.  Otherwise it returns
        the average of ER over the volume-parameter dispersion mesh,
        weighted by the mesh weights.

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        Returns 1.0 when the model has no VR function.  Otherwise it returns
        the weighted sum of *part* divided by the weighted sum of *whole*
        over the volume-parameter dispersion mesh.

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        The result depends on the kind of parameter:

        * Multiplicity control: returns ([multiplicity], [1.0]).
        * Any other parameter missing from *params* (a hidden parameter):
          returns its model-info default, or NaN if it has none.
        * Polydisperse parameter: returns the dispersion values and weights,
          with the weights normalized to sum to 1.
        * Otherwise: returns ([value], [1.0]).
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
[ce27e21]735
def test_model():
    # type: () -> np.ndarray
    """
    Build the cylinder model and evaluate it at (qx, qy) = (0.1, 0.1).
    """
    cylinder_class = _make_standard_model('cylinder')
    return cylinder_class().evalDistribution([0.1, 0.1])
[de97440]744
def test_structure_factor():
    # type: () -> None
    """
    Test that a structure factor model (hardsphere) can be run and returns
    a non-NaN value at (qx, qy) = (0.1, 0.1).
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value = model.evalDistribution([0.1, 0.1])
    # Reduce with np.any so the check also works for multi-element results.
    if np.any(np.isnan(value)):
        raise ValueError("hardsphere returns null")
755
def test_rpa():
    # type: () -> np.ndarray
    """
    Test that the rpa model can be built with multiplicity 3 and run at
    (qx, qy) = (0.1, 0.1).
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
[04045f4]764
[4d76711]765
def test_model_list():
    # type: () -> None
    """
    Build a sasview model class for every model in core.list_models().

    A failure is annotated with the name of the model being loaded and then
    re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except:  # pylint: disable=bare-except
            # annotate, then propagate the original exception unchanged
            annotate_exception("when loading " + model_name)
            raise
778
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model from sas.models.CylinderModel

    Does nothing when old-style plugin support is off or when sasview
    cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        # if sasview is not on the path then don't try to test it
        import sas  # pylint: disable=unused-import
    except ImportError:
        sasview_available = False
    else:
        sasview_available = True
    if sasview_available:
        load_standard_models()
        from sas.models.CylinderModel import CylinderModel
        old_style_model = CylinderModel()
        old_style_model.evalDistribution([0.1, 0.1])
794
if __name__ == "__main__":
    # Smoke test: print the cylinder intensity at (qx, qy) = (0.1, 0.1).
    cylinder_iq = test_model()
    print("cylinder(0.1,0.1)=%g" % cylinder_iq)
Note: See TracBrowser for help on using the repository browser.