Changeset ca9e54e in sasmodels


Ignore:
Timestamp:
Oct 20, 2016 2:55:35 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
97f9b46
Parents:
bbc2b34
Message:

sascomp: allow interactive comparison of different model sets

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    r3e8ea5d rca9e54e  
    654654    parameters. 
    655655    """ 
     656    result = run_models(opts, verbose=True) 
     657    if opts['plot']:  # Note: never called from explore 
     658        plot_models(opts, result, limits=limits) 
     659 
     660def run_models(opts, verbose=False): 
     661    # type: (Dict[str, Any]) -> Dict[str, Any] 
     662 
    656663    n_base, n_comp = opts['count'] 
    657664    pars, pars2 = opts['pars'] 
     
    661668    base = opts['engines'][0] if n_base else None 
    662669    comp = opts['engines'][1] if n_comp else None 
     670 
    663671    base_time = comp_time = None 
    664672    base_value = comp_value = resid = relerr = None 
     
    669677            base_raw, base_time = time_calculation(base, pars, n_base) 
    670678            base_value = np.ma.masked_invalid(base_raw) 
    671             print("%s t=%.2f ms, intensity=%.0f" 
    672                   % (base.engine, base_time, base_value.sum())) 
     679            if verbose: 
     680                print("%s t=%.2f ms, intensity=%.0f" 
     681                      % (base.engine, base_time, base_value.sum())) 
    673682            _show_invalid(data, base_value) 
    674683        except ImportError: 
     
    681690            comp_raw, comp_time = time_calculation(comp, pars2, n_comp) 
    682691            comp_value = np.ma.masked_invalid(comp_raw) 
    683             print("%s t=%.2f ms, intensity=%.0f" 
    684                   % (comp.engine, comp_time, comp_value.sum())) 
     692            if verbose: 
     693                print("%s t=%.2f ms, intensity=%.0f" 
     694                      % (comp.engine, comp_time, comp_value.sum())) 
    685695            _show_invalid(data, comp_value) 
    686696        except ImportError: 
     
    692702        resid = (base_value - comp_value) 
    693703        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0) 
    694         _print_stats("|%s-%s|" 
    695                      % (base.engine, comp.engine) + (" "*(3+len(comp.engine))), 
    696                      resid) 
    697         _print_stats("|(%s-%s)/%s|" 
    698                      % (base.engine, comp.engine, comp.engine), 
    699                      relerr) 
     704        if verbose: 
     705            _print_stats("|%s-%s|" 
     706                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))), 
     707                         resid) 
     708            _print_stats("|(%s-%s)/%s|" 
     709                         % (base.engine, comp.engine, comp.engine), 
     710                         relerr) 
     711 
     712    return dict(base_value=base_value, comp_value=comp_value, 
     713                base_time=base_time, comp_time=comp_time, 
     714                resid=resid, relerr=relerr) 
     715 
     716 
     717def _print_stats(label, err): 
     718    # type: (str, np.ma.ndarray) -> None 
     719    # work with trimmed data, not the full set 
     720    sorted_err = np.sort(abs(err.compressed())) 
     721    if len(sorted_err) == 0.: 
     722        print(label + "  no valid values") 
     723        return 
     724 
     725    p50 = int((len(sorted_err)-1)*0.50) 
     726    p98 = int((len(sorted_err)-1)*0.98) 
     727    data = [ 
     728        "max:%.3e"%sorted_err[-1], 
     729        "median:%.3e"%sorted_err[p50], 
     730        "98%%:%.3e"%sorted_err[p98], 
     731        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)), 
     732        "zero-offset:%+.3e"%np.mean(sorted_err), 
     733        ] 
     734    print(label+"  "+"  ".join(data)) 
     735 
     736 
     737def plot_models(opts, result, limits=None): 
     738    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float] 
     739    base_value, comp_value= result['base_value'], result['comp_value'] 
     740    base_time, comp_time = result['base_time'], result['comp_time'] 
     741    resid, relerr = result['resid'], result['relerr'] 
     742 
     743    have_base, have_comp = (base_value is not None), (comp_value is not None) 
     744    base = opts['engines'][0] if have_base else None 
     745    comp = opts['engines'][1] if have_comp else None 
     746    data = opts['data'] 
    700747 
    701748    # Plot if requested 
    702     if not opts['plot'] and not opts['explore']: return 
    703749    view = opts['view'] 
    704750    import matplotlib.pyplot as plt 
    705751    if limits is None: 
    706752        vmin, vmax = np.Inf, -np.Inf 
    707         if n_base > 0: 
     753        if have_base: 
    708754            vmin = min(vmin, base_value.min()) 
    709755            vmax = max(vmax, base_value.max()) 
    710         if n_comp > 0: 
     756        if have_comp: 
    711757            vmin = min(vmin, comp_value.min()) 
    712758            vmax = max(vmax, comp_value.max()) 
    713759        limits = vmin, vmax 
    714760 
    715     if n_base > 0: 
    716         if n_comp > 0: plt.subplot(131) 
     761    if have_base: 
     762        if have_comp: plt.subplot(131) 
    717763        plot_theory(data, base_value, view=view, use_data=False, limits=limits) 
    718764        plt.title("%s t=%.2f ms"%(base.engine, base_time)) 
    719765        #cbar_title = "log I" 
    720     if n_comp > 0: 
    721         if n_base > 0: plt.subplot(132) 
     766    if have_comp: 
     767        if have_base: plt.subplot(132) 
     768        if not opts['is2d'] and have_base: 
     769            plot_theory(data, base_value, view=view, use_data=False, limits=limits) 
    722770        plot_theory(data, comp_value, view=view, use_data=False, limits=limits) 
    723771        plt.title("%s t=%.2f ms"%(comp.engine, comp_time)) 
    724772        #cbar_title = "log I" 
    725     if n_comp > 0 and n_base > 0: 
    726         if not opts['is2d']: 
    727             plot_theory(data, base_value, view=view, use_data=False, limits=limits) 
     773    if have_base and have_comp: 
    728774        plt.subplot(133) 
    729775        if not opts['rel_err']: 
     
    748794    fig.suptitle(":".join(opts['name']) + extra_title) 
    749795 
    750     if n_comp > 0 and n_base > 0 and opts['show_hist']: 
     796    if have_base and have_comp and opts['show_hist']: 
    751797        plt.figure() 
    752798        v = relerr 
     
    763809    return limits 
    764810 
    765 def _print_stats(label, err): 
    766     # type: (str, np.ma.ndarray) -> None 
    767     # work with trimmed data, not the full set 
    768     sorted_err = np.sort(abs(err.compressed())) 
    769     p50 = int((len(sorted_err)-1)*0.50) 
    770     p98 = int((len(sorted_err)-1)*0.98) 
    771     data = [ 
    772         "max:%.3e"%sorted_err[-1], 
    773         "median:%.3e"%sorted_err[p50], 
    774         "98%%:%.3e"%sorted_err[p98], 
    775         "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)), 
    776         "zero-offset:%+.3e"%np.mean(sorted_err), 
    777         ] 
    778     print(label+"  "+"  ".join(data)) 
    779811 
    780812 
     
    11131145    from bumps.names import FitProblem  # type: ignore 
    11141146    from bumps.gui.app_frame import AppFrame  # type: ignore 
     1147    from bumps.gui import signal 
    11151148 
    11161149    is_mac = "cocoa" in wx.version() 
    11171150    # Create an app if not running embedded 
    11181151    app = wx.App() if wx.GetApp() is None else None 
    1119     problem = FitProblem(Explore(opts)) 
     1152    model = Explore(opts) 
     1153    problem = FitProblem(model) 
    11201154    frame = AppFrame(parent=None, title="explore", size=(1000,700)) 
    11211155    if not is_mac: frame.Show() 
     
    11231157    frame.panel.Layout() 
    11241158    frame.panel.aui.Split(0, wx.TOP) 
     1159    def reset_parameters(event): 
     1160        model.revert_values() 
     1161        signal.update_parameters(problem) 
     1162    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1)) 
    11251163    if is_mac: frame.Show() 
    11261164    # If running withing an app, start the main loop 
     
    11401178        config_matplotlib() 
    11411179        self.opts = opts 
    1142         model_info = opts['def'][0] 
    1143         pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'][0]) 
     1180        p1, p2 = opts['pars'] 
     1181        m1, m2 = opts['def'] 
     1182        self.fix_p2 = m1 != m2 or p1 != p2 
     1183        model_info = m1 
     1184        pars, pd_types = bumps_model.create_parameters(model_info, **p1) 
    11441185        # Initialize parameter ranges, fixing the 2D parameters for 1D data. 
    11451186        if not opts['is2d']: 
     
    11551196 
    11561197        self.pars = pars 
     1198        self.starting_values = dict((k, v.value) for k, v in pars.items()) 
    11571199        self.pd_types = pd_types 
    11581200        self.limits = None 
     1201 
     1202    def revert_values(self): 
     1203        for k, v in self.starting_values.items(): 
     1204            self.pars[k].value = v 
     1205 
     1206    def model_update(self): 
     1207        pass 
    11591208 
    11601209    def numpoints(self): 
     
    11881237        pars.update(self.pd_types) 
    11891238        self.opts['pars'][0] = pars 
    1190         self.opts['pars'][1] = pars 
    1191         limits = compare(self.opts, limits=self.limits) 
     1239        if not self.fix_p2: 
     1240            self.opts['pars'][1] = pars 
     1241        result = run_models(self.opts) 
     1242        limits = plot_models(self.opts, result, limits=self.limits) 
    11921243        if self.limits is None: 
    11931244            vmin, vmax = limits 
    11941245            self.limits = vmax*1e-7, 1.3*vmax 
     1246            import pylab; pylab.clf() 
     1247            plot_models(self.opts, result, limits=self.limits) 
    11951248 
    11961249 
Note: See TracChangeset for help on using the changeset viewer.