Changeset 0ff62d4 in sasmodels for sasmodels/kernel.py


Ignore:
Timestamp:
Apr 15, 2016 12:31:35 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
f2f67a6
Parents:
8f6817d
Message:

refactor: move dispersion_mesh alongside build_details in kernel.py

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/kernel.py

    ra5b8477 r0ff62d4  
    1010""" 
    1111 
     12import numpy as np 
     13from .details import mono_details, poly_details 
     14 
    1215try: 
    1316    from typing import List 
     17except ImportError: 
     18    pass 
     19else: 
    1420    from .details import CallDetails 
    1521    from .modelinfo import ModelInfo 
    1622    import numpy as np  # type: ignore 
    17 except ImportError: 
    18     pass 
    1923 
    2024class KernelModel(object): 
     
    4246        # type: () -> None 
    4347        pass 
     48 
     49try: 
     50    np.meshgrid([]) 
     51    meshgrid = np.meshgrid 
     52except ValueError: 
     53    # CRUFT: np.meshgrid requires multiple vectors 
     54    def meshgrid(*args): 
     55        if len(args) > 1: 
     56            return np.meshgrid(*args) 
     57        else: 
     58            return [np.asarray(v) for v in args] 
     59 
     60def dispersion_mesh(model_info, pars): 
     61    """ 
     62    Create a mesh grid of dispersion parameters and weights. 
     63 
     64    Returns [p1,p2,...],w where pj is a vector of values for parameter j 
     65    and w is a vector containing the products for weights for each 
     66    parameter set in the vector. 
     67    """ 
     68    value, weight = zip(*pars) 
     69    weight = [w if w else [1.] for w in weight] 
     70    weight = np.vstack([v.flatten() for v in meshgrid(*weight)]) 
     71    weight = np.prod(weight, axis=0) 
     72    value = [v.flatten() for v in meshgrid(*value)] 
     73    lengths = [par.length for par in model_info.parameters.kernel_parameters 
     74               if par.type == 'volume'] 
     75    if any(n > 1 for n in lengths): 
     76        pars = [] 
     77        offset = 0 
     78        for n in lengths: 
     79            pars.append(np.vstack(value[offset:offset+n]) if n > 1 else value[offset]) 
     80            offset += n 
     81        value = pars 
     82    return value, weight 
     83 
     84 
     85 
     86def build_details(kernel, pairs): 
     87    # type: (Kernel, Tuple[List[np.ndarray], List[np.ndarray]]) -> Tuple[CallDetails, np.ndarray, np.ndarray] 
     88    """ 
     89    Construct the kernel call details object for calling the particular kernel. 
     90    """ 
     91    values, weights = zip(*pairs) 
     92    if max([len(w) for w in weights]) > 1: 
     93        call_details = poly_details(kernel.info, weights) 
     94    else: 
     95        call_details = mono_details(kernel.info) 
     96    weights, values = [np.hstack(v) for v in (weights, values)] 
     97    weights = weights.astype(dtype=kernel.dtype) 
     98    values = values.astype(dtype=kernel.dtype) 
     99    return call_details, weights, values 
     100 
Note: See TracChangeset for help on using the changeset viewer.