Changeset dd7fc12 in sasmodels


Ignore:
Timestamp:
Apr 15, 2016 9:11:43 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
3599d36
Parents:
b151003
Message:

fix kerneldll dtype problem; more type hinting

Location:
sasmodels
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    rb151003 rdd7fc12  
    4141from .direct_model import DirectModel 
    4242from .convert import revert_name, revert_pars, constrain_new_to_old 
     43 
     44try: 
     45    from typing import Optional, Dict, Any, Callable, Tuple 
     46except: 
     47    pass 
     48else: 
     49    from .modelinfo import ModelInfo, Parameter, ParameterSet 
     50    from .data import Data 
     51    Calculator = Callable[[float, ...], np.ndarray] 
    4352 
    4453USAGE = """ 
     
    97106kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 
    98107 
    99 MODELS = core.list_models() 
    100  
    101108# CRUFT python 2.6 
    102109if not hasattr(datetime.timedelta, 'total_seconds'): 
     
    160167        ...            print(randint(0,1000000,3)) 
    161168        ...            raise Exception() 
    162         ...    except: 
     169        ...    except Exception: 
    163170        ...        print("Exception raised") 
    164171        ...    print(randint(0,1000000)) 
     
    169176    """ 
    170177    def __init__(self, seed=None): 
     178        # type: (Optional[int]) -> None 
    171179        self._state = np.random.get_state() 
    172180        np.random.seed(seed) 
    173181 
    174182    def __enter__(self): 
    175         return None 
    176  
    177     def __exit__(self, *args): 
     183        # type: () -> None 
     184        pass 
     185 
     186    def __exit__(self, type, value, traceback): 
     187        # type: (Any, BaseException, Any) -> None 
     188        # TODO: better typing for __exit__ method 
    178189        np.random.set_state(self._state) 
    179190 
    180191def tic(): 
     192    # type: () -> Callable[[], float] 
    181193    """ 
    182194    Timer function. 
     
    190202 
    191203def set_beam_stop(data, radius, outer=None): 
     204    # type: (Data, float, float) -> None 
    192205    """ 
    193206    Add a beam stop of the given *radius*.  If *outer*, make an annulus. 
    194207 
    195     Note: this function does not use the sasview package 
     208    Note: this function does not require sasview 
    196209    """ 
    197210    if hasattr(data, 'qx_data'): 
     
    207220 
    208221def parameter_range(p, v): 
     222    # type: (str, float) -> Tuple[float, float] 
    209223    """ 
    210224    Choose a parameter range based on parameter name and initial value. 
     
    212226    # process the polydispersity options 
    213227    if p.endswith('_pd_n'): 
    214         return [0, 100] 
     228        return 0., 100. 
    215229    elif p.endswith('_pd_nsigma'): 
    216         return [0, 5] 
     230        return 0., 5. 
    217231    elif p.endswith('_pd_type'): 
    218         return v 
     232        raise ValueError("Cannot return a range for a string value") 
    219233    elif any(s in p for s in ('theta', 'phi', 'psi')): 
    220234        # orientation in [-180,180], orientation pd in [0,45] 
    221235        if p.endswith('_pd'): 
    222             return [0, 45] 
     236            return 0., 45. 
    223237        else: 
    224             return [-180, 180] 
     238            return -180., 180. 
    225239    elif p.endswith('_pd'): 
    226         return [0, 1] 
     240        return 0., 1. 
    227241    elif 'sld' in p: 
    228         return [-0.5, 10] 
     242        return -0.5, 10. 
    229243    elif p == 'background': 
    230         return [0, 10] 
     244        return 0., 10. 
    231245    elif p == 'scale': 
    232         return [0, 1e3] 
    233     elif v < 0: 
    234         return [2*v, -2*v] 
     246        return 0., 1.e3 
     247    elif v < 0.: 
     248        return 2.*v, -2.*v 
    235249    else: 
    236         return [0, (2*v if v > 0 else 1)] 
     250        return 0., (2.*v if v > 0. else 1.) 
    237251 
    238252 
    239253def _randomize_one(model_info, p, v): 
     254    # type: (ModelInfo, str, float) -> float 
     255    # type: (ModelInfo, str, str) -> str 
    240256    """ 
    241257    Randomize a single parameter. 
     
    263279 
    264280def randomize_pars(model_info, pars, seed=None): 
     281    # type: (ModelInfo, ParameterSet, int) -> ParameterSet 
    265282    """ 
    266283    Generate random values for all of the parameters. 
     
    273290    with push_seed(seed): 
    274291        # Note: the sort guarantees order `of calls to random number generator 
    275         pars = dict((p, _randomize_one(model_info, p, v)) 
    276                     for p, v in sorted(pars.items())) 
    277     return pars 
     292        random_pars = dict((p, _randomize_one(model_info, p, v)) 
     293                           for p, v in sorted(pars.items())) 
     294    return random_pars 
    278295 
    279296def constrain_pars(model_info, pars): 
     297    # type: (ModelInfo, ParameterSet) -> None 
    280298    """ 
    281299    Restrict parameters to valid values. 
     
    284302    which need to support within model constraints (cap radius more than 
    285303    cylinder radius in this case). 
     304 
     305    Warning: this updates the *pars* dictionary in place. 
    286306    """ 
    287307    name = model_info.id 
     
    315335 
    316336def parlist(model_info, pars, is2d): 
     337    # type: (ModelInfo, ParameterSet, bool) -> str 
    317338    """ 
    318339    Format the parameter list for printing. 
     
    326347            n=int(pars.get(p.id+"_pd_n", 0)), 
    327348            nsigma=pars.get(p.id+"_pd_nsgima", 3.), 
    328             type=pars.get(p.id+"_pd_type", 'gaussian')) 
     349            pdtype=pars.get(p.id+"_pd_type", 'gaussian'), 
     350        ) 
    329351        lines.append(_format_par(p.name, **fields)) 
    330352    return "\n".join(lines) 
     
    332354    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items())) 
    333355 
    334 def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'): 
     356def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian'): 
     357    # type: (str, float, float, int, float, str) -> str 
    335358    line = "%s: %g"%(name, value) 
    336359    if pd != 0.  and n != 0: 
    337360        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\ 
    338                 % (pd, n, nsigma, nsigma, type) 
     361                % (pd, n, nsigma, nsigma, pdtype) 
    339362    return line 
    340363 
    341364def suppress_pd(pars): 
     365    # type: (ParameterSet) -> ParameterSet 
    342366    """ 
    343367    Suppress theta_pd for now until the normalization is resolved. 
     
    352376 
    353377def eval_sasview(model_info, data): 
     378    # type: (Modelinfo, Data) -> Calculator 
    354379    """ 
    355380    Return a model calculator using the pre-4.0 SasView models. 
     
    359384    import sas 
    360385    from sas.models.qsmearing import smear_selection 
     386    import sas.models 
    361387 
    362388    def get_model(name): 
     389        # type: (str) -> "sas.models.BaseComponent" 
    363390        #print("new",sorted(_pars.items())) 
    364         sas = __import__('sas.models.' + name) 
     391        __import__('sas.models.' + name) 
    365392        ModelClass = getattr(getattr(sas.models, name, None), name, None) 
    366393        if ModelClass is None: 
     
    400427 
    401428    def calculator(**pars): 
     429        # type: (float, ...) -> np.ndarray 
    402430        """ 
    403431        Sasview calculator for model. 
     
    406434        pars = revert_pars(model_info, pars) 
    407435        for k, v in pars.items(): 
    408             parts = k.split('.')  # polydispersity components 
    409             if len(parts) == 2: 
    410                 model.dispersion[parts[0]][parts[1]] = v 
     436            name_attr = k.split('.')  # polydispersity components 
     437            if len(name_attr) == 2: 
     438                model.dispersion[name_attr[0]][name_attr[1]] = v 
    411439            else: 
    412440                model.setParam(k, v) 
     
    428456} 
    429457def eval_opencl(model_info, data, dtype='single', cutoff=0.): 
     458    # type: (ModelInfo, Data, str, float) -> Calculator 
    430459    """ 
    431460    Return a model calculator using the OpenCL calculation engine. 
     
    442471 
    443472def eval_ctypes(model_info, data, dtype='double', cutoff=0.): 
     473    # type: (ModelInfo, Data, str, float) -> Calculator 
    444474    """ 
    445475    Return a model calculator using the DLL calculation engine. 
    446476    """ 
    447     if dtype == 'quad': 
    448         dtype = 'longdouble' 
    449477    model = core.build_model(model_info, dtype=dtype, platform="dll") 
    450478    calculator = DirectModel(data, model, cutoff=cutoff) 
     
    453481 
    454482def time_calculation(calculator, pars, Nevals=1): 
     483    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float] 
    455484    """ 
    456485    Compute the average calculation time over N evaluations. 
     
    461490    # initialize the code so time is more accurate 
    462491    if Nevals > 1: 
    463         value = calculator(**suppress_pd(pars)) 
     492        calculator(**suppress_pd(pars)) 
    464493    toc = tic() 
    465     for _ in range(max(Nevals, 1)):  # make sure there is at least one eval 
     494    # make sure there is at least one eval 
     495    value = calculator(**pars) 
     496    for _ in range(Nevals-1): 
    466497        value = calculator(**pars) 
    467498    average_time = toc()*1000./Nevals 
     
    469500 
    470501def make_data(opts): 
     502    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray] 
    471503    """ 
    472504    Generate an empty dataset, used with the model to set Q points 
     
    478510    qmax, nq, res = opts['qmax'], opts['nq'], opts['res'] 
    479511    if opts['is2d']: 
    480         data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res) 
     512        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray 
     513        data = empty_data2D(q, resolution=res) 
    481514        data.accuracy = opts['accuracy'] 
    482515        set_beam_stop(data, 0.0004) 
     
    495528 
    496529def make_engine(model_info, data, dtype, cutoff): 
     530    # type: (ModelInfo, Data, str, float) -> Calculator 
    497531    """ 
    498532    Generate the appropriate calculation engine for the given datatype. 
     
    509543 
    510544def _show_invalid(data, theory): 
     545    # type: (Data, np.ma.ndarray) -> None 
     546    """ 
     547    Display a list of the non-finite values in theory. 
     548    """ 
    511549    if not theory.mask.any(): 
    512550        return 
     
    514552    if hasattr(data, 'x'): 
    515553        bad = zip(data.x[theory.mask], theory[theory.mask]) 
    516         print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x,y in bad)) 
     554        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad)) 
    517555 
    518556 
    519557def compare(opts, limits=None): 
     558    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float] 
    520559    """ 
    521560    Preform a comparison using options from the command line. 
     
    532571    data = opts['data'] 
    533572 
     573    # silence the linter 
     574    base = opts['engines'][0] if Nbase else None 
     575    comp = opts['engines'][1] if Ncomp else None 
     576    base_time = comp_time = None 
     577    base_value = comp_value = resid = relerr = None 
     578 
    534579    # Base calculation 
    535580    if Nbase > 0: 
    536         base = opts['engines'][0] 
    537581        try: 
    538             base_value, base_time = time_calculation(base, pars, Nbase) 
    539             base_value = np.ma.masked_invalid(base_value) 
     582            base_raw, base_time = time_calculation(base, pars, Nbase) 
     583            base_value = np.ma.masked_invalid(base_raw) 
    540584            print("%s t=%.2f ms, intensity=%.0f" 
    541585                  % (base.engine, base_time, base_value.sum())) 
     
    547591    # Comparison calculation 
    548592    if Ncomp > 0: 
    549         comp = opts['engines'][1] 
    550593        try: 
    551             comp_value, comp_time = time_calculation(comp, pars, Ncomp) 
    552             comp_value = np.ma.masked_invalid(comp_value) 
     594            comp_raw, comp_time = time_calculation(comp, pars, Ncomp) 
     595            comp_value = np.ma.masked_invalid(comp_raw) 
    553596            print("%s t=%.2f ms, intensity=%.0f" 
    554597                  % (comp.engine, comp_time, comp_value.sum())) 
     
    625668 
    626669def _print_stats(label, err): 
     670    # type: (str, np.ma.ndarray) -> None 
     671    # work with trimmed data, not the full set 
    627672    sorted_err = np.sort(abs(err.compressed())) 
    628     p50 = int((len(err)-1)*0.50) 
    629     p98 = int((len(err)-1)*0.98) 
     673    p50 = int((len(sorted_err)-1)*0.50) 
     674    p98 = int((len(sorted_err)-1)*0.98) 
    630675    data = [ 
    631676        "max:%.3e"%sorted_err[-1], 
    632677        "median:%.3e"%sorted_err[p50], 
    633678        "98%%:%.3e"%sorted_err[p98], 
    634         "rms:%.3e"%np.sqrt(np.mean(err**2)), 
    635         "zero-offset:%+.3e"%np.mean(err), 
     679        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)), 
     680        "zero-offset:%+.3e"%np.mean(sorted_err), 
    636681        ] 
    637682    print(label+"  "+"  ".join(data)) 
     
    662707 
    663708def columnize(L, indent="", width=79): 
     709    # type: (List[str], str, int) -> str 
    664710    """ 
    665711    Format a list of strings into columns. 
     
    679725 
    680726def get_pars(model_info, use_demo=False): 
     727    # type: (ModelInfo, bool) -> ParameterSet 
    681728    """ 
    682729    Extract demo parameters from the model definition. 
     
    704751 
    705752def parse_opts(): 
     753    # type: () -> Dict[str, Any] 
    706754    """ 
    707755    Parse command line options. 
     
    757805        'explore'   : False, 
    758806        'use_demo'  : True, 
     807        'zero'      : False, 
    759808    } 
    760809    engines = [] 
     
    777826        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:]) 
    778827        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:]) 
    779         elif arg == '-random':  opts['seed'] = np.random.randint(1e6) 
     828        elif arg == '-random':  opts['seed'] = np.random.randint(1000000) 
    780829        elif arg == '-preset':  opts['seed'] = -1 
    781830        elif arg == '-mono':    opts['mono'] = True 
     
    874923 
    875924def explore(opts): 
     925    # type: (Dict[str, Any]) -> None 
    876926    """ 
    877927    Explore the model using the Bumps GUI. 
     
    900950    """ 
    901951    def __init__(self, opts): 
     952        # type: (Dict[str, Any]) -> None 
    902953        from bumps.cli import config_matplotlib  # type: ignore 
    903954        from . import bumps_model 
     
    923974 
    924975    def numpoints(self): 
     976        # type: () -> int 
    925977        """ 
    926978        Return the number of points. 
     
    929981 
    930982    def parameters(self): 
     983        # type: () -> Any   # Dict/List hierarchy of parameters 
    931984        """ 
    932985        Return a dictionary of parameters. 
     
    935988 
    936989    def nllf(self): 
     990        # type: () -> float 
    937991        """ 
    938992        Return cost. 
     
    942996 
    943997    def plot(self, view='log'): 
     998        # type: (str) -> None 
    944999        """ 
    9451000        Plot the data and residuals. 
     
    9511006        if self.limits is None: 
    9521007            vmin, vmax = limits 
    953             vmax = 1.3*vmax 
    954             vmin = vmax*1e-7 
    955             self.limits = vmin, vmax 
     1008            self.limits = vmax*1e-7, 1.3*vmax 
    9561009 
    9571010 
    9581011def main(): 
     1012    # type: () -> None 
    9591013    """ 
    9601014    Main program. 
  • sasmodels/core.py

    rfa5fd8d rdd7fc12  
    2828try: 
    2929    from typing import List, Union, Optional, Any 
    30     DType = Union[None, str, np.dtype] 
    3130    from .kernel import KernelModel 
     31    from .modelinfo import ModelInfo 
    3232except ImportError: 
    3333    pass 
     
    5656    return available_models 
    5757 
    58 def isstr(s): 
    59     # type: (Any) -> bool 
    60     """ 
    61     Return True if *s* is a string-like object. 
    62     """ 
    63     try: s + '' 
    64     except Exception: return False 
    65     return True 
    66  
    6758def load_model(model_name, dtype=None, platform='ocl'): 
    68     # type: (str, DType, str) -> KernelModel 
     59    # type: (str, str, str) -> KernelModel 
    6960    """ 
    7061    Load model info and build model. 
     
    10293 
    10394def build_model(model_info, dtype=None, platform="ocl"): 
    104     # type: (modelinfo.ModelInfo, np.dtype, str) -> KernelModel 
     95    # type: (modelinfo.ModelInfo, str, str) -> KernelModel 
    10596    """ 
    10697    Prepare the model for the default execution platform. 
     
    113104 
    114105    *dtype* indicates whether the model should use single or double precision 
    115     for the calculation. Any valid numpy single or double precision identifier 
    116     is valid, such as 'single', 'f', 'f32', or np.float32 for single, or 
    117     'double', 'd', 'f64'  and np.float64 for double.  If *None*, then use 
    118     'single' unless the model defines single=False. 
     106    for the calculation.  Choices are 'single', 'double', 'quad', 'half', 
     107    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather 
     108    than OpenCL for the calculation. 
    119109 
    120110    *platform* should be "dll" to force the dll to be used for C models, 
     
    147137    # source = open(model_info.name+'.cl','r').read() 
    148138    source = generate.make_source(model_info) 
    149     if dtype is None: 
    150         dtype = generate.F32 if model_info.single else generate.F64 
     139    numpy_dtype, fast = parse_dtype(model_info, dtype) 
    151140    if (platform == "dll" 
     141            or dtype.endswith('!') 
    152142            or not HAVE_OPENCL 
    153             or not kernelcl.environment().has_type(dtype)): 
    154         return kerneldll.load_dll(source, model_info, dtype) 
     143            or not kernelcl.environment().has_type(numpy_dtype)): 
     144        return kerneldll.load_dll(source, model_info, numpy_dtype) 
    155145    else: 
    156         return kernelcl.GpuModel(source, model_info, dtype) 
     146        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast) 
    157147 
    158148def precompile_dll(model_name, dtype="double"): 
    159     # type: (str, DType) -> Optional[str] 
     149    # type: (str, str) -> Optional[str] 
    160150    """ 
    161151    Precompile the dll for a model. 
     
    172162    """ 
    173163    model_info = load_model_info(model_name) 
     164    numpy_dtype, fast = parse_dtype(model_info, dtype) 
    174165    source = generate.make_source(model_info) 
    175     return kerneldll.make_dll(source, model_info, dtype=dtype) if source else None 
     166    return kerneldll.make_dll(source, model_info, dtype=numpy_dtype) if source else None 
     167 
     168def parse_dtype(model_info, dtype): 
     169    # type: (ModelInfo, str) -> Tuple[np.dtype, bool] 
     170    """ 
     171    Interpret dtype string, returning np.dtype and fast flag. 
     172 
     173    Possible types include 'half', 'single', 'double' and 'quad'.  If the 
     174    type is 'fast', then this is equivalent to dtype 'single' with the 
     175    fast flag set to True. 
     176    """ 
     177    # Fill in default type based on required precision in the model 
     178    if dtype is None: 
     179        dtype = 'single' if model_info.single else 'double' 
     180 
     181    # Ignore platform indicator 
     182    if dtype.endswith('!'): 
     183        dtype = dtype[:-1] 
     184 
     185    # Convert type string to type 
     186    if dtype == 'quad': 
     187        return generate.F128, False 
     188    elif dtype == 'half': 
     189        return generate.F16, False 
     190    elif dtype == 'fast': 
     191        return generate.F32, True 
     192    else: 
     193        return np.dtype(dtype), False 
     194 
  • sasmodels/kernelcl.py

    ra5b8477 rdd7fc12  
    103103ENV = None 
    104104def environment(): 
     105    # type: () -> "GpuEnvironment" 
    105106    """ 
    106107    Returns a singleton :class:`GpuEnvironment`. 
     
    114115 
    115116def has_type(device, dtype): 
     117    # type: (cl.Device, np.dtype) -> bool 
    116118    """ 
    117119    Return true if device supports the requested precision. 
     
    127129 
    128130def get_warp(kernel, queue): 
     131    # type: (cl.Kernel, cl.CommandQueue) -> int 
    129132    """ 
    130133    Return the size of an execution batch for *kernel* running on *queue*. 
     
    135138 
    136139def _stretch_input(vector, dtype, extra=1e-3, boundary=32): 
     140    # type: (np.ndarray, np.dtype, float, int) -> np.ndarray 
    137141    """ 
    138142    Stretch an input vector to the correct boundary. 
     
    157161 
    158162def compile_model(context, source, dtype, fast=False): 
     163    # type: (cl.Context, str, np.dtype, bool) -> cl.Program 
    159164    """ 
    160165    Build a model to run on the gpu. 
     
    192197    """ 
    193198    def __init__(self): 
     199        # type: () -> None 
    194200        # find gpu context 
    195201        #self.context = cl.create_some_context() 
     
    210216 
    211217    def has_type(self, dtype): 
     218        # type: (np.dtype) -> bool 
    212219        """ 
    213220        Return True if all devices support a given type. 
    214221        """ 
    215         dtype = generate.F32 if dtype == 'fast' else np.dtype(dtype) 
    216222        return any(has_type(d, dtype) 
    217223                   for context in self.context 
     
    219225 
    220226    def get_queue(self, dtype): 
     227        # type: (np.dtype) -> cl.CommandQueue 
    221228        """ 
    222229        Return a command queue for the kernels of type dtype. 
     
    227234 
    228235    def get_context(self, dtype): 
     236        # type: (np.dtype) -> cl.Context 
    229237        """ 
    230238        Return a OpenCL context for the kernels of type dtype. 
     
    235243 
    236244    def _create_some_context(self): 
     245        # type: () -> cl.Context 
    237246        """ 
    238247        Protected call to cl.create_some_context without interactivity.  Use 
     
    248257 
    249258    def compile_program(self, name, source, dtype, fast=False): 
     259        # type: (str, str, np.dtype, bool) -> cl.Program 
    250260        """ 
    251261        Compile the program for the device in the given context. 
     
    261271 
    262272    def release_program(self, name): 
     273        # type: (str) -> None 
    263274        """ 
    264275        Free memory associated with the program on the device. 
     
    269280 
    270281def _get_default_context(): 
     282    # type: () -> cl.Context 
    271283    """ 
    272284    Get an OpenCL context, preferring GPU over CPU, and preferring Intel 
     
    334346    that the compiler is allowed to take shortcuts. 
    335347    """ 
    336     def __init__(self, source, model_info, dtype=generate.F32): 
     348    def __init__(self, source, model_info, dtype=generate.F32, fast=False): 
     349        # type: (str, ModelInfo, np.dtype, bool) -> None 
    337350        self.info = model_info 
    338351        self.source = source 
    339         self.dtype = generate.F32 if dtype == 'fast' else np.dtype(dtype) 
    340         self.fast = (dtype == 'fast') 
     352        self.dtype = dtype 
     353        self.fast = fast 
    341354        self.program = None # delay program creation 
    342355 
    343356    def __getstate__(self): 
     357        # type: () -> Tuple[ModelInfo, str, np.dtype, bool] 
    344358        return self.info, self.source, self.dtype, self.fast 
    345359 
    346360    def __setstate__(self, state): 
     361        # type: (Tuple[ModelInfo, str, np.dtype, bool]) -> None 
    347362        self.info, self.source, self.dtype, self.fast = state 
    348363        self.program = None 
    349364 
    350365    def make_kernel(self, q_vectors): 
     366        # type: (List[np.ndarray]) -> "GpuKernel" 
    351367        if self.program is None: 
    352368            compiler = environment().compile_program 
     
    356372        kernel_name = generate.kernel_name(self.info, is_2d) 
    357373        kernel = getattr(self.program, kernel_name) 
    358         return GpuKernel(kernel, self.info, q_vectors, self.dtype) 
     374        return GpuKernel(kernel, self.info, q_vectors) 
    359375 
    360376    def release(self): 
     377        # type: () -> None 
    361378        """ 
    362379        Free the resources associated with the model. 
     
    367384 
    368385    def __del__(self): 
     386        # type: () -> None 
    369387        self.release() 
    370388 
     
    390408    """ 
    391409    def __init__(self, q_vectors, dtype=generate.F32): 
     410        # type: (List[np.ndarray], np.dtype) -> None 
    392411        # TODO: do we ever need double precision q? 
    393412        env = environment() 
     
    419438 
    420439    def release(self): 
     440        # type: () -> None 
    421441        """ 
    422442        Free the memory. 
     
    427447 
    428448    def __del__(self): 
     449        # type: () -> None 
    429450        self.release() 
    430451 
     
    450471    """ 
    451472    def __init__(self, kernel, model_info, q_vectors): 
    452         # type: (KernelModel, ModelInfo, List[np.ndarray]) -> None 
     473        # type: (cl.Kernel, ModelInfo, List[np.ndarray]) -> None 
    453474        max_pd = model_info.parameters.max_pd 
    454475        npars = len(model_info.parameters.kernel_parameters)-2 
     
    505526 
    506527    def release(self): 
     528        # type: () -> None 
    507529        """ 
    508530        Release resources associated with the kernel. 
     
    513535 
    514536    def __del__(self): 
     537        # type: () -> None 
    515538        self.release() 
Note: See TracChangeset for help on using the changeset viewer.