Changeset dd7fc12 in sasmodels for sasmodels/compare.py


Ignore:
Timestamp:
Apr 15, 2016 11:11:43 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
3599d36
Parents:
b151003
Message:

fix kerneldll dtype problem; more type hinting

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    rb151003 rdd7fc12  
    4141from .direct_model import DirectModel 
    4242from .convert import revert_name, revert_pars, constrain_new_to_old 
     43 
     44try: 
     45    from typing import Optional, Dict, Any, Callable, Tuple 
     46except: 
     47    pass 
     48else: 
     49    from .modelinfo import ModelInfo, Parameter, ParameterSet 
     50    from .data import Data 
     51    Calculator = Callable[[float, ...], np.ndarray] 
    4352 
    4453USAGE = """ 
     
    97106kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 
    98107 
    99 MODELS = core.list_models() 
    100  
    101108# CRUFT python 2.6 
    102109if not hasattr(datetime.timedelta, 'total_seconds'): 
     
    160167        ...            print(randint(0,1000000,3)) 
    161168        ...            raise Exception() 
    162         ...    except: 
     169        ...    except Exception: 
    163170        ...        print("Exception raised") 
    164171        ...    print(randint(0,1000000)) 
     
    169176    """ 
    170177    def __init__(self, seed=None): 
     178        # type: (Optional[int]) -> None 
    171179        self._state = np.random.get_state() 
    172180        np.random.seed(seed) 
    173181 
    174182    def __enter__(self): 
    175         return None 
    176  
    177     def __exit__(self, *args): 
     183        # type: () -> None 
     184        pass 
     185 
     186    def __exit__(self, type, value, traceback): 
     187        # type: (Any, BaseException, Any) -> None 
     188        # TODO: better typing for __exit__ method 
    178189        np.random.set_state(self._state) 
    179190 
    180191def tic(): 
     192    # type: () -> Callable[[], float] 
    181193    """ 
    182194    Timer function. 
     
    190202 
    191203def set_beam_stop(data, radius, outer=None): 
     204    # type: (Data, float, float) -> None 
    192205    """ 
    193206    Add a beam stop of the given *radius*.  If *outer*, make an annulus. 
    194207 
    195     Note: this function does not use the sasview package 
     208    Note: this function does not require sasview 
    196209    """ 
    197210    if hasattr(data, 'qx_data'): 
     
    207220 
    208221def parameter_range(p, v): 
     222    # type: (str, float) -> Tuple[float, float] 
    209223    """ 
    210224    Choose a parameter range based on parameter name and initial value. 
     
    212226    # process the polydispersity options 
    213227    if p.endswith('_pd_n'): 
    214         return [0, 100] 
     228        return 0., 100. 
    215229    elif p.endswith('_pd_nsigma'): 
    216         return [0, 5] 
     230        return 0., 5. 
    217231    elif p.endswith('_pd_type'): 
    218         return v 
     232        raise ValueError("Cannot return a range for a string value") 
    219233    elif any(s in p for s in ('theta', 'phi', 'psi')): 
    220234        # orientation in [-180,180], orientation pd in [0,45] 
    221235        if p.endswith('_pd'): 
    222             return [0, 45] 
     236            return 0., 45. 
    223237        else: 
    224             return [-180, 180] 
     238            return -180., 180. 
    225239    elif p.endswith('_pd'): 
    226         return [0, 1] 
     240        return 0., 1. 
    227241    elif 'sld' in p: 
    228         return [-0.5, 10] 
     242        return -0.5, 10. 
    229243    elif p == 'background': 
    230         return [0, 10] 
     244        return 0., 10. 
    231245    elif p == 'scale': 
    232         return [0, 1e3] 
    233     elif v < 0: 
    234         return [2*v, -2*v] 
     246        return 0., 1.e3 
     247    elif v < 0.: 
     248        return 2.*v, -2.*v 
    235249    else: 
    236         return [0, (2*v if v > 0 else 1)] 
     250        return 0., (2.*v if v > 0. else 1.) 
    237251 
    238252 
    239253def _randomize_one(model_info, p, v): 
     254    # type: (ModelInfo, str, float) -> float 
     255    # type: (ModelInfo, str, str) -> str 
    240256    """ 
    241257    Randomize a single parameter. 
     
    263279 
    264280def randomize_pars(model_info, pars, seed=None): 
     281    # type: (ModelInfo, ParameterSet, int) -> ParameterSet 
    265282    """ 
    266283    Generate random values for all of the parameters. 
     
    273290    with push_seed(seed): 
    274291        # Note: the sort guarantees order `of calls to random number generator 
    275         pars = dict((p, _randomize_one(model_info, p, v)) 
    276                     for p, v in sorted(pars.items())) 
    277     return pars 
     292        random_pars = dict((p, _randomize_one(model_info, p, v)) 
     293                           for p, v in sorted(pars.items())) 
     294    return random_pars 
    278295 
    279296def constrain_pars(model_info, pars): 
     297    # type: (ModelInfo, ParameterSet) -> None 
    280298    """ 
    281299    Restrict parameters to valid values. 
     
    284302    which need to support within model constraints (cap radius more than 
    285303    cylinder radius in this case). 
     304 
     305    Warning: this updates the *pars* dictionary in place. 
    286306    """ 
    287307    name = model_info.id 
     
    315335 
    316336def parlist(model_info, pars, is2d): 
     337    # type: (ModelInfo, ParameterSet, bool) -> str 
    317338    """ 
    318339    Format the parameter list for printing. 
     
    326347            n=int(pars.get(p.id+"_pd_n", 0)), 
    327348            nsigma=pars.get(p.id+"_pd_nsgima", 3.), 
    328             type=pars.get(p.id+"_pd_type", 'gaussian')) 
     349            pdtype=pars.get(p.id+"_pd_type", 'gaussian'), 
     350        ) 
    329351        lines.append(_format_par(p.name, **fields)) 
    330352    return "\n".join(lines) 
     
    332354    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items())) 
    333355 
    334 def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'): 
     356def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian'): 
     357    # type: (str, float, float, int, float, str) -> str 
    335358    line = "%s: %g"%(name, value) 
    336359    if pd != 0.  and n != 0: 
    337360        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\ 
    338                 % (pd, n, nsigma, nsigma, type) 
     361                % (pd, n, nsigma, nsigma, pdtype) 
    339362    return line 
    340363 
    341364def suppress_pd(pars): 
     365    # type: (ParameterSet) -> ParameterSet 
    342366    """ 
    343367    Suppress theta_pd for now until the normalization is resolved. 
     
    352376 
    353377def eval_sasview(model_info, data): 
     378    # type: (Modelinfo, Data) -> Calculator 
    354379    """ 
    355380    Return a model calculator using the pre-4.0 SasView models. 
     
    359384    import sas 
    360385    from sas.models.qsmearing import smear_selection 
     386    import sas.models 
    361387 
    362388    def get_model(name): 
     389        # type: (str) -> "sas.models.BaseComponent" 
    363390        #print("new",sorted(_pars.items())) 
    364         sas = __import__('sas.models.' + name) 
     391        __import__('sas.models.' + name) 
    365392        ModelClass = getattr(getattr(sas.models, name, None), name, None) 
    366393        if ModelClass is None: 
     
    400427 
    401428    def calculator(**pars): 
     429        # type: (float, ...) -> np.ndarray 
    402430        """ 
    403431        Sasview calculator for model. 
     
    406434        pars = revert_pars(model_info, pars) 
    407435        for k, v in pars.items(): 
    408             parts = k.split('.')  # polydispersity components 
    409             if len(parts) == 2: 
    410                 model.dispersion[parts[0]][parts[1]] = v 
     436            name_attr = k.split('.')  # polydispersity components 
     437            if len(name_attr) == 2: 
     438                model.dispersion[name_attr[0]][name_attr[1]] = v 
    411439            else: 
    412440                model.setParam(k, v) 
     
    428456} 
    429457def eval_opencl(model_info, data, dtype='single', cutoff=0.): 
     458    # type: (ModelInfo, Data, str, float) -> Calculator 
    430459    """ 
    431460    Return a model calculator using the OpenCL calculation engine. 
     
    442471 
    443472def eval_ctypes(model_info, data, dtype='double', cutoff=0.): 
     473    # type: (ModelInfo, Data, str, float) -> Calculator 
    444474    """ 
    445475    Return a model calculator using the DLL calculation engine. 
    446476    """ 
    447     if dtype == 'quad': 
    448         dtype = 'longdouble' 
    449477    model = core.build_model(model_info, dtype=dtype, platform="dll") 
    450478    calculator = DirectModel(data, model, cutoff=cutoff) 
     
    453481 
    454482def time_calculation(calculator, pars, Nevals=1): 
     483    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float] 
    455484    """ 
    456485    Compute the average calculation time over N evaluations. 
     
    461490    # initialize the code so time is more accurate 
    462491    if Nevals > 1: 
    463         value = calculator(**suppress_pd(pars)) 
     492        calculator(**suppress_pd(pars)) 
    464493    toc = tic() 
    465     for _ in range(max(Nevals, 1)):  # make sure there is at least one eval 
     494    # make sure there is at least one eval 
     495    value = calculator(**pars) 
     496    for _ in range(Nevals-1): 
    466497        value = calculator(**pars) 
    467498    average_time = toc()*1000./Nevals 
     
    469500 
    470501def make_data(opts): 
     502    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray] 
    471503    """ 
    472504    Generate an empty dataset, used with the model to set Q points 
     
    478510    qmax, nq, res = opts['qmax'], opts['nq'], opts['res'] 
    479511    if opts['is2d']: 
    480         data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res) 
     512        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray 
     513        data = empty_data2D(q, resolution=res) 
    481514        data.accuracy = opts['accuracy'] 
    482515        set_beam_stop(data, 0.0004) 
     
    495528 
    496529def make_engine(model_info, data, dtype, cutoff): 
     530    # type: (ModelInfo, Data, str, float) -> Calculator 
    497531    """ 
    498532    Generate the appropriate calculation engine for the given datatype. 
     
    509543 
    510544def _show_invalid(data, theory): 
     545    # type: (Data, np.ma.ndarray) -> None 
     546    """ 
     547    Display a list of the non-finite values in theory. 
     548    """ 
    511549    if not theory.mask.any(): 
    512550        return 
     
    514552    if hasattr(data, 'x'): 
    515553        bad = zip(data.x[theory.mask], theory[theory.mask]) 
    516         print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x,y in bad)) 
     554        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad)) 
    517555 
    518556 
    519557def compare(opts, limits=None): 
     558    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float] 
    520559    """ 
    521560    Preform a comparison using options from the command line. 
     
    532571    data = opts['data'] 
    533572 
     573    # silence the linter 
     574    base = opts['engines'][0] if Nbase else None 
     575    comp = opts['engines'][1] if Ncomp else None 
     576    base_time = comp_time = None 
     577    base_value = comp_value = resid = relerr = None 
     578 
    534579    # Base calculation 
    535580    if Nbase > 0: 
    536         base = opts['engines'][0] 
    537581        try: 
    538             base_value, base_time = time_calculation(base, pars, Nbase) 
    539             base_value = np.ma.masked_invalid(base_value) 
     582            base_raw, base_time = time_calculation(base, pars, Nbase) 
     583            base_value = np.ma.masked_invalid(base_raw) 
    540584            print("%s t=%.2f ms, intensity=%.0f" 
    541585                  % (base.engine, base_time, base_value.sum())) 
     
    547591    # Comparison calculation 
    548592    if Ncomp > 0: 
    549         comp = opts['engines'][1] 
    550593        try: 
    551             comp_value, comp_time = time_calculation(comp, pars, Ncomp) 
    552             comp_value = np.ma.masked_invalid(comp_value) 
     594            comp_raw, comp_time = time_calculation(comp, pars, Ncomp) 
     595            comp_value = np.ma.masked_invalid(comp_raw) 
    553596            print("%s t=%.2f ms, intensity=%.0f" 
    554597                  % (comp.engine, comp_time, comp_value.sum())) 
     
    625668 
    626669def _print_stats(label, err): 
     670    # type: (str, np.ma.ndarray) -> None 
     671    # work with trimmed data, not the full set 
    627672    sorted_err = np.sort(abs(err.compressed())) 
    628     p50 = int((len(err)-1)*0.50) 
    629     p98 = int((len(err)-1)*0.98) 
     673    p50 = int((len(sorted_err)-1)*0.50) 
     674    p98 = int((len(sorted_err)-1)*0.98) 
    630675    data = [ 
    631676        "max:%.3e"%sorted_err[-1], 
    632677        "median:%.3e"%sorted_err[p50], 
    633678        "98%%:%.3e"%sorted_err[p98], 
    634         "rms:%.3e"%np.sqrt(np.mean(err**2)), 
    635         "zero-offset:%+.3e"%np.mean(err), 
     679        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)), 
     680        "zero-offset:%+.3e"%np.mean(sorted_err), 
    636681        ] 
    637682    print(label+"  "+"  ".join(data)) 
     
    662707 
    663708def columnize(L, indent="", width=79): 
     709    # type: (List[str], str, int) -> str 
    664710    """ 
    665711    Format a list of strings into columns. 
     
    679725 
    680726def get_pars(model_info, use_demo=False): 
     727    # type: (ModelInfo, bool) -> ParameterSet 
    681728    """ 
    682729    Extract demo parameters from the model definition. 
     
    704751 
    705752def parse_opts(): 
     753    # type: () -> Dict[str, Any] 
    706754    """ 
    707755    Parse command line options. 
     
    757805        'explore'   : False, 
    758806        'use_demo'  : True, 
     807        'zero'      : False, 
    759808    } 
    760809    engines = [] 
     
    777826        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:]) 
    778827        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:]) 
    779         elif arg == '-random':  opts['seed'] = np.random.randint(1e6) 
     828        elif arg == '-random':  opts['seed'] = np.random.randint(1000000) 
    780829        elif arg == '-preset':  opts['seed'] = -1 
    781830        elif arg == '-mono':    opts['mono'] = True 
     
    874923 
    875924def explore(opts): 
     925    # type: (Dict[str, Any]) -> None 
    876926    """ 
    877927    Explore the model using the Bumps GUI. 
     
    900950    """ 
    901951    def __init__(self, opts): 
     952        # type: (Dict[str, Any]) -> None 
    902953        from bumps.cli import config_matplotlib  # type: ignore 
    903954        from . import bumps_model 
     
    923974 
    924975    def numpoints(self): 
     976        # type: () -> int 
    925977        """ 
    926978        Return the number of points. 
     
    929981 
    930982    def parameters(self): 
     983        # type: () -> Any   # Dict/List hierarchy of parameters 
    931984        """ 
    932985        Return a dictionary of parameters. 
     
    935988 
    936989    def nllf(self): 
     990        # type: () -> float 
    937991        """ 
    938992        Return cost. 
     
    942996 
    943997    def plot(self, view='log'): 
     998        # type: (str) -> None 
    944999        """ 
    9451000        Plot the data and residuals. 
     
    9511006        if self.limits is None: 
    9521007            vmin, vmax = limits 
    953             vmax = 1.3*vmax 
    954             vmin = vmax*1e-7 
    955             self.limits = vmin, vmax 
     1008            self.limits = vmax*1e-7, 1.3*vmax 
    9561009 
    9571010 
    9581011def main(): 
     1012    # type: () -> None 
    9591013    """ 
    9601014    Main program. 
Note: See TracChangeset for help on using the changeset viewer.