source: sasmodels/sasmodels/core.py @ ef07e95

Branches: core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since ef07e95 was 2d81cfe, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

lint

  • Property mode set to 100644
File size: 12.4 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[6d6508e]4from __future__ import print_function
5
# Public API of this module: model discovery, loading and building.
__all__ = [
    "list_models",
    "load_model",
    "load_model_info",
    "build_model",
    "precompile_dlls",
]
[f734e7d]10
[7bf4757]11import os
[e65c3ba]12from os.path import basename, join as joinpath
[f734e7d]13from glob import glob
[e65c3ba]14import re
[f734e7d]15
[7ae2b7f]16import numpy as np # type: ignore
[f734e7d]17
[aa4946b]18from . import generate
[6d6508e]19from . import modelinfo
20from . import product
[72a081d]21from . import mixture
[aa4946b]22from . import kernelpy
23from . import kerneldll
[60335cc]24from . import custom
[880a2ed]25
# OpenCL support is optional.  Setting SAS_OPENCL=none skips even trying to
# import the OpenCL kernel module; otherwise any failure while importing it
# (presumably a missing pyopencl or a broken driver -- not checked here)
# simply leaves OpenCL disabled.
HAVE_OPENCL = False
if os.environ.get("SAS_OPENCL", "").lower() != "none":
    try:
        from . import kernelcl
    except Exception:
        pass  # deliberate best-effort: fall back to the DLL/python kernels
    else:
        HAVE_OPENCL = True
[aa4946b]34
# Directory searched for 'custom.ModelName' models.  Override it with the
# SAS_MODELPATH environment variable; otherwise ~/.sasmodels/custom_models
# is used and created on first import.
CUSTOM_MODEL_PATH = os.environ.get('SAS_MODELPATH', "")
if CUSTOM_MODEL_PATH == "":
    CUSTOM_MODEL_PATH = joinpath(os.path.expanduser("~"), ".sasmodels", "custom_models")
    if not os.path.isdir(CUSTOM_MODEL_PATH):
        try:
            os.makedirs(CUSTOM_MODEL_PATH)
        except OSError:
            # Either another process created the directory between the
            # isdir() check and makedirs() (harmless), or it cannot be
            # created at all (e.g. an unwritable home directory).  Don't make
            # the whole module unimportable for that; load_model_info raises
            # a ValueError for a missing custom model file when one is
            # actually requested.
            pass
[60335cc]40
[2d81cfe]41# pylint: disable=unused-import
[f619de7]42try:
43    from typing import List, Union, Optional, Any
44    from .kernel import KernelModel
[dd7fc12]45    from .modelinfo import ModelInfo
[f619de7]46except ImportError:
47    pass
[2d81cfe]48# pylint: enable=unused-import
[f619de7]49
[4d76711]50# TODO: refactor composite model support
51# The current load_model_info/build_model does not reuse existing model
52# definitions when loading a composite model, instead reloading and
53# rebuilding the kernel for each component model in the expression.  This
54# is fine in a scripting environment where the model is built when the script
55# starts and is thrown away when the script ends, but may not be the best
56# solution in a long-lived application.  This affects the following functions:
57#
58#    load_model
59#    load_model_info
60#    build_model
[f734e7d]61
# Model categories accepted by list_models (combine several with '+').
KINDS = ("all", "py", "c", "double", "single", "opencl", "1d", "2d",
         "nonmagnetic", "magnetic")

def list_models(kind=None):
    # type: (str) -> List[str]
    """
    Return the list of available models on the model path.

    *kind* can be one of the following:

        * all: all models
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * opencl: models whose definition allows them to run under OpenCL
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld
        * nonmagnetic: models without an sld

    For multiple conditions, combine with plus.  For example, *c+single+2d*
    would return all oriented models implemented in C which can be computed
    accurately with single precision arithmetic.

    If *kind* is None (the default) every model is returned.

    Raises *ValueError* if any '+'-separated component of *kind* is not one
    of :data:`KINDS`.
    """
    # A single kind is just the one-element case of a '+' combination.
    kinds = kind.split('+') if kind else [kind]
    if kind and any(k not in KINDS for k in kinds):
        raise ValueError("kind not in " + ", ".join(KINDS))
    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
    available_models = [basename(f)[:-3] for f in files]
    return [name for name in available_models
            if all(_matches(name, k) for k in kinds)]
98
def _matches(name, kind):
    # type: (str, str) -> bool
    """
    Return True if the model *name* belongs to the category *kind*.

    *kind* is a single entry of KINDS (not a '+' combination).  None and
    "all" match every model without loading its definition; any other
    unrecognized kind matches nothing.
    """
    if kind is None or kind == "all":
        return True
    info = load_model_info(name)
    pars = info.parameters.kernel_parameters
    if kind == "py":
        return callable(info.Iq)
    elif kind == "c":
        return not callable(info.Iq)
    elif kind == "double":
        return not info.single
    elif kind == "single":
        return bool(info.single)
    elif kind == "opencl":
        return bool(info.opencl)
    elif kind == "2d":
        return any(p.type == 'orientation' for p in pars)
    elif kind == "1d":
        return all(p.type != 'orientation' for p in pars)
    elif kind == "magnetic":
        return any(p.type == 'sld' for p in pars)
    elif kind == "nonmagnetic":
        # "Without an sld": no sld parameter at all.  (Testing for any
        # non-sld parameter would match nearly every model.)
        return not any(p.type == 'sld' for p in pars)
    return False
[407bf48]113        return True
[2547694]114    elif kind == "2d" and any(p.type == 'orientation' for p in pars):
[d2d6100]115        return True
[40a87fa]116    elif kind == "1d" and all(p.type != 'orientation' for p in pars):
[0b03001]117        return True
[2547694]118    elif kind == "magnetic" and any(p.type == 'sld' for p in pars):
[0b03001]119        return True
[2547694]120    elif kind == "nonmagnetic" and any(p.type != 'sld' for p in pars):
[d2d6100]121        return True
[0b03001]122    return False
[f734e7d]123
[f619de7]124def load_model(model_name, dtype=None, platform='ocl'):
[dd7fc12]125    # type: (str, str, str) -> KernelModel
[b8e5e21]126    """
127    Load model info and build model.
[f619de7]128
[2e66ef5]129    *model_name* is the name of the model, or perhaps a model expression
130    such as sphere*hardsphere or sphere+cylinder.
131
132    *dtype* and *platform* are given by :func:`build_model`.
[b8e5e21]133    """
[f619de7]134    return build_model(load_model_info(model_name),
135                       dtype=dtype, platform=platform)
[aa4946b]136
[61a4bd4]137def load_model_info(model_string):
[f619de7]138    # type: (str) -> modelinfo.ModelInfo
[aa4946b]139    """
140    Load a model definition given the model name.
[1d4017a]141
[481ff64]142    *model_string* is the name of the model, or perhaps a model expression
143    such as sphere*cylinder or sphere+cylinder. Use '@' for a structure
[60335cc]144    factor product, e.g. sphere@hardsphere. Custom models can be specified by
145    prefixing the model name with 'custom.', e.g. 'custom.MyModel+sphere'.
[2e66ef5]146
[1d4017a]147    This returns a handle to the module defining the model.  This can be
148    used with functions in generate to build the docs or extract model info.
[aa4946b]149    """
[481ff64]150    if '@' in model_string:
151        parts = model_string.split('@')
152        if len(parts) != 2:
153            raise ValueError("Use P@S to apply a structure factor S to model P")
154        P_info, Q_info = [load_model_info(part) for part in parts]
155        return product.make_product_info(P_info, Q_info)
156
[61a4bd4]157    product_parts = []
158    addition_parts = []
159
160    addition_parts_names = model_string.split('+')
161    if len(addition_parts_names) >= 2:
162        addition_parts = [load_model_info(part) for part in addition_parts_names]
163    elif len(addition_parts_names) == 1:
164        product_parts_names = model_string.split('*')
165        if len(product_parts_names) >= 2:
166            product_parts = [load_model_info(part) for part in product_parts_names]
167        elif len(product_parts_names) == 1:
[60335cc]168            if "custom." in product_parts_names[0]:
169                # Extract ModelName from "custom.ModelName"
170                pattern = "custom.([A-Za-z0-9_-]+)"
171                result = re.match(pattern, product_parts_names[0])
172                if result is None:
173                    raise ValueError("Model name in invalid format: " + product_parts_names[0])
174                model_name = result.group(1)
175                # Use ModelName to find the path to the custom model file
176                model_path = joinpath(CUSTOM_MODEL_PATH, model_name + ".py")
177                if not os.path.isfile(model_path):
178                    raise ValueError("The model file {} doesn't exist".format(model_path))
179                kernel_module = custom.load_custom_kernel_module(model_path)
180                return modelinfo.make_model_info(kernel_module)
181            # Model is a core model
[61a4bd4]182            kernel_module = generate.load_kernel_module(product_parts_names[0])
183            return modelinfo.make_model_info(kernel_module)
184
185    model = None
186    if len(product_parts) > 1:
187        model = mixture.make_mixture_info(product_parts, operation='*')
188    if len(addition_parts) > 1:
189        if model is not None:
190            addition_parts.append(model)
191        model = mixture.make_mixture_info(addition_parts, operation='+')
192    return model
[d19962c]193
194
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    Composite models (mixtures and products) are assembled from their
    recursively built components.  Python models (callable *Iq*) become
    python kernels; all other models are generated from C source and
    returned as either an OpenCL model or a DLL model.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double
    precision for the calculation.  Choices are 'single', 'double', 'quad',
    'half', 'fast' or 'default' (see :func:`parse_dtype`).  If *dtype* ends
    with '!', then force the use of the DLL rather than OpenCL.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl" (parse_dtype falls back to the DLL
    when OpenCL cannot be used).

    Raises *ValueError* for an unrecognized composition type.
    """
    composition = model_info.composition
    if composition is not None:
        kind, components = composition
        submodels = [build_model(part, dtype=dtype, platform=platform)
                     for part in components]
        if kind == 'mixture':
            return mixture.MixtureModel(model_info, submodels)
        if kind == 'product':
            P, S = submodels
            return product.ProductModel(model_info, P, S)
        raise ValueError('unknown mixture type %s'%kind)

    # Pure python models need no code generation: hand them back directly.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
    source = generate.make_source(model_info)
    if platform != "dll":
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
229
[def2c1b]230    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
[7891a2a]231
[72a081d]232    source = generate.make_source(model_info)
[7891a2a]233    if platform == "dll":
[f2f67a6]234        #print("building dll", numpy_dtype)
[a4280bd]235        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
[aa4946b]236    else:
[f2f67a6]237        #print("building ocl", numpy_dtype)
[dd7fc12]238        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
[aa4946b]239
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Precompile the dlls for all builtin models, returning a list of dll paths.

    *path* is the directory in which to save the dlls.  It will be created if
    it does not already exist.

    *dtype* is the precision to compile for, interpreted by numpy.dtype.

    Python models are skipped since they have nothing to compile.  This can
    be used when building the windows distribution of sasmodels, which may be
    missing the OpenCL driver and the dll compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)
    dll_paths = []
    for name in list_models():
        info = load_model_info(name)
        if callable(info.Iq):
            continue
        dll_source = generate.make_source(info)['dll']
        # Redirect kerneldll's output directory only for the duration of
        # this build, restoring it even if compilation fails.
        saved_dll_path = kerneldll.DLL_PATH
        try:
            kerneldll.DLL_PATH = path
            built = kerneldll.make_dll(dll_source, info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_dll_path
        dll_paths.append(built)
    return dll_paths
[b8e5e21]267
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning (np.dtype, fast flag, platform).

    Possible types include 'half', 'single', 'double' and 'quad'; any other
    name accepted by numpy.dtype is passed through.  If the type is 'fast',
    then this is equivalent to dtype 'single' but using fast native
    functions rather than those with the precision level guaranteed by the
    OpenCL standard.  'default' (or None) will choose the appropriate
    default for the model and platform: single precision on OpenCL if the
    model allows it, otherwise double precision.

    Platform preference can be specified ("ocl" vs "dll"), with the default
    being OpenCL if it is available.  The DLL is used instead when OpenCL is
    unavailable, when the model definition switches OpenCL off, when the
    dtype name ends with '!', or when the OpenCL environment does not
    support the requested type.

    Apart from the OpenCL switch, an explicitly requested dtype overrides the
    precision preference in the model definition (model_info.single is only
    consulted when choosing the default).  This is by design.  It allows us
    to test models in single precision even when we have flagged them as
    requiring double precision so we can easily check the performance on
    different platforms without having to change the model definition.
    """
    # Assign default platform, overriding ocl with dll if OpenCL is
    # unavailable or if the model has OpenCL switched off (opencl=False).
    if platform is None:
        platform = "ocl"
    if (platform == "ocl" and not HAVE_OPENCL) or not model_info.opencl:
        platform = "dll"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # 'default' means exactly the same as not giving a dtype; normalize it
    # so the OpenCL fallback below treats both cases identically.
    if dtype == "default":
        dtype = None

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        dtype = "float16"

    # Convert dtype string to numpy dtype.
    if dtype is None:
        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            if dtype is None:
                # The default precision on the dll platform is double.
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
[0c24a82]327
def list_models_main():
    # type: () -> None
    """
    Command-line entry point: print the names of the requested models.

    The optional first command-line argument is the *kind* filter accepted
    by :func:`list_models` (default "all"); names are printed one per line.
    """
    import sys
    requested = "all"
    if len(sys.argv) > 1:
        requested = sys.argv[1]
    names = list_models(requested)
    print("\n".join(names))

if __name__ == "__main__":
    list_models_main()
Note: See TracBrowser for help on using the repository browser.