Changeset 8d62008 in sasmodels


Ignore:
Timestamp:
Apr 15, 2016 12:16:41 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
8f6817d
Parents:
3599d36
Message:

remove circular dependency between details/modelinfo; fix compare Calculator type hint

Location:
sasmodels
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    rdd7fc12 r8d62008  
    4949    from .modelinfo import ModelInfo, Parameter, ParameterSet 
    5050    from .data import Data 
    51     Calculator = Callable[[float, ...], np.ndarray] 
     51    Calculator = Callable[[float], np.ndarray] 
    5252 
    5353USAGE = """ 
     
    383383    # import rather than the more obscure smear_selection not imported error 
    384384    import sas 
     385    import sas.models 
    385386    from sas.models.qsmearing import smear_selection 
    386     import sas.models 
     387    from sas.models.MultiplicationModel import MultiplicationModel 
    387388 
    388389    def get_model(name): 
     
    399400        composition_type, parts = model_info.composition 
    400401        if composition_type == 'product': 
    401             from sas.models.MultiplicationModel import MultiplicationModel 
    402402            P, S = [get_model(revert_name(p)) for p in parts] 
    403403            model = MultiplicationModel(P, S) 
  • sasmodels/details.py

    r7ae2b7f r8d62008  
    55except ImportError: 
    66    pass 
     7else: 
     8    from .modelinfo import ModelInfo 
    79 
    810 
     
    1012    parts = None  # type: List["CallDetails"] 
    1113    def __init__(self, model_info): 
     14        # type: (ModelInfo) -> None 
    1215        parameters = model_info.parameters 
    1316        max_pd = parameters.max_pd 
    1417        npars = parameters.npars 
    1518        par_offset = 4*max_pd 
    16         self._details = np.zeros(par_offset + 3*npars + 4, 'i4') 
     19        self.buffer = np.zeros(par_offset + 3 * npars + 4, 'i4') 
    1720 
    1821        # generate views on different parts of the array 
    19         self._pd_par     = self._details[0*max_pd:1*max_pd] 
    20         self._pd_length  = self._details[1*max_pd:2*max_pd] 
    21         self._pd_offset  = self._details[2*max_pd:3*max_pd] 
    22         self._pd_stride  = self._details[3*max_pd:4*max_pd] 
    23         self._par_offset = self._details[par_offset+0*npars:par_offset+1*npars] 
    24         self._par_coord  = self._details[par_offset+1*npars:par_offset+2*npars] 
    25         self._pd_coord   = self._details[par_offset+2*npars:par_offset+3*npars] 
     22        self._pd_par     = self.buffer[0 * max_pd:1 * max_pd] 
     23        self._pd_length  = self.buffer[1 * max_pd:2 * max_pd] 
     24        self._pd_offset  = self.buffer[2 * max_pd:3 * max_pd] 
     25        self._pd_stride  = self.buffer[3 * max_pd:4 * max_pd] 
     26        self._par_offset = self.buffer[par_offset + 0 * npars:par_offset + 1 * npars] 
     27        self._par_coord  = self.buffer[par_offset + 1 * npars:par_offset + 2 * npars] 
     28        self._pd_coord   = self.buffer[par_offset + 2 * npars:par_offset + 3 * npars] 
    2629 
    2730        # theta_par is fixed 
    28         self._details[-1] = parameters.theta_offset 
    29  
    30     @property 
    31     def ctypes(self): return self._details.ctypes 
     31        self.buffer[-1] = parameters.theta_offset 
    3232 
    3333    @property 
     
    5353 
    5454    @property 
    55     def num_active(self): return self._details[-4] 
     55    def num_active(self): return self.buffer[-4] 
    5656    @num_active.setter 
    57     def num_active(self, v): self._details[-4] = v 
     57    def num_active(self, v): self.buffer[-4] = v 
    5858 
    5959    @property 
    60     def total_pd(self): return self._details[-3] 
     60    def total_pd(self): return self.buffer[-3] 
    6161    @total_pd.setter 
    62     def total_pd(self, v): self._details[-3] = v 
     62    def total_pd(self, v): self.buffer[-3] = v 
    6363 
    6464    @property 
    65     def num_coord(self): return self._details[-2] 
     65    def num_coord(self): return self.buffer[-2] 
    6666    @num_coord.setter 
    67     def num_coord(self, v): self._details[-2] = v 
     67    def num_coord(self, v): self.buffer[-2] = v 
    6868 
    6969    @property 
    70     def theta_par(self): return self._details[-1] 
     70    def theta_par(self): return self.buffer[-1] 
    7171 
    7272    def show(self): 
     
    8181        print("par_coord", self.par_coord) 
    8282        print("pd_coord", self.pd_coord) 
    83         print("theta par", self._details[-1]) 
     83        print("theta par", self.buffer[-1]) 
    8484 
    8585def build_details(kernel, pairs): 
     
    8888        call_details = poly_details(kernel.info, weights) 
    8989    else: 
    90         call_details = kernel.info.mono_details 
     90        call_details = mono_details(kernel.info) 
    9191    weights, values = [np.hstack(v) for v in (weights, values)] 
    9292    weights = weights.astype(dtype=kernel.dtype) 
  • sasmodels/kernelcl.py

    rdd7fc12 r8d62008  
    6060    # Ask OpenCL for the default context so that we know that one exists 
    6161    cl.create_some_context(interactive=False) 
    62 except Exception as exc: 
    63     warnings.warn(str(exc)) 
     62except Exception as ocl_exc: 
     63    warnings.warn(str(ocl_exc)) 
     64    del ocl_exc 
    6465    raise RuntimeError("OpenCL not available") 
    6566 
     
    479480        self.dtype = kernel.dtype 
    480481        self.dim = '2d' if q_input.is_2d else '1d' 
    481         self.pd_stop_index = 4*max_pd-1 
    482482        # plus three for the normalization values 
    483483        self.result = np.empty(q_input.nq+3, q_input.dtype) 
     
    495495 
    496496        self._need_release = [ self.result_b, self.q_input ] 
     497        self.real = (np.float32 if self.q_input.dtype == generate.F32 
     498                     else np.float64 if self.q_input.dtype == generate.F64 
     499                     else np.float16 if self.q_input.dtype == generate.F16 
     500                     else np.float32)  # will never get here, so use np.float32 
    497501 
    498502    def __call__(self, call_details, weights, values, cutoff): 
    499503        # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 
    500         real = (np.float32 if self.q_input.dtype == generate.F32 
    501                 else np.float64 if self.q_input.dtype == generate.F64 
    502                 else np.float16 if self.q_input.dtype == generate.F16 
    503                 else np.float32)  # will never get here, so use np.float32 
    504         assert call_details.dtype == np.int32 
    505         assert weights.dtype == real and values.dtype == real 
    506504 
    507505        context = self.queue.context 
     506        # Arrange data transfer to card 
    508507        details_b = cl.Buffer(context, mf.READ_ONLY | mf.COPY_HOST_PTR, 
    509                               hostbuf=call_details) 
     508                              hostbuf=call_details.buffer) 
    510509        weights_b = cl.Buffer(context, mf.READ_ONLY | mf.COPY_HOST_PTR, 
    511510                              hostbuf=weights) 
     
    513512                             hostbuf=values) 
    514513 
    515         start, stop = 0, self.details[self.pd_stop_index] 
     514        start, stop = 0, call_details.total_pd 
    516515        args = [ 
    517             np.uint32(self.q_input.nq), np.uint32(start), np.uint32(stop), 
    518             self.details_b, self.weights_b, self.values_b, 
    519             self.q_input.q_b, self.result_b, real(cutoff), 
     516            np.uint32(self.q_input.nq), np.int32(start), np.int32(stop), 
     517            details_b, weights_b, values_b, self.q_input.q_b, self.result_b, 
     518            self.real(cutoff), 
    520519        ] 
    521520        self.kernel(self.queue, self.q_input.global_size, None, *args) 
    522521        cl.enqueue_copy(self.queue, self.result, self.result_b) 
    523         [v.release() for v in (details_b, weights_b, values_b)] 
    524  
    525         return self.result[:self.nq] 
     522        for v in (details_b, weights_b, values_b): 
     523            v.release() 
     524 
     525        return self.result[:self.q_input.nq] 
    526526 
    527527    def release(self): 
  • sasmodels/kerneldll.py

    ra5b8477 r8d62008  
    294294            start, # pd_start 
    295295            stop, # pd_stop pd_stride[MAX_PD] 
    296             call_details.ctypes.data, # problem 
     296            call_details.buffer.ctypes.data, # problem 
    297297            weights.ctypes.data,  # weights 
    298298            values.ctypes.data,  #pars 
  • sasmodels/modelinfo.py

    ra5b8477 r8d62008  
    1313 
    1414import numpy as np  # type: ignore 
    15  
    16 from .details import mono_details 
    1715 
    1816# Optional typing 
     
    2220    pass 
    2321else: 
    24     from .details import CallDetails 
    2522    Limits = Tuple[float, float] 
    2623    #LimitsOrChoice = Union[Limits, Tuple[Sequence[str]]] 
     
    658655    info.hidden = getattr(kernel_module, 'hidden', None) # type: ignore 
    659656 
    660     # Precalculate the monodisperse parameter details 
    661     info.mono_details = mono_details(info) 
    662657    return info 
    663658 
     
    811806    #: the SESANS correlation function.  Note: not currently implemented. 
    812807    sesans = None           # type: Optional[Callable[[np.ndarray], np.ndarray]] 
    813     #: :class:details.CallDetails data for mono-disperse function evaluation. 
    814     #: This field is created automatically by the model loader, and should 
    815     #: not be defined as part of the model definition file. 
    816     mono_details = None     # type: CallDetails 
    817808 
    818809    def __init__(self): 
Note: See TracChangeset for help on using the changeset viewer.