Changeset f619de7 in sasmodels for sasmodels/product.py


Ignore:
Timestamp:
Apr 11, 2016 11:14:50 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
7ae2b7f
Parents:
9a943d0
Message:

more type hinting

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/product.py

    r6d6508e rf619de7  
    1414 
    1515from .details import dispersion_mesh 
    16 from .modelinfo import suffix_parameter, ParameterTable, Parameter, ModelInfo 
     16from .modelinfo import suffix_parameter, ParameterTable, ModelInfo 
     17from .kernel import KernelModel, Kernel 
     18 
     19try: 
     20    from typing import Tuple 
     21    from .modelinfo import ParameterSet 
     22    from .details import CallDetails 
     23except ImportError: 
     24    pass 
    1725 
    1826# TODO: make estimates available to constraints 
     
    2533# revert it after making VR and ER available at run time as constraints. 
    2634def make_product_info(p_info, s_info): 
     35    # type: (ModelInfo, ModelInfo) -> ModelInfo 
    2736    """ 
    2837    Create info block for product model. 
    2938    """ 
    30     p_id, p_name, p_partable = p_info.id, p_info.name, p_info.parameters 
    31     s_id, s_name, s_partable = s_info.id, s_info.name, s_info.parameters 
    32     p_set = set(p.id for p in p_partable) 
    33     s_set = set(p.id for p in s_partable) 
     39    p_id, p_name, p_pars = p_info.id, p_info.name, p_info.parameters 
     40    s_id, s_name, s_pars = s_info.id, s_info.name, s_info.parameters 
     41    p_set = set(p.id for p in p_pars.call_parameters) 
     42    s_set = set(p.id for p in s_pars.call_parameters) 
    3443 
    3544    if p_set & s_set: 
    3645        # there is some overlap between the parameter names; tag the 
    3746        # overlapping S parameters with name_S 
    38         s_pars = [(suffix_parameter(par, "_S") if par.id in p_set else par) 
    39                   for par in s_partable.kernel_parameters] 
    40         pars = p_partable.kernel_parameters + s_pars 
     47        s_list = [(suffix_parameter(par, "_S") if par.id in p_set else par) 
     48                  for par in s_pars.kernel_parameters] 
     49        combined_pars = p_pars.kernel_parameters + s_list 
    4150    else: 
    42         pars= p_partable.kernel_parameters + s_partable.kernel_parameters 
     51        combined_pars = p_pars.kernel_parameters + s_pars.kernel_parameters 
     52    parameters = ParameterTable(combined_pars) 
    4353 
    4454    model_info = ModelInfo() 
     
    5060    model_info.docs = model_info.title 
    5161    model_info.category = "custom" 
    52     model_info.parameters = ParameterTable(pars) 
     62    model_info.parameters = parameters 
    5363    #model_info.single = p_info.single and s_info.single 
    5464    model_info.structure_factor = False 
     
    6070    return model_info 
    6171 
    62 class ProductModel(object): 
     72class ProductModel(KernelModel): 
    6373    def __init__(self, model_info, P, S): 
     74        # type: (ModelInfo, KernelModel, KernelModel) -> None 
    6475        self.info = model_info 
    6576        self.P = P 
     
    6778 
    6879    def __call__(self, q_vectors): 
     80        # type: (List[np.ndarray]) -> Kernel 
    6981        # Note: may be sending the q_vectors to the GPU twice even though they 
    7082        # are only needed once.  It would mess up modularity quite a bit to 
     
    7385        # in opencl; or both in opencl, but one in single precision and the 
    7486        # other in double precision). 
    75         p_kernel = self.P(q_vectors) 
    76         s_kernel = self.S(q_vectors) 
     87        p_kernel = self.P.make_kernel(q_vectors) 
     88        s_kernel = self.S.make_kernel(q_vectors) 
    7789        return ProductKernel(self.info, p_kernel, s_kernel) 
    7890 
    7991    def release(self): 
     92        # type: (None) -> None 
    8093        """ 
    8194        Free resources associated with the model. 
     
    8598 
    8699 
    87 class ProductKernel(object): 
     100class ProductKernel(Kernel): 
    88101    def __init__(self, model_info, p_kernel, s_kernel): 
     102        # type: (ModelInfo, Kernel, Kernel) -> None 
    89103        self.info = model_info 
    90104        self.p_kernel = p_kernel 
     
    92106 
    93107    def __call__(self, details, weights, values, cutoff): 
     108        # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 
    94109        effect_radius, vol_ratio = call_ER_VR(self.p_kernel.info, vol_pars) 
    95110 
     
    108123 
    109124    def release(self): 
     125        # type: () -> None 
    110126        self.p_kernel.release() 
    111         self.q_kernel.release() 
     127        self.s_kernel.release() 
    112128 
    113 def call_ER_VR(model_info, vol_pars): 
     129def call_ER_VR(model_info, pars): 
    114130    """ 
    115131    Return effect radius and volume ratio for the model. 
    116132    """ 
    117     value, weight = dispersion_mesh(vol_pars) 
     133    if model_info.ER is None and model_info.VR is None: 
     134        return 1.0, 1.0 
    118135 
    119     individual_radii = model_info.ER(*value) if model_info.ER else 1.0 
    120     whole, part = model_info.VR(*value) if model_info.VR else (1.0, 1.0) 
     136    value, weight = _vol_pars(model_info, pars) 
    121137 
    122     effect_radius = np.sum(weight*individual_radii) / np.sum(weight) 
    123     volume_ratio = np.sum(weight*part)/np.sum(weight*whole) 
     138    if model_info.ER is not None: 
     139        individual_radii = model_info.ER(*value) 
     140        effect_radius = np.sum(weight*individual_radii) / np.sum(weight) 
     141    else: 
     142        effect_radius = 1.0 
     143 
     144    if model_info.VR is not None: 
     145        whole, part = model_info.VR(*value) 
     146        volume_ratio = np.sum(weight*part)/np.sum(weight*whole) 
     147    else: 
     148        volume_ratio = 1.0 
     149 
    124150    return effect_radius, volume_ratio 
     151 
     152def _vol_pars(model_info, pars): 
     153    # type: (ModelInfo, ParameterSet) -> Tuple[np.ndarray, np.ndarray] 
     154    vol_pars = [get_weights(p, pars) 
     155                for p in model_info.parameters.call_parameters 
     156                if p.type == 'volume'] 
     157    value, weight = dispersion_mesh(model_info, vol_pars) 
     158    return value, weight 
     159 
Note: See TracChangeset for help on using the changeset viewer.