Changeset 256dfe1 in sasmodels for sasmodels/compare.py


Ignore:
Timestamp:
Jul 18, 2016 12:42:29 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
c5ac2b2
Parents:
46ed760
Message:

allow comparison of multiplicity models with sasview 3.x

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    rf3bd37f r256dfe1  
    387387    from sas.models.MultiplicationModel import MultiplicationModel 
    388388 
    389     def get_model(name): 
     389    def get_model_class(name): 
    390390        # type: (str) -> "sas.models.BaseComponent" 
    391391        #print("new",sorted(_pars.items())) 
     
    394394        if ModelClass is None: 
    395395            raise ValueError("could not find model %r in sas.models"%name) 
    396         return ModelClass() 
     396        return ModelClass 
     397 
     398    # WARNING: ugly hack when handling model! 
     399    # Sasview models with multiplicity need to be created with the target 
     400    # multiplicity, so we cannot create the target model ahead of time for 
     401    # for multiplicity models.  Instead we store the model in a list and 
     402    # update the first element of that list with the new multiplicity model 
     403    # every time we evaluate. 
    397404 
    398405    # grab the sasview model, or create it if it is a product model 
     
    401408        if composition_type == 'product': 
    402409            P, S = [get_model(revert_name(p)) for p in parts] 
    403             model = MultiplicationModel(P, S) 
     410            model = [MultiplicationModel(P, S)] 
    404411        else: 
    405412            raise ValueError("sasview mixture models not supported by compare") 
     
    409416            raise ValueError("model %r does not exist in old sasview" 
    410417                            % model_info.id) 
    411         model = get_model(old_name) 
     418        ModelClass = get_model_class(old_name) 
     419        model = [ModelClass()] 
    412420 
    413421    # build a smearer with which to call the model, if necessary 
     
    421429            smearer.accuracy = data.accuracy 
    422430            smearer.set_index(index) 
    423             theory = lambda: smearer.get_value() 
     431            def _call_smearer(): 
     432                smearer.model = model[0] 
     433                return smearer.get_value() 
     434            theory = lambda: _call_smearer() 
    424435        else: 
    425             theory = lambda: model.evalDistribution([data.qx_data[index], 
    426                                                      data.qy_data[index]]) 
     436            theory = lambda: model[0].evalDistribution([data.qx_data[index], 
     437                                                        data.qy_data[index]]) 
    427438    elif smearer is not None: 
    428         theory = lambda: smearer(model.evalDistribution(data.x)) 
     439        theory = lambda: smearer(model[0].evalDistribution(data.x)) 
    429440    else: 
    430         theory = lambda: model.evalDistribution(data.x) 
     441        theory = lambda: model[0].evalDistribution(data.x) 
    431442 
    432443    def calculator(**pars): 
     
    435446        Sasview calculator for model. 
    436447        """ 
     448        # For multiplicity models, recreate the model the first time the 
     449        if model_info.control: 
     450            model[0] = ModelClass(int(pars[model_info.control])) 
    437451        # paying for parameter conversion each time to keep life simple, if not fast 
    438         pars = revert_pars(model_info, pars) 
    439         for k, v in pars.items(): 
     452        oldpars = revert_pars(model_info, pars) 
     453        for k, v in oldpars.items(): 
    440454            name_attr = k.split('.')  # polydispersity components 
    441455            if len(name_attr) == 2: 
    442                 model.dispersion[name_attr[0]][name_attr[1]] = v 
     456                model[0].dispersion[name_attr[0]][name_attr[1]] = v 
    443457            else: 
    444                 model.setParam(k, v) 
     458                model[0].setParam(k, v) 
    445459        return theory() 
    446460 
     
    742756        for ext, val in parts: 
    743757            if p.length > 1: 
    744                 dict(("%s%d%s"%(p.id,k,ext), val) for k in range(p.length)) 
     758                dict(("%s%d%s"%(p.id,k,ext), val) for k in range(1, p.length+1)) 
    745759            else: 
    746760                pars[p.id+ext] = val 
Note: See TracChangeset for help on using the changeset viewer.