source: sasmodels/sasmodels/kernel.py @ def2c1b

core_shell_microgels costrafo411 magnetic_model release_v0.94 release_v0.95 ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since def2c1b was def2c1b, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

honour platform request when selecting kernel

  • Property mode set to 100644
File size: 3.2 KB
Line 
1"""
2Execution kernel interface
3==========================
4
5:class:`KernelModel` defines the interface to all kernel models.
6In particular, each model should provide a :meth:`KernelModel.make_kernel`
7call which returns an executable kernel, :class:`Kernel`, that operates
8on the given set of *q_vector* inputs.  On completion of the computation,
9the kernel should be released, which also releases the inputs.
10"""
11
from __future__ import division, print_function

import numpy as np
from .details import mono_details, poly_details

try:
    from typing import List, Tuple
except ImportError:
    pass
else:
    from .details import CallDetails
    from .modelinfo import ModelInfo
    import numpy as np  # type: ignore
25
class KernelModel(object):
    """
    Interface shared by all kernel models.

    A kernel model hands out executable :class:`Kernel` objects through
    :meth:`make_kernel`.  This base class only declares the interface.
    """
    #: model description (None in this base class)
    info = None  # type: ModelInfo
    #: numpy data type used by the model (None in this base class)
    dtype = None  # type: np.dtype

    def make_kernel(self, q_vectors):
        # type: (List[np.ndarray]) -> "Kernel"
        """
        Return a :class:`Kernel` that operates on *q_vectors*.

        Subclasses must override this; the base version always raises
        NotImplementedError.
        """
        message = "need to implement make_kernel"
        raise NotImplementedError(message)

    def release(self):
        # type: () -> None
        """
        Release hook for subclasses; the base version does nothing.
        """
36
class Kernel(object):
    """
    Executable kernel, as returned by :meth:`KernelModel.make_kernel`.

    A kernel operates on the set of *q* inputs it was created for.  When
    the computation is complete it should be released with
    :meth:`release`.  This base class only declares the interface.
    """
    #: kernel dimension, either "1d" or "2d"
    dim = None  # type: str
    #: model description (None in this base class)
    info = None  # type: ModelInfo
    #: result arrays (None in this base class)
    results = None  # type: List[np.ndarray]

    def __call__(self, call_details, values, cutoff):
        # type: (CallDetails, np.ndarray, float) -> np.ndarray
        """
        Evaluate the kernel and return the computed values.

        *call_details* and *values* describe the parameter set to evaluate;
        ``build_details`` in this module produces such a (details, data)
        pair.  *cutoff* is a float whose meaning is defined by subclasses
        (presumably a polydispersity weight cutoff -- TODO confirm).

        The base version always raises NotImplementedError.
        """
        raise NotImplementedError("need to implement __call__")

    def release(self):
        # type: () -> None
        """
        Release hook for subclasses; the base version does nothing.
        """
50
try:
    np.meshgrid([])
except ValueError:
    # CRUFT: this numpy raises ValueError when np.meshgrid is given a single
    # vector, so wrap it and pass the zero/one-vector case through directly.
    def meshgrid(*args):
        """
        Return coordinate arrays for *args*, as np.meshgrid does.

        With zero or one input vector the inputs are returned as arrays
        without calling np.meshgrid.
        """
        if len(args) <= 1:
            return [np.asarray(vec) for vec in args]
        return np.meshgrid(*args)
else:
    meshgrid = np.meshgrid
61
def dispersion_mesh(model_info, pars):
    """
    Create a mesh grid of dispersion parameters and weights.

    *pars* is a sequence of (value, weight) pairs.  The grouping step below
    assumes one pair per element of each volume parameter listed in
    *model_info.parameters.kernel_parameters*, in that order, so a vector
    parameter of length n contributes n consecutive pairs.  A weight of
    None, or an empty weight vector, is treated as a single point with
    weight 1.

    Returns [p1,p2,...],w where pj is a vector of values for parameter j
    and w is a vector containing the products for weights for each
    parameter set in the vector.  When any volume parameter has length
    greater than one, the value vectors belonging to each such parameter
    are stacked into a single (n, npoints) array.
    """
    value, weight = zip(*pars)
    # Substitute a unit weight for missing/empty weight vectors.  Test the
    # size explicitly: ``if w`` raises ValueError for numpy arrays with more
    # than one element, and would treat a single zero weight as missing.
    weight = [w if w is not None and np.size(w) > 0 else [1.]
              for w in weight]
    # Product of the weights at every point of the mesh, flattened to one
    # entry per point.
    weight = np.prod(np.vstack([w.flatten() for w in meshgrid(*weight)]),
                     axis=0)
    # Matching flattened parameter values for each mesh point.
    value = [v.flatten() for v in meshgrid(*value)]
    lengths = [par.length for par in model_info.parameters.kernel_parameters
               if par.type == 'volume']
    if any(n > 1 for n in lengths):
        # Regroup the per-element vectors so that each vector parameter
        # becomes a single (n, npoints) array.
        grouped = []
        offset = 0
        for n in lengths:
            grouped.append(np.vstack(value[offset:offset+n]) if n > 1
                           else value[offset])
            offset += n
        value = grouped
    return value, weight
85
86
87
def build_details(kernel, pairs):
    # type: (Kernel, List[Tuple[np.ndarray, np.ndarray]]) -> Tuple[CallDetails, np.ndarray]
    """
    Construct the kernel call details object for calling the particular kernel.

    *pairs* is a sequence of (values, weights) pairs, one per parameter
    passed to the kernel (presumably in kernel-parameter order -- confirm
    against callers).

    If every weight vector has exactly one entry, the details come from
    ``mono_details(kernel.info)`` and the data array holds only the first
    value of each parameter.  Otherwise the details come from
    ``poly_details(kernel.info, weights)`` and the data array is the first
    value of each parameter, followed by all value vectors, followed by all
    weight vectors.

    Returns (call_details, data), with *data* converted to *kernel.dtype*.
    """
    values, weights = zip(*pairs)
    scalars = [v[0] for v in values]
    if all(len(w) == 1 for w in weights):
        call_details = mono_details(kernel.info)
        data = np.array(scalars, dtype=kernel.dtype)
    else:
        call_details = poly_details(kernel.info, weights)
        # Layout: [first values..., value vectors..., weight vectors...]
        data = np.hstack(scalars + list(values) + list(weights))
        data = data.astype(kernel.dtype)
    return call_details, data
Note: See TracBrowser for help on using the repository browser.