Changeset e8d2276 in sasmodels for sasmodels/kernelcl.py


Ignore:
Timestamp:
Mar 21, 2016 4:09:13 PM (8 years ago)
Author:
wojciech
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
5462ffb
Parents:
abc03d8 (diff), 48fbd50 (diff)
Note: this is a merge changeset; the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

Merged with branch

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/kernelcl.py

    r445d1c0 r48fbd50  
    336336        self.program = None 
    337337 
    338     def make_calculator(self, q_vectors, details): 
     338    def make_kernel(self, q_vectors): 
    339339        if self.program is None: 
    340340            compiler = environment().compile_program 
     
    344344        kernel_name = generate.kernel_name(self.info, is_2d) 
    345345        kernel = getattr(self.program, kernel_name) 
    346         return GpuKernel(kernel, self.info, q_vectors, details, self.dtype) 
     346        return GpuKernel(kernel, self.info, q_vectors, self.dtype) 
    347347 
    348348    def release(self): 
     
    403403        context = env.get_context(self.dtype) 
    404404        #print("creating inputs of size", self.global_size) 
    405         # COPY_HOST_PTR initiates transfer as necessary? 
    406405        self.q_b = cl.Buffer(context, mf.READ_ONLY | mf.COPY_HOST_PTR, 
    407406                             hostbuf=self.q) 
     
    438437    Call :meth:`release` when done with the kernel instance. 
    439438    """ 
    440     def __init__(self, kernel, model_info, q_vectors, details, dtype): 
    441         if details.dtype != np.int32: 
    442             raise TypeError("numeric type does not match the kernel type") 
    443  
     439    def __init__(self, kernel, model_info, q_vectors, dtype): 
    444440        max_pd = self.info['max_pd'] 
    445441        npars = len(model_info['parameters'])-2 
     
    449445        self.kernel = kernel 
    450446        self.info = model_info 
    451         self.details = details 
    452447        self.pd_stop_index = 4*max_pd-1 
    453448        # plus three for the normalization values 
     
    459454        self.queue = env.get_queue(dtype) 
    460455 
    461         # details is int32 data, padded to a 32 integer boundary 
    462         size = 4*((self.info['mono'].size+7)//8)*8 # padded to 32 byte boundary 
    463         self.details_b = cl.Buffer(self.queue.context, 
    464                                    mf.READ_ONLY | mf.COPY_HOST_PTR, 
    465                                    hostbuf=details) 
    466         size = np.sum(details[max_pd:2*max_pd]) 
    467         self.weights_b = cl.Buffer(self.queue.context, mf.READ_ONLY, size) 
    468         size = np.sum(details[max_pd:2*max_pd])+npars 
    469         self.values_b = cl.Buffer(self.queue.context, mf.READ_ONLY, size) 
     456        # details is int32 data, padded to an 8 integer boundary 
     457        size = ((max_pd*5 + npars*3 + 2 + 7)//8)*8 
    470458        self.result_b = cl.Buffer(self.queue.context, mf.READ_WRITE, 
    471459                               q_input.global_size[0] * q_input.dtype.itemsize) 
    472460        self.q_input = q_input # allocated by GpuInput above 
    473461 
    474         self._need_release = [ 
    475             self.details_b, self.weights_b, self.values_b, self.result_b, 
    476             self.q_input, 
    477         ] 
    478  
    479     def __call__(self, weights, values, cutoff): 
     462        self._need_release = [ self.result_b, self.q_input ] 
     463 
     464    def __call__(self, details, weights, values, cutoff): 
    480465        real = (np.float32 if self.q_input.dtype == generate.F32 
    481466                else np.float64 if self.q_input.dtype == generate.F64 
    482467                else np.float16 if self.q_input.dtype == generate.F16 
    483468                else np.float32)  # will never get here, so use np.float32 
    484  
    485         if weights.dtype != real or values.dtype != real: 
    486             raise TypeError("numeric type does not match the kernel type") 
    487  
    488         cl.enqueue_copy(self.queue, self.weights_b, weights) 
    489         cl.enqueue_copy(self.queue, self.values_b, values) 
    490  
     469        assert details.dtype == np.int32 
     470        assert weights.dtype == real and values.dtype == real 
     471 
     472        context = self.queue.context 
     473        details_b = cl.Buffer(context, mf.READ_ONLY | mf.COPY_HOST_PTR, 
     474                              hostbuf=details) 
     475        weights_b = cl.Buffer(context, mf.READ_ONLY | mf.COPY_HOST_PTR, 
     476                              hostbuf=weights) 
     477        values_b = cl.Buffer(context, mf.READ_ONLY | mf.COPY_HOST_PTR, 
     478                             hostbuf=values) 
     479 
     480        start, stop = 0, self.details[self.pd_stop_index] 
    491481        args = [ 
    492             np.uint32(self.q_input.nq), 
    493             np.uint32(0), 
    494             np.uint32(self.details[self.pd_stop_index]), 
    495             self.details_b, 
    496             self.weights_b, 
    497             self.values_b, 
    498             self.q_input.q_b, 
    499             self.result_b, 
    500             real(cutoff), 
     482            np.uint32(self.q_input.nq), np.uint32(start), np.uint32(stop), 
     483            self.details_b, self.weights_b, self.values_b, 
     484            self.q_input.q_b, self.result_b, real(cutoff), 
    501485        ] 
    502486        self.kernel(self.queue, self.q_input.global_size, None, *args) 
    503487        cl.enqueue_copy(self.queue, self.result, self.result_b) 
     488        [v.release() for v in details_b, weights_b, values_b] 
    504489 
    505490        return self.result[:self.nq] 
Note: See TracChangeset for help on using the changeset viewer.