Changeset f619de7 in sasmodels


Ignore:
Timestamp:
Apr 11, 2016 11:14:50 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
7ae2b7f
Parents:
9a943d0
Message:

more type hinting

Location:
sasmodels
Files:
1 added
9 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/core.py

    r6d6508e rf619de7  
    2626    HAVE_OPENCL = False 
    2727 
     28try: 
     29    from typing import List, Union, Optional, Any 
     30    DType = Union[None, str, np.dtype] 
     31    from .kernel import KernelModel 
     32except ImportError: 
     33    pass 
     34 
     35 
    2836# TODO: refactor composite model support 
    2937# The current load_model_info/build_model does not reuse existing model 
     
    3947 
    4048def list_models(): 
     49    # type: () -> List[str] 
    4150    """ 
    4251    Return the list of available models on the model path. 
     
    4857 
    4958def isstr(s): 
     59    # type: (Any) -> bool 
    5060    """ 
    5161    Return True if *s* is a string-like object. 
     
    5565    return True 
    5666 
    57 def load_model(model_name, **kw): 
     67def load_model(model_name, dtype=None, platform='ocl'): 
     68    # type: (str, DType, str) -> KernelModel 
    5869    """ 
    5970    Load model info and build model. 
     71 
     72    *model_name* is the name of the model as used by :func:`load_model_info`. 
     73    Additional keyword arguments are passed directly to :func:`build_model`. 
    6074    """ 
    61     return build_model(load_model_info(model_name), **kw) 
     75    return build_model(load_model_info(model_name), 
     76                       dtype=dtype, platform=platform) 
    6277 
    6378 
    6479def load_model_info(model_name): 
     80    # type: (str) -> modelinfo.ModelInfo 
    6581    """ 
    6682    Load a model definition given the model name. 
     
    86102 
    87103def build_model(model_info, dtype=None, platform="ocl"): 
     104    # type: (modelinfo.ModelInfo, DType, str) -> KernelModel 
    88105    """ 
    89106    Prepare the model for the default execution platform. 
     
    138155 
    139156def precompile_dll(model_name, dtype="double"): 
     157    # type: (str, DType) -> Optional[str] 
    140158    """ 
    141159    Precompile the dll for a model. 
  • sasmodels/generate.py

    r6d6508e rf619de7  
    164164from .modelinfo import Parameter 
    165165from .custom import load_custom_kernel_module 
     166 
     167try: 
     168    from typing import Tuple, Sequence, Iterator 
     169    from .modelinfo import ModelInfo 
     170except ImportError: 
     171    pass 
    166172 
    167173TEMPLATE_ROOT = dirname(__file__) 
     
    220226 
    221227def format_units(units): 
     228    # type: (str) -> str 
    222229    """ 
    223230    Convert units into ReStructured Text format. 
     
    226233 
    227234def make_partable(pars): 
     235    # type: (List[Parameter]) -> str 
    228236    """ 
    229237    Generate the parameter table to include in the sphinx documentation. 
     
    256264 
    257265def _search(search_path, filename): 
     266    # type: (List[str], str) -> str 
    258267    """ 
    259268    Find *filename* in *search_path*. 
     
    269278 
    270279def model_sources(model_info): 
     280    # type: (ModelInfo) -> List[str] 
    271281    """ 
    272282    Return a list of the sources file paths for the module. 
     
    277287 
    278288def timestamp(model_info): 
     289    # type: (ModelInfo) -> int 
    279290    """ 
    280291    Return a timestamp for the model corresponding to the most recently 
     
    288299 
    289300def convert_type(source, dtype): 
     301    # type: (str, np.dtype) -> str 
    290302    """ 
    291303    Convert code from double precision to the desired type. 
     
    312324 
    313325def _convert_type(source, type_name, constant_flag): 
     326    # type: (str, str, str) -> str 
    314327    """ 
    315328    Replace 'double' with *type_name* in *source*, tagging floating point 
     
    330343 
    331344def kernel_name(model_info, is_2d): 
     345    # type: (ModelInfo, bool) -> str 
    332346    """ 
    333347    Name of the exported kernel symbol. 
     
    337351 
    338352def indent(s, depth): 
     353    # type: (str, int) -> str 
    339354    """ 
    340355    Indent a string of text with *depth* additional spaces on each line. 
     
    345360 
    346361 
    347 _template_cache = {} 
     362_template_cache = {}  # type: Dict[str, Tuple[int, str, str]] 
    348363def load_template(filename): 
     364    # type: (str) -> str 
    349365    path = joinpath(TEMPLATE_ROOT, filename) 
    350366    mtime = getmtime(path) 
     
    355371 
    356372def model_templates(): 
     373    # type: () -> List[str] 
    357374    # TODO: fails DRY; templates are listed in two places. 
    358375    # should instead have model_info contain a list of paths 
     
    371388 
    372389def _gen_fn(name, pars, body): 
     390    # type: (str, List[Parameter], str) -> str 
    373391    """ 
    374392    Generate a function given pars and body. 
     
    385403 
    386404def _call_pars(prefix, pars): 
     405    # type: (str, List[Parameter]) -> List[str] 
    387406    """ 
    388407    Return a list of *prefix.parameter* from parameter items. 
     
    393412                           flags=re.MULTILINE) 
    394413def _have_Iqxy(sources): 
     414    # type: (List[str]) -> bool 
    395415    """ 
    396416    Return true if any file defines Iqxy. 
     
    414434 
    415435def make_source(model_info): 
     436    # type: (ModelInfo) -> str 
    416437    """ 
    417438    Generate the OpenCL/ctypes kernel from the module info. 
    418439 
    419     Uses source files found in the given search path. 
     440    Uses source files found in the given search path.  Returns None if this 
     441    is a pure python model, with no C source components. 
    420442    """ 
    421443    if callable(model_info.Iq): 
    422         return None 
     444        raise ValueError("can't compile python model") 
    423445 
    424446    # TODO: need something other than volume to indicate dispersion parameters 
     
    447469    q, qx, qy = [Parameter(name=v) for v in ('q', 'qx', 'qy')] 
    448470    # Generate form_volume function, etc. from body only 
    449     if model_info.form_volume is not None: 
     471    if isinstance(model_info.form_volume, str): 
    450472        pars = partable.form_volume_parameters 
    451473        source.append(_gen_fn('form_volume', pars, model_info.form_volume)) 
    452     if model_info.Iq is not None: 
     474    if isinstance(model_info.Iq, str): 
    453475        pars = [q] + partable.iq_parameters 
    454476        source.append(_gen_fn('Iq', pars, model_info.Iq)) 
    455     if model_info.Iqxy is not None: 
     477    if isinstance(model_info.Iqxy, str): 
    456478        pars = [qx, qy] + partable.iqxy_parameters 
    457479        source.append(_gen_fn('Iqxy', pars, model_info.Iqxy)) 
     
    509531 
    510532def load_kernel_module(model_name): 
     533    # type: (str) -> module 
    511534    if model_name.endswith('.py'): 
    512535        kernel_module = load_custom_kernel_module(model_name) 
     
    522545                            %re.escape(string.punctuation)) 
    523546def _convert_section_titles_to_boldface(lines): 
     547    # type: (Sequence[str]) -> Iterator[str] 
    524548    """ 
    525549    Do the actual work of identifying and converting section headings. 
     
    543567 
    544568def convert_section_titles_to_boldface(s): 
     569    # type: (str) -> str 
    545570    """ 
    546571    Use explicit bold-face rather than section headings so that the table of 
     
    553578 
    554579def make_doc(model_info): 
     580    # type: (ModelInfo) -> str 
    555581    """ 
    556582    Return the documentation for the model. 
     
    562588                 name=model_info.name, 
    563589                 title=model_info.title, 
    564                  parameters=make_partable(model_info.parameters), 
     590                 parameters=make_partable(model_info.parameters.kernel_parameters), 
    565591                 returns=Sq_units if model_info.structure_factor else Iq_units, 
    566592                 docs=docs) 
     
    569595 
    570596def demo_time(): 
     597    # type: () -> None 
    571598    """ 
    572599    Show how long it takes to process a model. 
     
    582609 
    583610def main(): 
     611    # type: () -> None 
    584612    """ 
    585613    Program which prints the source produced by the model. 
  • sasmodels/kernelcl.py

    r6d6508e rf619de7  
    6767 
    6868from . import generate 
     69from .kernel import KernelModel, Kernel 
    6970 
    7071# The max loops number is limited by the amount of local memory available 
     
    310311 
    311312 
    312 class GpuModel(object): 
     313class GpuModel(KernelModel): 
    313314    """ 
    314315    GPU wrapper for a single model. 
     
    420421        self.release() 
    421422 
    422 class GpuKernel(object): 
     423class GpuKernel(Kernel): 
    423424    """ 
    424425    Callable SAS kernel. 
     
    489490        self.kernel(self.queue, self.q_input.global_size, None, *args) 
    490491        cl.enqueue_copy(self.queue, self.result, self.result_b) 
    491         [v.release() for v in details_b, weights_b, values_b] 
     492        [v.release() for v in (details_b, weights_b, values_b)] 
    492493 
    493494        return self.result[:self.nq] 
  • sasmodels/kerneldll.py

    r6d6508e rf619de7  
    5656from . import generate 
    5757from . import details 
    58 from .kernelpy import PyInput, PyModel 
     58from .kernel import KernelModel, Kernel 
     59from .kernelpy import PyInput 
    5960from .exception import annotate_exception 
     61from .generate import F16, F32, F64 
     62 
     63try: 
     64    from typing import Tuple, Callable, Any 
     65    from .modelinfo import ModelInfo 
     66    from .details import CallDetails 
     67except ImportError: 
     68    pass 
    6069 
    6170# Compiler platform details 
     
    91100 
    92101def dll_name(model_info, dtype): 
     102    # type: (ModelInfo, np.dtype) ->  str 
    93103    """ 
    94104    Name of the dll containing the model.  This is the base file name without 
     
    98108    return "sas_%s%d"%(model_info.id, bits) 
    99109 
     110 
    100111def dll_path(model_info, dtype): 
     112    # type: (ModelInfo, np.dtype) -> str 
    101113    """ 
    102114    Complete path to the dll for the model.  Note that the dll may not 
     
    105117    return os.path.join(DLL_PATH, dll_name(model_info, dtype)+".so") 
    106118 
    107 def make_dll(source, model_info, dtype="double"): 
    108     """ 
    109     Load the compiled model defined by *kernel_module*. 
    110  
    111     Recompile if any files are newer than the model file. 
     119 
     120def make_dll(source, model_info, dtype=F64): 
     121    # type: (str, ModelInfo, np.dtype) -> str 
     122    """ 
     123    Returns the path to the compiled model defined by *kernel_module*. 
     124 
     125    If the model has not been compiled, or if the source file(s) are newer 
     126    than the dll, then *make_dll* will compile the model before returning. 
     127    This routine does not load the resulting dll. 
    112128 
    113129    *dtype* is a numpy floating point precision specifier indicating whether 
    114     the model should be single or double precision.  The default is double 
    115     precision. 
    116  
    117     The DLL is not loaded until the kernel is called so models can 
    118     be defined without using too many resources. 
     130    the model should be single, double or long double precision.  The default 
     131    is double precision, *np.dtype('d')*. 
     132 
     133    Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to False if single precision 
     134    models are not allowed as DLLs. 
    119135 
    120136    Set *sasmodels.kerneldll.DLL_PATH* to the compiled dll output path. 
    121137    The default is the system temporary directory. 
    122  
    123     Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to True if single precision 
    124     models are allowed as DLLs. 
    125     """ 
    126     if callable(model_info.Iq): 
    127         return PyModel(model_info) 
    128      
    129     dtype = np.dtype(dtype) 
    130     if dtype == generate.F16: 
     138    """ 
     139    if dtype == F16: 
    131140        raise ValueError("16 bit floats not supported") 
    132     if dtype == generate.F32 and not ALLOW_SINGLE_PRECISION_DLLS: 
    133         dtype = generate.F64  # Force 64-bit dll 
    134  
    135     source = generate.convert_type(source, dtype) 
     141    if dtype == F32 and not ALLOW_SINGLE_PRECISION_DLLS: 
     142        dtype = F64  # Force 64-bit dll 
     143    # Note: dtype may be F128 for long double precision 
     144 
    136145    newest = generate.timestamp(model_info) 
    137146    dll = dll_path(model_info, dtype) 
     
    139148        basename = dll_name(model_info, dtype) + "_" 
    140149        fid, filename = tempfile.mkstemp(suffix=".c", prefix=basename) 
     150        source = generate.convert_type(source, dtype) 
    141151        os.fdopen(fid, "w").write(source) 
    142152        command = COMPILE%{"source":filename, "output":dll} 
     
    152162 
    153163 
    154 def load_dll(source, model_info, dtype="double"): 
     164def load_dll(source, model_info, dtype=F64): 
     165    # type: (str, ModelInfo, np.dtype) -> "DllModel" 
    155166    """ 
    156167    Create and load a dll corresponding to the source, info pair returned 
     
    163174    return DllModel(filename, model_info, dtype=dtype) 
    164175 
    165 class DllModel(object): 
     176 
     177class DllModel(KernelModel): 
    166178    """ 
    167179    ctypes wrapper for a single model. 
     
    179191     
    180192    def __init__(self, dllpath, model_info, dtype=generate.F32): 
     193        # type: (str, ModelInfo, np.dtype) -> None 
    181194        self.info = model_info 
    182195        self.dllpath = dllpath 
    183         self.dll = None 
     196        self._dll = None  # type: ct.CDLL 
    184197        self.dtype = np.dtype(dtype) 
    185198 
    186199    def _load_dll(self): 
     200        # type: () -> None 
    187201        #print("dll", self.dllpath) 
    188202        try: 
    189             self.dll = ct.CDLL(self.dllpath) 
     203            self._dll = ct.CDLL(self.dllpath) 
    190204        except: 
    191205            annotate_exception("while loading "+self.dllpath) 
     
    198212        # int, int, int, int*, double*, double*, double*, double*, double*, double 
    199213        argtypes = [c_int32]*3 + [c_void_p]*5 + [fp] 
    200         self.Iq = self.dll[generate.kernel_name(self.info, False)] 
    201         self.Iqxy = self.dll[generate.kernel_name(self.info, True)] 
    202         self.Iq.argtypes = argtypes 
    203         self.Iqxy.argtypes = argtypes 
     214        self._Iq = self._dll[generate.kernel_name(self.info, is_2d=False)] 
     215        self._Iqxy = self._dll[generate.kernel_name(self.info, is_2d=True)] 
     216        self._Iq.argtypes = argtypes 
     217        self._Iqxy.argtypes = argtypes 
    204218 
    205219    def __getstate__(self): 
     220        # type: () -> Tuple[ModelInfo, str] 
    206221        return self.info, self.dllpath 
    207222 
    208223    def __setstate__(self, state): 
     224        # type: (Tuple[ModelInfo, str]) -> None 
    209225        self.info, self.dllpath = state 
    210         self.dll = None 
     226        self._dll = None 
    211227 
    212228    def make_kernel(self, q_vectors): 
     229        # type: (List[np.ndarray]) -> DllKernel 
    213230        q_input = PyInput(q_vectors, self.dtype) 
    214         if self.dll is None: self._load_dll() 
    215         kernel = self.Iqxy if q_input.is_2d else self.Iq 
     231        # Note: pickle not supported for DllKernel 
     232        if self._dll is None: 
     233            self._load_dll() 
     234        kernel = self._Iqxy if q_input.is_2d else self._Iq 
    216235        return DllKernel(kernel, self.info, q_input) 
    217236 
    218237    def release(self): 
     238        # type: () -> None 
    219239        """ 
    220240        Release any resources associated with the model. 
     
    225245            libHandle = dll._handle 
    226246            #libHandle = ct.c_void_p(dll._handle) 
    227             del dll, self.dll 
    228             self.dll = None 
     247            del dll, self._dll 
     248            self._dll = None 
    229249            #_ctypes.FreeLibrary(libHandle) 
    230250            ct.windll.kernel32.FreeLibrary(libHandle) 
     
    233253 
    234254 
    235 class DllKernel(object): 
     255class DllKernel(Kernel): 
    236256    """ 
    237257    Callable SAS kernel. 
     
    253273    """ 
    254274    def __init__(self, kernel, model_info, q_input): 
     275        # type: (Callable[[], np.ndarray], ModelInfo, PyInput) -> None 
    255276        self.kernel = kernel 
    256277        self.info = model_info 
     
    261282 
    262283    def __call__(self, call_details, weights, values, cutoff): 
     284        # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 
    263285        real = (np.float32 if self.q_input.dtype == generate.F32 
    264286                else np.float64 if self.q_input.dtype == generate.F64 
     
    282304            real(cutoff), # cutoff 
    283305            ] 
    284         self.kernel(*args) 
     306        self.kernel(*args) # type: ignore 
    285307        return self.result[:-3] 
    286308 
    287309    def release(self): 
     310        # type: () -> None 
    288311        """ 
    289312        Release any resources associated with the kernel. 
    290313        """ 
    291         pass 
     314        self.q_input.release() 
  • sasmodels/kernelpy.py

    r9a943d0 rf619de7  
    1212from . import details 
    1313from .generate import F64 
     14from .kernel import KernelModel, Kernel 
    1415 
    1516try: 
     
    2021    DType = Union[None, str, np.dtype] 
    2122 
    22 class PyModel(object): 
     23class PyModel(KernelModel): 
    2324    """ 
    2425    Wrapper for pure python models. 
     
    7778        self.q = None 
    7879 
    79 class PyKernel(object): 
     80class PyKernel(Kernel): 
    8081    """ 
    8182    Callable SAS kernel. 
     
    162163        Free resources associated with the kernel. 
    163164        """ 
     165        self.q_input.release() 
    164166        self.q_input = None 
    165167 
  • sasmodels/mixture.py

    r6d6508e rf619de7  
    1515 
    1616from .modelinfo import Parameter, ParameterTable, ModelInfo 
     17from .kernel import KernelModel, Kernel 
     18 
     19try: 
     20    from typing import List 
     21    from .details import CallDetails 
     22except ImportError: 
     23    pass 
    1724 
    1825def make_mixture_info(parts): 
     26    # type: (List[ModelInfo]) -> ModelInfo 
    1927    """ 
    2028    Create info block for product model. 
     
    2230    flatten = [] 
    2331    for part in parts: 
    24         if part['composition'] and part['composition'][0] == 'mixture': 
    25             flatten.extend(part['compostion'][1]) 
     32        if part.composition and part.composition[0] == 'mixture': 
     33            flatten.extend(part.composition[1]) 
    2634        else: 
    2735            flatten.append(part) 
     
    2937 
    3038    # Build new parameter list 
    31     pars = [] 
     39    combined_pars = [] 
    3240    for k, part in enumerate(parts): 
    3341        # Parameter prefix per model, A_, B_, ... 
     
    3543        # to support vector parameters 
    3644        prefix = chr(ord('A')+k) + '_' 
    37         pars.append(Parameter(prefix+'scale')) 
    38         for p in part['parameters'].kernel_pars: 
     45        combined_pars.append(Parameter(prefix+'scale')) 
     46        for p in part.parameters.kernel_parameters: 
    3947            p = copy(p) 
    40             p.name = prefix+p.name 
    41             p.id = prefix+p.id 
     48            p.name = prefix + p.name 
     49            p.id = prefix + p.id 
    4250            if p.length_control is not None: 
    43                 p.length_control = prefix+p.length_control 
    44             pars.append(p) 
    45     partable = ParameterTable(pars) 
     51                p.length_control = prefix + p.length_control 
     52            combined_pars.append(p) 
     53    parameters = ParameterTable(combined_pars) 
    4654 
    4755    model_info = ModelInfo() 
    48     model_info.id = '+'.join(part['id']) 
    49     model_info.name = ' + '.join(part['name']) 
     56    model_info.id = '+'.join(part.id for part in parts) 
     57    model_info.name = ' + '.join(part.name for part in parts) 
    5058    model_info.filename = None 
    5159    model_info.title = 'Mixture model with ' + model_info.name 
     
    5361    model_info.docs = model_info.title 
    5462    model_info.category = "custom" 
    55     model_info.parameters = partable 
     63    model_info.parameters = parameters 
    5664    #model_info.single = any(part['single'] for part in parts) 
    5765    model_info.structure_factor = False 
     
    6472 
    6573 
    66 class MixtureModel(object): 
     74class MixtureModel(KernelModel): 
    6775    def __init__(self, model_info, parts): 
     76        # type: (ModelInfo, List[KernelModel]) -> None 
    6877        self.info = model_info 
    6978        self.parts = parts 
    7079 
    7180    def __call__(self, q_vectors): 
     81        # type: (List[np.ndarray]) -> MixtureKernel 
    7282        # Note: may be sending the q_vectors to the n times even though they 
    7383        # are only needed once.  It would mess up modularity quite a bit to 
     
    7686        # in opencl; or both in opencl, but one in single precision and the 
    7787        # other in double precision). 
    78         kernels = [part(q_vectors) for part in self.parts] 
     88        kernels = [part.make_kernel(q_vectors) for part in self.parts] 
    7989        return MixtureKernel(self.info, kernels) 
    8090 
    8191    def release(self): 
     92        # type: () -> None 
    8293        """ 
    8394        Free resources associated with the model. 
     
    8798 
    8899 
    89 class MixtureKernel(object): 
     100class MixtureKernel(Kernel): 
    90101    def __init__(self, model_info, kernels): 
    91         dim = '2d' if kernels[0].q_input.is_2d else '1d' 
     102        # type: (ModelInfo, List[Kernel]) -> None 
     103        self.dim = kernels[0].dim 
     104        self.info =  model_info 
     105        self.kernels = kernels 
    92106 
    93         # fixed offsets starts at 2 for scale and background 
    94         fixed_pars, pd_pars = [], [] 
    95         offsets = [[2, 0]] 
    96         #vol_index = [] 
    97         def accumulate(fixed, pd, volume): 
    98             # subtract 1 from fixed since we are removing background 
    99             fixed_offset, pd_offset = offsets[-1] 
    100             #vol_index.extend(k+pd_offset for k,v in pd if v in volume) 
    101             offsets.append([fixed_offset + len(fixed) - 1, pd_offset + len(pd)]) 
    102             pd_pars.append(pd) 
    103         if dim == '2d': 
    104             for p in kernels: 
    105                 partype = p.info.partype 
    106                 accumulate(partype['fixed-2d'], partype['pd-2d'], partype['volume']) 
    107         else: 
    108             for p in kernels: 
    109                 partype = p.info.partype 
    110                 accumulate(partype['fixed-1d'], partype['pd-1d'], partype['volume']) 
    111  
    112         #self.vol_index = vol_index 
    113         self.offsets = offsets 
    114         self.fixed_pars = fixed_pars 
    115         self.pd_pars = pd_pars 
    116         self.info = model_info 
    117         self.kernels = kernels 
    118         self.results = None 
    119  
    120     def __call__(self, fixed_pars, pd_pars, cutoff=1e-5): 
    121         scale, background = fixed_pars[0:2] 
     107    def __call__(self, call_details, value, weight, cutoff): 
     108        # type: (CallDetails, np.ndarray, np.ndarry, float) -> np.ndarray 
     109        scale, background = value[0:2] 
    122110        total = 0.0 
    123         self.results = []  # remember the parts for plotting later 
    124         for k in range(len(self.offsets)-1): 
    125             start_fixed, start_pd = self.offsets[k] 
    126             end_fixed, end_pd = self.offsets[k+1] 
    127             part_fixed = [fixed_pars[start_fixed], 0.0] + fixed_pars[start_fixed+1:end_fixed] 
    128             part_pd = [pd_pars[start_pd], 0.0] + pd_pars[start_pd+1:end_pd] 
    129             part_result = self.kernels[k](part_fixed, part_pd) 
     111        # remember the parts for plotting later 
     112        self.results = [] 
     113        for kernel, kernel_details in zip(self.kernels, call_details.parts): 
     114            part_result = kernel(kernel_details, value, weight, cutoff) 
    130115            total += part_result 
    131             self.results.append(scale*sum+background) 
     116            self.results.append(part_result) 
    132117 
    133118        return scale*total + background 
    134119 
    135120    def release(self): 
    136         self.p_kernel.release() 
    137         self.q_kernel.release() 
     121        # type: () -> None 
     122        for k in self.kernels: 
     123            k.release() 
    138124 
  • sasmodels/model_test.py

    rc1a888b rf619de7  
    6969    # type: (ModelInfo, ParameterSet) -> float 
    7070    """ 
    71     Call the model ER function using *values*. *model_info* is either 
    72     *model.info* if you have a loaded model, or *kernel.info* if you 
    73     have a model kernel prepared for evaluation. 
     71    Call the model ER function using *values*. 
     72 
     73    *model_info* is either *model.info* if you have a loaded model, 
     74    or *kernel.info* if you have a model kernel prepared for evaluation. 
    7475    """ 
    7576    if model_info.ER is None: 
     
    8485    """ 
    8586    Call the model VR function using *pars*. 
    86     *info* is either *model.info* if you have a loaded model, or *kernel.info* 
    87     if you have a model kernel prepared for evaluation. 
     87 
     88    *model_info* is either *model.info* if you have a loaded model, 
     89    or *kernel.info* if you have a model kernel prepared for evaluation. 
    8890    """ 
    8991    if model_info.VR is None: 
  • sasmodels/modelinfo.py

    r9a943d0 rf619de7  
    713713    ER = None               # type: Optional[Callable[[np.ndarray], np.ndarray]] 
    714714    VR = None               # type: Optional[Callable[[np.ndarray], Tuple[np.ndarray, np.ndarray]]] 
    715     form_volume = None      # type: Optional[Callable[[np.ndarray], float]] 
    716     Iq = None               # type: Optional[Callable[[np.ndarray], np.ndarray]] 
    717     Iqxy = None             # type: Optional[Callable[[np.ndarray], np.ndarray]] 
     715    form_volume = None      # type: Union[None, str, Callable[[np.ndarray], float]] 
     716    Iq = None               # type: Union[None, str, Callable[[np.ndarray], np.ndarray]] 
     717    Iqxy = None             # type: Union[None, str, Callable[[np.ndarray], np.ndarray]] 
    718718    profile = None          # type: Optional[Callable[[np.ndarray], None]] 
    719719    sesans = None           # type: Optional[Callable[[np.ndarray], np.ndarray]] 
  • sasmodels/product.py

    r6d6508e rf619de7  
    1414 
    1515from .details import dispersion_mesh 
    16 from .modelinfo import suffix_parameter, ParameterTable, Parameter, ModelInfo 
     16from .modelinfo import suffix_parameter, ParameterTable, ModelInfo 
     17from .kernel import KernelModel, Kernel 
     18 
     19try: 
     20    from typing import Tuple 
     21    from .modelinfo import ParameterSet 
     22    from .details import CallDetails 
     23except ImportError: 
     24    pass 
    1725 
    1826# TODO: make estimates available to constraints 
     
    2533# revert it after making VR and ER available at run time as constraints. 
    2634def make_product_info(p_info, s_info): 
     35    # type: (ModelInfo, ModelInfo) -> ModelInfo 
    2736    """ 
    2837    Create info block for product model. 
    2938    """ 
    30     p_id, p_name, p_partable = p_info.id, p_info.name, p_info.parameters 
    31     s_id, s_name, s_partable = s_info.id, s_info.name, s_info.parameters 
    32     p_set = set(p.id for p in p_partable) 
    33     s_set = set(p.id for p in s_partable) 
     39    p_id, p_name, p_pars = p_info.id, p_info.name, p_info.parameters 
     40    s_id, s_name, s_pars = s_info.id, s_info.name, s_info.parameters 
     41    p_set = set(p.id for p in p_pars.call_parameters) 
     42    s_set = set(p.id for p in s_pars.call_parameters) 
    3443 
    3544    if p_set & s_set: 
    3645        # there is some overlap between the parameter names; tag the 
    3746        # overlapping S parameters with name_S 
    38         s_pars = [(suffix_parameter(par, "_S") if par.id in p_set else par) 
    39                   for par in s_partable.kernel_parameters] 
    40         pars = p_partable.kernel_parameters + s_pars 
     47        s_list = [(suffix_parameter(par, "_S") if par.id in p_set else par) 
     48                  for par in s_pars.kernel_parameters] 
     49        combined_pars = p_pars.kernel_parameters + s_list 
    4150    else: 
    42         pars= p_partable.kernel_parameters + s_partable.kernel_parameters 
     51        combined_pars = p_pars.kernel_parameters + s_pars.kernel_parameters 
     52    parameters = ParameterTable(combined_pars) 
    4353 
    4454    model_info = ModelInfo() 
     
    5060    model_info.docs = model_info.title 
    5161    model_info.category = "custom" 
    52     model_info.parameters = ParameterTable(pars) 
     62    model_info.parameters = parameters 
    5363    #model_info.single = p_info.single and s_info.single 
    5464    model_info.structure_factor = False 
     
    6070    return model_info 
    6171 
    62 class ProductModel(object): 
     72class ProductModel(KernelModel): 
    6373    def __init__(self, model_info, P, S): 
     74        # type: (ModelInfo, KernelModel, KernelModel) -> None 
    6475        self.info = model_info 
    6576        self.P = P 
     
    6778 
    6879    def __call__(self, q_vectors): 
     80        # type: (List[np.ndarray]) -> Kernel 
    6981        # Note: may be sending the q_vectors to the GPU twice even though they 
    7082        # are only needed once.  It would mess up modularity quite a bit to 
     
    7385        # in opencl; or both in opencl, but one in single precision and the 
    7486        # other in double precision). 
    75         p_kernel = self.P(q_vectors) 
    76         s_kernel = self.S(q_vectors) 
     87        p_kernel = self.P.make_kernel(q_vectors) 
     88        s_kernel = self.S.make_kernel(q_vectors) 
    7789        return ProductKernel(self.info, p_kernel, s_kernel) 
    7890 
    7991    def release(self): 
     92        # type: (None) -> None 
    8093        """ 
    8194        Free resources associated with the model. 
     
    8598 
    8699 
    87 class ProductKernel(object): 
     100class ProductKernel(Kernel): 
    88101    def __init__(self, model_info, p_kernel, s_kernel): 
     102        # type: (ModelInfo, Kernel, Kernel) -> None 
    89103        self.info = model_info 
    90104        self.p_kernel = p_kernel 
     
    92106 
    93107    def __call__(self, details, weights, values, cutoff): 
     108        # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 
    94109        effect_radius, vol_ratio = call_ER_VR(self.p_kernel.info, vol_pars) 
    95110 
     
    108123 
    109124    def release(self): 
     125        # type: () -> None 
    110126        self.p_kernel.release() 
    111         self.q_kernel.release() 
     127        self.s_kernel.release() 
    112128 
    113 def call_ER_VR(model_info, vol_pars): 
     129def call_ER_VR(model_info, pars): 
    114130    """ 
    115131    Return effect radius and volume ratio for the model. 
    116132    """ 
    117     value, weight = dispersion_mesh(vol_pars) 
     133    if model_info.ER is None and model_info.VR is None: 
     134        return 1.0, 1.0 
    118135 
    119     individual_radii = model_info.ER(*value) if model_info.ER else 1.0 
    120     whole, part = model_info.VR(*value) if model_info.VR else (1.0, 1.0) 
     136    value, weight = _vol_pars(model_info, pars) 
    121137 
    122     effect_radius = np.sum(weight*individual_radii) / np.sum(weight) 
    123     volume_ratio = np.sum(weight*part)/np.sum(weight*whole) 
     138    if model_info.ER is not None: 
     139        individual_radii = model_info.ER(*value) 
     140        effect_radius = np.sum(weight*individual_radii) / np.sum(weight) 
     141    else: 
     142        effect_radius = 1.0 
     143 
     144    if model_info.VR is not None: 
     145        whole, part = model_info.VR(*value) 
     146        volume_ratio = np.sum(weight*part)/np.sum(weight*whole) 
     147    else: 
     148        volume_ratio = 1.0 
     149 
    124150    return effect_radius, volume_ratio 
     151 
     152def _vol_pars(model_info, pars): 
     153    # type: (ModelInfo, ParameterSet) -> Tuple[np.ndarray, np.ndarray] 
     154    vol_pars = [get_weights(p, pars) 
     155                for p in model_info.parameters.call_parameters 
     156                if p.type == 'volume'] 
     157    value, weight = dispersion_mesh(model_info, vol_pars) 
     158    return value, weight 
     159 
Note: See TracChangeset for help on using the changeset viewer.