source: sasmodels/sasmodels/core.py @ a85a569

Branches: core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since a85a569 was a85a569, checked in by lewis, 7 years ago

Merge branch 'master' into ticket-767

  • Property mode set to 100644
File size: 12.3 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[6d6508e]4from __future__ import print_function
5
[98f60fc]6__all__ = [
[6d6508e]7    "list_models", "load_model", "load_model_info",
[2547694]8    "build_model", "precompile_dlls",
[98f60fc]9    ]
[f734e7d]10
[7bf4757]11import os
[60335cc]12import re
[2547694]13from os.path import basename, dirname, join as joinpath
[f734e7d]14from glob import glob
15
[7ae2b7f]16import numpy as np # type: ignore
[f734e7d]17
[aa4946b]18from . import generate
[6d6508e]19from . import modelinfo
20from . import product
[72a081d]21from . import mixture
[aa4946b]22from . import kernelpy
23from . import kerneldll
[60335cc]24from . import custom
[880a2ed]25
# SAS_OPENCL=none (any case) disables OpenCL outright.  Otherwise probe for
# the kernelcl backend; any failure while importing it simply means OpenCL
# is not available and C models will fall back to the DLL path.
HAVE_OPENCL = False
if os.environ.get("SAS_OPENCL", "").lower() != "none":
    try:
        from . import kernelcl
    except Exception:
        pass
    else:
        HAVE_OPENCL = True
[aa4946b]34
# Directory searched for "custom.ModelName" models.  SAS_MODELPATH overrides
# it; when that is unset or empty, ~/.sasmodels/custom_models is used and
# created if it does not exist yet.
CUSTOM_MODEL_PATH = os.environ.get('SAS_MODELPATH', "")
if CUSTOM_MODEL_PATH == "":
    path = joinpath(os.path.expanduser("~"), ".sasmodels", "custom_models")
    if not os.path.isdir(path):
        try:
            os.makedirs(path)
        except OSError:
            # Another process may have created the directory between the
            # isdir check and makedirs; only a still-missing directory
            # (permissions, bad home, ...) is a real error.
            if not os.path.isdir(path):
                raise
    CUSTOM_MODEL_PATH = path
41
[f619de7]42try:
43    from typing import List, Union, Optional, Any
44    from .kernel import KernelModel
[dd7fc12]45    from .modelinfo import ModelInfo
[f619de7]46except ImportError:
47    pass
48
[4d76711]49# TODO: refactor composite model support
50# The current load_model_info/build_model does not reuse existing model
51# definitions when loading a composite model, instead reloading and
52# rebuilding the kernel for each component model in the expression.  This
53# is fine in a scripting environment where the model is built when the script
54# starts and is thrown away when the script ends, but may not be the best
55# solution in a long-lived application.  This affects the following functions:
56#
57#    load_model
58#    load_model_info
59#    build_model
[f734e7d]60
KINDS = ("all", "py", "c", "double", "single", "opencl", "1d", "2d",
         "nonmagnetic", "magnetic")


def list_models(kind=None):
    # type: (str) -> List[str]
    """
    Return the list of available models on the model path.

    *kind* can be one of the following:

        * all: all models
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * opencl: models which can be evaluated with OpenCL
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld
        * nonmagnetic: models without an sld

    For multiple conditions, combine with plus.  For example, *c+single+2d*
    would return all oriented models implemented in C which can be computed
    accurately with single precision arithmetic.

    A *kind* of None or "" is the same as "all".

    Raises ValueError if any '+'-separated part of *kind* is not in KINDS.
    """
    requested = kind.split('+') if kind else []
    if any(k not in KINDS for k in requested):
        raise ValueError("kind not in " + ", ".join(KINDS))
    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
    available_models = [basename(f)[:-3] for f in files]
    if not requested:
        # No filter: skip loading each model's info entirely.
        return available_models
    # A model is selected only when it satisfies every requested kind.
    return [name for name in available_models
            if all(_matches(name, k) for k in requested)]
97
98def _matches(name, kind):
[2547694]99    if kind is None or kind == "all":
[0b03001]100        return True
101    info = load_model_info(name)
102    pars = info.parameters.kernel_parameters
103    if kind == "py" and callable(info.Iq):
104        return True
105    elif kind == "c" and not callable(info.Iq):
106        return True
107    elif kind == "double" and not info.single:
108        return True
[d2d6100]109    elif kind == "single" and info.single:
110        return True
[8407d8c]111    elif kind == "opencl" and info.opencl:
[407bf48]112        return True
[2547694]113    elif kind == "2d" and any(p.type == 'orientation' for p in pars):
[d2d6100]114        return True
[40a87fa]115    elif kind == "1d" and all(p.type != 'orientation' for p in pars):
[0b03001]116        return True
[2547694]117    elif kind == "magnetic" and any(p.type == 'sld' for p in pars):
[0b03001]118        return True
[2547694]119    elif kind == "nonmagnetic" and any(p.type != 'sld' for p in pars):
[d2d6100]120        return True
[0b03001]121    return False
[f734e7d]122
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Look up *model_name* and return a kernel model ready for evaluation.

    *model_name* may be a single model name or a model expression such as
    sphere*hardsphere or sphere+cylinder; it is resolved with
    :func:`load_model_info`.

    *dtype* and *platform* are passed unchanged to :func:`build_model`.
    """
    info = load_model_info(model_name)
    return build_model(info, dtype=dtype, platform=platform)
[aa4946b]135
[61a4bd4]136def load_model_info(model_string):
[f619de7]137    # type: (str) -> modelinfo.ModelInfo
[aa4946b]138    """
139    Load a model definition given the model name.
[1d4017a]140
[481ff64]141    *model_string* is the name of the model, or perhaps a model expression
142    such as sphere*cylinder or sphere+cylinder. Use '@' for a structure
[60335cc]143    factor product, e.g. sphere@hardsphere. Custom models can be specified by
144    prefixing the model name with 'custom.', e.g. 'custom.MyModel+sphere'.
[2e66ef5]145
[1d4017a]146    This returns a handle to the module defining the model.  This can be
147    used with functions in generate to build the docs or extract model info.
[aa4946b]148    """
[481ff64]149    if '@' in model_string:
150        parts = model_string.split('@')
151        if len(parts) != 2:
152            raise ValueError("Use P@S to apply a structure factor S to model P")
153        P_info, Q_info = [load_model_info(part) for part in parts]
154        return product.make_product_info(P_info, Q_info)
155
[61a4bd4]156    product_parts = []
157    addition_parts = []
158
159    addition_parts_names = model_string.split('+')
160    if len(addition_parts_names) >= 2:
161        addition_parts = [load_model_info(part) for part in addition_parts_names]
162    elif len(addition_parts_names) == 1:
163        product_parts_names = model_string.split('*')
164        if len(product_parts_names) >= 2:
165            product_parts = [load_model_info(part) for part in product_parts_names]
166        elif len(product_parts_names) == 1:
[60335cc]167            if "custom." in product_parts_names[0]:
168                # Extract ModelName from "custom.ModelName"
169                pattern = "custom.([A-Za-z0-9_-]+)"
170                result = re.match(pattern, product_parts_names[0])
171                if result is None:
172                    raise ValueError("Model name in invalid format: " + product_parts_names[0])
173                model_name = result.group(1)
174                # Use ModelName to find the path to the custom model file
175                model_path = joinpath(CUSTOM_MODEL_PATH, model_name + ".py")
176                if not os.path.isfile(model_path):
177                    raise ValueError("The model file {} doesn't exist".format(model_path))
178                kernel_module = custom.load_custom_kernel_module(model_path)
179                return modelinfo.make_model_info(kernel_module)
180            # Model is a core model
[61a4bd4]181            kernel_module = generate.load_kernel_module(product_parts_names[0])
182            return modelinfo.make_model_info(kernel_module)
183
184    model = None
185    if len(product_parts) > 1:
186        model = mixture.make_mixture_info(product_parts, operation='*')
187    if len(addition_parts) > 1:
188        if model is not None:
189            addition_parts.append(model)
190        model = mixture.make_mixture_info(addition_parts, operation='+')
191    return model
[d19962c]192
193
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Create a kernel model for *model_info* on the preferred platform.

    Composite models (mixtures and products) are built recursively from
    their parts.  Models whose Iq is a python callable are wrapped in a
    PyModel; C models are compiled either as a DLL or as an OpenCL kernel.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* selects the precision: 'single', 'double', 'quad', 'half' or
    'fast'.  A trailing '!' (e.g. 'double!') forces the DLL rather than
    OpenCL.  See :func:`parse_dtype` for details.

    *platform* is "dll" to force the DLL for C models; the default "ocl"
    uses OpenCL when :func:`parse_dtype` allows it.

    Raises ValueError for an unknown composition type.
    """
    if model_info.composition is not None:
        kind, parts = model_info.composition
        models = [build_model(part, dtype=dtype, platform=platform)
                  for part in parts]
        if kind == 'mixture':
            return mixture.MixtureModel(model_info, models)
        if kind == 'product':
            P, S = models
            return product.ProductModel(model_info, P, S)
        raise ValueError('unknown mixture type %s'%kind)

    # Python kernels need no compilation step.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
    source = generate.make_source(model_info)
    if platform != "dll":
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
[aa4946b]238
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Compile the DLL for every builtin C model into *path*.

    *path* is the target directory; it is created if it does not exist.
    *dtype* is the numeric precision to compile for (any name accepted by
    numpy.dtype).

    Returns the list of DLL paths produced, one per compiled model.  Models
    whose Iq is a python callable are skipped since they have no C kernel.

    Useful when building a Windows distribution of sasmodels, which may be
    missing the OpenCL driver and the DLL compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)
    dll_paths = []
    for name in list_models():
        info = load_model_info(name)
        if callable(info.Iq):
            continue
        dll_source = generate.make_source(info)['dll']
        # Redirect kerneldll output into *path* only for this compile, and
        # always restore the previous setting even if compilation fails.
        saved_dll_path = kerneldll.DLL_PATH
        try:
            kerneldll.DLL_PATH = path
            dll = kerneldll.make_dll(dll_source, info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_dll_path
        dll_paths.append(dll)
    return dll_paths
[b8e5e21]266
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning (np.dtype, fast flag, platform).

    Possible types include 'half', 'single', 'double' and 'quad'.  If the
    type is 'fast', then this is equivalent to dtype 'single' but using
    fast native functions rather than those with the precision level guaranteed
    by the OpenCL standard.  A dtype of None or 'default' selects
    generate.F32 when running on OpenCL with model_info.single set, and
    generate.F64 otherwise.

    Platform preference can be specified ("ocl" vs "dll"), with the default
    being OpenCL if it is available.  The platform becomes "dll" when OpenCL
    is unavailable, when the model has opencl disabled, when the dtype name
    ends with '!', or when the OpenCL environment lacks the requested type.

    This routine ignores the preferences within the model definition.  This
    is by design.  It allows us to test models in single precision even when
    we have flagged them as requiring double precision so we can easily check
    the performance on different platforms without having to change the model
    definition.
    """
    # Assign default platform, using dll if OpenCL is unavailable or the
    # model is flagged as not supporting OpenCL.
    if platform is None:
        platform = "ocl"
    if (platform == "ocl" and not HAVE_OPENCL) or not model_info.opencl:
        platform = "dll"

    # A trailing '!' on the type forces dll regardless of platform.
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # "default" means "no explicit type"; treat it exactly like None so the
    # OpenCL fallback below applies to both spellings.
    if dtype == "default":
        dtype = None

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        dtype = "float16"

    # Convert dtype string to numpy dtype.
    if dtype is None:
        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll.
    # (platform == "ocl" here implies HAVE_OPENCL, so kernelcl is imported.)
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            if dtype is None:
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
[0c24a82]325
def list_models_main():
    # type: () -> None
    """
    Command-line entry point: print matching model names, one per line.

    The optional first argument is the kind filter accepted by
    :func:`list_models` (e.g. "c+single"); it defaults to "all".
    """
    import sys
    kind = "all" if len(sys.argv) < 2 else sys.argv[1]
    names = list_models(kind)
    print("\n".join(names))


if __name__ == "__main__":
    list_models_main()
Note: See TracBrowser for help on using the repository browser.