Changeset b0de252 in sasmodels for sasmodels/core.py
- Timestamp:
- Oct 12, 2018 7:31:24 PM (6 years ago)
- Branches:
- master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
- Children:
- 74e9b5f
- Parents:
- 47fb816
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/core.py
r47fb816 rb0de252
  13   13     from glob import glob
  14   14     import re
  15       -
  16       -  # Set "SAS_OPENCL=cuda" in the environment to use the CUDA rather than OpenCL
  17       -  USE_CUDA = os.environ.get("SAS_OPENCL", "") == "cuda"
  18   15
  19   16     import numpy as np # type: ignore
   …    …
  24   21     from . import mixture
  25   22     from . import kernelpy
  26       -  if USE_CUDA:
  27       -      from . import kernelcuda
  28       -  else:
  29       -      from . import kernelcl
       23  +  from . import kernelcuda
       24  +  from . import kernelcl
  30   25     from . import kerneldll
  31   26     from . import custom
   …    …
 216  211             #print("building dll", numpy_dtype)
 217  212             return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
 218       -      elif USE_CUDA:
 219       -          #print("building cuda", numpy_dtype)
      213  +      elif platform == "cuda":
 220  214             return kernelcuda.GpuModel(source, model_info, numpy_dtype, fast=fast)
 221  215         else:
   …    …
 254  248         # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
 255  249         """
 256       -      Interpret dtype string, returning np.dtype and fast flag.
      250  +      Interpret dtype string, returning np.dtype, fast flag and platform.
 257  251
 258  252         Possible types include 'half', 'single', 'double' and 'quad'. If the
   …    …
 262  256         default for the model and platform.
 263  257
 264       -      Platform preference can be specfied ("ocl" vs "dll"), with the default
 265       -      being OpenCL if it is availabe. If the dtype name ends with '!' then
 266       -      platform is forced to be DLL rather than OpenCL.
      258  +      Platform preference can be specfied ("ocl", "cuda", "dll"), with the
      259  +      default being OpenCL or CUDA if available, otherwise DLL. If the dtype
      260  +      name ends with '!' then platform is forced to be DLL rather than GPU.
      261  +      The default platform is set by the environment variable SAS_OPENCL,
      262  +      SAS_OPENCL=driver:device for OpenCL, SAS_OPENCL=cuda:device for CUDA
      263  +      or SAS_OPENCL=none for DLL.
 267  264
 268  265         This routine ignores the preferences within the model definition. This
   …    …
 277  274         if platform is None:
 278  275             platform = "ocl"
 279       -          if not model_info.opencl:
 280       -              platform = "dll"
 281       -          elif USE_CUDA:
 282       -              if not kernelcuda.use_cuda():
 283       -                  platform = "dll"
 284       -          else:
 285       -              if not kernelcl.use_opencl():
 286       -                  platform = "dll"
 287  276
 288  277         # Check if type indicates dll regardless of which platform is given
   …    …
 290  279             platform = "dll"
 291  280             dtype = dtype[:-1]
      281  +
      282  +      # Make sure model allows opencl/gpu
      283  +      if not model_info.opencl:
      284  +          platform = "dll"
      285  +
      286  +      # Make sure opencl is available, or fallback to cuda then to dll
      287  +      if platform == "ocl" and not kernelcl.use_opencl():
      288  +          platform = "cuda" if kernelcuda.use_cuda() else "dll"
 292  289
 293  290         # Convert special type names "half", "fast", and "quad"
   …    …
 300  297             dtype = "float16"
 301  298
 302       -      # Convert dtype string to numpy dtype.
      299  +      # Convert dtype string to numpy dtype. Use single precision for GPU
      300  +      # if model allows it, otherwise use double precision.
 303  301         if dtype is None or dtype == "default":
 304       -          numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
      302  +          numpy_dtype = (generate.F32 if model_info.single and platform in ("ocl", "cuda")
 305  303                            else generate.F64)
 306  304         else:
 307  305             numpy_dtype = np.dtype(dtype)
 308  306
 309       -      # Make sure that the type is supported by opencl, otherwise use dll
      307  +      # Make sure that the type is supported by GPU, otherwise use dll
 310  308         if platform == "ocl":
 311       -          if USE_CUDA:
 312       -              env = kernelcuda.environment()
 313       -          else:
 314       -              env = kernelcl.environment()
 315       -          if not env.has_type(numpy_dtype):
 316       -              platform = "dll"
 317       -              if dtype is None:
 318       -                  numpy_dtype = generate.F64
      309  +          env = kernelcl.environment()
      310  +      elif platform == "cuda":
      311  +          env = kernelcuda.environment()
      312  +      else:
      313  +          env = None
      314  +      if env is not None and not env.has_type(numpy_dtype):
      315  +          platform = "dll"
      316  +          if dtype is None:
      317  +              numpy_dtype = generate.F64
 319  318
 320  319         return numpy_dtype, fast, platform
Note: See TracChangeset
for help on using the changeset viewer.