Changeset 07646b6 in sasmodels for sasmodels/core.py


Ignore:
Timestamp:
Oct 25, 2018 3:43:01 PM (6 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
149eb53
Parents:
31fc4ad (diff), d5ce7fa (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
git-author:
Paul Kienzle <pkienzle@…> (10/25/18 14:41:48)
git-committer:
Paul Kienzle <pkienzle@…> (10/25/18 15:43:01)
Message:

Merge branch 'cuda-test' into beta_approx

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/core.py

    ree60aa7 r07646b6  
    2121from . import mixture 
    2222from . import kernelpy 
     23from . import kernelcuda 
    2324from . import kernelcl 
    2425from . import kerneldll 
     
    209210        #print("building dll", numpy_dtype) 
    210211        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype) 
     212    elif platform == "cuda": 
     213        return kernelcuda.GpuModel(source, model_info, numpy_dtype, fast=fast) 
    211214    else: 
    212215        #print("building ocl", numpy_dtype) 
     
    244247    # type: (ModelInfo, str, str) -> (np.dtype, bool, str) 
    245248    """ 
    246     Interpret dtype string, returning np.dtype and fast flag. 
     249    Interpret dtype string, returning np.dtype, fast flag and platform. 
    247250 
    248251    Possible types include 'half', 'single', 'double' and 'quad'.  If the 
     
    252255    default for the model and platform. 
    253256 
    254     Platform preference can be specified ("ocl" vs "dll"), with the default 
    255     being OpenCL if it is available.  If the dtype name ends with '!' then 
    256     platform is forced to be DLL rather than OpenCL. 
     257    Platform preference can be specified ("ocl", "cuda", "dll"), with the 
     258    default being OpenCL or CUDA if available, otherwise DLL.  If the dtype 
     259    name ends with '!' then platform is forced to be DLL rather than GPU. 
     260    The default platform is set by the environment variable SAS_OPENCL, 
     261    SAS_OPENCL=driver:device for OpenCL, SAS_OPENCL=cuda:device for CUDA 
     262    or SAS_OPENCL=none for DLL. 
    257263 
    258264    This routine ignores the preferences within the model definition.  This 
     
    266272    if platform is None: 
    267273        platform = "ocl" 
    268     if not kernelcl.use_opencl() or not model_info.opencl: 
    269         platform = "dll" 
    270274 
    271275    # Check if type indicates dll regardless of which platform is given 
     
    273277        platform = "dll" 
    274278        dtype = dtype[:-1] 
     279 
     280    # Make sure model allows opencl/gpu 
     281    if not model_info.opencl: 
     282        platform = "dll" 
     283 
     284    # Make sure opencl is available, or fallback to cuda then to dll 
     285    if platform == "ocl" and not kernelcl.use_opencl(): 
     286        platform = "cuda" if kernelcuda.use_cuda() else "dll" 
    275287 
    276288    # Convert special type names "half", "fast", and "quad" 
     
    283295        dtype = "float16" 
    284296 
    285     # Convert dtype string to numpy dtype. 
     297    # Convert dtype string to numpy dtype.  Use single precision for GPU 
     298    # if model allows it, otherwise use double precision. 
    286299    if dtype is None or dtype == "default": 
    287         numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single 
     300        numpy_dtype = (generate.F32 if model_info.single and platform in ("ocl", "cuda") 
    288301                       else generate.F64) 
    289302    else: 
    290303        numpy_dtype = np.dtype(dtype) 
    291304 
    292     # Make sure that the type is supported by opencl, otherwise use dll 
     305    # Make sure that the type is supported by GPU, otherwise use dll 
    293306    if platform == "ocl": 
    294307        env = kernelcl.environment() 
    295         if not env.has_type(numpy_dtype): 
    296             platform = "dll" 
    297             if dtype is None: 
    298                 numpy_dtype = generate.F64 
     308    elif platform == "cuda": 
     309        env = kernelcuda.environment() 
     310    else: 
     311        env = None 
     312    if env is not None and not env.has_type(numpy_dtype): 
     313        platform = "dll" 
     314        if dtype is None: 
     315            numpy_dtype = generate.F64 
    299316 
    300317    return numpy_dtype, fast, platform 
Note: See TracChangeset for help on using the changeset viewer.