Changeset 256dfe1 in sasmodels


Ignore:
Timestamp:
Jul 18, 2016 12:42:29 AM (7 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
c5ac2b2
Parents:
46ed760
Message:

allow comparison of multiplicity models with sasview 3.x

Location:
sasmodels
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    rf3bd37f r256dfe1  
    387387    from sas.models.MultiplicationModel import MultiplicationModel 
    388388 
    389     def get_model(name): 
     389    def get_model_class(name): 
    390390        # type: (str) -> "sas.models.BaseComponent" 
    391391        #print("new",sorted(_pars.items())) 
     
    394394        if ModelClass is None: 
    395395            raise ValueError("could not find model %r in sas.models"%name) 
    396         return ModelClass() 
     396        return ModelClass 
     397 
     398    # WARNING: ugly hack when handling model! 
     399    # Sasview models with multiplicity need to be created with the target 
     400    # multiplicity, so we cannot create the target model ahead of time for 
     401    # for multiplicity models.  Instead we store the model in a list and 
     402    # update the first element of that list with the new multiplicity model 
     403    # every time we evaluate. 
    397404 
    398405    # grab the sasview model, or create it if it is a product model 
     
    401408        if composition_type == 'product': 
    402409            P, S = [get_model(revert_name(p)) for p in parts] 
    403             model = MultiplicationModel(P, S) 
     410            model = [MultiplicationModel(P, S)] 
    404411        else: 
    405412            raise ValueError("sasview mixture models not supported by compare") 
     
    409416            raise ValueError("model %r does not exist in old sasview" 
    410417                            % model_info.id) 
    411         model = get_model(old_name) 
     418        ModelClass = get_model_class(old_name) 
     419        model = [ModelClass()] 
    412420 
    413421    # build a smearer with which to call the model, if necessary 
     
    421429            smearer.accuracy = data.accuracy 
    422430            smearer.set_index(index) 
    423             theory = lambda: smearer.get_value() 
     431            def _call_smearer(): 
     432                smearer.model = model[0] 
     433                return smearer.get_value() 
     434            theory = lambda: _call_smearer() 
    424435        else: 
    425             theory = lambda: model.evalDistribution([data.qx_data[index], 
    426                                                      data.qy_data[index]]) 
     436            theory = lambda: model[0].evalDistribution([data.qx_data[index], 
     437                                                        data.qy_data[index]]) 
    427438    elif smearer is not None: 
    428         theory = lambda: smearer(model.evalDistribution(data.x)) 
     439        theory = lambda: smearer(model[0].evalDistribution(data.x)) 
    429440    else: 
    430         theory = lambda: model.evalDistribution(data.x) 
     441        theory = lambda: model[0].evalDistribution(data.x) 
    431442 
    432443    def calculator(**pars): 
     
    435446        Sasview calculator for model. 
    436447        """ 
     448        # For multiplicity models, recreate the model the first time the 
     449        if model_info.control: 
     450            model[0] = ModelClass(int(pars[model_info.control])) 
    437451        # paying for parameter conversion each time to keep life simple, if not fast 
    438         pars = revert_pars(model_info, pars) 
    439         for k, v in pars.items(): 
     452        oldpars = revert_pars(model_info, pars) 
     453        for k, v in oldpars.items(): 
    440454            name_attr = k.split('.')  # polydispersity components 
    441455            if len(name_attr) == 2: 
    442                 model.dispersion[name_attr[0]][name_attr[1]] = v 
     456                model[0].dispersion[name_attr[0]][name_attr[1]] = v 
    443457            else: 
    444                 model.setParam(k, v) 
     458                model[0].setParam(k, v) 
    445459        return theory() 
    446460 
     
    742756        for ext, val in parts: 
    743757            if p.length > 1: 
    744                 dict(("%s%d%s"%(p.id,k,ext), val) for k in range(p.length)) 
     758                dict(("%s%d%s"%(p.id,k,ext), val) for k in range(1, p.length+1)) 
    745759            else: 
    746760                pars[p.id+ext] = val 
  • sasmodels/convert.json

    r6e7ff6d r256dfe1  
    7979  ],  
    8080  "core_multi_shell": [ 
    81     "CoreMultiShellModel",  
    82     { 
    83       "thick_shell": "thick_shell",  
    84       "radius": "rad_core0",  
    85       "sld": "sld_core0",  
    86       "sld_shell": "sld_in_shell",  
    87       "sld_solvent": "sld_solv",  
     81    "CoreMultiShellModel", 
     82    { 
     83      "thickness": "thick_shell", 
     84      "sld": "sld_shell", 
     85      "radius": "rad_core0", 
     86      "sld_core": "sld_core0", 
     87      "sld_solvent": "sld_solv", 
    8888      "n": "n_shells" 
    8989    } 
     
    475475  ],  
    476476  "onion": [ 
    477     "OnionExpShellModel",  
     477    "OnionModel", 
    478478    { 
    479479      "A": "A_shell",  
    480       "core_sld": "sld_core0",  
     480      "sld_core": "sld_core0", 
    481481      "core_radius": "rad_core0",  
    482       "out_sld": "sld_out_shell",  
    483       "n": "n_shells",  
    484       "solvent_sld": "sld_solv",  
     482      "n": "n_shells", 
     483      "sld_solvent": "sld_solv", 
    485484      "thickness": "thick_shell",  
    486       "in_sld": "sld_in_shell" 
     485      "sld_in": "sld_in_shell", 
     486      "sld_out": "sld_out_shell" 
    487487    } 
    488488  ],  
  • sasmodels/convert.py

    r7ae2b7f r256dfe1  
    22Convert models to and from sasview. 
    33""" 
     4from __future__ import print_function 
     5 
    46from os.path import join as joinpath, abspath, dirname 
    57import warnings 
     
    2830] 
    2931 
     32MODELS_WITHOUT_VOLFRACTION = [ 
     33    'fractal', 
     34    'vesicle', 
     35    'multilayer_vesicle', 
     36    'core_multi_shell', 
     37] 
     38 
     39 
    3040# Convert new style names for polydispersity info to old style names 
    3141PD_DOT = [ 
     
    103113    # model name mapping 
    104114 
     115def _unscale(par, scale): 
     116    return [pk*scale for pk in par] if isinstance(par, list) else par*scale 
     117 
    105118def _unscale_sld(pars): 
    106119    """ 
     
    109122    new model definition end with sld. 
    110123    """ 
    111     return dict((p, (v*1e-6 if p.startswith('sld') or p.endswith('sld') 
    112                      else v*1e15 if 'ndensity' in p 
     124    return dict((p, (_unscale(v,1e-6) if p.startswith('sld') or p.endswith('sld') 
     125                     else _unscale(v,1e15) if 'ndensity' in p 
    113126                     else v)) 
    114127                for p, v in pars.items()) 
     
    155168    return oldname 
    156169 
    157 def _get_old_pars(model_info): 
     170def _get_translation_table(model_info): 
    158171    _read_conversion_table() 
    159     oldname, oldpars = CONVERSION_TABLE.get(model_info.id, [None, {}]) 
     172    _, translation = CONVERSION_TABLE.get(model_info.id, [None, {}]) 
     173    translation = translation.copy() 
     174    for p in model_info.parameters.kernel_parameters: 
     175        if p.length > 1: 
     176            newid = p.id 
     177            oldid = translation.get(p.id, p.id) 
     178            del translation[newid] 
     179            for k in range(1, p.length+1): 
     180                translation[newid+str(k)] = oldid+str(k) 
     181    # Remove control parameter from the result 
     182    if model_info.control: 
     183        translation[model_info.control] = None 
     184    return translation 
     185 
     186def _trim_vectors(model_info, pars, oldpars): 
     187    _read_conversion_table() 
     188    _, translation = CONVERSION_TABLE.get(model_info.id, [None, {}]) 
     189    for p in model_info.parameters.kernel_parameters: 
     190        if p.length_control is not None: 
     191            n = int(pars[p.length_control]) 
     192            oldname = translation.get(p.id, p.id) 
     193            for k in range(n+1, p.length+1): 
     194                for _, old in PD_DOT: 
     195                    oldpars.pop(oldname+str(k)+old, None) 
    160196    return oldpars 
    161197 
     
    174210            raise NotImplementedError("cannot convert to sasview sum") 
    175211    else: 
    176         oldpars = _get_old_pars(model_info) 
    177     oldpars = _revert_pars(_unscale_sld(pars), oldpars) 
     212        translation = _get_translation_table(model_info) 
     213    oldpars = _revert_pars(_unscale_sld(pars), translation) 
     214    oldpars = _trim_vectors(model_info, pars, oldpars) 
    178215 
    179216 
     
    194231    namelist = name.split('*') if '*' in name else [name] 
    195232    for name in namelist: 
     233        if name in MODELS_WITHOUT_VOLFRACTION: 
     234            del oldpars['volfraction'] 
    196235        if name == 'stacked_disks': 
    197236            _remove_pd(oldpars, 'n_stacking', name) 
     
    211250        elif name in ['mono_gauss_coil','poly_gauss_coil']: 
    212251            del oldpars['i_zero'] 
    213         elif name == 'fractal': 
    214             del oldpars['volfraction'] 
    215         elif name == 'vesicle': 
    216             del oldpars['volfraction'] 
    217         elif name == 'multilayer_vesicle': 
    218             del oldpars['volfraction'] 
    219252 
    220253    return oldpars 
     
    240273    namelist = name.split('*') if '*' in name else [name] 
    241274    for name in namelist: 
     275        if name in MODELS_WITHOUT_VOLFRACTION: 
     276            pars['volfraction'] = 1 
    242277        if name == 'pearl_necklace': 
    243278            pars['string_thickness_pd_n'] = 0 
     
    252287        elif name == 'poly_gauss_coil': 
    253288            pars['i_zero'] = 1 
    254         elif name == 'fractal': 
    255             pars['volfraction'] = 1 
    256         elif name == 'vesicle': 
    257             pars['volfraction'] = 1 
    258         elif name == 'multilayer_vesicle': 
    259             pars['volfraction'] = 1 
    260              
     289 
  • sasmodels/modelinfo.py

    r98ba1fc r256dfe1  
    186186        vectors = dict((name,value) for name,value in pars.items() 
    187187                       if name in lookup and lookup[name].length > 1) 
     188        #print("lookup", lookup) 
     189        #print("scalars", scalars) 
     190        #print("vectors", vectors) 
    188191        if vectors: 
    189192            for name, value in vectors.items(): 
     
    194197                        key = name+str(k) 
    195198                        if key not in scalars: 
    196                             scalars[key] = vectors 
     199                            scalars[key] = value 
    197200                else: 
    198201                    # supoprt for the form 
    199202                    #    dict(thickness=[20,10,3]) 
    200203                    for (k,v) in enumerate(value): 
    201                         scalars[name+str(k)] = v 
     204                        scalars[name+str(k+1)] = v 
    202205        result.update(scalars) 
     206        #print("expanded", result) 
    203207 
    204208    return result 
     
    402406    parameters don't use vector notation, and instead use p1, p2, ... 
    403407 
    404     * *control_parameters* is the 
    405  
    406408    """ 
    407409    # scale and background are implicit parameters 
     
    455457                         if p.polydisperse and p.type != 'magnetic') 
    456458 
    457  
    458459    def _set_vector_lengths(self): 
    459         # type: () -> None 
     460        # type: () -> List[str] 
    460461        """ 
    461462        Walk the list of kernel parameters, setting the length field of the 
     
    466467        initially created. 
    467468 
     469        Returns the list of control parameter names. 
     470 
    468471        Note: This modifies the underlying parameter object. 
    469472        """ 
    470473        # Sort out the length of the vector parameters such as thickness[n] 
     474 
    471475        for p in self.kernel_parameters: 
    472476            if p.length_control: 
     
    478482                                     % (p.length_control, p.name)) 
    479483                ref.is_control = True 
     484                ref.polydisperse = False 
    480485                low, high = ref.limits 
    481486                if int(low) != low or int(high) != high or low < 0 or high > 20: 
     
    689694    info.profile = getattr(kernel_module, 'profile', None) # type: ignore 
    690695    info.sesans = getattr(kernel_module, 'sesans', None) # type: ignore 
    691     info.control = getattr(kernel_module, 'control', None) 
     696 
     697    # multiplicity info 
     698    control_pars = [p.id for p in parameters.kernel_parameters if p.is_control] 
     699    default_control = control_pars[0] if control_pars else None 
     700    info.control = getattr(kernel_module, 'control', default_control) 
    692701    info.hidden = getattr(kernel_module, 'hidden', None) # type: ignore 
    693702 
Note: See TracChangeset for help on using the changeset viewer.