Changeset d368d21 in sasmodels for sasmodels/compare.py
- Timestamp:
- Feb 6, 2016 10:56:54 AM (8 years ago)
- Branches:
- master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
- Children:
- 321736f
- Parents:
- 03582f9 (diff), d5e650d (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent. - File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/compare.py
r608e31e rd5e650d 28 28 29 29 from __future__ import print_function 30 31 import sys 32 import math 33 from os.path import basename, dirname, join as joinpath 34 import glob 35 import datetime 36 import traceback 37 38 import numpy as np 39 40 from . import core 41 from . import kerneldll 42 from . import generate 43 from .data import plot_theory, empty_data1D, empty_data2D 44 from .direct_model import DirectModel 45 from .convert import revert_model, constrain_new_to_old 30 46 31 47 USAGE = """ … … 72 88 # doc string so that we can display it at run time if there is an error. 73 89 # lin 74 __doc__ = __doc__ + """ 90 __doc__ = (__doc__ # pylint: disable=redefined-builtin 91 + """ 75 92 Program description 76 93 ------------------- 77 94 78 """ + USAGE 79 80 81 82 import sys 83 import math 84 from os.path import basename, dirname, join as joinpath 85 import glob 86 import datetime 87 import traceback 88 89 import numpy as np 90 95 """ 96 + USAGE) 97 98 kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 99 100 # List of available models 91 101 ROOT = dirname(__file__) 92 sys.path.insert(0, ROOT) # Make sure sasmodels is first on the path93 94 95 from . import core96 from . import kerneldll97 from . import generate98 from .data import plot_theory, empty_data1D, empty_data2D99 from .direct_model import DirectModel100 from .convert import revert_model, constrain_new_to_old101 kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True102 103 # List of available models104 102 MODELS = [basename(f)[:-3] 105 103 for f in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))] … … 115 113 return dt.total_seconds() 116 114 115 116 class push_seed(object): 117 """ 118 Set the seed value for the random number generator. 119 120 When used in a with statement, the random number generator state is 121 restored after the with statement is complete. 122 123 :Parameters: 124 125 *seed* : int or array_like, optional 126 Seed for RandomState 127 128 :Example: 129 130 Seed can be used directly to set the seed:: 131 132 >>> from numpy.random import randint 133 >>> push_seed(24) 134 <...push_seed object at...> 135 >>> print(randint(0,1000000,3)) 136 [242082 899 211136] 137 138 Seed can also be used in a with statement, which sets the random 139 number generator state for the enclosed computations and restores 140 it to the previous state on completion:: 141 142 >>> with push_seed(24): 143 ... print(randint(0,1000000,3)) 144 [242082 899 211136] 145 146 Using nested contexts, we can demonstrate that state is indeed 147 restored after the block completes:: 148 149 >>> with push_seed(24): 150 ... print(randint(0,1000000)) 151 ... with push_seed(24): 152 ... print(randint(0,1000000,3)) 153 ... print(randint(0,1000000)) 154 242082 155 [242082 899 211136] 156 899 157 158 The restore step is protected against exceptions in the block:: 159 160 >>> with push_seed(24): 161 ... print(randint(0,1000000)) 162 ... try: 163 ... with push_seed(24): 164 ... print(randint(0,1000000,3)) 165 ... raise Exception() 166 ... except: 167 ... print("Exception raised") 168 ... print(randint(0,1000000)) 169 242082 170 [242082 899 211136] 171 Exception raised 172 899 173 """ 174 def __init__(self, seed=None): 175 self._state = np.random.get_state() 176 np.random.seed(seed) 177 178 def __enter__(self): 179 return None 180 181 def __exit__(self, *args): 182 np.random.set_state(self._state) 117 183 118 184 def tic(): … … 177 243 return [0, (2*v if v > 0 else 1)] 178 244 245 179 246 def _randomize_one(p, v): 180 247 """ … … 186 253 return np.random.uniform(*parameter_range(p, v)) 187 254 255 188 256 def randomize_pars(pars, seed=None): 189 257 """ … … 195 263 :func:`constrain_pars` needs to be called afterward.. 196 264 """ 197 np.random.seed(seed)198 # Note: the sort guarantees order `of calls to random number generator199 pars = dict((p, _randomize_one(p, v))200 for p, v in sorted(pars.items()))265 with push_seed(seed): 266 # Note: the sort guarantees order `of calls to random number generator 267 pars = dict((p, _randomize_one(p, v)) 268 for p, v in sorted(pars.items())) 201 269 return pars 202 270 … … 281 349 theory = lambda: smearer.get_value() 282 350 else: 283 theory = lambda: model.evalDistribution([data.qx_data[index], data.qy_data[index]]) 351 theory = lambda: model.evalDistribution([data.qx_data[index], 352 data.qy_data[index]]) 284 353 elif smearer is not None: 285 354 theory = lambda: smearer(model.evalDistribution(data.x)) … … 416 485 try: 417 486 base_value, base_time = time_calculation(base, pars, Nbase) 418 print("%s t=%.1f ms, intensity=%.0f"%(base.engine, base_time, sum(base_value))) 487 print("%s t=%.1f ms, intensity=%.0f" 488 % (base.engine, base_time, sum(base_value))) 419 489 except ImportError: 420 490 traceback.print_exc() … … 426 496 try: 427 497 comp_value, comp_time = time_calculation(comp, pars, Ncomp) 428 print("%s t=%.1f ms, intensity=%.0f"%(comp.engine, comp_time, sum(comp_value))) 498 print("%s t=%.1f ms, intensity=%.0f" 499 % (comp.engine, comp_time, sum(comp_value))) 429 500 except ImportError: 430 501 traceback.print_exc() … … 433 504 # Compare, but only if computing both forms 434 505 if Nbase > 0 and Ncomp > 0: 435 #print("speedup %.2g"%(comp_time/base_time))436 #print("max |base/comp|", max(abs(base_value/comp_value)), "%.15g"%max(abs(base_value)), "%.15g"%max(abs(comp_value)))437 #comp *= max(base_value/comp_value)438 506 resid = (base_value - comp_value) 439 507 relerr = resid/comp_value 440 _print_stats("|%s-%s|"%(base.engine, comp.engine) + (" "*(3+len(comp.engine))), 508 _print_stats("|%s-%s|" 509 % (base.engine, comp.engine) + (" "*(3+len(comp.engine))), 441 510 resid) 442 _print_stats("|(%s-%s)/%s|"%(base.engine, comp.engine, comp.engine), 511 _print_stats("|(%s-%s)/%s|" 512 % (base.engine, comp.engine, comp.engine), 443 513 relerr) 444 514 … … 459 529 if Nbase > 0: 460 530 if Ncomp > 0: plt.subplot(131) 461 plot_theory(data, base_value, view=view, plot_data=False, limits=limits)531 plot_theory(data, base_value, view=view, use_data=False, limits=limits) 462 532 plt.title("%s t=%.1f ms"%(base.engine, base_time)) 463 533 #cbar_title = "log I" 464 534 if Ncomp > 0: 465 535 if Nbase > 0: plt.subplot(132) 466 plot_theory(data, comp_value, view=view, plot_data=False, limits=limits)536 plot_theory(data, comp_value, view=view, use_data=False, limits=limits) 467 537 plt.title("%s t=%.1f ms"%(comp.engine, comp_time)) 468 538 #cbar_title = "log I" 469 539 if Ncomp > 0 and Nbase > 0: 470 540 plt.subplot(133) 471 if '-abs' in opts:541 if not opts['rel_err']: 472 542 err, errstr, errview = resid, "abs err", "linear" 473 543 else: 474 544 err, errstr, errview = abs(relerr), "rel err", "log" 475 545 #err,errstr = base/comp,"ratio" 476 plot_theory(data, None, resid=err, view=errview, plot_data=False) 546 plot_theory(data, None, resid=err, view=errview, use_data=False) 547 if view == 'linear': 548 plt.xscale('linear') 477 549 plt.title("max %s = %.3g"%(errstr, max(abs(err)))) 478 550 #cbar_title = errstr if errview=="linear" else "log "+errstr … … 534 606 def columnize(L, indent="", width=79): 535 607 """ 536 Format a list of strings into columns for printing. 608 Format a list of strings into columns. 609 610 Returns a string with carriage returns ready for printing. 537 611 """ 538 612 column_width = max(len(w) for w in L) + 1 … … 591 665 invalid = [o[1:] for o in flags 592 666 if o[1:] not in NAME_OPTIONS 593 667 and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)] 594 668 if invalid: 595 669 print("Invalid options: %s"%(", ".join(invalid))) … … 597 671 598 672 673 # pylint: disable=bad-whitespace 599 674 # Interpret the flags 600 675 opts = { … … 651 726 elif arg == '-sasview': engines.append(arg[1:]) 652 727 elif arg == '-edit': opts['explore'] = True 728 # pylint: enable=bad-whitespace 653 729 654 730 if len(engines) == 0: … … 675 751 presets = {} 676 752 for arg in values: 677 k, v = arg.split('=',1)753 k, v = arg.split('=', 1) 678 754 if k not in pars: 679 755 # extract base name without polydispersity info 680 756 s = set(p.split('_pd')[0] for p in pars) 681 print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))757 print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s)))) 682 758 sys.exit(1) 683 759 presets[k] = float(v) if not k.endswith('type') else v … … 697 773 698 774 # Create the computational engines 699 data, _ index= make_data(opts)775 data, _ = make_data(opts) 700 776 if n1: 701 777 base = make_engine(model_definition, data, engines[0], opts['cutoff']) … … 707 783 comp = None 708 784 785 # pylint: disable=bad-whitespace 709 786 # Remember it all 710 787 opts.update({ … … 718 795 'engines' : [base, comp], 719 796 }) 797 # pylint: enable=bad-whitespace 720 798 721 799 return opts 722 800 723 def main():724 opts = parse_opts()725 if opts['explore']:726 explore(opts)727 else:728 compare(opts)729 730 801 def explore(opts): 802 """ 803 Explore the model using the Bumps GUI. 804 """ 731 805 import wx 732 806 from bumps.names import FitProblem … … 734 808 735 809 problem = FitProblem(Explore(opts)) 736 is Mac = "cocoa" in wx.version()810 is_mac = "cocoa" in wx.version() 737 811 app = wx.App() 738 812 frame = AppFrame(parent=None, title="explore") 739 if not is Mac: frame.Show()813 if not is_mac: frame.Show() 740 814 frame.panel.set_model(model=problem) 741 815 frame.panel.Layout() 742 816 frame.panel.aui.Split(0, wx.TOP) 743 if is Mac: frame.Show()817 if is_mac: frame.Show() 744 818 app.MainLoop() 745 819 746 820 class Explore(object): 747 821 """ 748 Return a bumps wrapper for a SAS model comparison. 822 Bumps wrapper for a SAS model comparison. 823 824 The resulting object can be used as a Bumps fit problem so that 825 parameters can be adjusted in the GUI, with plots updated on the fly. 749 826 """ 750 827 def __init__(self, opts): … … 787 864 Return cost. 788 865 """ 866 # pylint: disable=no-self-use 789 867 return 0. # No nllf 790 868 … … 804 882 805 883 884 def main(): 885 """ 886 Main program. 887 """ 888 opts = parse_opts() 889 if opts['explore']: 890 explore(opts) 891 else: 892 compare(opts) 893 806 894 if __name__ == "__main__": 807 895 main()
Note: See TracChangeset
for help on using the changeset viewer.