Changeset 3221de0 in sasmodels for sasmodels/compare.py


Ignore:
Timestamp:
Feb 1, 2018 9:40:30 AM (6 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
b4272a2
Parents:
b3af1c2
Message:

restructure handling of opencl flags so it works with sasview

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    r2a7e20e r3221de0  
    4040from . import core 
    4141from . import kerneldll 
     42from . import kernelcl 
    4243from .data import plot_theory, empty_data1D, empty_data2D, load_data 
    4344from .direct_model import DirectModel, get_mesh 
     
    623624 
    624625 
     626def time_calculation(calculator, pars, evals=1): 
     627    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float] 
     628    """ 
     629    Compute the average calculation time over N evaluations. 
     630 
     631    An additional call is generated without polydispersity in order to 
     632    initialize the calculation engine, and make the average more stable. 
     633    """ 
     634    # initialize the code so time is more accurate 
     635    if evals > 1: 
     636        calculator(**suppress_pd(pars)) 
     637    toc = tic() 
     638    # make sure there is at least one eval 
     639    value = calculator(**pars) 
     640    for _ in range(evals-1): 
     641        value = calculator(**pars) 
     642    average_time = toc()*1000. / evals 
     643    #print("I(q)",value) 
     644    return value, average_time 
     645 
     646def make_data(opts): 
     647    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray] 
     648    """ 
     649    Generate an empty dataset, used with the model to set Q points 
     650    and resolution. 
     651 
     652    *opts* contains the options, with 'qmax', 'nq', 'res', 
     653    'accuracy', 'is2d' and 'view' parsed from the command line. 
     654    """ 
     655    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res'] 
     656    if opts['is2d']: 
     657        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray 
     658        data = empty_data2D(q, resolution=res) 
     659        data.accuracy = opts['accuracy'] 
     660        set_beam_stop(data, qmin) 
     661        index = ~data.mask 
     662    else: 
     663        if opts['view'] == 'log' and not opts['zero']: 
     664            q = np.logspace(math.log10(qmin), math.log10(qmax), nq) 
     665        else: 
     666            q = np.linspace(qmin, qmax, nq) 
     667        if opts['zero']: 
     668            q = np.hstack((0, q)) 
     669        data = empty_data1D(q, resolution=res) 
     670        index = slice(None, None) 
     671    return data, index 
     672 
    625673DTYPE_MAP = { 
    626674    'half': '16', 
     
    643691    Return a model calculator using the OpenCL calculation engine. 
    644692    """ 
    645     if not core.HAVE_OPENCL: 
    646         raise RuntimeError("OpenCL not available") 
    647     model = core.build_model(model_info, dtype=dtype, platform="ocl") 
    648     calculator = DirectModel(data, model, cutoff=cutoff) 
    649     calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)] 
    650     return calculator 
    651693 
    652694def eval_ctypes(model_info, data, dtype='double', cutoff=0.): 
     
    660702    return calculator 
    661703 
    662 def time_calculation(calculator, pars, evals=1): 
    663     # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float] 
    664     """ 
    665     Compute the average calculation time over N evaluations. 
    666  
    667     An additional call is generated without polydispersity in order to 
    668     initialize the calculation engine, and make the average more stable. 
    669     """ 
    670     # initialize the code so time is more accurate 
    671     if evals > 1: 
    672         calculator(**suppress_pd(pars)) 
    673     toc = tic() 
    674     # make sure there is at least one eval 
    675     value = calculator(**pars) 
    676     for _ in range(evals-1): 
    677         value = calculator(**pars) 
    678     average_time = toc()*1000. / evals 
    679     #print("I(q)",value) 
    680     return value, average_time 
    681  
    682 def make_data(opts): 
    683     # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray] 
    684     """ 
    685     Generate an empty dataset, used with the model to set Q points 
    686     and resolution. 
    687  
    688     *opts* contains the options, with 'qmax', 'nq', 'res', 
    689     'accuracy', 'is2d' and 'view' parsed from the command line. 
    690     """ 
    691     qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res'] 
    692     if opts['is2d']: 
    693         q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray 
    694         data = empty_data2D(q, resolution=res) 
    695         data.accuracy = opts['accuracy'] 
    696         set_beam_stop(data, qmin) 
    697         index = ~data.mask 
    698     else: 
    699         if opts['view'] == 'log' and not opts['zero']: 
    700             q = np.logspace(math.log10(qmin), math.log10(qmax), nq) 
    701         else: 
    702             q = np.linspace(qmin, qmax, nq) 
    703         if opts['zero']: 
    704             q = np.hstack((0, q)) 
    705         data = empty_data1D(q, resolution=res) 
    706         index = slice(None, None) 
    707     return data, index 
    708  
    709704def make_engine(model_info, data, dtype, cutoff, ngauss=0): 
    710705    # type: (ModelInfo, Data, str, float) -> Calculator 
     
    718713        set_integration_size(model_info, ngauss) 
    719714 
    720     if dtype is None or not dtype.endswith('!'): 
    721         return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff) 
    722     else: 
    723         return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff) 
     715    if dtype != "default" and not dtype.endswith('!') and not kernelcl.use_opencl(): 
     716        raise RuntimeError("OpenCL not available " + kernelcl.OPENCL_ERROR) 
     717 
     718    model = core.build_model(model_info, dtype=dtype, platform="ocl") 
     719    calculator = DirectModel(data, model, cutoff=cutoff) 
     720    engine_type = calculator._model.__class__.__name__.replace('Model','').upper() 
     721    bits = calculator._model.dtype.itemsize*8 
     722    precision = "fast" if getattr(calculator._model, 'fast', False) else str(bits) 
     723    calculator.engine = "%s[%s]" % (engine_type, precision) 
     724    return calculator 
    724725 
    725726def _show_invalid(data, theory): 
Note: See TracChangeset for help on using the changeset viewer.