Changeset d368d21 in sasmodels for sasmodels/compare.py


Ignore:
Timestamp:
Feb 6, 2016 10:56:54 AM (8 years ago)
Author:
butler
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
321736f
Parents:
03582f9 (diff), d5e650d (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

Merge branch 'master' of https://github.com/SasView/sasmodels.git

Conflicts:

sasmodels/convert.py

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    r608e31e rd5e650d  
    2828 
    2929from __future__ import print_function 
     30 
     31import sys 
     32import math 
     33from os.path import basename, dirname, join as joinpath 
     34import glob 
     35import datetime 
     36import traceback 
     37 
     38import numpy as np 
     39 
     40from . import core 
     41from . import kerneldll 
     42from . import generate 
     43from .data import plot_theory, empty_data1D, empty_data2D 
     44from .direct_model import DirectModel 
     45from .convert import revert_model, constrain_new_to_old 
    3046 
    3147USAGE = """ 
     
    7288# doc string so that we can display it at run time if there is an error. 
    7389# lin 
    74 __doc__ = __doc__ + """ 
     90__doc__ = (__doc__  # pylint: disable=redefined-builtin 
     91           + """ 
    7592Program description 
    7693------------------- 
    7794 
    78 """ + USAGE 
    79  
    80  
    81  
    82 import sys 
    83 import math 
    84 from os.path import basename, dirname, join as joinpath 
    85 import glob 
    86 import datetime 
    87 import traceback 
    88  
    89 import numpy as np 
    90  
     95""" 
     96           + USAGE) 
     97 
     98kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 
     99 
     100# List of available models 
    91101ROOT = dirname(__file__) 
    92 sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path 
    93  
    94  
    95 from . import core 
    96 from . import kerneldll 
    97 from . import generate 
    98 from .data import plot_theory, empty_data1D, empty_data2D 
    99 from .direct_model import DirectModel 
    100 from .convert import revert_model, constrain_new_to_old 
    101 kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 
    102  
    103 # List of available models 
    104102MODELS = [basename(f)[:-3] 
    105103          for f in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))] 
     
    115113        return dt.total_seconds() 
    116114 
     115 
     116class push_seed(object): 
     117    """ 
     118    Set the seed value for the random number generator. 
     119 
     120    When used in a with statement, the random number generator state is 
     121    restored after the with statement is complete. 
     122 
     123    :Parameters: 
     124 
     125    *seed* : int or array_like, optional 
     126        Seed for RandomState 
     127 
     128    :Example: 
     129 
     130    Seed can be used directly to set the seed:: 
     131 
     132        >>> from numpy.random import randint 
     133        >>> push_seed(24) 
     134        <...push_seed object at...> 
     135        >>> print(randint(0,1000000,3)) 
     136        [242082    899 211136] 
     137 
     138    Seed can also be used in a with statement, which sets the random 
     139    number generator state for the enclosed computations and restores 
     140    it to the previous state on completion:: 
     141 
     142        >>> with push_seed(24): 
     143        ...    print(randint(0,1000000,3)) 
     144        [242082    899 211136] 
     145 
     146    Using nested contexts, we can demonstrate that state is indeed 
     147    restored after the block completes:: 
     148 
     149        >>> with push_seed(24): 
     150        ...    print(randint(0,1000000)) 
     151        ...    with push_seed(24): 
     152        ...        print(randint(0,1000000,3)) 
     153        ...    print(randint(0,1000000)) 
     154        242082 
     155        [242082    899 211136] 
     156        899 
     157 
     158    The restore step is protected against exceptions in the block:: 
     159 
     160        >>> with push_seed(24): 
     161        ...    print(randint(0,1000000)) 
     162        ...    try: 
     163        ...        with push_seed(24): 
     164        ...            print(randint(0,1000000,3)) 
     165        ...            raise Exception() 
     166        ...    except: 
     167        ...        print("Exception raised") 
     168        ...    print(randint(0,1000000)) 
     169        242082 
     170        [242082    899 211136] 
     171        Exception raised 
     172        899 
     173    """ 
     174    def __init__(self, seed=None): 
     175        self._state = np.random.get_state() 
     176        np.random.seed(seed) 
     177 
     178    def __enter__(self): 
     179        return None 
     180 
     181    def __exit__(self, *args): 
     182        np.random.set_state(self._state) 
    117183 
    118184def tic(): 
     
    177243        return [0, (2*v if v > 0 else 1)] 
    178244 
     245 
    179246def _randomize_one(p, v): 
    180247    """ 
     
    186253        return np.random.uniform(*parameter_range(p, v)) 
    187254 
     255 
    188256def randomize_pars(pars, seed=None): 
    189257    """ 
     
    195263    :func:`constrain_pars` needs to be called afterward.. 
    196264    """ 
    197     np.random.seed(seed) 
    198     # Note: the sort guarantees order `of calls to random number generator 
    199     pars = dict((p, _randomize_one(p, v)) 
    200                 for p, v in sorted(pars.items())) 
     265    with push_seed(seed): 
     266        # Note: the sort guarantees order `of calls to random number generator 
     267        pars = dict((p, _randomize_one(p, v)) 
     268                    for p, v in sorted(pars.items())) 
    201269    return pars 
    202270 
     
    281349            theory = lambda: smearer.get_value() 
    282350        else: 
    283             theory = lambda: model.evalDistribution([data.qx_data[index], data.qy_data[index]]) 
     351            theory = lambda: model.evalDistribution([data.qx_data[index], 
     352                                                     data.qy_data[index]]) 
    284353    elif smearer is not None: 
    285354        theory = lambda: smearer(model.evalDistribution(data.x)) 
     
    416485        try: 
    417486            base_value, base_time = time_calculation(base, pars, Nbase) 
    418             print("%s t=%.1f ms, intensity=%.0f"%(base.engine, base_time, sum(base_value))) 
     487            print("%s t=%.1f ms, intensity=%.0f" 
     488                  % (base.engine, base_time, sum(base_value))) 
    419489        except ImportError: 
    420490            traceback.print_exc() 
     
    426496        try: 
    427497            comp_value, comp_time = time_calculation(comp, pars, Ncomp) 
    428             print("%s t=%.1f ms, intensity=%.0f"%(comp.engine, comp_time, sum(comp_value))) 
     498            print("%s t=%.1f ms, intensity=%.0f" 
     499                  % (comp.engine, comp_time, sum(comp_value))) 
    429500        except ImportError: 
    430501            traceback.print_exc() 
     
    433504    # Compare, but only if computing both forms 
    434505    if Nbase > 0 and Ncomp > 0: 
    435         #print("speedup %.2g"%(comp_time/base_time)) 
    436         #print("max |base/comp|", max(abs(base_value/comp_value)), "%.15g"%max(abs(base_value)), "%.15g"%max(abs(comp_value))) 
    437         #comp *= max(base_value/comp_value) 
    438506        resid = (base_value - comp_value) 
    439507        relerr = resid/comp_value 
    440         _print_stats("|%s-%s|"%(base.engine, comp.engine) + (" "*(3+len(comp.engine))), 
     508        _print_stats("|%s-%s|" 
     509                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))), 
    441510                     resid) 
    442         _print_stats("|(%s-%s)/%s|"%(base.engine, comp.engine, comp.engine), 
     511        _print_stats("|(%s-%s)/%s|" 
     512                     % (base.engine, comp.engine, comp.engine), 
    443513                     relerr) 
    444514 
     
    459529    if Nbase > 0: 
    460530        if Ncomp > 0: plt.subplot(131) 
    461         plot_theory(data, base_value, view=view, plot_data=False, limits=limits) 
     531        plot_theory(data, base_value, view=view, use_data=False, limits=limits) 
    462532        plt.title("%s t=%.1f ms"%(base.engine, base_time)) 
    463533        #cbar_title = "log I" 
    464534    if Ncomp > 0: 
    465535        if Nbase > 0: plt.subplot(132) 
    466         plot_theory(data, comp_value, view=view, plot_data=False, limits=limits) 
     536        plot_theory(data, comp_value, view=view, use_data=False, limits=limits) 
    467537        plt.title("%s t=%.1f ms"%(comp.engine, comp_time)) 
    468538        #cbar_title = "log I" 
    469539    if Ncomp > 0 and Nbase > 0: 
    470540        plt.subplot(133) 
    471         if '-abs' in opts: 
     541        if not opts['rel_err']: 
    472542            err, errstr, errview = resid, "abs err", "linear" 
    473543        else: 
    474544            err, errstr, errview = abs(relerr), "rel err", "log" 
    475545        #err,errstr = base/comp,"ratio" 
    476         plot_theory(data, None, resid=err, view=errview, plot_data=False) 
     546        plot_theory(data, None, resid=err, view=errview, use_data=False) 
     547        if view == 'linear': 
     548            plt.xscale('linear') 
    477549        plt.title("max %s = %.3g"%(errstr, max(abs(err)))) 
    478550        #cbar_title = errstr if errview=="linear" else "log "+errstr 
     
    534606def columnize(L, indent="", width=79): 
    535607    """ 
    536     Format a list of strings into columns for printing. 
     608    Format a list of strings into columns. 
     609 
     610    Returns a string with carriage returns ready for printing. 
    537611    """ 
    538612    column_width = max(len(w) for w in L) + 1 
     
    591665    invalid = [o[1:] for o in flags 
    592666               if o[1:] not in NAME_OPTIONS 
    593                    and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)] 
     667               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)] 
    594668    if invalid: 
    595669        print("Invalid options: %s"%(", ".join(invalid))) 
     
    597671 
    598672 
     673    # pylint: disable=bad-whitespace 
    599674    # Interpret the flags 
    600675    opts = { 
     
    651726        elif arg == '-sasview': engines.append(arg[1:]) 
    652727        elif arg == '-edit':    opts['explore'] = True 
     728    # pylint: enable=bad-whitespace 
    653729 
    654730    if len(engines) == 0: 
     
    675751    presets = {} 
    676752    for arg in values: 
    677         k,v = arg.split('=',1) 
     753        k, v = arg.split('=', 1) 
    678754        if k not in pars: 
    679755            # extract base name without polydispersity info 
    680756            s = set(p.split('_pd')[0] for p in pars) 
    681             print("%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))) 
     757            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s)))) 
    682758            sys.exit(1) 
    683759        presets[k] = float(v) if not k.endswith('type') else v 
     
    697773 
    698774    # Create the computational engines 
    699     data, _index = make_data(opts) 
     775    data, _ = make_data(opts) 
    700776    if n1: 
    701777        base = make_engine(model_definition, data, engines[0], opts['cutoff']) 
     
    707783        comp = None 
    708784 
     785    # pylint: disable=bad-whitespace 
    709786    # Remember it all 
    710787    opts.update({ 
     
    718795        'engines'   : [base, comp], 
    719796    }) 
     797    # pylint: enable=bad-whitespace 
    720798 
    721799    return opts 
    722800 
    723 def main(): 
    724     opts = parse_opts() 
    725     if opts['explore']: 
    726         explore(opts) 
    727     else: 
    728         compare(opts) 
    729  
    730801def explore(opts): 
     802    """ 
     803    Explore the model using the Bumps GUI. 
     804    """ 
    731805    import wx 
    732806    from bumps.names import FitProblem 
     
    734808 
    735809    problem = FitProblem(Explore(opts)) 
    736     isMac = "cocoa" in wx.version() 
     810    is_mac = "cocoa" in wx.version() 
    737811    app = wx.App() 
    738812    frame = AppFrame(parent=None, title="explore") 
    739     if not isMac: frame.Show() 
     813    if not is_mac: frame.Show() 
    740814    frame.panel.set_model(model=problem) 
    741815    frame.panel.Layout() 
    742816    frame.panel.aui.Split(0, wx.TOP) 
    743     if isMac: frame.Show() 
     817    if is_mac: frame.Show() 
    744818    app.MainLoop() 
    745819 
    746820class Explore(object): 
    747821    """ 
    748     Return a bumps wrapper for a SAS model comparison. 
     822    Bumps wrapper for a SAS model comparison. 
     823 
     824    The resulting object can be used as a Bumps fit problem so that 
     825    parameters can be adjusted in the GUI, with plots updated on the fly. 
    749826    """ 
    750827    def __init__(self, opts): 
     
    787864        Return cost. 
    788865        """ 
     866        # pylint: disable=no-self-use 
    789867        return 0.  # No nllf 
    790868 
     
    804882 
    805883 
     884def main(): 
     885    """ 
     886    Main program. 
     887    """ 
     888    opts = parse_opts() 
     889    if opts['explore']: 
     890        explore(opts) 
     891    else: 
     892        compare(opts) 
     893 
    806894if __name__ == "__main__": 
    807895    main() 
Note: See TracChangeset for help on using the changeset viewer.