Changeset 7cf2cfd in sasmodels for sasmodels/bumps_model.py


Ignore:
Timestamp:
Nov 22, 2015 11:37:15 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
3b4243d
Parents:
677ccf1
Message:

refactor compare.py so that bumps/sasview not required for simple tests

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/bumps_model.py

    r5d80bbf r7cf2cfd  
    1010how far the polydispersity integral extends. 
    1111 
    12 A variety of helper functions are provided: 
    13  
    14     :func:`load_data` loads a sasview data file. 
    15  
    16     :func:`empty_data1D` creates an empty dataset, which is useful for plotting 
    17     a theory function before the data is measured. 
    18  
    19     :func:`empty_data2D` creates an empty 2D dataset. 
    20  
    21     :func:`set_beam_stop` masks the beam stop from the data. 
    22  
    23     :func:`set_half` selects the right or left half of the data, which can 
    24     be useful for shear measurements which have not been properly corrected 
    25     for path length and reflections. 
    26  
    27     :func:`set_top` cuts the top part off the data. 
    28  
    29     :func:`plot_data` plots the data file. 
    30  
    31     :func:`plot_theory` plots a calculated result from the model. 
    32  
    3312""" 
    3413 
    3514import datetime 
    3615import warnings 
    37 import traceback 
    3816 
    3917import numpy as np 
    4018 
     19from bumps.names import Parameter 
     20 
    4121from . import sesans 
    42 from .resolution import Perfect1D, Pinhole1D, Slit1D 
    43 from .resolution2d import Pinhole2D 
    44  
    45 # CRUFT python 2.6 
    46 if not hasattr(datetime.timedelta, 'total_seconds'): 
    47     def delay(dt): 
    48         """Return number date-time delta as number seconds""" 
    49         return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds 
    50 else: 
    51     def delay(dt): 
    52         """Return number date-time delta as number seconds""" 
    53         return dt.total_seconds() 
    54  
     22from . import weights 
     23from .data import plot_theory 
     24from .direct_model import DataMixin 
    5525 
    5626# CRUFT: old style bumps wrapper which doesn't separate data and model 
     
    6434 
    6535 
    66  
    67 def tic(): 
    68     """ 
    69     Timer function. 
    70  
    71     Use "toc=tic()" to start the clock and "toc()" to measure 
    72     a time interval. 
    73     """ 
    74     then = datetime.datetime.now() 
    75     return lambda: delay(datetime.datetime.now() - then) 
    76  
    77  
    78 def load_data(filename): 
    79     """ 
    80     Load data using a sasview loader. 
    81     """ 
    82     from sas.dataloader.loader import Loader 
    83     loader = Loader() 
    84     data = loader.load(filename) 
    85     if data is None: 
    86         raise IOError("Data %r could not be loaded" % filename) 
    87     return data 
    88  
    89 def plot_data(data, view='log'): 
    90     """ 
    91     Plot data loaded by the sasview loader. 
    92     """ 
    93     if hasattr(data, 'qx_data'): 
    94         _plot_2d_signal(data, data.data, view=view) 
    95     else: 
    96         # Note: kind of weird using the _plot_result1D to plot just the 
    97         # data, but it handles the masking and graph markup already, so 
    98         # do not repeat. 
    99         _plot_result1D(data, None, None, view) 
    100  
    101 def plot_theory(data, theory, view='log'): 
    102     if hasattr(data, 'qx_data'): 
    103         _plot_2d_signal(data, theory, view=view) 
    104     else: 
    105         _plot_result1D(data, theory, None, view, include_data=False) 
    106  
    107  
    108 def empty_data1D(q, resolution=0.05): 
    109     """ 
    110     Create empty 1D data using the given *q* as the x value. 
    111  
    112     *resolution* dq/q defaults to 5%. 
    113     """ 
    114  
    115     from sas.dataloader.data_info import Data1D 
    116  
    117     Iq = 100 * np.ones_like(q) 
    118     dIq = np.sqrt(Iq) 
    119     data = Data1D(q, Iq, dx=resolution * q, dy=dIq) 
    120     data.filename = "fake data" 
    121     data.qmin, data.qmax = q.min(), q.max() 
    122     data.mask = np.zeros(len(Iq), dtype='bool') 
    123     return data 
    124  
    125  
    126 def empty_data2D(qx, qy=None, resolution=0.05): 
    127     """ 
    128     Create empty 2D data using the given mesh. 
    129  
    130     If *qy* is missing, create a square mesh with *qy=qx*. 
    131  
    132     *resolution* dq/q defaults to 5%. 
    133     """ 
    134     from sas.dataloader.data_info import Data2D, Detector 
    135  
    136     if qy is None: 
    137         qy = qx 
    138     Qx, Qy = np.meshgrid(qx, qy) 
    139     Qx, Qy = Qx.flatten(), Qy.flatten() 
    140     Iq = 100 * np.ones_like(Qx) 
    141     dIq = np.sqrt(Iq) 
    142     mask = np.ones(len(Iq), dtype='bool') 
    143  
    144     data = Data2D() 
    145     data.filename = "fake data" 
    146     data.qx_data = Qx 
    147     data.qy_data = Qy 
    148     data.data = Iq 
    149     data.err_data = dIq 
    150     data.mask = mask 
    151     data.qmin = 1e-16 
    152     data.qmax = np.inf 
    153  
    154     # 5% dQ/Q resolution 
    155     if resolution != 0: 
    156         # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf 
    157         # Should have an additional constant which depends on distances and 
    158         # radii of the aperture, pixel dimensions and wavelength spread 
    159         # Instead, assume radial dQ/Q is constant, and perpendicular matches 
    160         # radial (which instead it should be inverse). 
    161         Q = np.sqrt(Qx**2 + Qy**2) 
    162         data.dqx_data = resolution * Q 
    163         data.dqy_data = resolution * Q 
    164  
    165     detector = Detector() 
    166     detector.pixel_size.x = 5 # mm 
    167     detector.pixel_size.y = 5 # mm 
    168     detector.distance = 4 # m 
    169     data.detector.append(detector) 
    170     data.xbins = qx 
    171     data.ybins = qy 
    172     data.source.wavelength = 5 # angstroms 
    173     data.source.wavelength_unit = "A" 
    174     data.Q_unit = "1/A" 
    175     data.I_unit = "1/cm" 
    176     data.q_data = np.sqrt(Qx ** 2 + Qy ** 2) 
    177     data.xaxis("Q_x", "A^{-1}") 
    178     data.yaxis("Q_y", "A^{-1}") 
    179     data.zaxis("Intensity", r"\text{cm}^{-1}") 
    180     return data 
    181  
    182  
    183 def set_beam_stop(data, radius, outer=None): 
    184     """ 
    185     Add a beam stop of the given *radius*.  If *outer*, make an annulus. 
    186     """ 
    187     from sas.dataloader.manipulations import Ringcut 
    188     if hasattr(data, 'qx_data'): 
    189         data.mask = Ringcut(0, radius)(data) 
    190         if outer is not None: 
    191             data.mask += Ringcut(outer, np.inf)(data) 
    192     else: 
    193         data.mask = (data.x >= radius) 
    194         if outer is not None: 
    195             data.mask &= (data.x < outer) 
    196  
    197  
    198 def set_half(data, half): 
    199     """ 
    200     Select half of the data, either "right" or "left". 
    201     """ 
    202     from sas.dataloader.manipulations import Boxcut 
    203     if half == 'right': 
    204         data.mask += \ 
    205             Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data) 
    206     if half == 'left': 
    207         data.mask += \ 
    208             Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data) 
    209  
    210  
    211 def set_top(data, cutoff): 
    212     """ 
    213     Chop the top off the data, above *cutoff*. 
    214     """ 
    215     from sas.dataloader.manipulations import Boxcut 
    216     data.mask += \ 
    217         Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data) 
    218  
    219 def protect(fn): 
    220     def wrapper(*args, **kw): 
    221         try:  
    222             return fn(*args, **kw) 
    223         except: 
    224             traceback.print_exc() 
    225     return wrapper 
    226  
    227 @protect 
    228 def _plot_result1D(data, theory, resid, view, include_data=True): 
    229     """ 
    230     Plot the data and residuals for 1D data. 
    231     """ 
    232     import matplotlib.pyplot as plt 
    233     from numpy.ma import masked_array, masked 
    234     #print "not a number",sum(np.isnan(data.y)) 
    235     #data.y[data.y<0.05] = 0.5 
    236     mdata = masked_array(data.y, data.mask) 
    237     mdata[~np.isfinite(mdata)] = masked 
    238     if view is 'log': 
    239         mdata[mdata <= 0] = masked 
    240  
    241     scale = data.x**4 if view == 'q4' else 1.0 
    242     if resid is not None: 
    243         plt.subplot(121) 
    244  
    245     positive = False 
    246     if include_data: 
    247         plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.') 
    248         positive = positive or (mdata>0).any() 
    249     if theory is not None: 
    250         mtheory = masked_array(theory, mdata.mask) 
    251         plt.plot(data.x, scale*mtheory, '-', hold=True) 
    252         positive = positive or (mtheory>0).any() 
    253     plt.xscale(view) 
    254     plt.yscale('linear' if view == 'q4' or not positive else view) 
    255     plt.xlabel('Q') 
    256     plt.ylabel('I(Q)') 
    257     if resid is not None: 
    258         mresid = masked_array(resid, mdata.mask) 
    259         plt.subplot(122) 
    260         plt.plot(data.x, mresid, 'x') 
    261         plt.ylabel('residuals') 
    262         plt.xlabel('Q') 
    263         plt.xscale(view) 
    264  
    265 # pylint: disable=unused-argument 
    266 @protect 
    267 def _plot_sesans(data, theory, resid, view): 
    268     import matplotlib.pyplot as plt 
    269     plt.subplot(121) 
    270     plt.errorbar(data.x, data.y, yerr=data.dy) 
    271     plt.plot(data.x, theory, '-', hold=True) 
    272     plt.xlabel('spin echo length (nm)') 
    273     plt.ylabel('polarization (P/P0)') 
    274     plt.subplot(122) 
    275     plt.plot(data.x, resid, 'x') 
    276     plt.xlabel('spin echo length (nm)') 
    277     plt.ylabel('residuals (P/P0)') 
    278  
    279 @protect 
    280 def _plot_result2D(data, theory, resid, view): 
    281     """ 
    282     Plot the data and residuals for 2D data. 
    283     """ 
    284     import matplotlib.pyplot as plt 
    285     target = data.data[~data.mask] 
    286     if view == 'log': 
    287         vmin = min(target[target>0].min(), theory[theory>0].min()) 
    288         vmax = max(target.max(), theory.max()) 
    289     else: 
    290         vmin = min(target.min(), theory.min()) 
    291         vmax = max(target.max(), theory.max()) 
    292     #print vmin, vmax 
    293     plt.subplot(131) 
    294     _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax) 
    295     plt.title('data') 
    296     plt.colorbar() 
    297     plt.subplot(132) 
    298     _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax) 
    299     plt.title('theory') 
    300     plt.colorbar() 
    301     plt.subplot(133) 
    302     _plot_2d_signal(data, resid, view='linear') 
    303     plt.title('residuals') 
    304     plt.colorbar() 
    305  
    306 @protect 
    307 def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'): 
    308     """ 
    309     Plot the target value for the data.  This could be the data itself, 
    310     the theory calculation, or the residuals. 
    311  
    312     *scale* can be 'log' for log scale data, or 'linear'. 
    313     """ 
    314     import matplotlib.pyplot as plt 
    315     from numpy.ma import masked_array 
    316  
    317     image = np.zeros_like(data.qx_data) 
    318     image[~data.mask] = signal 
    319     valid = np.isfinite(image) 
    320     if view == 'log': 
    321         valid[valid] = (image[valid] > 0) 
    322         image[valid] = np.log10(image[valid]) 
    323     elif view == 'q4': 
    324         image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2 
    325     image[~valid | data.mask] = 0 
    326     #plottable = Iq 
    327     plottable = masked_array(image, ~valid | data.mask) 
    328     xmin, xmax = min(data.qx_data), max(data.qx_data) 
    329     ymin, ymax = min(data.qy_data), max(data.qy_data) 
    330     # TODO: fix vmin, vmax so it is shared for theory/resid 
    331     vmin = vmax = None 
    332     try: 
    333         if vmin is None: vmin = image[valid & ~data.mask].min() 
    334         if vmax is None: vmax = image[valid & ~data.mask].max() 
    335     except: 
    336         vmin, vmax = 0, 1 
    337     #print vmin,vmax 
    338     plt.imshow(plottable.reshape(128, 128), 
    339                interpolation='nearest', aspect=1, origin='upper', 
    340                extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax) 
    341  
    342  
    34336class Model(object): 
    344     def __init__(self, kernel, **kw): 
    345         from bumps.names import Parameter 
    346  
    347         self.kernel = kernel 
    348         partype = kernel.info['partype'] 
     37    def __init__(self, model, **kw): 
     38        self._sasmodel = model 
     39        partype = model.info['partype'] 
    34940 
    35041        pars = [] 
    351         for p in kernel.info['parameters']: 
     42        for p in model.info['parameters']: 
    35243            name, default, limits = p[0], p[2], p[3] 
    35344            value = kw.pop(name, default) 
     
    37869        return dict((k, getattr(self, k)) for k in self._parameter_names) 
    37970 
    380 class Experiment(object): 
     71 
     72class Experiment(DataMixin): 
    38173    """ 
    38274    Return a bumps wrapper for a SAS model. 
     
    39789 
    39890        # remember inputs so we can inspect from outside 
    399         self.data = data 
    40091        self.model = model 
    40192        self.cutoff = cutoff 
    402         if hasattr(data, 'lam'): 
    403             self.data_type = 'sesans' 
    404         elif hasattr(data, 'qx_data'): 
    405             self.data_type = 'Iqxy' 
    406         else: 
    407             self.data_type = 'Iq' 
    408  
    409         # interpret data 
    410         partype = model.kernel.info['partype'] 
    411         if self.data_type == 'sesans': 
    412             q = sesans.make_q(data.sample.zacceptance, data.Rmax) 
    413             self.index = slice(None, None) 
    414             self.Iq = data.y 
    415             self.dIq = data.dy 
    416             #self._theory = np.zeros_like(q) 
    417             q_vectors = [q] 
    418         elif self.data_type == 'Iqxy': 
    419             q = np.sqrt(data.qx_data**2 + data.qy_data**2) 
    420             qmin = getattr(data, 'qmin', 1e-16) 
    421             qmax = getattr(data, 'qmax', np.inf) 
    422             accuracy = getattr(data, 'accuracy', 'Low') 
    423             self.index = (~data.mask) & (~np.isnan(data.data)) \ 
    424                          & (q >= qmin) & (q <= qmax) 
    425             self.Iq = data.data[self.index] 
    426             self.dIq = data.err_data[self.index] 
    427             self.resolution = Pinhole2D(data=data, index=self.index, 
    428                                         nsigma=3.0, accuracy=accuracy) 
    429             #self._theory = np.zeros_like(self.Iq) 
    430             if not partype['orientation'] and not partype['magnetic']: 
    431                 raise ValueError("not 2D without orientation or magnetic parameters") 
    432                 #qx,qy = self.resolution.q_calc 
    433                 #q_vectors = [np.sqrt(qx**2 + qy**2)] 
    434             else: 
    435                 q_vectors = self.resolution.q_calc 
    436         elif self.data_type == 'Iq': 
    437             self.index = (data.x >= data.qmin) & (data.x <= data.qmax) & ~np.isnan(data.y) 
    438             self.Iq = data.y[self.index] 
    439             self.dIq = data.dy[self.index] 
    440             if getattr(data, 'dx', None) is not None: 
    441                 q, dq = data.x[self.index], data.dx[self.index] 
    442                 if (dq>0).any(): 
    443                     self.resolution = Pinhole1D(q, dq) 
    444                 else: 
    445                     self.resolution = Perfect1D(q) 
    446             elif (getattr(data, 'dxl', None) is not None and 
    447                   getattr(data, 'dxw', None) is not None): 
    448                 q = data.x[self.index] 
    449                 width = data.dxh[self.index]  # Note: dx 
    450                 self.resolution = Slit1D(data.x[self.index], 
    451                                          width=data.dxh[self.index], 
    452                                          height=data.dxw[self.index]) 
    453             else: 
    454                 self.resolution = Perfect1D(data.x[self.index]) 
    455  
    456             #self._theory = np.zeros_like(self.Iq) 
    457             q_vectors = [self.resolution.q_calc] 
    458         else: 
    459             raise ValueError("Unknown data type") # never gets here 
    460  
    461         # Remember function inputs so we can delay loading the function and 
    462         # so we can save/restore state 
    463         self._fn_inputs = [v for v in q_vectors] 
    464         self._fn = None 
    465  
     93        self._interpret_data(data, model._sasmodel) 
    46694        self.update() 
    46795 
     
    483111    def theory(self): 
    484112        if 'theory' not in self._cache: 
     113            pars = dict((k, v.value) for k,v in self.model.parameters().items()) 
     114            self._cache['theory'] = self._calc_theory(pars, cutoff=self.cutoff) 
     115            """ 
    485116            if self._fn is None: 
    486                 q_input = self.model.kernel.make_input(self._fn_inputs) 
     117                q_input = self.model.kernel.make_input(self._kernel_inputs) 
    487118                self._fn = self.model.kernel(q_input) 
    488119 
     
    495126                result = sesans.hankel(self.data.x, self.data.lam * 1e-9, 
    496127                                       self.data.sample.thickness / 10, 
    497                                        self._fn_inputs[0], Iq_calc) 
     128                                       self._kernel_inputs[0], Iq_calc) 
    498129                self._cache['theory'] = result 
    499130            else: 
    500131                Iq = self.resolution.apply(Iq_calc) 
    501132                self._cache['theory'] = Iq 
     133            """ 
    502134        return self._cache['theory'] 
    503135 
     
    518150        Plot the data and residuals. 
    519151        """ 
    520         data, theory, resid = self.data, self.theory(), self.residuals() 
    521         if self.data_type == 'Iq': 
    522             _plot_result1D(data, theory, resid, view) 
    523         elif self.data_type == 'Iqxy': 
    524             _plot_result2D(data, theory, resid, view) 
    525         elif self.data_type == 'sesans': 
    526             _plot_sesans(data, theory, resid, view) 
    527         else: 
    528             raise ValueError("Unknown data type") 
     152        data, theory, resid = self._data, self.theory(), self.residuals() 
     153        plot_theory(data, theory, resid, view) 
    529154 
    530155    def simulate_data(self, noise=None): 
    531         theory = self.theory() 
    532         if noise is not None: 
    533             self.dIq = theory*noise*0.01 
    534         dy = self.dIq 
    535         y = theory + np.random.randn(*dy.shape) * dy 
    536         self.Iq = y 
    537         if self.data_type == 'Iq': 
    538             self.data.dy[self.index] = dy 
    539             self.data.y[self.index] = y 
    540         elif self.data_type == 'Iqxy': 
    541             self.data.data[self.index] = y 
    542         elif self.data_type == 'sesans': 
    543             self.data.y[self.index] = y 
    544         else: 
    545             raise ValueError("Unknown model") 
     156        Iq = self.theory() 
     157        self._set_data(Iq, noise) 
    546158 
    547159    def save(self, basename): 
    548160        pass 
    549161 
    550     def _get_weights(self, par): 
     162    def remove_get_weights(self, name): 
    551163        """ 
    552164        Get parameter dispersion weights 
    553165        """ 
    554         from . import weights 
    555  
    556         relative = self.model.kernel.info['partype']['pd-rel'] 
    557         limits = self.model.kernel.info['limits'] 
     166        info = self.model.kernel.info 
     167        relative = name in info['partype']['pd-rel'] 
     168        limits = info['limits'][name] 
    558169        disperser, value, npts, width, nsigma = [ 
    559             getattr(self.model, par + ext) 
     170            getattr(self.model, name + ext) 
    560171            for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')] 
    561172        value, weight = weights.get_weights( 
    562173            disperser, int(npts.value), width.value, nsigma.value, 
    563             value.value, limits[par], par in relative) 
     174            value.value, limits, relative) 
    564175        return value, weight / np.sum(weight) 
    565176 
     
    567178        # Can't pickle gpu functions, so instead make them lazy 
    568179        state = self.__dict__.copy() 
    569         state['_fn'] = None 
     180        state['_kernel'] = None 
    570181        return state 
    571182 
     
    573184        # pylint: disable=attribute-defined-outside-init 
    574185        self.__dict__ = state 
    575  
    576  
    577 def demo(): 
    578     data = load_data('DEC07086.DAT') 
    579     set_beam_stop(data, 0.004) 
    580     plot_data(data) 
    581     import matplotlib.pyplot as plt; plt.show() 
    582  
    583  
    584 if __name__ == "__main__": 
    585     demo() 
Note: See TracChangeset for help on using the changeset viewer.