Changeset f619de7 in sasmodels
- Timestamp:
- Apr 11, 2016 11:14:50 AM (9 years ago)
- Branches:
- master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
- Children:
- 7ae2b7f
- Parents:
- 9a943d0
- Location:
- sasmodels
- Files:
-
- 1 added
- 9 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/core.py
r6d6508e rf619de7 26 26 HAVE_OPENCL = False 27 27 28 try: 29 from typing import List, Union, Optional, Any 30 DType = Union[None, str, np.dtype] 31 from .kernel import KernelModel 32 except ImportError: 33 pass 34 35 28 36 # TODO: refactor composite model support 29 37 # The current load_model_info/build_model does not reuse existing model … … 39 47 40 48 def list_models(): 49 # type: () -> List[str] 41 50 """ 42 51 Return the list of available models on the model path. … … 48 57 49 58 def isstr(s): 59 # type: (Any) -> bool 50 60 """ 51 61 Return True if *s* is a string-like object. … … 55 65 return True 56 66 57 def load_model(model_name, **kw): 67 def load_model(model_name, dtype=None, platform='ocl'): 68 # type: (str, DType, str) -> KernelModel 58 69 """ 59 70 Load model info and build model. 71 72 *model_name* is the name of the model as used by :func:`load_model_info`. 73 Additional keyword arguments are passed directly to :func:`build_model`. 60 74 """ 61 return build_model(load_model_info(model_name), **kw) 75 return build_model(load_model_info(model_name), 76 dtype=dtype, platform=platform) 62 77 63 78 64 79 def load_model_info(model_name): 80 # type: (str) -> modelinfo.ModelInfo 65 81 """ 66 82 Load a model definition given the model name. … … 86 102 87 103 def build_model(model_info, dtype=None, platform="ocl"): 104 # type: (modelinfo.ModelInfo, DType, str) -> KernelModel 88 105 """ 89 106 Prepare the model for the default execution platform. … … 138 155 139 156 def precompile_dll(model_name, dtype="double"): 157 # type: (str, DType) -> Optional[str] 140 158 """ 141 159 Precompile the dll for a model. -
sasmodels/generate.py
r6d6508e rf619de7 164 164 from .modelinfo import Parameter 165 165 from .custom import load_custom_kernel_module 166 167 try: 168 from typing import Tuple, Sequence, Iterator 169 from .modelinfo import ModelInfo 170 except ImportError: 171 pass 166 172 167 173 TEMPLATE_ROOT = dirname(__file__) … … 220 226 221 227 def format_units(units): 228 # type: (str) -> str 222 229 """ 223 230 Convert units into ReStructured Text format. … … 226 233 227 234 def make_partable(pars): 235 # type: (List[Parameter]) -> str 228 236 """ 229 237 Generate the parameter table to include in the sphinx documentation. … … 256 264 257 265 def _search(search_path, filename): 266 # type: (List[str], str) -> str 258 267 """ 259 268 Find *filename* in *search_path*. … … 269 278 270 279 def model_sources(model_info): 280 # type: (ModelInfo) -> List[str] 271 281 """ 272 282 Return a list of the sources file paths for the module. … … 277 287 278 288 def timestamp(model_info): 289 # type: (ModelInfo) -> int 279 290 """ 280 291 Return a timestamp for the model corresponding to the most recently … … 288 299 289 300 def convert_type(source, dtype): 301 # type: (str, np.dtype) -> str 290 302 """ 291 303 Convert code from double precision to the desired type. … … 312 324 313 325 def _convert_type(source, type_name, constant_flag): 326 # type: (str, str, str) -> str 314 327 """ 315 328 Replace 'double' with *type_name* in *source*, tagging floating point … … 330 343 331 344 def kernel_name(model_info, is_2d): 345 # type: (ModelInfo, bool) -> str 332 346 """ 333 347 Name of the exported kernel symbol. … … 337 351 338 352 def indent(s, depth): 353 # type: (str, int) -> str 339 354 """ 340 355 Indent a string of text with *depth* additional spaces on each line. … … 345 360 346 361 347 _template_cache = {} 362 _template_cache = {} # type: Dict[str, Tuple[int, str, str]] 348 363 def load_template(filename): 364 # type: (str) -> str 349 365 path = joinpath(TEMPLATE_ROOT, filename) 350 366 mtime = getmtime(path) … … 355 371 356 372 def model_templates(): 373 # type: () -> List[str] 357 374 # TODO: fails DRY; templates are listed in two places. 358 375 # should instead have model_info contain a list of paths … … 371 388 372 389 def _gen_fn(name, pars, body): 390 # type: (str, List[Parameter], str) -> str 373 391 """ 374 392 Generate a function given pars and body. … … 385 403 386 404 def _call_pars(prefix, pars): 405 # type: (str, List[Parameter]) -> List[str] 387 406 """ 388 407 Return a list of *prefix.parameter* from parameter items. … … 393 412 flags=re.MULTILINE) 394 413 def _have_Iqxy(sources): 414 # type: (List[str]) -> bool 395 415 """ 396 416 Return true if any file defines Iqxy. … … 414 434 415 435 def make_source(model_info): 436 # type: (ModelInfo) -> str 416 437 """ 417 438 Generate the OpenCL/ctypes kernel from the module info. 418 439 419 Uses source files found in the given search path. 440 Uses source files found in the given search path. Returns None if this 441 is a pure python model, with no C source components. 420 442 """ 421 443 if callable(model_info.Iq): 422 r eturn None444 raise ValueError("can't compile python model") 423 445 424 446 # TODO: need something other than volume to indicate dispersion parameters … … 447 469 q, qx, qy = [Parameter(name=v) for v in ('q', 'qx', 'qy')] 448 470 # Generate form_volume function, etc. from body only 449 if model_info.form_volume is not None:471 if isinstance(model_info.form_volume, str): 450 472 pars = partable.form_volume_parameters 451 473 source.append(_gen_fn('form_volume', pars, model_info.form_volume)) 452 if model_info.Iq is not None:474 if isinstance(model_info.Iq, str): 453 475 pars = [q] + partable.iq_parameters 454 476 source.append(_gen_fn('Iq', pars, model_info.Iq)) 455 if model_info.Iqxy is not None:477 if isinstance(model_info.Iqxy, str): 456 478 pars = [qx, qy] + partable.iqxy_parameters 457 479 source.append(_gen_fn('Iqxy', pars, model_info.Iqxy)) … … 509 531 510 532 def load_kernel_module(model_name): 533 # type: (str) -> module 511 534 if model_name.endswith('.py'): 512 535 kernel_module = load_custom_kernel_module(model_name) … … 522 545 %re.escape(string.punctuation)) 523 546 def _convert_section_titles_to_boldface(lines): 547 # type: (Sequence[str]) -> Iterator[str] 524 548 """ 525 549 Do the actual work of identifying and converting section headings. … … 543 567 544 568 def convert_section_titles_to_boldface(s): 569 # type: (str) -> str 545 570 """ 546 571 Use explicit bold-face rather than section headings so that the table of … … 553 578 554 579 def make_doc(model_info): 580 # type: (ModelInfo) -> str 555 581 """ 556 582 Return the documentation for the model. … … 562 588 name=model_info.name, 563 589 title=model_info.title, 564 parameters=make_partable(model_info.parameters ),590 parameters=make_partable(model_info.parameters.kernel_parameters), 565 591 returns=Sq_units if model_info.structure_factor else Iq_units, 566 592 docs=docs) … … 569 595 570 596 def demo_time(): 597 # type: () -> None 571 598 """ 572 599 Show how long it takes to process a model. … … 582 609 583 610 def main(): 611 # type: () -> None 584 612 """ 585 613 Program which prints the source produced by the model. -
sasmodels/kernelcl.py
r6d6508e rf619de7 67 67 68 68 from . import generate 69 from .kernel import KernelModel, Kernel 69 70 70 71 # The max loops number is limited by the amount of local memory available … … 310 311 311 312 312 class GpuModel( object):313 class GpuModel(KernelModel): 313 314 """ 314 315 GPU wrapper for a single model. … … 420 421 self.release() 421 422 422 class GpuKernel( object):423 class GpuKernel(Kernel): 423 424 """ 424 425 Callable SAS kernel. … … 489 490 self.kernel(self.queue, self.q_input.global_size, None, *args) 490 491 cl.enqueue_copy(self.queue, self.result, self.result_b) 491 [v.release() for v in details_b, weights_b, values_b]492 [v.release() for v in (details_b, weights_b, values_b)] 492 493 493 494 return self.result[:self.nq] -
sasmodels/kerneldll.py
r6d6508e rf619de7 56 56 from . import generate 57 57 from . import details 58 from .kernelpy import PyInput, PyModel 58 from .kernel import KernelModel, Kernel 59 from .kernelpy import PyInput 59 60 from .exception import annotate_exception 61 from .generate import F16, F32, F64 62 63 try: 64 from typing import Tuple, Callable, Any 65 from .modelinfo import ModelInfo 66 from .details import CallDetails 67 except ImportError: 68 pass 60 69 61 70 # Compiler platform details … … 91 100 92 101 def dll_name(model_info, dtype): 102 # type: (ModelInfo, np.dtype) -> str 93 103 """ 94 104 Name of the dll containing the model. This is the base file name without … … 98 108 return "sas_%s%d"%(model_info.id, bits) 99 109 110 100 111 def dll_path(model_info, dtype): 112 # type: (ModelInfo, np.dtype) -> str 101 113 """ 102 114 Complete path to the dll for the model. Note that the dll may not … … 105 117 return os.path.join(DLL_PATH, dll_name(model_info, dtype)+".so") 106 118 107 def make_dll(source, model_info, dtype="double"): 108 """ 109 Load the compiled model defined by *kernel_module*. 110 111 Recompile if any files are newer than the model file. 119 120 def make_dll(source, model_info, dtype=F64): 121 # type: (str, ModelInfo, np.dtype) -> str 122 """ 123 Returns the path to the compiled model defined by *kernel_module*. 124 125 If the model has not been compiled, or if the source file(s) are newer 126 than the dll, then *make_dll* will compile the model before returning. 127 This routine does not load the resulting dll. 112 128 113 129 *dtype* is a numpy floating point precision specifier indicating whether 114 the model should be single or double precision. The default is double115 precision.116 117 The DLL is not loaded until the kernel is called so models can118 be defined without using too many resources.130 the model should be single, double or long double precision. The default 131 is double precision, *np.dtype('d')*. 132 133 Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to False if single precision 134 models are not allowed as DLLs. 119 135 120 136 Set *sasmodels.kerneldll.DLL_PATH* to the compiled dll output path. 121 137 The default is the system temporary directory. 122 123 Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to True if single precision 124 models are allowed as DLLs. 125 """ 126 if callable(model_info.Iq): 127 return PyModel(model_info) 128 129 dtype = np.dtype(dtype) 130 if dtype == generate.F16: 138 """ 139 if dtype == F16: 131 140 raise ValueError("16 bit floats not supported") 132 if dtype == generate.F32 and not ALLOW_SINGLE_PRECISION_DLLS:133 dtype = generate.F64 # Force 64-bit dll134 135 source = generate.convert_type(source, dtype) 141 if dtype == F32 and not ALLOW_SINGLE_PRECISION_DLLS: 142 dtype = F64 # Force 64-bit dll 143 # Note: dtype may be F128 for long double precision 144 136 145 newest = generate.timestamp(model_info) 137 146 dll = dll_path(model_info, dtype) … … 139 148 basename = dll_name(model_info, dtype) + "_" 140 149 fid, filename = tempfile.mkstemp(suffix=".c", prefix=basename) 150 source = generate.convert_type(source, dtype) 141 151 os.fdopen(fid, "w").write(source) 142 152 command = COMPILE%{"source":filename, "output":dll} … … 152 162 153 163 154 def load_dll(source, model_info, dtype="double"): 164 def load_dll(source, model_info, dtype=F64): 165 # type: (str, ModelInfo, np.dtype) -> "DllModel" 155 166 """ 156 167 Create and load a dll corresponding to the source, info pair returned … … 163 174 return DllModel(filename, model_info, dtype=dtype) 164 175 165 class DllModel(object): 176 177 class DllModel(KernelModel): 166 178 """ 167 179 ctypes wrapper for a single model. … … 179 191 180 192 def __init__(self, dllpath, model_info, dtype=generate.F32): 193 # type: (str, ModelInfo, np.dtype) -> None 181 194 self.info = model_info 182 195 self.dllpath = dllpath 183 self. dll = None196 self._dll = None # type: ct.CDLL 184 197 self.dtype = np.dtype(dtype) 185 198 186 199 def _load_dll(self): 200 # type: () -> None 187 201 #print("dll", self.dllpath) 188 202 try: 189 self. dll = ct.CDLL(self.dllpath)203 self._dll = ct.CDLL(self.dllpath) 190 204 except: 191 205 annotate_exception("while loading "+self.dllpath) … … 198 212 # int, int, int, int*, double*, double*, double*, double*, double*, double 199 213 argtypes = [c_int32]*3 + [c_void_p]*5 + [fp] 200 self. Iq = self.dll[generate.kernel_name(self.info,False)]201 self. Iqxy = self.dll[generate.kernel_name(self.info,True)]202 self. Iq.argtypes = argtypes203 self. Iqxy.argtypes = argtypes214 self._Iq = self._dll[generate.kernel_name(self.info, is_2d=False)] 215 self._Iqxy = self._dll[generate.kernel_name(self.info, is_2d=True)] 216 self._Iq.argtypes = argtypes 217 self._Iqxy.argtypes = argtypes 204 218 205 219 def __getstate__(self): 220 # type: () -> Tuple[ModelInfo, str] 206 221 return self.info, self.dllpath 207 222 208 223 def __setstate__(self, state): 224 # type: (Tuple[ModelInfo, str]) -> None 209 225 self.info, self.dllpath = state 210 self. dll = None226 self._dll = None 211 227 212 228 def make_kernel(self, q_vectors): 229 # type: (List[np.ndarray]) -> DllKernel 213 230 q_input = PyInput(q_vectors, self.dtype) 214 if self.dll is None: self._load_dll() 215 kernel = self.Iqxy if q_input.is_2d else self.Iq 231 # Note: pickle not supported for DllKernel 232 if self._dll is None: 233 self._load_dll() 234 kernel = self._Iqxy if q_input.is_2d else self._Iq 216 235 return DllKernel(kernel, self.info, q_input) 217 236 218 237 def release(self): 238 # type: () -> None 219 239 """ 220 240 Release any resources associated with the model. … … 225 245 libHandle = dll._handle 226 246 #libHandle = ct.c_void_p(dll._handle) 227 del dll, self. dll228 self. dll = None247 del dll, self._dll 248 self._dll = None 229 249 #_ctypes.FreeLibrary(libHandle) 230 250 ct.windll.kernel32.FreeLibrary(libHandle) … … 233 253 234 254 235 class DllKernel( object):255 class DllKernel(Kernel): 236 256 """ 237 257 Callable SAS kernel. … … 253 273 """ 254 274 def __init__(self, kernel, model_info, q_input): 275 # type: (Callable[[], np.ndarray], ModelInfo, PyInput) -> None 255 276 self.kernel = kernel 256 277 self.info = model_info … … 261 282 262 283 def __call__(self, call_details, weights, values, cutoff): 284 # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 263 285 real = (np.float32 if self.q_input.dtype == generate.F32 264 286 else np.float64 if self.q_input.dtype == generate.F64 … … 282 304 real(cutoff), # cutoff 283 305 ] 284 self.kernel(*args) 306 self.kernel(*args) # type: ignore 285 307 return self.result[:-3] 286 308 287 309 def release(self): 310 # type: () -> None 288 311 """ 289 312 Release any resources associated with the kernel. 290 313 """ 291 pass314 self.q_input.release() -
sasmodels/kernelpy.py
r9a943d0 rf619de7 12 12 from . import details 13 13 from .generate import F64 14 from .kernel import KernelModel, Kernel 14 15 15 16 try: … … 20 21 DType = Union[None, str, np.dtype] 21 22 22 class PyModel( object):23 class PyModel(KernelModel): 23 24 """ 24 25 Wrapper for pure python models. … … 77 78 self.q = None 78 79 79 class PyKernel( object):80 class PyKernel(Kernel): 80 81 """ 81 82 Callable SAS kernel. … … 162 163 Free resources associated with the kernel. 163 164 """ 165 self.q_input.release() 164 166 self.q_input = None 165 167 -
sasmodels/mixture.py
r6d6508e rf619de7 15 15 16 16 from .modelinfo import Parameter, ParameterTable, ModelInfo 17 from .kernel import KernelModel, Kernel 18 19 try: 20 from typing import List 21 from .details import CallDetails 22 except ImportError: 23 pass 17 24 18 25 def make_mixture_info(parts): 26 # type: (List[ModelInfo]) -> ModelInfo 19 27 """ 20 28 Create info block for product model. … … 22 30 flatten = [] 23 31 for part in parts: 24 if part ['composition'] and part['composition'][0] == 'mixture':25 flatten.extend(part ['compostion'][1])32 if part.composition and part.composition[0] == 'mixture': 33 flatten.extend(part.composition[1]) 26 34 else: 27 35 flatten.append(part) … … 29 37 30 38 # Build new parameter list 31 pars = []39 combined_pars = [] 32 40 for k, part in enumerate(parts): 33 41 # Parameter prefix per model, A_, B_, ... … … 35 43 # to support vector parameters 36 44 prefix = chr(ord('A')+k) + '_' 37 pars.append(Parameter(prefix+'scale'))38 for p in part ['parameters'].kernel_pars:45 combined_pars.append(Parameter(prefix+'scale')) 46 for p in part.parameters.kernel_parameters: 39 47 p = copy(p) 40 p.name = prefix +p.name41 p.id = prefix +p.id48 p.name = prefix + p.name 49 p.id = prefix + p.id 42 50 if p.length_control is not None: 43 p.length_control = prefix +p.length_control44 pars.append(p)45 par table = ParameterTable(pars)51 p.length_control = prefix + p.length_control 52 combined_pars.append(p) 53 parameters = ParameterTable(combined_pars) 46 54 47 55 model_info = ModelInfo() 48 model_info.id = '+'.join(part ['id'])49 model_info.name = ' + '.join(part ['name'])56 model_info.id = '+'.join(part.id for part in parts) 57 model_info.name = ' + '.join(part.name for part in parts) 50 58 model_info.filename = None 51 59 model_info.title = 'Mixture model with ' + model_info.name … … 53 61 model_info.docs = model_info.title 54 62 model_info.category = "custom" 55 model_info.parameters = par table63 model_info.parameters = parameters 56 64 #model_info.single = any(part['single'] for part in parts) 57 65 model_info.structure_factor = False … … 64 72 65 73 66 class MixtureModel( object):74 class MixtureModel(KernelModel): 67 75 def __init__(self, model_info, parts): 76 # type: (ModelInfo, List[KernelModel]) -> None 68 77 self.info = model_info 69 78 self.parts = parts 70 79 71 80 def __call__(self, q_vectors): 81 # type: (List[np.ndarray]) -> MixtureKernel 72 82 # Note: may be sending the q_vectors to the n times even though they 73 83 # are only needed once. It would mess up modularity quite a bit to … … 76 86 # in opencl; or both in opencl, but one in single precision and the 77 87 # other in double precision). 78 kernels = [part (q_vectors) for part in self.parts]88 kernels = [part.make_kernel(q_vectors) for part in self.parts] 79 89 return MixtureKernel(self.info, kernels) 80 90 81 91 def release(self): 92 # type: () -> None 82 93 """ 83 94 Free resources associated with the model. … … 87 98 88 99 89 class MixtureKernel( object):100 class MixtureKernel(Kernel): 90 101 def __init__(self, model_info, kernels): 91 dim = '2d' if kernels[0].q_input.is_2d else '1d' 102 # type: (ModelInfo, List[Kernel]) -> None 103 self.dim = kernels[0].dim 104 self.info = model_info 105 self.kernels = kernels 92 106 93 # fixed offsets starts at 2 for scale and background 94 fixed_pars, pd_pars = [], [] 95 offsets = [[2, 0]] 96 #vol_index = [] 97 def accumulate(fixed, pd, volume): 98 # subtract 1 from fixed since we are removing background 99 fixed_offset, pd_offset = offsets[-1] 100 #vol_index.extend(k+pd_offset for k,v in pd if v in volume) 101 offsets.append([fixed_offset + len(fixed) - 1, pd_offset + len(pd)]) 102 pd_pars.append(pd) 103 if dim == '2d': 104 for p in kernels: 105 partype = p.info.partype 106 accumulate(partype['fixed-2d'], partype['pd-2d'], partype['volume']) 107 else: 108 for p in kernels: 109 partype = p.info.partype 110 accumulate(partype['fixed-1d'], partype['pd-1d'], partype['volume']) 111 112 #self.vol_index = vol_index 113 self.offsets = offsets 114 self.fixed_pars = fixed_pars 115 self.pd_pars = pd_pars 116 self.info = model_info 117 self.kernels = kernels 118 self.results = None 119 120 def __call__(self, fixed_pars, pd_pars, cutoff=1e-5): 121 scale, background = fixed_pars[0:2] 107 def __call__(self, call_details, value, weight, cutoff): 108 # type: (CallDetails, np.ndarray, np.ndarry, float) -> np.ndarray 109 scale, background = value[0:2] 122 110 total = 0.0 123 self.results = [] # remember the parts for plotting later 124 for k in range(len(self.offsets)-1): 125 start_fixed, start_pd = self.offsets[k] 126 end_fixed, end_pd = self.offsets[k+1] 127 part_fixed = [fixed_pars[start_fixed], 0.0] + fixed_pars[start_fixed+1:end_fixed] 128 part_pd = [pd_pars[start_pd], 0.0] + pd_pars[start_pd+1:end_pd] 129 part_result = self.kernels[k](part_fixed, part_pd) 111 # remember the parts for plotting later 112 self.results = [] 113 for kernel, kernel_details in zip(self.kernels, call_details.parts): 114 part_result = kernel(kernel_details, value, weight, cutoff) 130 115 total += part_result 131 self.results.append( scale*sum+background)116 self.results.append(part_result) 132 117 133 118 return scale*total + background 134 119 135 120 def release(self): 136 self.p_kernel.release() 137 self.q_kernel.release() 121 # type: () -> None 122 for k in self.kernels: 123 k.release() 138 124 -
sasmodels/model_test.py
rc1a888b rf619de7 69 69 # type: (ModelInfo, ParameterSet) -> float 70 70 """ 71 Call the model ER function using *values*. *model_info* is either 72 *model.info* if you have a loaded model, or *kernel.info* if you 73 have a model kernel prepared for evaluation. 71 Call the model ER function using *values*. 72 73 *model_info* is either *model.info* if you have a loaded model, 74 or *kernel.info* if you have a model kernel prepared for evaluation. 74 75 """ 75 76 if model_info.ER is None: … … 84 85 """ 85 86 Call the model VR function using *pars*. 86 *info* is either *model.info* if you have a loaded model, or *kernel.info* 87 if you have a model kernel prepared for evaluation. 87 88 *model_info* is either *model.info* if you have a loaded model, 89 or *kernel.info* if you have a model kernel prepared for evaluation. 88 90 """ 89 91 if model_info.VR is None: -
sasmodels/modelinfo.py
r9a943d0 rf619de7 713 713 ER = None # type: Optional[Callable[[np.ndarray], np.ndarray]] 714 714 VR = None # type: Optional[Callable[[np.ndarray], Tuple[np.ndarray, np.ndarray]]] 715 form_volume = None # type: Optional[Callable[[np.ndarray], float]]716 Iq = None # type: Optional[Callable[[np.ndarray], np.ndarray]]717 Iqxy = None # type: Optional[Callable[[np.ndarray], np.ndarray]]715 form_volume = None # type: Union[None, str, Callable[[np.ndarray], float]] 716 Iq = None # type: Union[None, str, Callable[[np.ndarray], np.ndarray]] 717 Iqxy = None # type: Union[None, str, Callable[[np.ndarray], np.ndarray]] 718 718 profile = None # type: Optional[Callable[[np.ndarray], None]] 719 719 sesans = None # type: Optional[Callable[[np.ndarray], np.ndarray]] -
sasmodels/product.py
r6d6508e rf619de7 14 14 15 15 from .details import dispersion_mesh 16 from .modelinfo import suffix_parameter, ParameterTable, Parameter, ModelInfo 16 from .modelinfo import suffix_parameter, ParameterTable, ModelInfo 17 from .kernel import KernelModel, Kernel 18 19 try: 20 from typing import Tuple 21 from .modelinfo import ParameterSet 22 from .details import CallDetails 23 except ImportError: 24 pass 17 25 18 26 # TODO: make estimates available to constraints … … 25 33 # revert it after making VR and ER available at run time as constraints. 26 34 def make_product_info(p_info, s_info): 35 # type: (ModelInfo, ModelInfo) -> ModelInfo 27 36 """ 28 37 Create info block for product model. 29 38 """ 30 p_id, p_name, p_par table= p_info.id, p_info.name, p_info.parameters31 s_id, s_name, s_par table= s_info.id, s_info.name, s_info.parameters32 p_set = set(p.id for p in p_par table)33 s_set = set(p.id for p in s_par table)39 p_id, p_name, p_pars = p_info.id, p_info.name, p_info.parameters 40 s_id, s_name, s_pars = s_info.id, s_info.name, s_info.parameters 41 p_set = set(p.id for p in p_pars.call_parameters) 42 s_set = set(p.id for p in s_pars.call_parameters) 34 43 35 44 if p_set & s_set: 36 45 # there is some overlap between the parameter names; tag the 37 46 # overlapping S parameters with name_S 38 s_ pars= [(suffix_parameter(par, "_S") if par.id in p_set else par)39 for par in s_par table.kernel_parameters]40 pars = p_partable.kernel_parameters + s_pars47 s_list = [(suffix_parameter(par, "_S") if par.id in p_set else par) 48 for par in s_pars.kernel_parameters] 49 combined_pars = p_pars.kernel_parameters + s_list 41 50 else: 42 pars= p_partable.kernel_parameters + s_partable.kernel_parameters 51 combined_pars = p_pars.kernel_parameters + s_pars.kernel_parameters 52 parameters = ParameterTable(combined_pars) 43 53 44 54 model_info = ModelInfo() … … 50 60 model_info.docs = model_info.title 51 61 model_info.category = "custom" 52 model_info.parameters = ParameterTable(pars)62 model_info.parameters = parameters 53 63 #model_info.single = p_info.single and s_info.single 54 64 model_info.structure_factor = False … … 60 70 return model_info 61 71 62 class ProductModel( object):72 class ProductModel(KernelModel): 63 73 def __init__(self, model_info, P, S): 74 # type: (ModelInfo, KernelModel, KernelModel) -> None 64 75 self.info = model_info 65 76 self.P = P … … 67 78 68 79 def __call__(self, q_vectors): 80 # type: (List[np.ndarray]) -> Kernel 69 81 # Note: may be sending the q_vectors to the GPU twice even though they 70 82 # are only needed once. It would mess up modularity quite a bit to … … 73 85 # in opencl; or both in opencl, but one in single precision and the 74 86 # other in double precision). 75 p_kernel = self.P (q_vectors)76 s_kernel = self.S (q_vectors)87 p_kernel = self.P.make_kernel(q_vectors) 88 s_kernel = self.S.make_kernel(q_vectors) 77 89 return ProductKernel(self.info, p_kernel, s_kernel) 78 90 79 91 def release(self): 92 # type: (None) -> None 80 93 """ 81 94 Free resources associated with the model. … … 85 98 86 99 87 class ProductKernel( object):100 class ProductKernel(Kernel): 88 101 def __init__(self, model_info, p_kernel, s_kernel): 102 # type: (ModelInfo, Kernel, Kernel) -> None 89 103 self.info = model_info 90 104 self.p_kernel = p_kernel … … 92 106 93 107 def __call__(self, details, weights, values, cutoff): 108 # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 94 109 effect_radius, vol_ratio = call_ER_VR(self.p_kernel.info, vol_pars) 95 110 … … 108 123 109 124 def release(self): 125 # type: () -> None 110 126 self.p_kernel.release() 111 self. q_kernel.release()127 self.s_kernel.release() 112 128 113 def call_ER_VR(model_info, vol_pars):129 def call_ER_VR(model_info, pars): 114 130 """ 115 131 Return effect radius and volume ratio for the model. 116 132 """ 117 value, weight = dispersion_mesh(vol_pars) 133 if model_info.ER is None and model_info.VR is None: 134 return 1.0, 1.0 118 135 119 individual_radii = model_info.ER(*value) if model_info.ER else 1.0 120 whole, part = model_info.VR(*value) if model_info.VR else (1.0, 1.0) 136 value, weight = _vol_pars(model_info, pars) 121 137 122 effect_radius = np.sum(weight*individual_radii) / np.sum(weight) 123 volume_ratio = np.sum(weight*part)/np.sum(weight*whole) 138 if model_info.ER is not None: 139 individual_radii = model_info.ER(*value) 140 effect_radius = np.sum(weight*individual_radii) / np.sum(weight) 141 else: 142 effect_radius = 1.0 143 144 if model_info.VR is not None: 145 whole, part = model_info.VR(*value) 146 volume_ratio = np.sum(weight*part)/np.sum(weight*whole) 147 else: 148 volume_ratio = 1.0 149 124 150 return effect_radius, volume_ratio 151 152 def _vol_pars(model_info, pars): 153 # type: (ModelInfo, ParameterSet) -> Tuple[np.ndarray, np.ndarray] 154 vol_pars = [get_weights(p, pars) 155 for p in model_info.parameters.call_parameters 156 if p.type == 'volume'] 157 value, weight = dispersion_mesh(model_info, vol_pars) 158 return value, weight 159
Note: See TracChangeset
for help on using the changeset viewer.