Changeset 17bbadd in sasmodels for sasmodels/compare.py
- Timestamp:
- Mar 15, 2016 10:47:12 AM (8 years ago)
- Branches:
- master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
- Children:
- 754e27b
- Parents:
- 5ceb7d0
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/compare.py
r6869ceb r17bbadd 38 38 from . import core 39 39 from . import kerneldll 40 from . import generate40 from . import product 41 41 from .data import plot_theory, empty_data1D, empty_data2D 42 42 from .direct_model import DirectModel 43 from .convert import revert_ model, constrain_new_to_old43 from .convert import revert_pars, constrain_new_to_old 44 44 45 45 USAGE = """ … … 264 264 return pars 265 265 266 def constrain_pars(model_ definition, pars):266 def constrain_pars(model_info, pars): 267 267 """ 268 268 Restrict parameters to valid values. … … 272 272 cylinder radius in this case). 273 273 """ 274 name = model_definition.name 274 name = model_info['id'] 275 # if it is a product model, then just look at the form factor since 276 # none of the structure factors need any constraints. 277 if '*' in name: 278 name = name.split('*')[0] 279 275 280 if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']: 276 281 pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius'] … … 340 345 return pars 341 346 342 def eval_sasview(model_ definition, data):347 def eval_sasview(model_info, data): 343 348 """ 344 349 Return a model calculator using the SasView fitting engine. … … 349 354 from sas.models.qsmearing import smear_selection 350 355 351 # convert model parameters from sasmodel form to sasview form 352 #print("old",sorted(pars.items())) 353 modelname, _ = revert_model(model_definition, {}) 354 #print("new",sorted(_pars.items())) 355 sas = __import__('sas.models.'+modelname) 356 ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None) 357 if ModelClass is None: 358 raise ValueError("could not find model %r in sas.models"%modelname) 359 model = ModelClass() 356 def get_model(name): 357 #print("new",sorted(_pars.items())) 358 sas = __import__('sas.models.' + name) 359 ModelClass = getattr(getattr(sas.models, name, None), name, None) 360 if ModelClass is None: 361 raise ValueError("could not find model %r in sas.models"%name) 362 return ModelClass() 363 364 # grab the sasview model, or create it if it is a product model 365 if model_info['composition']: 366 composition_type, parts = model_info['composition'] 367 if composition_type == 'product': 368 from sas.models import MultiplicationModel 369 P, S = [get_model(p) for p in model_info['oldname']] 370 model = MultiplicationModel(P, S) 371 else: 372 raise ValueError("mixture models not handled yet") 373 else: 374 model = get_model(model_info['oldname']) 375 376 # build a smearer with which to call the model, if necessary 360 377 smearer = smear_selection(data, model=model) 361 362 378 if hasattr(data, 'qx_data'): 363 379 q = np.sqrt(data.qx_data**2 + data.qy_data**2) … … 382 398 """ 383 399 # paying for parameter conversion each time to keep life simple, if not fast 384 _, pars = revert_model(model_definition, pars)400 pars = revert_pars(model_info, pars) 385 401 for k, v in pars.items(): 386 402 parts = k.split('.') # polydispersity components … … 405 421 'longdouble': '128', 406 422 } 407 def eval_opencl(model_ definition, data, dtype='single', cutoff=0.):423 def eval_opencl(model_info, data, dtype='single', cutoff=0.): 408 424 """ 409 425 Return a model calculator using the OpenCL calculation engine. 410 426 """ 411 try: 412 model = core.load_model(model_definition, dtype=dtype, platform="ocl") 413 except Exception as exc: 414 print(exc) 415 print("... trying again with single precision") 416 dtype = 'single' 417 model = core.load_model(model_definition, dtype=dtype, platform="ocl") 427 def builder(model_info): 428 try: 429 return core.build_model(model_info, dtype=dtype, platform="ocl") 430 except Exception as exc: 431 print(exc) 432 print("... trying again with single precision") 433 dtype = 'single' 434 return core.build_model(model_info, dtype=dtype, platform="ocl") 435 if model_info['composition']: 436 composition_type, parts = model_info['composition'] 437 if composition_type == 'product': 438 P, S = [builder(p) for p in parts] 439 model = product.ProductModel(P, S) 440 else: 441 raise ValueError("mixture models not handled yet") 442 else: 443 model = builder(model_info) 418 444 calculator = DirectModel(data, model, cutoff=cutoff) 419 445 calculator.engine = "OCL%s"%DTYPE_MAP[dtype] 420 446 return calculator 421 447 422 def eval_ctypes(model_ definition, data, dtype='double', cutoff=0.):448 def eval_ctypes(model_info, data, dtype='double', cutoff=0.): 423 449 """ 424 450 Return a model calculator using the DLL calculation engine. … … 426 452 if dtype == 'quad': 427 453 dtype = 'longdouble' 428 model = core.load_model(model_definition, dtype=dtype, platform="dll") 454 def builder(model_info): 455 return core.build_model(model_info, dtype=dtype, platform="dll") 456 457 if model_info['composition']: 458 composition_type, parts = model_info['composition'] 459 if composition_type == 'product': 460 P, S = [builder(p) for p in parts] 461 model = product.ProductModel(P, S) 462 else: 463 raise ValueError("mixture models not handled yet") 464 else: 465 model = builder(model_info) 429 466 calculator = DirectModel(data, model, cutoff=cutoff) 430 467 calculator.engine = "OMP%s"%DTYPE_MAP[dtype] … … 470 507 return data, index 471 508 472 def make_engine(model_ definition, data, dtype, cutoff):509 def make_engine(model_info, data, dtype, cutoff): 473 510 """ 474 511 Generate the appropriate calculation engine for the given datatype. … … 478 515 """ 479 516 if dtype == 'sasview': 480 return eval_sasview(model_ definition, data)517 return eval_sasview(model_info, data) 481 518 elif dtype.endswith('!'): 482 return eval_ctypes(model_definition, data, dtype=dtype[:-1], 483 cutoff=cutoff) 484 else: 485 return eval_opencl(model_definition, data, dtype=dtype, 486 cutoff=cutoff) 519 return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff) 520 else: 521 return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff) 487 522 488 523 def compare(opts, limits=None): … … 642 677 643 678 644 def get_demo_pars(model_ definition):679 def get_demo_pars(model_info): 645 680 """ 646 681 Extract demo parameters from the model definition. 647 682 """ 648 info = generate.make_info(model_definition)649 683 # Get the default values for the parameters 650 pars = dict((p[0], p[2]) for p in info['parameters'])684 pars = dict((p[0], p[2]) for p in model_info['parameters']) 651 685 652 686 # Fill in default values for the polydispersity parameters 653 for p in info['parameters']:687 for p in model_info['parameters']: 654 688 if p[4] in ('volume', 'orientation'): 655 689 pars[p[0]+'_pd'] = 0.0 … … 659 693 660 694 # Plug in values given in demo 661 pars.update( info['demo'])695 pars.update(model_info['demo']) 662 696 return pars 697 663 698 664 699 def parse_opts(): … … 679 714 print(columnize(MODELS, indent=" ")) 680 715 sys.exit(1) 681 682 name = args[0]683 try:684 model_definition = core.load_model_definition(name)685 except ImportError, exc:686 print(str(exc))687 print("Use one of:\n " + models)688 sys.exit(1)689 716 if len(args) > 3: 690 717 print("expected parameters: model N1 N2") 718 719 def load_model(name): 720 try: 721 model_info = core.load_model_info(name) 722 except ImportError, exc: 723 print(str(exc)) 724 print("Use one of:\n " + models) 725 sys.exit(1) 726 return model_info 727 728 name = args[0] 729 if '*' in name: 730 parts = [load_model(k) for k in name.split('*')] 731 model_info = product.make_product_info(*parts) 732 else: 733 model_info = load_model(name) 691 734 692 735 invalid = [o[1:] for o in flags … … 770 813 # Get demo parameters from model definition, or use default parameters 771 814 # if model does not define demo parameters 772 pars = get_demo_pars(model_ definition)815 pars = get_demo_pars(model_info) 773 816 774 817 # Fill in parameters given on the command line … … 791 834 pars = suppress_pd(pars) 792 835 pars.update(presets) # set value after random to control value 793 constrain_pars(model_ definition, pars)794 constrain_new_to_old(model_ definition, pars)836 constrain_pars(model_info, pars) 837 constrain_new_to_old(model_info, pars) 795 838 if opts['show_pars']: 796 839 print(str(parlist(pars))) … … 799 842 data, _ = make_data(opts) 800 843 if n1: 801 base = make_engine(model_ definition, data, engines[0], opts['cutoff'])844 base = make_engine(model_info, data, engines[0], opts['cutoff']) 802 845 else: 803 846 base = None 804 847 if n2: 805 comp = make_engine(model_ definition, data, engines[1], opts['cutoff'])848 comp = make_engine(model_info, data, engines[1], opts['cutoff']) 806 849 else: 807 850 comp = None … … 811 854 opts.update({ 812 855 'name' : name, 813 'def' : model_ definition,856 'def' : model_info, 814 857 'n1' : n1, 815 858 'n2' : n2, … … 854 897 config_matplotlib() 855 898 self.opts = opts 856 info = generate.make_info(opts['def'])857 pars, pd_types = bumps_model.create_parameters( info, **opts['pars'])899 model_info = opts['def'] 900 pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars']) 858 901 if not opts['is2d']: 859 902 active = [base + ext 860 for base in info['partype']['pd-1d']903 for base in model_info['partype']['pd-1d'] 861 904 for ext in ['', '_pd', '_pd_n', '_pd_nsigma']] 862 active.extend( info['partype']['fixed-1d'])905 active.extend(model_info['partype']['fixed-1d']) 863 906 for k in active: 864 907 v = pars[k]
Note: See TracChangeset
for help on using the changeset viewer.