Changes in sasmodels/core.py [2dcd6e7:b0de252] in sasmodels


Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/core.py

    r2dcd6e7 rb0de252  
    2121from . import mixture 
    2222from . import kernelpy 
     23from . import kernelcuda 
    2324from . import kernelcl 
    2425from . import kerneldll 
     
    210211        #print("building dll", numpy_dtype) 
    211212        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype) 
     213    elif platform == "cuda": 
     214        return kernelcuda.GpuModel(source, model_info, numpy_dtype, fast=fast) 
    212215    else: 
    213216        #print("building ocl", numpy_dtype) 
     
    245248    # type: (ModelInfo, str, str) -> (np.dtype, bool, str) 
    246249    """ 
    247     Interpret dtype string, returning np.dtype and fast flag. 
     250    Interpret dtype string, returning np.dtype, fast flag and platform. 
    248251 
    249252    Possible types include 'half', 'single', 'double' and 'quad'.  If the 
     
    253256    default for the model and platform. 
    254257 
    255     Platform preference can be specfied ("ocl" vs "dll"), with the default 
    256     being OpenCL if it is availabe.  If the dtype name ends with '!' then 
    257     platform is forced to be DLL rather than OpenCL. 
     258    Platform preference can be specified ("ocl", "cuda", "dll"), with the 
     259    default being OpenCL or CUDA if available, otherwise DLL.  If the dtype 
     260    name ends with '!' then platform is forced to be DLL rather than GPU. 
     261    The default platform is set by the environment variable SAS_OPENCL: 
     262    use SAS_OPENCL=driver:device for OpenCL, SAS_OPENCL=cuda:device for 
     263    CUDA, or SAS_OPENCL=none for DLL. 
    258264 
    259265    This routine ignores the preferences within the model definition.  This 
     
    268274    if platform is None: 
    269275        platform = "ocl" 
    270     if not kernelcl.use_opencl() or not model_info.opencl: 
    271         platform = "dll" 
    272276 
    273277    # Check if type indicates dll regardless of which platform is given 
     
    275279        platform = "dll" 
    276280        dtype = dtype[:-1] 
     281 
     282    # Make sure model allows opencl/gpu 
     283    if not model_info.opencl: 
     284        platform = "dll" 
     285 
     286    # Make sure opencl is available, or fallback to cuda then to dll 
     287    if platform == "ocl" and not kernelcl.use_opencl(): 
     288        platform = "cuda" if kernelcuda.use_cuda() else "dll" 
    277289 
    278290    # Convert special type names "half", "fast", and "quad" 
     
    285297        dtype = "float16" 
    286298 
    287     # Convert dtype string to numpy dtype. 
     299    # Convert dtype string to numpy dtype.  Use single precision for GPU 
     300    # if model allows it, otherwise use double precision. 
    288301    if dtype is None or dtype == "default": 
    289         numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single 
     302        numpy_dtype = (generate.F32 if model_info.single and platform in ("ocl", "cuda") 
    290303                       else generate.F64) 
    291304    else: 
    292305        numpy_dtype = np.dtype(dtype) 
    293306 
    294     # Make sure that the type is supported by opencl, otherwise use dll 
     307    # Make sure that the type is supported by GPU, otherwise use dll 
    295308    if platform == "ocl": 
    296309        env = kernelcl.environment() 
    297         if not env.has_type(numpy_dtype): 
    298             platform = "dll" 
    299             if dtype is None: 
    300                 numpy_dtype = generate.F64 
     310    elif platform == "cuda": 
     311        env = kernelcuda.environment() 
     312    else: 
     313        env = None 
     314    if env is not None and not env.has_type(numpy_dtype): 
     315        platform = "dll" 
     316        if dtype is None: 
     317            numpy_dtype = generate.F64 
    301318 
    302319    return numpy_dtype, fast, platform 
Note: See TracChangeset for help on using the changeset viewer.