Changeset ce27e21 in sasmodels for sasmodels/dll.py


Ignore:
Timestamp:
Aug 24, 2014 7:18:14 PM (10 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
1780d59
Parents:
14de349
Message:

first pass for sasview wrapper around opencl models

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/dll.py

    r14de349 rce27e21  
    1919    ctypes wrapper for a single model. 
    2020 
    21     *source* and *meta* are the model source and interface as returned 
     21    *source* and *info* are the model source and interface as returned 
    2222    from :func:`gen.make`. 
    2323 
     
    2727    is an optional extension which may not be available on all devices. 
    2828    """ 
    29     def __init__(self, dllpath, meta): 
    30         self.meta = meta 
    31         self.dll = ct.CDLL(dllpath) 
    32         self.Iq = self.dll[gen.kernel_name(self.meta, False)] 
    33         self.Iqxy = self.dll[gen.kernel_name(self.meta, True)] 
     29    def __init__(self, dllpath, info): 
     30        self.info = info 
     31        self.dllpath = dllpath 
     32        self.dll = None 
    3433 
     34    def _load_dll(self): 
     35        Nfixed1d = len(self.info['partype']['fixed-1d']) 
     36        Nfixed2d = len(self.info['partype']['fixed-2d']) 
     37        Npd1d = len(self.info['partype']['pd-1d']) 
     38        Npd2d = len(self.info['partype']['pd-2d']) 
    3539 
    36         self.PARS = dict((p[0],p[2]) for p in meta['parameters']) 
    37         self.PD_PARS = [p[0] for p in meta['parameters'] if p[4] != ""] 
     40        self.dll = ct.CDLL(self.dllpath) 
    3841 
    39         # Determine the set of fixed and polydisperse parameters 
    40         Nfixed = len([p[0] for p in meta['parameters'] if p[4] == ""]) 
    41         N1D = len([p for p in meta['parameters'] if p[4]=="volume"]) 
    42         N2D = len([p for p in meta['parameters'] if p[4]!=""]) 
    43         self.Iq.argtypes = IQ_ARGS + [c_double]*Nfixed + [c_int]*N1D 
    44         self.Iqxy.argtypes = IQXY_ARGS + [c_double]*Nfixed + [c_int]*N2D 
     42        self.Iq = self.dll[gen.kernel_name(self.info, False)] 
     43        self.Iq.argtypes = IQ_ARGS + [c_double]*Nfixed1d + [c_int]*Npd1d 
    4544 
    46     def __call__(self, input, cutoff=1e-5): 
     45        self.Iqxy = self.dll[gen.kernel_name(self.info, True)] 
     46        self.Iqxy.argtypes = IQXY_ARGS + [c_double]*Nfixed2d + [c_int]*Npd2d 
     47 
     48    def __getstate__(self): 
     49        return {'info': self.info, 'dllpath': self.dllpath, 'dll': None} 
     50 
     51    def __setstate__(self, state): 
     52        self.__dict__ = state 
     53 
     54    def __call__(self, input): 
     55        if self.dll is None: self._load_dll() 
     56 
    4757        kernel = self.Iqxy if input.is_2D else self.Iq 
    48         return DllKernel(kernel, self.meta, input, cutoff) 
     58        return DllKernel(kernel, self.info, input) 
    4959 
    5060    def make_input(self, q_vectors): 
     
    90100 
    91101class DllKernel(object): 
    92     def __init__(self, kernel, meta, input, cutoff): 
    93         self.cutoff = cutoff 
     102    def __init__(self, kernel, info, input): 
    94103        self.input = input 
    95104        self.kernel = kernel 
    96         self.meta = meta 
     105        self.info = info 
     106        self.res = np.empty(input.nq, input.dtype) 
     107        dim = '2d' if input.is_2D else '1d' 
     108        self.fixed_pars = info['partype']['fixed-'+dim] 
     109        self.pd_pars = info['partype']['pd-'+dim] 
    97110 
    98         self.res = np.empty(input.nq, input.dtype) 
     111        # In dll kernel, but not in opencl kernel 
    99112        self.p_res = self.res.ctypes.data 
    100113 
    101         # Determine the set of fixed and polydisperse parameters 
    102         self.fixed_pars = [p[0] for p in meta['parameters'] if p[4] == ""] 
    103         self.pd_pars = [p for p in meta['parameters'] 
    104                if p[4]=="volume" or (p[4]=="orientation" and input.is_2D)] 
     114    def __call__(self, pars, pd_pars, cutoff): 
     115        real = np.float32 if self.input.dtype == F32 else np.float64 
     116        fixed = [real(p) for p in pars] 
     117        cutoff = real(cutoff) 
     118        loops = np.hstack(pd_pars) 
     119        loops = np.ascontiguousarray(loops.T, self.input.dtype).flatten() 
     120        loops_N = [np.uint32(len(p[0])) for p in pd_pars] 
    105121 
    106     def eval(self, pars): 
    107         fixed, loops, loop_n = \ 
    108             gen.kernel_pars(pars, self.meta, self.input.is_2D, dtype=self.input.dtype) 
    109         real = np.float32 if self.input.dtype == F32 else np.float64 
    110122        nq = c_int(self.input.nq) 
    111         cutoff = real(self.cutoff) 
    112  
    113123        p_loops = loops.ctypes.data 
    114         pars = self.input.q_pointers + [self.p_res, nq, p_loops, cutoff] + fixed + loop_n 
     124        args = self.input.q_pointers + [self.p_res, nq, p_loops, cutoff] + fixed + loops_N 
    115125        #print pars 
    116         self.kernel(*pars) 
     126        self.kernel(*args) 
    117127 
    118128        return self.res 
     
    120130    def release(self): 
    121131        pass 
    122  
    123     def __del__(self): 
    124         self.release() 
Note: See TracChangeset for help on using the changeset viewer.