Changeset b0de252 in sasmodels for sasmodels/core.py


Ignore:
Timestamp:
Oct 12, 2018 7:31:24 PM (6 years ago)
Author:
pkienzle
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
74e9b5f
Parents:
47fb816
Message:

improve control over cuda context

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/core.py

    r47fb816 rb0de252  
    1313from glob import glob 
    1414import re 
    15  
    16 # Set "SAS_OPENCL=cuda" in the environment to use the CUDA rather than OpenCL 
    17 USE_CUDA = os.environ.get("SAS_OPENCL", "") == "cuda" 
    1815 
    1916import numpy as np # type: ignore 
     
    2421from . import mixture 
    2522from . import kernelpy 
    26 if USE_CUDA: 
    27     from . import kernelcuda 
    28 else: 
    29     from . import kernelcl 
     23from . import kernelcuda 
     24from . import kernelcl 
    3025from . import kerneldll 
    3126from . import custom 
     
    216211        #print("building dll", numpy_dtype) 
    217212        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype) 
    218     elif USE_CUDA: 
    219         #print("building cuda", numpy_dtype) 
     213    elif platform == "cuda": 
    220214        return kernelcuda.GpuModel(source, model_info, numpy_dtype, fast=fast) 
    221215    else: 
     
    254248    # type: (ModelInfo, str, str) -> (np.dtype, bool, str) 
    255249    """ 
    256     Interpret dtype string, returning np.dtype and fast flag. 
     250    Interpret dtype string, returning np.dtype, fast flag and platform. 
    257251 
    258252    Possible types include 'half', 'single', 'double' and 'quad'.  If the 
     
    262256    default for the model and platform. 
    263257 
    264     Platform preference can be specified ("ocl" vs "dll"), with the default 
    265     being OpenCL if it is available.  If the dtype name ends with '!' then 
    266     platform is forced to be DLL rather than OpenCL. 
     258    Platform preference can be specified ("ocl", "cuda", "dll"), with the 
     259    default being OpenCL or CUDA if available, otherwise DLL.  If the dtype 
     260    name ends with '!' then platform is forced to be DLL rather than GPU. 
     261    The default platform is set by the environment variable SAS_OPENCL, 
     262    SAS_OPENCL=driver:device for OpenCL, SAS_OPENCL=cuda:device for CUDA 
     263    or SAS_OPENCL=none for DLL. 
    267264 
    268265    This routine ignores the preferences within the model definition.  This 
     
    277274    if platform is None: 
    278275        platform = "ocl" 
    279     if not model_info.opencl: 
    280         platform = "dll" 
    281     elif USE_CUDA: 
    282         if not kernelcuda.use_cuda(): 
    283             platform = "dll" 
    284     else: 
    285         if not kernelcl.use_opencl(): 
    286             platform = "dll" 
    287276 
    288277    # Check if type indicates dll regardless of which platform is given 
     
    290279        platform = "dll" 
    291280        dtype = dtype[:-1] 
     281 
     282    # Make sure model allows opencl/gpu 
     283    if not model_info.opencl: 
     284        platform = "dll" 
     285 
     286    # Make sure opencl is available, or fallback to cuda then to dll 
     287    if platform == "ocl" and not kernelcl.use_opencl(): 
     288        platform = "cuda" if kernelcuda.use_cuda() else "dll" 
    292289 
    293290    # Convert special type names "half", "fast", and "quad" 
     
    300297        dtype = "float16" 
    301298 
    302     # Convert dtype string to numpy dtype. 
     299    # Convert dtype string to numpy dtype.  Use single precision for GPU 
     300    # if model allows it, otherwise use double precision. 
    303301    if dtype is None or dtype == "default": 
    304         numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single 
     302        numpy_dtype = (generate.F32 if model_info.single and platform in ("ocl", "cuda") 
    305303                       else generate.F64) 
    306304    else: 
    307305        numpy_dtype = np.dtype(dtype) 
    308306 
    309     # Make sure that the type is supported by opencl, otherwise use dll 
     307    # Make sure that the type is supported by GPU, otherwise use dll 
    310308    if platform == "ocl": 
    311         if USE_CUDA: 
    312             env = kernelcuda.environment() 
    313         else: 
    314             env = kernelcl.environment() 
    315         if not env.has_type(numpy_dtype): 
    316             platform = "dll" 
    317             if dtype is None: 
    318                 numpy_dtype = generate.F64 
     309        env = kernelcl.environment() 
     310    elif platform == "cuda": 
     311        env = kernelcuda.environment() 
     312    else: 
     313        env = None 
     314    if env is not None and not env.has_type(numpy_dtype): 
     315        platform = "dll" 
     316        if dtype is None: 
     317            numpy_dtype = generate.F64 
    319318 
    320319    return numpy_dtype, fast, platform 
Note: See TracChangeset for help on using the changeset viewer.