Changeset 69ec80f in sasmodels


Ignore:
Timestamp:
Jan 28, 2016 5:42:09 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
841753c
Parents:
a3e78c3
Message:

refactor code to reduce lint count

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/data.py

    rd15a908 r69ec80f  
    9090        self.x, self.y, self.dx, self.dy = x, y, dx, dy 
    9191        self.dxl = None 
     92        self.filename = None 
     93        self.qmin = x.min() if x is not None else np.NaN 
     94        self.qmax = x.max() if x is not None else np.NaN 
     95        self.mask = np.isnan(y) if y is not None else None 
     96        self._xaxis, self._xunit = "x", "" 
     97        self._yaxis, self._yunit = "y", "" 
    9298 
    9399    def xaxis(self, label, unit): 
     
    108114 
    109115class Data2D(object): 
    110     def __init__(self): 
     116    def __init__(self, x=None, y=None, z=None, dx=None, dy=None, dz=None): 
     117        self.qx_data, self.dqx_data = x, dx 
     118        self.qy_data, self.dqy_data = y, dy 
     119        self.data, self.err_data = z, dz 
     120        self.mask = ~np.isnan(z) if z is not None else None 
     121        self.q_data = np.sqrt(x**2 + y**2) 
     122        self.qmin = 1e-16 
     123        self.qmax = np.inf 
    111124        self.detector = [] 
    112125        self.source = Source() 
     126        self.Q_unit = "1/A" 
     127        self.I_unit = "1/cm" 
     128        self.xaxis("Q_x", "A^{-1}") 
     129        self.yaxis("Q_y", "A^{-1}") 
     130        self.zaxis("Intensity", r"\text{cm}^{-1}") 
     131        self._xaxis, self._xunit = "x", "" 
     132        self._yaxis, self._yunit = "y", "" 
     133        self._zaxis, self._zunit = "z", "" 
     134        self.x_bins, self.y_bins = None, None 
    113135 
    114136    def xaxis(self, label, unit): 
     
    139161 
    140162class Detector(object): 
     163    """ 
     164    Detector attributes. 
     165    """ 
     166    def __init__(self, pixel_size=(None, None), distance=None): 
     167        self.pixel_size = Vector(*pixel_size) 
     168        self.distance = distance 
     169 
     170class Source(object): 
     171    """ 
     172    Beam attributes. 
     173    """ 
    141174    def __init__(self): 
    142         self.pixel_size = Vector() 
    143  
    144 class Source(object): 
    145     pass 
     175        self.wavelength = np.NaN 
     176        self.wavelength_unit = "A" 
    146177 
    147178 
     
    158189    data = Data1D(q, Iq, dx=resolution * q, dy=dIq) 
    159190    data.filename = "fake data" 
    160     data.qmin, data.qmax = q.min(), q.max() 
    161     data.mask = np.zeros(len(q), dtype='bool') 
    162191    return data 
    163192 
     
    173202    if qy is None: 
    174203        qy = qx 
     204    # 5% dQ/Q resolution 
    175205    Qx, Qy = np.meshgrid(qx, qy) 
    176206    Qx, Qy = Qx.flatten(), Qy.flatten() 
    177207    Iq = 100 * np.ones_like(Qx) 
    178208    dIq = np.sqrt(Iq) 
    179     mask = np.ones(len(Iq), dtype='bool') 
    180  
    181     data = Data2D() 
    182     data.filename = "fake data" 
    183     data.qx_data = Qx 
    184     data.qy_data = Qy 
    185     data.data = Iq 
    186     data.err_data = dIq 
    187     data.mask = mask 
    188     data.qmin = 1e-16 
    189     data.qmax = np.inf 
    190  
    191     # 5% dQ/Q resolution 
    192209    if resolution != 0: 
    193210        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf 
     
    197214        # radial (which instead it should be inverse). 
    198215        Q = np.sqrt(Qx**2 + Qy**2) 
    199         data.dqx_data = resolution * Q 
    200         data.dqy_data = resolution * Q 
     216        dqx = resolution * Q 
     217        dqy = resolution * Q 
    201218    else: 
    202         data.dqx_data = data.dqy_data = None 
    203  
    204     detector = Detector() 
    205     detector.pixel_size.x = 5 # mm 
    206     detector.pixel_size.y = 5 # mm 
    207     detector.distance = 4 # m 
    208     data.detector.append(detector) 
     219        dqx = dqy = None 
     220 
     221    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq) 
    209222    data.x_bins = qx 
    210223    data.y_bins = qy 
     224    data.filename = "fake data" 
     225 
     226    # pixel_size in mm, distance in m 
     227    detector = Detector(pixel_size=(5, 5), distance=4) 
     228    data.detector.append(detector) 
    211229    data.source.wavelength = 5 # angstroms 
    212230    data.source.wavelength_unit = "A" 
    213     data.Q_unit = "1/A" 
    214     data.I_unit = "1/cm" 
    215     data.q_data = np.sqrt(Qx ** 2 + Qy ** 2) 
    216     data.xaxis("Q_x", "A^{-1}") 
    217     data.yaxis("Q_y", "A^{-1}") 
    218     data.zaxis("Intensity", r"\text{cm}^{-1}") 
    219231    return data 
    220232 
     
    228240    # do not repeat. 
    229241    if hasattr(data, 'lam'): 
    230         _plot_result_sesans(data, None, None, plot_data=True, limits=limits) 
     242        _plot_result_sesans(data, None, None, use_data=True, limits=limits) 
    231243    elif hasattr(data, 'qx_data'): 
    232         _plot_result2D(data, None, None, view, plot_data=True, limits=limits) 
     244        _plot_result2D(data, None, None, view, use_data=True, limits=limits) 
    233245    else: 
    234         _plot_result1D(data, None, None, view, plot_data=True, limits=limits) 
     246        _plot_result1D(data, None, None, view, use_data=True, limits=limits) 
    235247 
    236248 
    237249def plot_theory(data, theory, resid=None, view='log', 
    238                 plot_data=True, limits=None): 
     250                use_data=True, limits=None): 
    239251    if hasattr(data, 'lam'): 
    240         _plot_result_sesans(data, theory, resid, plot_data=True, limits=limits) 
     252        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits) 
    241253    elif hasattr(data, 'qx_data'): 
    242         _plot_result2D(data, theory, resid, view, plot_data, limits=limits) 
     254        _plot_result2D(data, theory, resid, view, use_data, limits=limits) 
    243255    else: 
    244         _plot_result1D(data, theory, resid, view, plot_data, limits=limits) 
     256        _plot_result1D(data, theory, resid, view, use_data, limits=limits) 
    245257 
    246258 
     
    251263        except: 
    252264            traceback.print_exc() 
    253             pass 
    254265 
    255266    return wrapper 
     
    257268 
    258269@protect 
    259 def _plot_result1D(data, theory, resid, view, plot_data, limits=None): 
     270def _plot_result1D(data, theory, resid, view, use_data, limits=None): 
    260271    """ 
    261272    Plot the data and residuals for 1D data. 
     
    264275    from numpy.ma import masked_array, masked 
    265276 
    266     plot_theory = theory is not None 
    267     plot_resid = resid is not None 
    268  
    269     if data.y is None: 
    270         plot_data = False 
     277    use_data = use_data and data.y is not None 
     278    use_theory = theory is not None 
     279    use_resid = resid is not None 
     280    num_plots = (use_data or use_theory) + use_resid 
     281 
    271282    scale = data.x**4 if view == 'q4' else 1.0 
    272283 
    273     if plot_data or plot_theory: 
    274         if plot_resid: 
    275             plt.subplot(121) 
    276  
     284    if use_data or use_theory: 
    277285        #print(vmin, vmax) 
    278286        all_positive = True 
    279287        some_present = False 
    280         if plot_data: 
     288        if use_data: 
    281289            mdata = masked_array(data.y, data.mask.copy()) 
    282290            mdata[~np.isfinite(mdata)] = masked 
     
    288296 
    289297 
    290         if plot_theory: 
     298        if use_theory: 
    291299            mtheory = masked_array(theory, data.mask.copy()) 
    292300            mtheory[~np.isfinite(mtheory)] = masked 
     
    299307        if limits is not None: 
    300308            plt.ylim(*limits) 
     309 
     310        if num_plots > 1: 
     311            plt.subplot(1, num_plots, 1) 
    301312        plt.xscale('linear' if not some_present else view) 
    302313        plt.yscale('linear' 
     
    306317        plt.ylabel('$I(q)$') 
    307318 
    308     if plot_resid: 
    309         if plot_data or plot_theory: 
    310             plt.subplot(122) 
    311  
     319    if use_resid: 
    312320        mresid = masked_array(resid, data.mask.copy()) 
    313321        mresid[~np.isfinite(mresid)] = masked 
    314322        some_present = (mresid.count() > 0) 
     323 
     324        if num_plots > 1: 
     325            plt.subplot(1, num_plots, (use_data or use_theory) + 1) 
    315326        plt.plot(data.x/10, mresid, '-') 
    316327        plt.xlabel("$q$/nm$^{-1}$") 
     
    320331 
    321332@protect 
    322 def _plot_result_sesans(data, theory, resid, plot_data, limits=None): 
     333def _plot_result_sesans(data, theory, resid, use_data, limits=None): 
    323334    import matplotlib.pyplot as plt 
    324     if data.y is None: 
    325         plot_data = False 
    326     plot_theory = theory is not None 
    327     plot_resid = resid is not None 
    328  
    329     if plot_data or plot_theory: 
    330         if plot_resid: 
    331             plt.subplot(121) 
    332         if plot_data: 
     335    use_data = use_data and data.y is not None 
     336    use_theory = theory is not None 
     337    use_resid = resid is not None 
     338    num_plots = (use_data or use_theory) + use_resid 
     339 
     340    if use_data or use_theory: 
     341        if num_plots > 1: 
     342            plt.subplot(1, num_plots, 1) 
     343        if use_data: 
    333344            plt.errorbar(data.x, data.y, yerr=data.dy) 
    334345        if theory is not None: 
     
    340351 
    341352    if resid is not None: 
    342         if plot_data or plot_theory: 
    343             plt.subplot(122) 
    344  
     353        if num_plots > 1: 
     354            plt.subplot(1, num_plots, (use_data or use_theory) + 1) 
    345355        plt.plot(data.x, resid, 'x') 
    346356        plt.xlabel('spin echo length (nm)') 
     
    349359 
    350360@protect 
    351 def _plot_result2D(data, theory, resid, view, plot_data, limits=None): 
     361def _plot_result2D(data, theory, resid, view, use_data, limits=None): 
    352362    """ 
    353363    Plot the data and residuals for 2D data. 
    354364    """ 
    355365    import matplotlib.pyplot as plt 
    356     if data.data is None: 
    357         plot_data = False 
    358     plot_theory = theory is not None 
    359     plot_resid = resid is not None 
     366    use_data = use_data and data.data is not None 
     367    use_theory = theory is not None 
     368    use_resid = resid is not None 
     369    num_plots = use_data + use_theory + use_resid 
    360370 
    361371    # Put theory and data on a common colormap scale 
    362     if limits is None: 
    363         vmin, vmax = np.inf, -np.inf 
    364         if plot_data: 
    365             target = data.data[~data.mask] 
    366             datamin = target[target>0].min() if view == 'log' else target.min() 
    367             datamax = target.max() 
    368             vmin = min(vmin, datamin) 
    369             vmax = max(vmax, datamax) 
    370         if plot_theory: 
    371             theorymin = theory[theory>0].min() if view=='log' else theory.min() 
    372             theorymax = theory.max() 
    373             vmin = min(vmin, theorymin) 
    374             vmax = max(vmax, theorymax) 
    375     else: 
     372    vmin, vmax = np.inf, -np.inf 
     373    if use_data: 
     374        target = data.data[~data.mask] 
     375        datamin = target[target > 0].min() if view == 'log' else target.min() 
     376        datamax = target.max() 
     377        vmin = min(vmin, datamin) 
     378        vmax = max(vmax, datamax) 
     379    if use_theory: 
     380        theorymin = theory[theory > 0].min() if view == 'log' else theory.min() 
     381        theorymax = theory.max() 
     382        vmin = min(vmin, theorymin) 
     383        vmax = max(vmax, theorymax) 
     384 
     385    # Override data limits from the caller 
     386    if limits is not None: 
    376387        vmin, vmax = limits 
    377388 
    378     if plot_data: 
    379         if plot_theory and plot_resid: 
    380             plt.subplot(131) 
    381         elif plot_theory or plot_resid: 
    382             plt.subplot(121) 
     389    # Plot data 
     390    if use_data: 
     391        if num_plots > 1: 
     392            plt.subplot(1, num_plots, 1) 
    383393        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax) 
    384394        plt.title('data') 
     
    386396        h.set_label('$I(q)$') 
    387397 
    388     if plot_theory: 
    389         if plot_data and plot_resid: 
    390             plt.subplot(132) 
    391         elif plot_data: 
    392             plt.subplot(122) 
    393         elif plot_resid: 
    394             plt.subplot(121) 
     398    # plot theory 
     399    if use_theory: 
     400        if num_plots > 1: 
     401            plt.subplot(1, num_plots, use_data+1) 
    395402        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax) 
    396403        plt.title('theory') 
     
    400407                    else '$I(q)$') 
    401408 
    402     #if plot_data or plot_theory: 
    403     #    plt.colorbar() 
    404  
    405     if plot_resid: 
    406         if plot_data and plot_theory: 
    407             plt.subplot(133) 
    408         elif plot_data or plot_theory: 
    409             plt.subplot(122) 
     409    # plot resid 
     410    if use_resid: 
     411        if num_plots > 1: 
     412            plt.subplot(1, num_plots, use_data+use_theory+1) 
    410413        _plot_2d_signal(data, resid, view='linear') 
    411414        plt.title('residuals') 
Note: See TracChangeset for help on using the changeset viewer.