Changeset dd7fc12 in sasmodels for sasmodels/compare.py
- Timestamp:
- Apr 15, 2016 11:11:43 AM (8 years ago)
- Branches:
- master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
- Children:
- 3599d36
- Parents:
- b151003
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/compare.py
rb151003 rdd7fc12 41 41 from .direct_model import DirectModel 42 42 from .convert import revert_name, revert_pars, constrain_new_to_old 43 44 try: 45 from typing import Optional, Dict, Any, Callable, Tuple 46 except: 47 pass 48 else: 49 from .modelinfo import ModelInfo, Parameter, ParameterSet 50 from .data import Data 51 Calculator = Callable[[float, ...], np.ndarray] 43 52 44 53 USAGE = """ … … 97 106 kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 98 107 99 MODELS = core.list_models()100 101 108 # CRUFT python 2.6 102 109 if not hasattr(datetime.timedelta, 'total_seconds'): … … 160 167 ... print(randint(0,1000000,3)) 161 168 ... raise Exception() 162 ... except :169 ... except Exception: 163 170 ... print("Exception raised") 164 171 ... print(randint(0,1000000)) … … 169 176 """ 170 177 def __init__(self, seed=None): 178 # type: (Optional[int]) -> None 171 179 self._state = np.random.get_state() 172 180 np.random.seed(seed) 173 181 174 182 def __enter__(self): 175 return None 176 177 def __exit__(self, *args): 183 # type: () -> None 184 pass 185 186 def __exit__(self, type, value, traceback): 187 # type: (Any, BaseException, Any) -> None 188 # TODO: better typing for __exit__ method 178 189 np.random.set_state(self._state) 179 190 180 191 def tic(): 192 # type: () -> Callable[[], float] 181 193 """ 182 194 Timer function. … … 190 202 191 203 def set_beam_stop(data, radius, outer=None): 204 # type: (Data, float, float) -> None 192 205 """ 193 206 Add a beam stop of the given *radius*. If *outer*, make an annulus. 194 207 195 Note: this function does not use the sasview package208 Note: this function does not require sasview 196 209 """ 197 210 if hasattr(data, 'qx_data'): … … 207 220 208 221 def parameter_range(p, v): 222 # type: (str, float) -> Tuple[float, float] 209 223 """ 210 224 Choose a parameter range based on parameter name and initial value. … … 212 226 # process the polydispersity options 213 227 if p.endswith('_pd_n'): 214 return [0, 100]228 return 0., 100. 215 229 elif p.endswith('_pd_nsigma'): 216 return [0, 5]230 return 0., 5. 217 231 elif p.endswith('_pd_type'): 218 r eturn v232 raise ValueError("Cannot return a range for a string value") 219 233 elif any(s in p for s in ('theta', 'phi', 'psi')): 220 234 # orientation in [-180,180], orientation pd in [0,45] 221 235 if p.endswith('_pd'): 222 return [0, 45]236 return 0., 45. 223 237 else: 224 return [-180, 180]238 return -180., 180. 225 239 elif p.endswith('_pd'): 226 return [0, 1]240 return 0., 1. 227 241 elif 'sld' in p: 228 return [-0.5, 10]242 return -0.5, 10. 229 243 elif p == 'background': 230 return [0, 10]244 return 0., 10. 231 245 elif p == 'scale': 232 return [0, 1e3]233 elif v < 0 :234 return [2*v, -2*v]246 return 0., 1.e3 247 elif v < 0.: 248 return 2.*v, -2.*v 235 249 else: 236 return [0, (2*v if v > 0 else 1)]250 return 0., (2.*v if v > 0. else 1.) 237 251 238 252 239 253 def _randomize_one(model_info, p, v): 254 # type: (ModelInfo, str, float) -> float 255 # type: (ModelInfo, str, str) -> str 240 256 """ 241 257 Randomize a single parameter. … … 263 279 264 280 def randomize_pars(model_info, pars, seed=None): 281 # type: (ModelInfo, ParameterSet, int) -> ParameterSet 265 282 """ 266 283 Generate random values for all of the parameters. … … 273 290 with push_seed(seed): 274 291 # Note: the sort guarantees order `of calls to random number generator 275 pars = dict((p, _randomize_one(model_info, p, v))276 for p, v in sorted(pars.items()))277 return pars292 random_pars = dict((p, _randomize_one(model_info, p, v)) 293 for p, v in sorted(pars.items())) 294 return random_pars 278 295 279 296 def constrain_pars(model_info, pars): 297 # type: (ModelInfo, ParameterSet) -> None 280 298 """ 281 299 Restrict parameters to valid values. … … 284 302 which need to support within model constraints (cap radius more than 285 303 cylinder radius in this case). 304 305 Warning: this updates the *pars* dictionary in place. 286 306 """ 287 307 name = model_info.id … … 315 335 316 336 def parlist(model_info, pars, is2d): 337 # type: (ModelInfo, ParameterSet, bool) -> str 317 338 """ 318 339 Format the parameter list for printing. … … 326 347 n=int(pars.get(p.id+"_pd_n", 0)), 327 348 nsigma=pars.get(p.id+"_pd_nsgima", 3.), 328 type=pars.get(p.id+"_pd_type", 'gaussian')) 349 pdtype=pars.get(p.id+"_pd_type", 'gaussian'), 350 ) 329 351 lines.append(_format_par(p.name, **fields)) 330 352 return "\n".join(lines) … … 332 354 #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items())) 333 355 334 def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'): 356 def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian'): 357 # type: (str, float, float, int, float, str) -> str 335 358 line = "%s: %g"%(name, value) 336 359 if pd != 0. and n != 0: 337 360 line += " +/- %g (%d points in [-%g,%g] sigma %s)"\ 338 % (pd, n, nsigma, nsigma, type)361 % (pd, n, nsigma, nsigma, pdtype) 339 362 return line 340 363 341 364 def suppress_pd(pars): 365 # type: (ParameterSet) -> ParameterSet 342 366 """ 343 367 Suppress theta_pd for now until the normalization is resolved. … … 352 376 353 377 def eval_sasview(model_info, data): 378 # type: (Modelinfo, Data) -> Calculator 354 379 """ 355 380 Return a model calculator using the pre-4.0 SasView models. … … 359 384 import sas 360 385 from sas.models.qsmearing import smear_selection 386 import sas.models 361 387 362 388 def get_model(name): 389 # type: (str) -> "sas.models.BaseComponent" 363 390 #print("new",sorted(_pars.items())) 364 sas =__import__('sas.models.' + name)391 __import__('sas.models.' + name) 365 392 ModelClass = getattr(getattr(sas.models, name, None), name, None) 366 393 if ModelClass is None: … … 400 427 401 428 def calculator(**pars): 429 # type: (float, ...) -> np.ndarray 402 430 """ 403 431 Sasview calculator for model. … … 406 434 pars = revert_pars(model_info, pars) 407 435 for k, v in pars.items(): 408 parts= k.split('.') # polydispersity components409 if len( parts) == 2:410 model.dispersion[ parts[0]][parts[1]] = v436 name_attr = k.split('.') # polydispersity components 437 if len(name_attr) == 2: 438 model.dispersion[name_attr[0]][name_attr[1]] = v 411 439 else: 412 440 model.setParam(k, v) … … 428 456 } 429 457 def eval_opencl(model_info, data, dtype='single', cutoff=0.): 458 # type: (ModelInfo, Data, str, float) -> Calculator 430 459 """ 431 460 Return a model calculator using the OpenCL calculation engine. … … 442 471 443 472 def eval_ctypes(model_info, data, dtype='double', cutoff=0.): 473 # type: (ModelInfo, Data, str, float) -> Calculator 444 474 """ 445 475 Return a model calculator using the DLL calculation engine. 446 476 """ 447 if dtype == 'quad':448 dtype = 'longdouble'449 477 model = core.build_model(model_info, dtype=dtype, platform="dll") 450 478 calculator = DirectModel(data, model, cutoff=cutoff) … … 453 481 454 482 def time_calculation(calculator, pars, Nevals=1): 483 # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float] 455 484 """ 456 485 Compute the average calculation time over N evaluations. … … 461 490 # initialize the code so time is more accurate 462 491 if Nevals > 1: 463 value =calculator(**suppress_pd(pars))492 calculator(**suppress_pd(pars)) 464 493 toc = tic() 465 for _ in range(max(Nevals, 1)): # make sure there is at least one eval 494 # make sure there is at least one eval 495 value = calculator(**pars) 496 for _ in range(Nevals-1): 466 497 value = calculator(**pars) 467 498 average_time = toc()*1000./Nevals … … 469 500 470 501 def make_data(opts): 502 # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray] 471 503 """ 472 504 Generate an empty dataset, used with the model to set Q points … … 478 510 qmax, nq, res = opts['qmax'], opts['nq'], opts['res'] 479 511 if opts['is2d']: 480 data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res) 512 q = np.linspace(-qmax, qmax, nq) # type: np.ndarray 513 data = empty_data2D(q, resolution=res) 481 514 data.accuracy = opts['accuracy'] 482 515 set_beam_stop(data, 0.0004) … … 495 528 496 529 def make_engine(model_info, data, dtype, cutoff): 530 # type: (ModelInfo, Data, str, float) -> Calculator 497 531 """ 498 532 Generate the appropriate calculation engine for the given datatype. … … 509 543 510 544 def _show_invalid(data, theory): 545 # type: (Data, np.ma.ndarray) -> None 546 """ 547 Display a list of the non-finite values in theory. 548 """ 511 549 if not theory.mask.any(): 512 550 return … … 514 552 if hasattr(data, 'x'): 515 553 bad = zip(data.x[theory.mask], theory[theory.mask]) 516 print(" *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))554 print(" *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad)) 517 555 518 556 519 557 def compare(opts, limits=None): 558 # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float] 520 559 """ 521 560 Preform a comparison using options from the command line. … … 532 571 data = opts['data'] 533 572 573 # silence the linter 574 base = opts['engines'][0] if Nbase else None 575 comp = opts['engines'][1] if Ncomp else None 576 base_time = comp_time = None 577 base_value = comp_value = resid = relerr = None 578 534 579 # Base calculation 535 580 if Nbase > 0: 536 base = opts['engines'][0]537 581 try: 538 base_ value, base_time = time_calculation(base, pars, Nbase)539 base_value = np.ma.masked_invalid(base_ value)582 base_raw, base_time = time_calculation(base, pars, Nbase) 583 base_value = np.ma.masked_invalid(base_raw) 540 584 print("%s t=%.2f ms, intensity=%.0f" 541 585 % (base.engine, base_time, base_value.sum())) … … 547 591 # Comparison calculation 548 592 if Ncomp > 0: 549 comp = opts['engines'][1]550 593 try: 551 comp_ value, comp_time = time_calculation(comp, pars, Ncomp)552 comp_value = np.ma.masked_invalid(comp_ value)594 comp_raw, comp_time = time_calculation(comp, pars, Ncomp) 595 comp_value = np.ma.masked_invalid(comp_raw) 553 596 print("%s t=%.2f ms, intensity=%.0f" 554 597 % (comp.engine, comp_time, comp_value.sum())) … … 625 668 626 669 def _print_stats(label, err): 670 # type: (str, np.ma.ndarray) -> None 671 # work with trimmed data, not the full set 627 672 sorted_err = np.sort(abs(err.compressed())) 628 p50 = int((len( err)-1)*0.50)629 p98 = int((len( err)-1)*0.98)673 p50 = int((len(sorted_err)-1)*0.50) 674 p98 = int((len(sorted_err)-1)*0.98) 630 675 data = [ 631 676 "max:%.3e"%sorted_err[-1], 632 677 "median:%.3e"%sorted_err[p50], 633 678 "98%%:%.3e"%sorted_err[p98], 634 "rms:%.3e"%np.sqrt(np.mean( err**2)),635 "zero-offset:%+.3e"%np.mean( err),679 "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)), 680 "zero-offset:%+.3e"%np.mean(sorted_err), 636 681 ] 637 682 print(label+" "+" ".join(data)) … … 662 707 663 708 def columnize(L, indent="", width=79): 709 # type: (List[str], str, int) -> str 664 710 """ 665 711 Format a list of strings into columns. … … 679 725 680 726 def get_pars(model_info, use_demo=False): 727 # type: (ModelInfo, bool) -> ParameterSet 681 728 """ 682 729 Extract demo parameters from the model definition. … … 704 751 705 752 def parse_opts(): 753 # type: () -> Dict[str, Any] 706 754 """ 707 755 Parse command line options. … … 757 805 'explore' : False, 758 806 'use_demo' : True, 807 'zero' : False, 759 808 } 760 809 engines = [] … … 777 826 elif arg.startswith('-cutoff='): opts['cutoff'] = float(arg[8:]) 778 827 elif arg.startswith('-random='): opts['seed'] = int(arg[8:]) 779 elif arg == '-random': opts['seed'] = np.random.randint(1 e6)828 elif arg == '-random': opts['seed'] = np.random.randint(1000000) 780 829 elif arg == '-preset': opts['seed'] = -1 781 830 elif arg == '-mono': opts['mono'] = True … … 874 923 875 924 def explore(opts): 925 # type: (Dict[str, Any]) -> None 876 926 """ 877 927 Explore the model using the Bumps GUI. … … 900 950 """ 901 951 def __init__(self, opts): 952 # type: (Dict[str, Any]) -> None 902 953 from bumps.cli import config_matplotlib # type: ignore 903 954 from . import bumps_model … … 923 974 924 975 def numpoints(self): 976 # type: () -> int 925 977 """ 926 978 Return the number of points. … … 929 981 930 982 def parameters(self): 983 # type: () -> Any # Dict/List hierarchy of parameters 931 984 """ 932 985 Return a dictionary of parameters. … … 935 988 936 989 def nllf(self): 990 # type: () -> float 937 991 """ 938 992 Return cost. … … 942 996 943 997 def plot(self, view='log'): 998 # type: (str) -> None 944 999 """ 945 1000 Plot the data and residuals. … … 951 1006 if self.limits is None: 952 1007 vmin, vmax = limits 953 vmax = 1.3*vmax 954 vmin = vmax*1e-7 955 self.limits = vmin, vmax 1008 self.limits = vmax*1e-7, 1.3*vmax 956 1009 957 1010 958 1011 def main(): 1012 # type: () -> None 959 1013 """ 960 1014 Main program.
Note: See TracChangeset
for help on using the changeset viewer.