Changeset a5b8477 in sasmodels for sasmodels/kernelcl.py


Ignore:
Timestamp:
Apr 13, 2016 8:17:10 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
0ce5710
Parents:
60f03de
Message:

update docs to work with the new ModelInfo/ParameterTable? classes

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/kernelcl.py

    r7ae2b7f ra5b8477  
    4949""" 
    5050from __future__ import print_function 
     51 
    5152import os 
    5253import warnings 
     
    6869from . import generate 
    6970from .kernel import KernelModel, Kernel 
     71 
     72try: 
     73    from typing import Tuple, Callable, Any 
     74    from .modelinfo import ModelInfo 
     75    from .details import CallDetails 
     76except ImportError: 
     77    pass 
    7078 
    7179# The max loops number is limited by the amount of local memory available 
     
    441449    Call :meth:`release` when done with the kernel instance. 
    442450    """ 
    443     def __init__(self, kernel, model_info, q_vectors, dtype): 
    444         max_pd = model_info.max_pd 
    445         npars = len(model_info.parameters)-2 
    446         q_input = GpuInput(q_vectors, dtype) 
    447         self.dtype = dtype 
    448         self.dim = '2d' if q_input.is_2d else '1d' 
     451    def __init__(self, kernel, model_info, q_vectors): 
     452        # type: (KernelModel, ModelInfo, List[np.ndarray]) -> None 
     453        max_pd = model_info.parameters.max_pd 
     454        npars = len(model_info.parameters.kernel_parameters)-2 
     455        q_input = GpuInput(q_vectors, kernel.dtype) 
    449456        self.kernel = kernel 
    450457        self.info = model_info 
     458        self.dtype = kernel.dtype 
     459        self.dim = '2d' if q_input.is_2d else '1d' 
    451460        self.pd_stop_index = 4*max_pd-1 
    452461        # plus three for the normalization values 
     
    456465        # Note: res may be shorter than res_b if global_size != nq 
    457466        env = environment() 
    458         self.queue = env.get_queue(dtype) 
     467        self.queue = env.get_queue(kernel.dtype) 
    459468 
    460469        # details is int32 data, padded to an 8 integer boundary 
    461470        size = ((max_pd*5 + npars*3 + 2 + 7)//8)*8 
    462471        self.result_b = cl.Buffer(self.queue.context, mf.READ_WRITE, 
    463                                q_input.global_size[0] * q_input.dtype.itemsize) 
     472                               q_input.global_size[0] * kernel.dtype.itemsize) 
    464473        self.q_input = q_input # allocated by GpuInput above 
    465474 
     
    467476 
    468477    def __call__(self, call_details, weights, values, cutoff): 
     478        # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 
    469479        real = (np.float32 if self.q_input.dtype == generate.F32 
    470480                else np.float64 if self.q_input.dtype == generate.F64 
Note: See TracChangeset for help on using the changeset viewer.