Changeset b32dafd in sasmodels for sasmodels/compare.py


Ignore:
Timestamp:
Aug 5, 2016 3:59:13 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
99f446a
Parents:
7722b4a
Message:

lint

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    rbd49c79 rb32dafd  
    184184        pass 
    185185 
    186     def __exit__(self, type, value, traceback): 
     186    def __exit__(self, exc_type, exc_value, traceback): 
    187187        # type: (Any, BaseException, Any) -> None 
    188188        # TODO: better typing for __exit__ method 
     
    436436                smearer.model = model[0] 
    437437                return smearer.get_value() 
    438             theory = lambda: _call_smearer() 
     438            theory = _call_smearer 
    439439        else: 
    440440            theory = lambda: model[0].evalDistribution([data.qx_data[index], 
     
    503503    return calculator 
    504504 
    505 def time_calculation(calculator, pars, Nevals=1): 
     505def time_calculation(calculator, pars, evals=1): 
    506506    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float] 
    507507    """ 
     
    512512    """ 
    513513    # initialize the code so time is more accurate 
    514     if Nevals > 1: 
     514    if evals > 1: 
    515515        calculator(**suppress_pd(pars)) 
    516516    toc = tic() 
    517517    # make sure there is at least one eval 
    518518    value = calculator(**pars) 
    519     for _ in range(Nevals-1): 
     519    for _ in range(evals-1): 
    520520        value = calculator(**pars) 
    521     average_time = toc()*1000./Nevals 
     521    average_time = toc()*1000. / evals 
    522522    #print("I(q)",value) 
    523523    return value, average_time 
     
    591591    parameters. 
    592592    """ 
    593     Nbase, Ncomp = opts['n1'], opts['n2'] 
     593    n_base, n_comp = opts['n1'], opts['n2'] 
    594594    pars = opts['pars'] 
    595595    data = opts['data'] 
    596596 
    597597    # silence the linter 
    598     base = opts['engines'][0] if Nbase else None 
    599     comp = opts['engines'][1] if Ncomp else None 
     598    base = opts['engines'][0] if n_base else None 
     599    comp = opts['engines'][1] if n_comp else None 
    600600    base_time = comp_time = None 
    601601    base_value = comp_value = resid = relerr = None 
    602602 
    603603    # Base calculation 
    604     if Nbase > 0: 
     604    if n_base > 0: 
    605605        try: 
    606             base_raw, base_time = time_calculation(base, pars, Nbase) 
     606            base_raw, base_time = time_calculation(base, pars, n_base) 
    607607            base_value = np.ma.masked_invalid(base_raw) 
    608608            print("%s t=%.2f ms, intensity=%.0f" 
     
    611611        except ImportError: 
    612612            traceback.print_exc() 
    613             Nbase = 0 
     613            n_base = 0 
    614614 
    615615    # Comparison calculation 
    616     if Ncomp > 0: 
     616    if n_comp > 0: 
    617617        try: 
    618             comp_raw, comp_time = time_calculation(comp, pars, Ncomp) 
     618            comp_raw, comp_time = time_calculation(comp, pars, n_comp) 
    619619            comp_value = np.ma.masked_invalid(comp_raw) 
    620620            print("%s t=%.2f ms, intensity=%.0f" 
     
    623623        except ImportError: 
    624624            traceback.print_exc() 
    625             Ncomp = 0 
     625            n_comp = 0 
    626626 
    627627    # Compare, but only if computing both forms 
    628     if Nbase > 0 and Ncomp > 0: 
     628    if n_base > 0 and n_comp > 0: 
    629629        resid = (base_value - comp_value) 
    630         relerr = resid/np.where(comp_value!=0., abs(comp_value), 1.0) 
     630        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0) 
    631631        _print_stats("|%s-%s|" 
    632632                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))), 
     
    642642    if limits is None: 
    643643        vmin, vmax = np.Inf, -np.Inf 
    644         if Nbase > 0: 
     644        if n_base > 0: 
    645645            vmin = min(vmin, base_value.min()) 
    646646            vmax = max(vmax, base_value.max()) 
    647         if Ncomp > 0: 
     647        if n_comp > 0: 
    648648            vmin = min(vmin, comp_value.min()) 
    649649            vmax = max(vmax, comp_value.max()) 
    650650        limits = vmin, vmax 
    651651 
    652     if Nbase > 0: 
    653         if Ncomp > 0: plt.subplot(131) 
     652    if n_base > 0: 
     653        if n_comp > 0: plt.subplot(131) 
    654654        plot_theory(data, base_value, view=view, use_data=False, limits=limits) 
    655655        plt.title("%s t=%.2f ms"%(base.engine, base_time)) 
    656656        #cbar_title = "log I" 
    657     if Ncomp > 0: 
    658         if Nbase > 0: plt.subplot(132) 
     657    if n_comp > 0: 
     658        if n_base > 0: plt.subplot(132) 
    659659        plot_theory(data, comp_value, view=view, use_data=False, limits=limits) 
    660660        plt.title("%s t=%.2f ms"%(comp.engine, comp_time)) 
    661661        #cbar_title = "log I" 
    662     if Ncomp > 0 and Nbase > 0: 
     662    if n_comp > 0 and n_base > 0: 
    663663        plt.subplot(133) 
    664664        if not opts['rel_err']: 
     
    678678    fig.suptitle(opts['name']) 
    679679 
    680     if Ncomp > 0 and Nbase > 0 and '-hist' in opts: 
     680    if n_comp > 0 and n_base > 0 and '-hist' in opts: 
    681681        plt.figure() 
    682682        v = relerr 
     
    732732    ] 
    733733 
    734 def columnize(L, indent="", width=79): 
     734def columnize(items, indent="", width=79): 
    735735    # type: (List[str], str, int) -> str 
    736736    """ 
     
    739739    Returns a string with carriage returns ready for printing. 
    740740    """ 
    741     column_width = max(len(w) for w in L) + 1 
     741    column_width = max(len(w) for w in items) + 1 
    742742    num_columns = (width - len(indent)) // column_width 
    743     num_rows = len(L) // num_columns 
    744     L = L + [""] * (num_rows*num_columns - len(L)) 
    745     columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)] 
     743    num_rows = len(items) // num_columns 
     744    items = items + [""] * (num_rows * num_columns - len(items)) 
     745    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)] 
    746746    lines = [" ".join("%-*s"%(column_width, entry) for entry in row) 
    747747             for row in zip(*columns)] 
     
    766766        for ext, val in parts: 
    767767            if p.length > 1: 
    768                 dict(("%s%d%s"%(p.id,k,ext), val) for k in range(1, p.length+1)) 
     768                dict(("%s%d%s" % (p.id, k, ext), val) 
     769                     for k in range(1, p.length+1)) 
    769770            else: 
    770                 pars[p.id+ext] = val 
     771                pars[p.id + ext] = val 
    771772 
    772773    # Plug in values given in demo 
     
    887888    n1 = int(args[1]) if len(args) > 1 else 1 
    888889    n2 = int(args[2]) if len(args) > 2 else 1 
    889     use_sasview = any(engine=='sasview' and count>0 
     890    use_sasview = any(engine == 'sasview' and count > 0 
    890891                      for engine, count in zip(engines, [n1, n2])) 
    891892 
Note: See TracChangeset for help on using the changeset viewer.