source: sasmodels/sasmodels/direct_model.py @ 7cf2cfd

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 7cf2cfd was 7cf2cfd, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor compare.py so that bumps/sasview not required for simple tests

  • Property mode set to 100644
File size: 6.7 KB
Line 
1import warnings
2
3import numpy as np
4
5from .core import load_model_definition, load_model, make_kernel
6from .core import call_kernel, call_ER, call_VR
7from . import sesans
8from . import resolution
9from . import resolution2d
10
11class DataMixin(object):
12    """
13    DataMixin captures the common aspects of evaluating a SAS model for a
14    particular data set, including calculating Iq and evaluating the
15    resolution function.  It is used in particular by :class:`DirectModel`,
16    which evaluates a SAS model parameters as key word arguments to the
17    calculator method, and by :class:`bumps_model.Experiment`, which wraps the
18    model and data for use with the Bumps fitting engine.  It is not
19    currently used by :class:`sasview_model.SasviewModel` since this will
20    require a number of changes to SasView before we can do it.
21    """
22    def _interpret_data(self, data, model):
23        self._data = data
24        self._model = model
25
26        # interpret data
27        if hasattr(data, 'lam'):
28            self.data_type = 'sesans'
29        elif hasattr(data, 'qx_data'):
30            self.data_type = 'Iqxy'
31        else:
32            self.data_type = 'Iq'
33
34        partype = model.info['partype']
35
36        if self.data_type == 'sesans':
37            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
38            self.index = slice(None, None)
39            if data.y is not None:
40                self.Iq = data.y
41                self.dIq = data.dy
42            #self._theory = np.zeros_like(q)
43            q_vectors = [q]
44        elif self.data_type == 'Iqxy':
45            if not partype['orientation'] and not partype['magnetic']:
46                raise ValueError("not 2D without orientation or magnetic parameters")
47            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
48            qmin = getattr(data, 'qmin', 1e-16)
49            qmax = getattr(data, 'qmax', np.inf)
50            accuracy = getattr(data, 'accuracy', 'Low')
51            self.index = ~data.mask & (q >= qmin) & (q <= qmax)
52            if data.data is not None:
53                self.index &= ~np.isnan(data.data)
54                self.Iq = data.data[self.index]
55                self.dIq = data.err_data[self.index]
56            self.resolution = resolution2d.Pinhole2D(data=data, index=self.index,
57                                                     nsigma=3.0, accuracy=accuracy)
58            #self._theory = np.zeros_like(self.Iq)
59            q_vectors = self.resolution.q_calc
60        elif self.data_type == 'Iq':
61            self.index = (data.x >= data.qmin) & (data.x <= data.qmax)
62            if data.y is not None:
63                self.index &= ~np.isnan(data.y)
64                self.Iq = data.y[self.index]
65                self.dIq = data.dy[self.index]
66            if getattr(data, 'dx', None) is not None:
67                q, dq = data.x[self.index], data.dx[self.index]
68                if (dq>0).any():
69                    self.resolution = resolution.Pinhole1D(q, dq)
70                else:
71                    self.resolution = resolution.Perfect1D(q)
72            elif (getattr(data, 'dxl', None) is not None and
73                          getattr(data, 'dxw', None) is not None):
74                self.resolution = resolution.Slit1D(data.x[self.index],
75                                                    width=data.dxh[self.index],
76                                                    height=data.dxw[self.index])
77            else:
78                self.resolution = resolution.Perfect1D(data.x[self.index])
79
80            #self._theory = np.zeros_like(self.Iq)
81            q_vectors = [self.resolution.q_calc]
82        else:
83            raise ValueError("Unknown data type") # never gets here
84
85        # Remember function inputs so we can delay loading the function and
86        # so we can save/restore state
87        self._kernel_inputs = [v for v in q_vectors]
88        self._kernel = None
89
90    def _set_data(self, Iq, noise=None):
91        if noise is not None:
92            self.dIq = Iq*noise*0.01
93        dy = self.dIq
94        y = Iq + np.random.randn(*dy.shape) * dy
95        self.Iq = y
96        if self.data_type == 'Iq':
97            self._data.dy[self.index] = dy
98            self._data.y[self.index] = y
99        elif self.data_type == 'Iqxy':
100            self._data.data[self.index] = y
101        elif self.data_type == 'sesans':
102            self._data.y[self.index] = y
103        else:
104            raise ValueError("Unknown model")
105
106    def _calc_theory(self, pars, cutoff=0.0):
107        if self._kernel is None:
108            q_input = self._model.make_input(self._kernel_inputs)
109            self._kernel = self._model(q_input)
110
111        Iq_calc = call_kernel(self._kernel, pars, cutoff=cutoff)
112        if self.data_type == 'sesans':
113            result = sesans.hankel(self._data.x, self._data.lam * 1e-9,
114                                   self._data.sample.thickness / 10,
115                                   self._kernel_inputs[0], Iq_calc)
116        else:
117            result = self.resolution.apply(Iq_calc)
118        return result
119
120
121class DirectModel(DataMixin):
122    def __init__(self, data, model, cutoff=1e-5):
123        self.model = model
124        self.cutoff = cutoff
125        self._interpret_data(data, model)
126        self.kernel = make_kernel(self.model, self._kernel_inputs)
127    def __call__(self, **pars):
128        return self._calc_theory(pars, cutoff=self.cutoff)
129    def ER(self, **pars):
130        return call_ER(self.model.info, pars)
131    def VR(self, **pars):
132        return call_VR(self.model.info, pars)
133    def simulate_data(self, noise=None, **pars):
134        Iq = self.__call__(**pars)
135        self._set_data(Iq, noise=noise)
136
137def demo():
138    import sys
139    from .data import empty_data1D, empty_data2D
140
141    if len(sys.argv) < 3:
142        print "usage: python -m sasmodels.direct_model modelname (q|qx,qy) par=val ..."
143        sys.exit(1)
144    model_name = sys.argv[1]
145    call = sys.argv[2].upper()
146    if call not in ("ER","VR"):
147        try:
148            values = [float(v) for v in call.split(',')]
149        except:
150            values = []
151        if len(values) == 1:
152            q, = values
153            data = empty_data1D([q])
154        elif len(values) == 2:
155            qx,qy = values
156            data = empty_data2D([qx],[qy])
157        else:
158            print "use q or qx,qy or ER or VR"
159            sys.exit(1)
160    else:
161        data = empty_data1D([0.001])  # Data not used in ER/VR
162
163    model_definition = load_model_definition(model_name)
164    model = load_model(model_definition, dtype='single')
165    calculator = DirectModel(data, model)
166    pars = dict((k,float(v))
167                for pair in sys.argv[3:]
168                for k,v in [pair.split('=')])
169    if call == "ER":
170        print calculator.ER(**pars)
171    elif call == "VR":
172        print calculator.VR(**pars)
173    else:
174        Iq = calculator(**pars)
175        print Iq[0]
176
177if __name__ == "__main__":
178    demo()
Note: See TracBrowser for help on using the repository browser.