Changeset ff1fff5 in sasmodels for sasmodels/compare.py


Ignore:
Timestamp:
Oct 16, 2016 4:14:15 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
7fcdc9f
Parents:
4f79d94 (diff), a0d75ce (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

allow comparison of different models, such as 'sascomp sphere:ellipsoid radius_polar=:radius radius_equatorial=:radius -mono'

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    r6831fa0 rff1fff5  
    3333import datetime 
    3434import traceback 
     35import re 
    3536 
    3637import numpy as np  # type: ignore 
     
    4243from .direct_model import DirectModel 
    4344from .convert import revert_name, revert_pars, constrain_new_to_old 
     45from .generate import FLOAT_RE 
    4446 
    4547try: 
     
    612614    parameters. 
    613615    """ 
    614     n_base, n_comp = opts['n1'], opts['n2'] 
    615     pars = opts['pars'] 
     616    n_base, n_comp = opts['count'] 
     617    pars, pars2 = opts['pars'] 
    616618    data = opts['data'] 
    617619 
     
    637639    if n_comp > 0: 
    638640        try: 
    639             comp_raw, comp_time = time_calculation(comp, pars, n_comp) 
     641            comp_raw, comp_time = time_calculation(comp, pars2, n_comp) 
    640642            comp_value = np.ma.masked_invalid(comp_raw) 
    641643            print("%s t=%.2f ms, intensity=%.0f" 
     
    700702    #    h.ax.set_title(cbar_title) 
    701703    fig = plt.gcf() 
    702     fig.suptitle(opts['name']) 
     704    extra_title = ' '+opts['title'] if opts['title'] else '' 
     705    fig.suptitle(":".join(opts['name']) + extra_title) 
    703706 
    704707    if n_comp > 0 and n_base > 0 and opts['show_hist']: 
     
    753756VALUE_OPTIONS = [ 
    754757    # Note: random is both a name option and a value option 
    755     'cutoff', 'random', 'nq', 'res', 'accuracy', 
     758    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title', 
    756759    ] 
    757760 
     
    800803    return pars 
    801804 
     805INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$") 
     806def isnumber(str): 
     807    match = FLOAT_RE.match(str) 
     808    isfloat = (match and not str[match.end():]) 
     809    return isfloat or INTEGER_RE.match(str) 
    802810 
    803811def parse_opts(argv): 
     
    822830        print("expected parameters: model N1 N2") 
    823831 
    824     name = positional_args[0] 
    825     try: 
    826         model_info = core.load_model_info(name) 
    827     except ImportError as exc: 
    828         print(str(exc)) 
    829         print("Could not find model; use one of:\n    " + models) 
    830         return None 
    831  
    832832    invalid = [o[1:] for o in flags 
    833833               if o[1:] not in NAME_OPTIONS 
     
    837837        return None 
    838838 
     839    name = positional_args[0] 
     840    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1 
     841    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1 
    839842 
    840843    # pylint: disable=bad-whitespace 
     
    858861        'zero'      : False, 
    859862        'html'      : False, 
     863        'title'     : None, 
    860864    } 
    861865    engines = [] 
     
    878882        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:]) 
    879883        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:]) 
     884        elif arg.startswith('-title'):     opts['title'] = arg[7:] 
    880885        elif arg == '-random':  opts['seed'] = np.random.randint(1000000) 
    881886        elif arg == '-preset':  opts['seed'] = -1 
     
    902907    # pylint: enable=bad-whitespace 
    903908 
     909    if ':' in name: 
     910        name, name2 = name.split(':',2) 
     911    else: 
     912        name2 = name 
     913    try: 
     914        model_info = core.load_model_info(name) 
     915        model_info2 = core.load_model_info(name2) if name2 != name else model_info 
     916    except ImportError as exc: 
     917        print(str(exc)) 
     918        print("Could not find model; use one of:\n    " + models) 
     919        return None 
     920 
     921    # Get demo parameters from model definition, or use default parameters 
     922    # if model does not define demo parameters 
     923    pars = get_pars(model_info, opts['use_demo']) 
     924    pars2 = get_pars(model_info2, opts['use_demo']) 
     925    # randomize parameters 
     926    #pars.update(set_pars)  # set value before random to control range 
     927    if opts['seed'] > -1: 
     928        pars = randomize_pars(model_info, pars, seed=opts['seed']) 
     929        if model_info != model_info2: 
     930            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed']) 
     931        else: 
     932            pars2 = pars.copy() 
     933        print("Randomize using -random=%i"%opts['seed']) 
     934    if opts['mono']: 
     935        pars = suppress_pd(pars) 
     936        pars2 = suppress_pd(pars2) 
     937 
     938    # Fill in parameters given on the command line 
     939    presets = {} 
     940    presets2 = {} 
     941    for arg in values: 
     942        k, v = arg.split('=', 1) 
     943        if k not in pars and k not in pars2: 
     944            # extract base name without polydispersity info 
     945            s = set(p.split('_pd')[0] for p in pars) 
     946            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s)))) 
     947            return None 
     948        v1, v2 = v.split(':',2) if ':' in v else (v,v) 
     949        if v1 and k in pars: 
     950            presets[k] = float(v1) if isnumber(v1) else v1 
     951        if v2 and k in pars2: 
     952            presets2[k] = float(v2) if isnumber(v2) else v2 
     953 
     954    # Evaluate preset parameter expressions 
     955    context = pars.copy() 
     956    context.update((k,v) for k,v in presets.items() if isinstance(v, float)) 
     957    for k, v in presets.items(): 
     958        if not isinstance(v, float) and not k.endswith('_type'): 
     959            presets[k] = eval(v, context) 
     960    context = pars.copy() 
     961    context.update(presets) 
     962    context.update((k,v) for k,v in presets2.items() if isinstance(v, float)) 
     963    for k, v in presets2.items(): 
     964        if not isinstance(v, float) and not k.endswith('_type'): 
     965            presets2[k] = eval(v, context) 
     966 
     967    # update parameters with presets 
     968    pars.update(presets)  # set value after random to control value 
     969    pars2.update(presets2)  # set value after random to control value 
     970    #import pprint; pprint.pprint(model_info) 
     971    constrain_pars(model_info, pars) 
     972    constrain_pars(model_info2, pars2) 
     973 
     974    same_model = name == name2 and pars == pars 
    904975    if len(engines) == 0: 
    905         engines.extend(['single', 'double']) 
     976        if same_model: 
     977            engines.extend(['single', 'double']) 
     978        else: 
     979            engines.extend(['single', 'single']) 
    906980    elif len(engines) == 1: 
    907         if engines[0][0] == 'double': 
     981        if not same_model: 
     982            engines.append(engines[0]) 
     983        elif engines[0] == 'double': 
    908984            engines.append('single') 
    909985        else: 
     
    912988        del engines[2:] 
    913989 
    914     n1 = int(positional_args[1]) if len(positional_args) > 1 else 1 
    915     n2 = int(positional_args[2]) if len(positional_args) > 2 else 1 
    916990    use_sasview = any(engine == 'sasview' and count > 0 
    917991                      for engine, count in zip(engines, [n1, n2])) 
    918  
    919     # Get demo parameters from model definition, or use default parameters 
    920     # if model does not define demo parameters 
    921     pars = get_pars(model_info, opts['use_demo']) 
    922  
    923  
    924     # Fill in parameters given on the command line 
    925     presets = {} 
    926     for arg in values: 
    927         k, v = arg.split('=', 1) 
    928         if k not in pars: 
    929             # extract base name without polydispersity info 
    930             s = set(p.split('_pd')[0] for p in pars) 
    931             print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s)))) 
    932             return None 
    933         presets[k] = float(v) if not k.endswith('type') else v 
    934  
    935     # randomize parameters 
    936     #pars.update(set_pars)  # set value before random to control range 
    937     if opts['seed'] > -1: 
    938         pars = randomize_pars(model_info, pars, seed=opts['seed']) 
    939         print("Randomize using -random=%i"%opts['seed']) 
    940     if opts['mono']: 
    941         pars = suppress_pd(pars) 
    942     pars.update(presets)  # set value after random to control value 
    943     #import pprint; pprint.pprint(model_info) 
    944     constrain_pars(model_info, pars) 
    945992    if use_sasview: 
    946993        constrain_new_to_old(model_info, pars) 
     994        constrain_new_to_old(model_info2, pars2) 
     995 
    947996    if opts['show_pars']: 
    948997        print(str(parlist(model_info, pars, opts['is2d']))) 
     
    9551004        base = None 
    9561005    if n2: 
    957         comp = make_engine(model_info, data, engines[1], opts['cutoff']) 
     1006        comp = make_engine(model_info2, data, engines[1], opts['cutoff']) 
    9581007    else: 
    9591008        comp = None 
     
    9621011    # Remember it all 
    9631012    opts.update({ 
    964         'name'      : name, 
    965         'def'       : model_info, 
    966         'n1'        : n1, 
    967         'n2'        : n2, 
    968         'presets'   : presets, 
    969         'pars'      : pars, 
    9701013        'data'      : data, 
     1014        'name'      : [name, name2], 
     1015        'def'       : [model_info, model_info2], 
     1016        'count'     : [n1, n2], 
     1017        'presets'   : [presets, presets2], 
     1018        'pars'      : [pars, pars2], 
    9711019        'engines'   : [base, comp], 
    9721020    }) 
     
    9831031    from .generate import view_html_from_info 
    9841032    app = wx.App() if wx.GetApp() is None else None 
    985     view_html_from_info(opts['def']) 
     1033    view_html_from_info(opts['def'][0]) 
    9861034    if app: app.MainLoop() 
    9871035 
     
    10221070        config_matplotlib() 
    10231071        self.opts = opts 
    1024         model_info = opts['def'] 
    1025         pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars']) 
     1072        model_info = opts['def'][0] 
     1073        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'][0]) 
    10261074        # Initialize parameter ranges, fixing the 2D parameters for 1D data. 
    10271075        if not opts['is2d']: 
     
    10691117        pars = dict((k, v.value) for k, v in self.pars.items()) 
    10701118        pars.update(self.pd_types) 
    1071         self.opts['pars'] = pars 
     1119        self.opts['pars'][0] = pars 
     1120        self.opts['pars'][1] = pars 
    10721121        limits = compare(self.opts, limits=self.limits) 
    10731122        if self.limits is None: 
Note: See TracChangeset for help on using the changeset viewer.