Changes in sasmodels/sasview_model.py [81ec7c8:60f03de] in sasmodels
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/sasview_model.py
r81ec7c8 r60f03de 21 21 import logging 22 22 23 import numpy as np 23 import numpy as np # type: ignore 24 24 25 25 from . import core 26 26 from . import custom 27 27 from . import generate 28 from . import weights 29 from . import details 30 from . import modelinfo 28 31 29 32 try: 30 from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional 33 from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable 34 from .modelinfo import ModelInfo, Parameter 31 35 from .kernel import KernelModel 32 36 MultiplicityInfoType = NamedTuple( … … 34 38 [("number", int), ("control", str), ("choices", List[str]), 35 39 ("x_axis_label", str)]) 40 SasviewModelType = Callable[[int], "SasviewModel"] 36 41 except ImportError: 37 42 pass 38 43 39 44 # TODO: separate x_axis_label from multiplicity info 40 # The x-axis label belongs with the profile generating function45 # The profile x-axis label belongs with the profile generating function 41 46 MultiplicityInfo = collections.namedtuple( 42 47 'MultiplicityInfo', … … 44 49 ) 45 50 51 # TODO: figure out how to say that the return type is a subclass 46 52 def load_standard_models(): 53 # type: () -> List[SasviewModelType] 47 54 """ 48 55 Load and return the list of predefined models. … … 55 62 try: 56 63 models.append(_make_standard_model(name)) 57 except :64 except Exception: 58 65 logging.error(traceback.format_exc()) 59 66 return models … … 61 68 62 69 def load_custom_model(path): 70 # type: (str) -> SasviewModelType 63 71 """ 64 72 Load a custom model given the model path. 65 73 """ 66 74 kernel_module = custom.load_custom_kernel_module(path) 67 model_info = generate.make_model_info(kernel_module)75 model_info = modelinfo.make_model_info(kernel_module) 68 76 return _make_model_from_info(model_info) 69 77 70 78 71 79 def _make_standard_model(name): 80 # type: (str) -> SasviewModelType 72 81 """ 73 82 Load the sasview model defined by *name*. … … 78 87 """ 79 88 kernel_module = generate.load_kernel_module(name) 80 model_info = generate.make_model_info(kernel_module)89 model_info = modelinfo.make_model_info(kernel_module) 81 90 return _make_model_from_info(model_info) 82 91 83 92 84 93 def _make_model_from_info(model_info): 94 # type: (ModelInfo) -> SasviewModelType 85 95 """ 86 96 Convert *model_info* into a SasView model wrapper. 87 97 """ 88 model_info['variant_info'] = None # temporary hack for older sasview 89 def __init__(self, multiplicity=1): 98 def __init__(self, multiplicity=None): 90 99 SasviewModel.__init__(self, multiplicity=multiplicity) 91 100 attrs = _generate_model_attributes(model_info) 92 101 attrs['__init__'] = __init__ 93 ConstructedModel = type(model_info ['id'], (SasviewModel,), attrs)102 ConstructedModel = type(model_info.name, (SasviewModel,), attrs) # type: SasviewModelType 94 103 return ConstructedModel 95 104 … … 104 113 All the attributes should be immutable to avoid accidents. 105 114 """ 115 116 # TODO: allow model to override axis labels input/output name/unit 117 118 # Process multiplicity 119 non_fittable = [] # type: List[str] 120 xlabel = model_info.profile_axes[0] if model_info.profile is not None else "" 121 variants = MultiplicityInfo(0, "", [], xlabel) 122 for p in model_info.parameters.kernel_parameters: 123 if p.name == model_info.control: 124 non_fittable.append(p.name) 125 variants = MultiplicityInfo( 126 len(p.choices), p.name, p.choices, xlabel 127 ) 128 break 129 elif p.is_control: 130 non_fittable.append(p.name) 131 variants = MultiplicityInfo( 132 int(p.limits[1]), p.name, p.choices, xlabel 133 ) 134 break 135 136 # Organize parameter sets 137 orientation_params = [] 138 magnetic_params = [] 139 fixed = [] 140 for p in model_info.parameters.user_parameters(): 141 if p.type == 'orientation': 142 orientation_params.append(p.name) 143 orientation_params.append(p.name+".width") 144 fixed.append(p.name+".width") 145 if p.type == 'magnetic': 146 orientation_params.append(p.name) 147 magnetic_params.append(p.name) 148 fixed.append(p.name+".width") 149 150 # Build class dictionary 106 151 attrs = {} # type: Dict[str, Any] 107 152 attrs['_model_info'] = model_info 108 attrs['name'] = model_info['name'] 109 attrs['id'] = model_info['id'] 110 attrs['description'] = model_info['description'] 111 attrs['category'] = model_info['category'] 112 113 # TODO: allow model to override axis labels input/output name/unit 114 115 #self.is_multifunc = False 116 non_fittable = [] # type: List[str] 117 variants = MultiplicityInfo(0, "", [], "") 118 attrs['is_structure_factor'] = model_info['structure_factor'] 119 attrs['is_form_factor'] = model_info['ER'] is not None 153 attrs['name'] = model_info.name 154 attrs['id'] = model_info.id 155 attrs['description'] = model_info.description 156 attrs['category'] = model_info.category 157 attrs['is_structure_factor'] = model_info.structure_factor 158 attrs['is_form_factor'] = model_info.ER is not None 120 159 attrs['is_multiplicity_model'] = variants[0] > 1 121 160 attrs['multiplicity_info'] = variants 122 123 partype = model_info['partype']124 orientation_params = (125 partype['orientation']126 + [n + '.width' for n in partype['orientation']]127 + partype['magnetic'])128 magnetic_params = partype['magnetic']129 fixed = [n + '.width' for n in partype['pd-2d']]130 131 161 attrs['orientation_params'] = tuple(orientation_params) 132 162 attrs['magnetic_params'] = tuple(magnetic_params) 133 163 attrs['fixed'] = tuple(fixed) 134 135 164 attrs['non_fittable'] = tuple(non_fittable) 136 165 … … 192 221 dispersion = None # type: Dict[str, Any] 193 222 #: units and limits for each parameter 194 details = None # type: Mapping[str, Tuple(str, float, float)] 195 #: multiplicity used, or None if no multiplicity controls 223 details = None # type: Dict[str, Sequence[Any]] 224 # # actual type is Dict[str, List[str, float, float]] 225 #: multiplicity value, or None if no multiplicity on the model 196 226 multiplicity = None # type: Optional[int] 197 198 def __init__(self, multiplicity): 199 # type: () -> None 200 print("initializing", self.name) 201 #raise Exception("first initialization") 202 self._model = None 203 204 ## _persistency_dict is used by sas.perspectives.fitting.basepage 205 ## to store dispersity reference. 227 #: memory for polydispersity array if using ArrayDispersion (used by sasview). 228 _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]] 229 230 def __init__(self, multiplicity=None): 231 # type: (Optional[int]) -> None 232 233 # TODO: _persistency_dict to persistency_dict throughout sasview 234 # TODO: refactor multiplicity to encompass variants 235 # TODO: dispersion should be a class 236 # TODO: refactor multiplicity info 237 # TODO: separate profile view from multiplicity 238 # The button label, x and y axis labels and scale need to be under 239 # the control of the model, not the fit page. Maximum flexibility, 240 # the fit page would supply the canvas and the profile could plot 241 # how it wants, but this assumes matplotlib. Next level is that 242 # we provide some sort of data description including title, labels 243 # and lines to plot. 244 245 # Get the list of hidden parameters given the mulitplicity 246 # Don't include multiplicity in the list of parameters 247 self.multiplicity = multiplicity 248 if multiplicity is not None: 249 hidden = self._model_info.get_hidden_parameters(multiplicity) 250 hidden |= set([self.multiplicity_info.control]) 251 else: 252 hidden = set() 253 206 254 self._persistency_dict = {} 207 208 self.multiplicity = multiplicity209 210 255 self.params = collections.OrderedDict() 211 256 self.dispersion = {} 212 257 self.details = {} 213 214 for p in self._model_info['parameters']: 258 for p in self._model_info.parameters.user_parameters(): 259 if p.name in hidden: 260 continue 215 261 self.params[p.name] = p.default 216 self.details[p.name] = [p.units] + p.limits 217 218 for name in self._model_info['partype']['pd-2d']: 219 self.dispersion[name] = { 220 'width': 0, 221 'npts': 35, 222 'nsigmas': 3, 223 'type': 'gaussian', 224 } 262 self.details[p.id] = [p.units, p.limits[0], p.limits[1]] 263 if p.polydisperse: 264 self.details[p.id+".width"] = [ 265 "", 0.0, 1.0 if p.relative_pd else np.inf 266 ] 267 self.dispersion[p.name] = { 268 'width': 0, 269 'npts': 35, 270 'nsigmas': 3, 271 'type': 'gaussian', 272 } 225 273 226 274 def __get_state__(self): 275 # type: () -> Dict[str, Any] 227 276 state = self.__dict__.copy() 228 277 state.pop('_model') … … 233 282 234 283 def __set_state__(self, state): 284 # type: (Dict[str, Any]) -> None 235 285 self.__dict__ = state 236 286 self._model = None 237 287 238 288 def __str__(self): 289 # type: () -> str 239 290 """ 240 291 :return: string representation … … 243 294 244 295 def is_fittable(self, par_name): 296 # type: (str) -> bool 245 297 """ 246 298 Check if a given parameter is fittable or not … … 253 305 254 306 255 # pylint: disable=no-self-use256 307 def getProfile(self): 308 # type: () -> (np.ndarray, np.ndarray) 257 309 """ 258 310 Get SLD profile … … 261 313 beta is a list of the corresponding SLD values 262 314 """ 263 return None, None 315 args = [] # type: List[Union[float, np.ndarray]] 316 for p in self._model_info.parameters.kernel_parameters: 317 if p.id == self.multiplicity_info.control: 318 args.append(float(self.multiplicity)) 319 elif p.length == 1: 320 args.append(self.params.get(p.id, np.NaN)) 321 else: 322 args.append([self.params.get(p.id+str(k), np.NaN) 323 for k in range(1,p.length+1)]) 324 return self._model_info.profile(*args) 264 325 265 326 def setParam(self, name, value): 327 # type: (str, float) -> None 266 328 """ 267 329 Set the value of a model parameter … … 290 352 291 353 def getParam(self, name): 354 # type: (str) -> float 292 355 """ 293 356 Set the value of a model parameter … … 313 376 314 377 def getParamList(self): 378 # type: () -> Sequence[str] 315 379 """ 316 380 Return a list of all available parameters for the model 317 381 """ 318 param_list = self.params.keys()382 param_list = list(self.params.keys()) 319 383 # WARNING: Extending the list with the dispersion parameters 320 384 param_list.extend(self.getDispParamList()) … … 322 386 323 387 def getDispParamList(self): 388 # type: () -> Sequence[str] 324 389 """ 325 390 Return a list of polydispersity parameters for the model 326 391 """ 327 392 # TODO: fix test so that parameter order doesn't matter 328 ret = ['%s.%s' % (d.lower(), p) 329 for d in self._model_info['partype']['pd-2d'] 330 for p in ('npts', 'nsigmas', 'width')] 393 ret = ['%s.%s' % (p.name.lower(), ext) 394 for p in self._model_info.parameters.user_parameters() 395 for ext in ('npts', 'nsigmas', 'width') 396 if p.polydisperse] 331 397 #print(ret) 332 398 return ret 333 399 334 400 def clone(self): 401 # type: () -> "SasviewModel" 335 402 """ Return a identical copy of self """ 336 403 return deepcopy(self) 337 404 338 405 def run(self, x=0.0): 406 # type: (Union[float, (float, float), List[float]]) -> float 339 407 """ 340 408 Evaluate the model … … 349 417 # pylint: disable=unpacking-non-sequence 350 418 q, phi = x 351 return self.calculate_Iq([q * math.cos(phi)], 352 [q * math.sin(phi)])[0] 353 else: 354 return self.calculate_Iq([float(x)])[0] 419 return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0] 420 else: 421 return self.calculate_Iq([x])[0] 355 422 356 423 357 424 def runXY(self, x=0.0): 425 # type: (Union[float, (float, float), List[float]]) -> float 358 426 """ 359 427 Evaluate the model in cartesian coordinates … … 366 434 """ 367 435 if isinstance(x, (list, tuple)): 368 return self.calculate_Iq([ float(x[0])], [float(x[1])])[0]369 else: 370 return self.calculate_Iq([ float(x)])[0]436 return self.calculate_Iq([x[0]], [x[1]])[0] 437 else: 438 return self.calculate_Iq([x])[0] 371 439 372 440 def evalDistribution(self, qdist): 441 # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray 373 442 r""" 374 443 Evaluate a distribution of q-values. … … 401 470 # Check whether we have a list of ndarrays [qx,qy] 402 471 qx, qy = qdist 403 partype = self._model_info['partype'] 404 if not partype['orientation'] and not partype['magnetic']: 472 if not self._model_info.parameters.has_2d: 405 473 return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2)) 406 474 else: … … 415 483 % type(qdist)) 416 484 417 def calculate_Iq(self, *args): 485 def calculate_Iq(self, qx, qy=None): 486 # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray 418 487 """ 419 488 Calculate Iq for one set of q with the current parameters. … … 426 495 if self._model is None: 427 496 self._model = core.build_model(self._model_info) 428 q_vectors = [np.asarray(q) for q in args] 429 fn = self._model.make_kernel(q_vectors) 430 pars = [self.params[v] for v in fn.fixed_pars] 431 pd_pars = [self._get_weights(p) for p in fn.pd_pars] 432 result = fn(pars, pd_pars, self.cutoff) 433 fn.q_input.release() 434 fn.release() 497 if qy is not None: 498 q_vectors = [np.asarray(qx), np.asarray(qy)] 499 else: 500 q_vectors = [np.asarray(qx)] 501 kernel = self._model.make_kernel(q_vectors) 502 pairs = [self._get_weights(p) 503 for p in self._model_info.parameters.call_parameters] 504 call_details, weight, value = details.build_details(kernel, pairs) 505 result = kernel(call_details, weight, value, cutoff=self.cutoff) 506 kernel.release() 435 507 return result 436 508 437 509 def calculate_ER(self): 510 # type: () -> float 438 511 """ 439 512 Calculate the effective radius for P(q)*S(q) … … 441 514 :return: the value of the effective radius 442 515 """ 443 ER = self._model_info.get('ER', None) 444 if ER is None: 516 if self._model_info.ER is None: 445 517 return 1.0 446 518 else: 447 value s, weights= self._dispersion_mesh()448 fv = ER(*values)519 value, weight = self._dispersion_mesh() 520 fv = self._model_info.ER(*value) 449 521 #print(values[0].shape, weights.shape, fv.shape) 450 return np.sum(weight s * fv) / np.sum(weights)522 return np.sum(weight * fv) / np.sum(weight) 451 523 452 524 def calculate_VR(self): 525 # type: () -> float 453 526 """ 454 527 Calculate the volf ratio for P(q)*S(q) … … 456 529 :return: the value of the volf ratio 457 530 """ 458 VR = self._model_info.get('VR', None) 459 if VR is None: 531 if self._model_info.VR is None: 460 532 return 1.0 461 533 else: 462 value s, weights= self._dispersion_mesh()463 whole, part = VR(*values)464 return np.sum(weight s * part) / np.sum(weights* whole)534 value, weight = self._dispersion_mesh() 535 whole, part = self._model_info.VR(*value) 536 return np.sum(weight * part) / np.sum(weight * whole) 465 537 466 538 def set_dispersion(self, parameter, dispersion): 539 # type: (str, weights.Dispersion) -> Dict[str, Any] 467 540 """ 468 541 Set the dispersion object for a model parameter … … 487 560 from . import weights 488 561 disperser = weights.dispersers[dispersion.__class__.__name__] 489 dispersion = weights. models[disperser]()562 dispersion = weights.MODELS[disperser]() 490 563 self.dispersion[parameter] = dispersion.get_pars() 491 564 else: … … 493 566 494 567 def _dispersion_mesh(self): 568 # type: () -> List[Tuple[np.ndarray, np.ndarray]] 495 569 """ 496 570 Create a mesh grid of dispersion parameters and weights. … … 500 574 parameter set in the vector. 501 575 """ 502 pars = self._model_info['partype']['volume'] 503 return core.dispersion_mesh([self._get_weights(p) for p in pars]) 576 pars = [self._get_weights(p) 577 for p in self._model_info.parameters.call_parameters 578 if p.type == 'volume'] 579 return details.dispersion_mesh(self._model_info, pars) 504 580 505 581 def _get_weights(self, par): 582 # type: (Parameter) -> Tuple[np.ndarray, np.ndarray] 506 583 """ 507 584 Return dispersion weights for parameter 508 585 """ 509 from . import weights 510 relative = self._model_info['partype']['pd-rel'] 511 limits = self._model_info['limits'] 512 dis = self.dispersion[par] 513 value, weight = weights.get_weights( 514 dis['type'], dis['npts'], dis['width'], dis['nsigmas'], 515 self.params[par], limits[par], par in relative) 516 return value, weight / np.sum(weight) 517 586 if par.name not in self.params: 587 if par.name == self.multiplicity_info.control: 588 return [self.multiplicity], [] 589 else: 590 return [np.NaN], [] 591 elif par.polydisperse: 592 dis = self.dispersion[par.name] 593 value, weight = weights.get_weights( 594 dis['type'], dis['npts'], dis['width'], dis['nsigmas'], 595 self.params[par.name], par.limits, par.relative_pd) 596 return value, weight / np.sum(weight) 597 else: 598 return [self.params[par.name]], [] 518 599 519 600 def test_model(): 601 # type: () -> float 520 602 """ 521 603 Test that a sasview model (cylinder) can be run. … … 525 607 return cylinder.evalDistribution([0.1,0.1]) 526 608 609 def test_rpa(): 610 # type: () -> float 611 """ 612 Test that a sasview model (cylinder) can be run. 613 """ 614 RPA = _make_standard_model('rpa') 615 rpa = RPA(3) 616 return rpa.evalDistribution([0.1,0.1]) 617 527 618 528 619 def test_model_list(): 620 # type: () -> None 529 621 """ 530 622 Make sure that all models build as sasview models.
Note: See TracChangeset
for help on using the changeset viewer.