source: sasmodels/sasmodels/sasview_model.py @ 7ae2b7f

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 7ae2b7f was 7ae2b7f, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

still more linting; ignore numpy types

  • Property mode set to 100644
File size: 14.9 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
[4d76711]15from __future__ import print_function
[87985ca]16
[ce27e21]17import math
18from copy import deepcopy
[2622b3f]19import collections
[4d76711]20import traceback
21import logging
[ce27e21]22
[7ae2b7f]23import numpy as np  # type: ignore
[ce27e21]24
[aa4946b]25from . import core
[4d76711]26from . import custom
[72a081d]27from . import generate
[fb5914f]28from . import weights
[6d6508e]29from . import details
30from . import modelinfo
[ff7119b]31
def load_standard_models():
    """
    Build a SasView model class for every model reported by sasmodels.

    A model whose construction raises is skipped; its traceback is written
    to the log at ERROR level and loading continues with the next model.

    :return: list of the model classes that were built successfully
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
        else:
            loaded.append(model_class)
    return loaded
[de97440]46
[4d76711]47
def load_custom_model(path):
    """
    Build a SasView model class from the custom model file at *path*.

    :param path: location of the custom kernel module
    :return: a SasviewModel subclass wrapping that model
    """
    info = modelinfo.make_model_info(custom.load_custom_kernel_module(path))
    return _make_model_from_info(info)
55
[87985ca]56
def _make_standard_model(name):
    """
    Load the sasview model defined by *name*.

    *name* can be a standard model name or a path to a custom model.

    Returns a class that can be used directly as a sasview model.
    """
    kernel_module = generate.load_kernel_module(name)
    # Model info must be derived from the kernel module that was just
    # loaded; otherwise the returned class would not describe *name*.
    model_info = modelinfo.make_model_info(kernel_module)
    return _make_model_from_info(model_info)
[72a081d]69
70
def _make_model_from_info(model_info):
    """
    Convert *model_info* into a SasView model wrapper class.

    The returned class is a SasviewModel subclass named ``model_info.name``
    whose ``_model_info`` class attribute is *model_info*.
    """
    def __init__(self, multfactor=1):
        # *multfactor* is accepted but not used.
        # NOTE(review): presumably kept so callers that pass a multiplicity
        # argument keep working; confirm against SasView.
        SasviewModel.__init__(self)

    class_attrs = {'__init__': __init__, '_model_info': model_info}
    return type(model_info.name, (SasviewModel,), class_attrs)
80
[4d76711]81
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are created by setting the ``_model_info`` class
    attribute, which :meth:`__init__` reads. The compiled model
    (``self._model``) is built lazily by :meth:`calculate_Iq`. It is left
    out of the pickled/copied state and rebuilt on demand.
    """
    _model_info = None # type: modelinfo.ModelInfo

    def __init__(self):
        # Compiled model; built on first use in calculate_Iq.
        self._model = None
        model_info = self._model_info
        parameters = model_info.parameters

        self.name = model_info.name
        self.description = model_info.description
        self.category = None
        #self.is_multifunc = False
        # For the first control parameter, multiplicity_info is
        # [upper limit, name, choices, first profile axis]; [] if there is none.
        for p in parameters.kernel_parameters:
            if p.is_control:
                profile_axes = model_info.profile_axes
                self.multiplicity_info = [
                    p.limits[1], p.name, p.choices, profile_axes[0]
                    ]
                break
        else:
            self.multiplicity_info = []

        ## interpret the parameters
        ## TODO: reorganize parameter handling
        # details: parameter name -> [units] followed by its limits
        self.details = dict()
        # params: parameter name -> current value, in model parameter order
        self.params = collections.OrderedDict()
        # dispersion: polydisperse parameter name -> disperser settings
        self.dispersion = dict()

        self.orientation_params = []
        self.magnetic_params = []
        self.fixed = []
        for p in parameters.user_parameters():
            self.params[p.name] = p.default
            self.details[p.name] = [p.units] + p.limits
            if p.polydisperse:
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }
            if p.type == 'orientation':
                self.orientation_params.append(p.name)
                self.orientation_params.append(p.name+".width")
                self.fixed.append(p.name+".width")
            if p.type == 'magnetic':
                self.orientation_params.append(p.name)
                self.magnetic_params.append(p.name)
                self.fixed.append(p.name+".width")

        self.non_fittable = []

        ## independent parameter name and unit [string]
        self.input_name = "Q"
        self.input_unit = "A^{-1}"
        self.output_name = "Intensity"
        self.output_unit = "cm^{-1}"

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        ## New fields introduced for opencl rewrite
        self.cutoff = 1e-5

    def __getstate__(self):
        """
        Return the instance state for pickling and copying.

        The compiled model is dropped; it is rebuilt lazily after the state
        is restored. ``_model_info`` is a class attribute, so it is not part
        of the instance state.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        return state

    def __setstate__(self, state):
        """
        Restore state produced by :meth:`__getstate__`.

        The compiled model is reset so it is rebuilt on the next evaluation.
        """
        self.__dict__ = state
        self._model = None

    # Aliases for the previous method names, kept for backward compatibility.
    # pickle/copy only recognize __getstate__/__setstate__.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        # NOTE(review): this returns True for names listed in self.fixed
        # (the ".width" entries of orientation/magnetic parameters). That
        # reads as inverted relative to the method name; kept as-is for
        # compatibility with existing callers.
        return par_name.lower() in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        Names of the form ``par.field`` (e.g. ``radius.width``) address the
        dispersion settings of *par*; other names address plain parameter
        values. Matching is case-insensitive.

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        """
        Get the value of a model parameter

        Accepts the same (case-insensitive) names as :meth:`setParam`.

        :param name: name of the parameter
        :return: the parameter value or dispersion setting
        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        """
        Return a list of all available parameters for the model
        """
        # dict.keys() is a view on Python 3; copy to a list so it can be
        # extended.
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name.lower(), ext)
               for p in self._model_info.parameters.user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        return ret

    def clone(self):
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]


    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither an ndarray nor a [qx, qy]
            list/tuple
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                # 1D-only model: evaluate at |q|. asarray lets plain lists
                # of qx/qy values work as well as ndarrays.
                qx, qy = np.asarray(qx), np.asarray(qy)
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        kernel = self._model.make_kernel(q_vectors)
        try:
            pairs = [self._get_weights(p)
                     for p in self._model_info.parameters.call_parameters]
            call_details, weight, value = details.build_details(kernel, pairs)
            return kernel(call_details, weight, value, cutoff=self.cutoff)
        finally:
            # Release kernel resources even if weighting or evaluation fails.
            kernel.q_input.release()
            kernel.release()

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius (1.0 if the model does
            not define ER); a weighted average over the volume-parameter
            dispersion mesh
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio (1.0 if the model does not
            define VR)
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string], matched
            case-insensitively against the model parameters
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        # Resolve to the parameter's canonical spelling so that the entry is
        # stored under the same key _get_weights looks up (par.name).
        target = parameter.lower()
        canonical = None
        for name in self.params:
            if name.lower() == target:
                canonical = name
                break
        if canonical is None:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        self.dispersion[canonical] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return details.dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        """
        Return dispersion weights for parameter

        Polydisperse parameters return (values, normalized weights) from the
        configured disperser. Other parameters return ([value], []).
        """
        if par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], []
[fb5914f]432        Return dispersion weights for parameter
[de0c4ba]433        """
[fb5914f]434        if par.polydisperse:
435            dis = self.dispersion[par.name]
436            value, weight = weights.get_weights(
437                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
438                self.params[par.name], par.limits, par.relative_pd)
439            return value, weight / np.sum(weight)
440        else:
441            return [self.params[par.name]], []
[ce27e21]442
[fb5914f]443def test_model():
[4d76711]444    """
445    Test that a sasview model (cylinder) can be run.
446    """
447    Cylinder = _make_standard_model('cylinder')
[fb5914f]448    cylinder = Cylinder()
449    return cylinder.evalDistribution([0.1,0.1])
[de97440]450
[4d76711]451
def test_model_list():
    """
    Make sure that all models build as sasview models.

    If a model fails to build, the exception is annotated with the name of
    the offending model and re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch only ordinary errors, so KeyboardInterrupt/SystemExit
            # propagate untouched.
            annotate_exception("when loading " + name)
            raise
463
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model at a single (qx, qy) point.
    iq = test_model()
    print("cylinder(0.1,0.1)=%g" % iq)
Note: See TracBrowser for help on using the repository browser.