Changeset f619de7 in sasmodels for sasmodels/mixture.py


Ignore:
Timestamp:
Apr 11, 2016 11:14:50 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
7ae2b7f
Parents:
9a943d0
Message:

more type hinting

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/mixture.py

    r6d6508e rf619de7  
    1515 
    1616from .modelinfo import Parameter, ParameterTable, ModelInfo 
     17from .kernel import KernelModel, Kernel 
     18 
     19try: 
     20    from typing import List 
     21    from .details import CallDetails 
     22except ImportError: 
     23    pass 
    1724 
    1825def make_mixture_info(parts): 
     26    # type: (List[ModelInfo]) -> ModelInfo 
    1927    """ 
    2028    Create info block for product model. 
     
    2230    flatten = [] 
    2331    for part in parts: 
    24         if part['composition'] and part['composition'][0] == 'mixture': 
    25             flatten.extend(part['compostion'][1]) 
     32        if part.composition and part.composition[0] == 'mixture': 
     33            flatten.extend(part.composition[1]) 
    2634        else: 
    2735            flatten.append(part) 
     
    2937 
    3038    # Build new parameter list 
    31     pars = [] 
     39    combined_pars = [] 
    3240    for k, part in enumerate(parts): 
    3341        # Parameter prefix per model, A_, B_, ... 
     
    3543        # to support vector parameters 
    3644        prefix = chr(ord('A')+k) + '_' 
    37         pars.append(Parameter(prefix+'scale')) 
    38         for p in part['parameters'].kernel_pars: 
     45        combined_pars.append(Parameter(prefix+'scale')) 
     46        for p in part.parameters.kernel_parameters: 
    3947            p = copy(p) 
    40             p.name = prefix+p.name 
    41             p.id = prefix+p.id 
     48            p.name = prefix + p.name 
     49            p.id = prefix + p.id 
    4250            if p.length_control is not None: 
    43                 p.length_control = prefix+p.length_control 
    44             pars.append(p) 
    45     partable = ParameterTable(pars) 
     51                p.length_control = prefix + p.length_control 
     52            combined_pars.append(p) 
     53    parameters = ParameterTable(combined_pars) 
    4654 
    4755    model_info = ModelInfo() 
    48     model_info.id = '+'.join(part['id']) 
    49     model_info.name = ' + '.join(part['name']) 
     56    model_info.id = '+'.join(part.id for part in parts) 
     57    model_info.name = ' + '.join(part.name for part in parts) 
    5058    model_info.filename = None 
    5159    model_info.title = 'Mixture model with ' + model_info.name 
     
    5361    model_info.docs = model_info.title 
    5462    model_info.category = "custom" 
    55     model_info.parameters = partable 
     63    model_info.parameters = parameters 
    5664    #model_info.single = any(part['single'] for part in parts) 
    5765    model_info.structure_factor = False 
     
    6472 
    6573 
    66 class MixtureModel(object): 
     74class MixtureModel(KernelModel): 
    6775    def __init__(self, model_info, parts): 
     76        # type: (ModelInfo, List[KernelModel]) -> None 
    6877        self.info = model_info 
    6978        self.parts = parts 
    7079 
    7180    def __call__(self, q_vectors): 
     81        # type: (List[np.ndarray]) -> MixtureKernel 
    7282        # Note: may be sending the q_vectors to the n times even though they 
    7383        # are only needed once.  It would mess up modularity quite a bit to 
     
    7686        # in opencl; or both in opencl, but one in single precision and the 
    7787        # other in double precision). 
    78         kernels = [part(q_vectors) for part in self.parts] 
     88        kernels = [part.make_kernel(q_vectors) for part in self.parts] 
    7989        return MixtureKernel(self.info, kernels) 
    8090 
    8191    def release(self): 
     92        # type: () -> None 
    8293        """ 
    8394        Free resources associated with the model. 
     
    8798 
    8899 
    89 class MixtureKernel(object): 
     100class MixtureKernel(Kernel): 
    90101    def __init__(self, model_info, kernels): 
    91         dim = '2d' if kernels[0].q_input.is_2d else '1d' 
     102        # type: (ModelInfo, List[Kernel]) -> None 
     103        self.dim = kernels[0].dim 
     104        self.info =  model_info 
     105        self.kernels = kernels 
    92106 
    93         # fixed offsets starts at 2 for scale and background 
    94         fixed_pars, pd_pars = [], [] 
    95         offsets = [[2, 0]] 
    96         #vol_index = [] 
    97         def accumulate(fixed, pd, volume): 
    98             # subtract 1 from fixed since we are removing background 
    99             fixed_offset, pd_offset = offsets[-1] 
    100             #vol_index.extend(k+pd_offset for k,v in pd if v in volume) 
    101             offsets.append([fixed_offset + len(fixed) - 1, pd_offset + len(pd)]) 
    102             pd_pars.append(pd) 
    103         if dim == '2d': 
    104             for p in kernels: 
    105                 partype = p.info.partype 
    106                 accumulate(partype['fixed-2d'], partype['pd-2d'], partype['volume']) 
    107         else: 
    108             for p in kernels: 
    109                 partype = p.info.partype 
    110                 accumulate(partype['fixed-1d'], partype['pd-1d'], partype['volume']) 
    111  
    112         #self.vol_index = vol_index 
    113         self.offsets = offsets 
    114         self.fixed_pars = fixed_pars 
    115         self.pd_pars = pd_pars 
    116         self.info = model_info 
    117         self.kernels = kernels 
    118         self.results = None 
    119  
    120     def __call__(self, fixed_pars, pd_pars, cutoff=1e-5): 
    121         scale, background = fixed_pars[0:2] 
     107    def __call__(self, call_details, value, weight, cutoff): 
     108        # type: (CallDetails, np.ndarray, np.ndarry, float) -> np.ndarray 
     109        scale, background = value[0:2] 
    122110        total = 0.0 
    123         self.results = []  # remember the parts for plotting later 
    124         for k in range(len(self.offsets)-1): 
    125             start_fixed, start_pd = self.offsets[k] 
    126             end_fixed, end_pd = self.offsets[k+1] 
    127             part_fixed = [fixed_pars[start_fixed], 0.0] + fixed_pars[start_fixed+1:end_fixed] 
    128             part_pd = [pd_pars[start_pd], 0.0] + pd_pars[start_pd+1:end_pd] 
    129             part_result = self.kernels[k](part_fixed, part_pd) 
     111        # remember the parts for plotting later 
     112        self.results = [] 
     113        for kernel, kernel_details in zip(self.kernels, call_details.parts): 
     114            part_result = kernel(kernel_details, value, weight, cutoff) 
    130115            total += part_result 
    131             self.results.append(scale*sum+background) 
     116            self.results.append(part_result) 
    132117 
    133118        return scale*total + background 
    134119 
    135120    def release(self): 
    136         self.p_kernel.release() 
    137         self.q_kernel.release() 
     121        # type: () -> None 
     122        for k in self.kernels: 
     123            k.release() 
    138124 
Note: See TracChangeset for help on using the changeset viewer.