Changeset 750ffa5 in sasmodels for sasmodels/kerneldll.py


Ignore:
Timestamp:
Mar 9, 2015 4:04:55 PM (9 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
3a45c2c
Parents:
48f0194
Message:

allow test of dll using single precision

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/kerneldll.py

    r3c56da87 r750ffa5  
    11""" 
    22C types wrapper for sasview models. 
     3 
     4The global attribute *ALLOW_SINGLE_PRECISION_DLLS* should be set to *True* if 
     5you wish to allow single precision floating point evaluation for the compiled 
     6models, otherwise it defaults to *False*. 
    37""" 
     8 
    49import sys 
    510import os 
    611import tempfile 
    712import ctypes as ct 
    8 from ctypes import c_void_p, c_int, c_double 
     13from ctypes import c_void_p, c_int, c_double, c_float 
    914 
    1015import numpy as np 
     
    3641DLL_PATH = tempfile.gettempdir() 
    3742 
    38  
    39 def dll_path(info): 
     43ALLOW_SINGLE_PRECISION_DLLS = False 
     44 
     45 
     46def dll_path(info, dtype="double"): 
    4047    """ 
    4148    Path to the compiled model defined by *info*. 
     
    4350    from os.path import join as joinpath, split as splitpath, splitext 
    4451    basename = splitext(splitpath(info['filename'])[1])[0] 
     52    if np.dtype(dtype) == generate.F32: 
     53        basename += "32" 
    4554    return joinpath(DLL_PATH, basename+'.so') 
    4655 
    4756 
    48 def load_model(kernel_module, dtype=None): 
     57def load_model(kernel_module, dtype="double"): 
    4958    """ 
    5059    Load the compiled model defined by *kernel_module*. 
     
    5766    be defined without using too many resources. 
    5867    """ 
     68    if not ALLOW_SINGLE_PRECISION_DLLS: dtype = "double"   # Force 64-bit dll 
     69    dtype = np.dtype(dtype) 
     70 
    5971    source, info = generate.make(kernel_module) 
    6072    if callable(info.get('Iq',None)): 
    6173        return PyModel(info) 
     74 
     75    if dtype == generate.F32: # 32-bit dll 
     76        source = generate.use_single(source) 
     77        tempfile_prefix = 'sas_'+info['name']+'32_' 
     78    else: 
     79        tempfile_prefix = 'sas_'+info['name']+'_' 
     80 
    6281    source_files = generate.sources(info) + [info['filename']] 
     82    dllpath = dll_path(info, dtype) 
    6383    newest = max(os.path.getmtime(f) for f in source_files) 
    64     dllpath = dll_path(info) 
    6584    if not os.path.exists(dllpath) or os.path.getmtime(dllpath)<newest: 
    6685        # Replace with a proper temp file 
    67         fid, filename = tempfile.mkstemp(suffix=".c",prefix="sas_"+info['name']) 
     86        fid, filename = tempfile.mkstemp(suffix=".c",prefix=tempfile_prefix) 
    6887        os.fdopen(fid,"w").write(source) 
    6988        command = COMPILE%{"source":filename, "output":dllpath} 
     
    7695            #os.unlink(filename); print "saving compiled file in %r"%filename 
    7796            pass 
    78     return DllModel(dllpath, info) 
     97    return DllModel(dllpath, info, dtype=dtype) 
    7998 
    8099 
     
    96115    Call :meth:`release` when done with the kernel. 
    97116    """ 
    98     def __init__(self, dllpath, info): 
     117    def __init__(self, dllpath, info, dtype=generate.F32): 
    99118        self.info = info 
    100119        self.dllpath = dllpath 
    101120        self.dll = None 
     121        self.dtype = np.dtype(dtype) 
    102122 
    103123    def _load_dll(self): 
     
    110130        self.dll = ct.CDLL(self.dllpath) 
    111131 
    112         pd_args_1d = [c_void_p, c_double] + [c_int]*Npd1d if Npd1d else [] 
    113         pd_args_2d= [c_void_p, c_double] + [c_int]*Npd2d if Npd2d else [] 
     132        fp = c_float if self.dtype == generate.F32 else c_double 
     133        pd_args_1d = [c_void_p, fp] + [c_int]*Npd1d if Npd1d else [] 
     134        pd_args_2d= [c_void_p, fp] + [c_int]*Npd2d if Npd2d else [] 
    114135        self.Iq = self.dll[generate.kernel_name(self.info, False)] 
    115         self.Iq.argtypes = IQ_ARGS + pd_args_1d + [c_double]*Nfixed1d 
     136        self.Iq.argtypes = IQ_ARGS + pd_args_1d + [fp]*Nfixed1d 
    116137 
    117138        self.Iqxy = self.dll[generate.kernel_name(self.info, True)] 
    118         self.Iqxy.argtypes = IQXY_ARGS + pd_args_2d + [c_double]*Nfixed2d 
     139        self.Iqxy.argtypes = IQXY_ARGS + pd_args_2d + [fp]*Nfixed2d 
    119140 
    120141    def __getstate__(self): 
     
    125146 
    126147    def __call__(self, q_input): 
     148        if self.dtype != q_input.dtype: 
     149            raise TypeError("data is %s kernel is %s" % (q_input.dtype, self.dtype)) 
    127150        if self.dll is None: self._load_dll() 
    128151        kernel = self.Iqxy if q_input.is_2D else self.Iq 
     
    138161        ctypes and some may be pure python. 
    139162        """ 
    140         return PyInput(q_vectors, dtype=F64) 
     163        return PyInput(q_vectors, dtype=self.dtype) 
    141164 
    142165    def release(self): 
Note: See TracChangeset for help on using the changeset viewer.