source: sasmodels/sasmodels/kernel.py @ 0ff62d4

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 0ff62d4 was 0ff62d4, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor: move dispersion_mesh alongside build_details in kernel.py

  • Property mode set to 100644
File size: 3.2 KB
Line 
1"""
2Execution kernel interface
3==========================
4
5:class:`KernelModel` defines the interface to all kernel models.
6In particular, each model should provide a :meth:`KernelModel.make_kernel`
7call which returns an executable kernel, :class:`Kernel`, that operates
8on the given set of *q_vector* inputs.  On completion of the computation,
9the kernel should be released, which also releases the inputs.
10"""
11
12import numpy as np
13from .details import mono_details, poly_details
14
15try:
16    from typing import List
17except ImportError:
18    pass
19else:
20    from .details import CallDetails
21    from .modelinfo import ModelInfo
22    import numpy as np  # type: ignore
23
class KernelModel(object):
    """
    Interface for a model that can build executable :class:`Kernel` objects.

    Subclasses must override :meth:`make_kernel`.  The base :meth:`release`
    does nothing; subclasses may override it as a cleanup hook.
    """
    # Model metadata; None until set by a subclass or instance.
    info = None  # type: ModelInfo
    # numpy dtype of the model; None until set by a subclass or instance.
    dtype = None  # type: np.dtype

    def make_kernel(self, q_vectors):
        # type: (List[np.ndarray]) -> "Kernel"
        """
        Return a :class:`Kernel` that evaluates this model on *q_vectors*.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError("need to implement make_kernel")

    def release(self):
        # type: () -> None
        """Cleanup hook for subclasses; the base class has nothing to free."""
34
class Kernel(object):
    """
    Executable kernel bound to a particular set of *q_vectors*.

    Kernels are produced by :meth:`KernelModel.make_kernel`.  Subclasses
    implement :meth:`__call__`.  The base :meth:`release` is a no-op; per
    the module contract, releasing a kernel also releases its inputs.
    """
    #: kernel dimension, either "1d" or "2d"
    dim = None  # type: str
    #: model metadata for this kernel
    info = None  # type: ModelInfo
    #: result arrays; None until set by a subclass
    results = None  # type: List[np.ndarray]
    #: numpy dtype of the kernel.  build_details() casts the weights and
    #: values to kernel.dtype, so the interface declares it next to info/dim.
    dtype = None  # type: np.dtype

    def __call__(self, call_details, weights, values, cutoff):
        # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray
        """
        Evaluate the kernel.

        *call_details*, *weights* and *values* have the form returned by
        build_details().  *cutoff* is a float (presumably a weight cutoff
        for the dispersion loop -- confirm against implementations).
        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError("need to implement __call__")

    def release(self):
        # type: () -> None
        """Release the kernel; the base class has nothing to free."""
        pass
48
try:
    # Probe whether this numpy accepts a single vector in meshgrid.
    np.meshgrid([])
except ValueError:
    # CRUFT: np.meshgrid requires multiple vectors
    def meshgrid(*args):
        """
        Compatibility wrapper for old numpy releases whose meshgrid rejects
        fewer than two vectors: with zero or one argument, return the inputs
        as arrays; otherwise delegate to np.meshgrid.
        """
        if len(args) <= 1:
            return [np.asarray(v) for v in args]
        return np.meshgrid(*args)
else:
    meshgrid = np.meshgrid
59
def dispersion_mesh(model_info, pars):
    """
    Create a mesh grid of dispersion parameters and weights.

    *pars* is a sequence of (values, weights) pairs.  When any volume
    parameter of *model_info* is a vector, the pairs must line up with the
    volume parameters in *model_info.parameters.kernel_parameters*, each
    vector parameter contributing *length* consecutive pairs.  An empty
    (or None) weight vector is treated as a single point with weight 1.

    Returns [p1,p2,...],w where pj is a vector of values for parameter j
    and w is a vector containing the products for weights for each
    parameter set in the vector.  If any volume parameter has length > 1,
    each such parameter is returned as one stacked array with a row per
    component; length-1 parameters stay 1-D.
    """
    values, weights = zip(*pars)
    # Replace an empty weight vector with a single unit weight.  Test the
    # size rather than truthiness: bool() of a numpy array raises for more
    # than one element, and treats a lone 0. weight as "empty".
    weights = [[1.] if w is None or np.size(w) == 0 else w for w in weights]
    weight = np.prod(np.vstack([v.flatten() for v in meshgrid(*weights)]),
                     axis=0)
    value = [v.flatten() for v in meshgrid(*values)]
    lengths = [par.length for par in model_info.parameters.kernel_parameters
               if par.type == 'volume']
    if any(n > 1 for n in lengths):
        # Regroup the flattened components of each vector parameter into a
        # single 2-D array (one row per component).
        grouped = []
        offset = 0
        for n in lengths:
            grouped.append(np.vstack(value[offset:offset+n])
                           if n > 1 else value[offset])
            offset += n
        value = grouped
    return value, weight
83
84
85
def build_details(kernel, pairs):
    # type: (Kernel, List[Tuple[np.ndarray, np.ndarray]]) -> Tuple[CallDetails, np.ndarray, np.ndarray]
    """
    Construct the kernel call details object for calling the particular kernel.

    *pairs* is a sequence of (values, weights) pairs, one per parameter.
    If any weight vector has more than one point, the details come from
    poly_details(); otherwise from mono_details().

    Returns (call_details, weights, values), where weights and values are
    the per-parameter vectors concatenated end to end and cast to
    *kernel.dtype*.
    """
    value_vectors, weight_vectors = zip(*pairs)
    if max(len(w) for w in weight_vectors) > 1:
        details = poly_details(kernel.info, weight_vectors)
    else:
        details = mono_details(kernel.info)
    dtype = kernel.dtype
    flat_weights = np.hstack(weight_vectors).astype(dtype=dtype)
    flat_values = np.hstack(value_vectors).astype(dtype=dtype)
    return details, flat_weights, flat_values
100
Note: See TracBrowser for help on using the repository browser.