Changeset de97440 in sasmodels


Ignore:
Timestamp:
Mar 18, 2016 11:00:59 AM (9 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
68b8734
Parents:
e98c1e0
Message:

picklable sasview model wrapper?

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/sasview_model.py

    r2622b3f rde97440  
    2323from . import core 
    2424 
    25 def make_class(model_info, dtype='single', namestyle='name'): 
     25def standard_models(): 
     26    return [make_class(model_name) for model_name in core.list_models()] 
     27 
     28def make_class(model_name, namestyle='name'): 
    2629    """ 
    2730    Load the sasview model defined in *kernel_module*. 
    2831 
    29     Returns a class that can be used directly as a sasview model. 
     32    Returns a class that can be used directly as a sasview model.t 
    3033 
    3134    Defaults to using the new name for a model.  Setting 
     
    3336    compatible with SasView. 
    3437    """ 
    35     model = core.build_model(model_info, dtype=dtype) 
     38    model_info = core.load_model_info(model_name) 
    3639    def __init__(self, multfactor=1): 
    37         SasviewModel.__init__(self, model) 
     40        SasviewModel.__init__(self, model_info) 
    3841    attrs = dict(__init__=__init__) 
    39     ConstructedModel = type(model.info[namestyle], (SasviewModel,), attrs) 
     42    ConstructedModel = type(model_info[namestyle], (SasviewModel,), attrs) 
    4043    return ConstructedModel 
    4144 
     
    4447    Sasview wrapper for opencl/ctypes model. 
    4548    """ 
    46     def __init__(self, model): 
    47         """Initialization""" 
    48         self._model = model 
    49  
    50         self.name = model.info['name'] 
    51         self.oldname = model.info['oldname'] 
    52         self.description = model.info['description'] 
     49    def __init__(self, model_info): 
     50        self._model_info = model_info 
     51        self._kernel = None 
     52 
     53        self.name = model_info['name'] 
     54        self.oldname = model_info['oldname'] 
     55        self.description = model_info['description'] 
    5356        self.category = None 
    5457        self.multiplicity_info = None 
     
    6063        self.params = collections.OrderedDict() 
    6164        self.dispersion = dict() 
    62         partype = model.info['partype'] 
    63  
    64         for p in model.info['parameters']: 
     65        partype = model_info['partype'] 
     66 
     67        for p in model_info['parameters']: 
    6568            self.params[p.name] = p.default 
    6669            self.details[p.name] = [p.units] + p.limits 
     
    8386 
    8487        ## independent parameter name and unit [string] 
    85         self.input_name = model.info.get("input_name", "Q") 
    86         self.input_unit = model.info.get("input_unit", "A^{-1}") 
    87         self.output_name = model.info.get("output_name", "Intensity") 
    88         self.output_unit = model.info.get("output_unit", "cm^{-1}") 
     88        self.input_name = model_info.get("input_name", "Q") 
     89        self.input_unit = model_info.get("input_unit", "A^{-1}") 
     90        self.output_name = model_info.get("output_name", "Intensity") 
     91        self.output_unit = model_info.get("output_unit", "cm^{-1}") 
    8992 
    9093        ## _persistency_dict is used by sas.perspectives.fitting.basepage 
     
    9598        ## New fields introduced for opencl rewrite 
    9699        self.cutoff = 1e-5 
     100 
     101    def __get_state__(self): 
     102        state = self.__dict__.copy() 
     103        model_id = self._model_info['id'] 
     104        state.pop('_kernel') 
     105        # May need to reload model info on set state since it has pointers 
     106        # to python implementations of Iq, etc. 
     107        #state.pop('_model_info') 
     108        return state 
     109 
     110    def __set_state__(self, state): 
     111        self.__dict__ = state 
     112        self._kernel = None 
    97113 
    98114    def __str__(self): 
     
    187203        # TODO: fix test so that parameter order doesn't matter 
    188204        ret = ['%s.%s' % (d.lower(), p) 
    189                for d in self._model.info['partype']['pd-2d'] 
     205               for d in self._model_info['partype']['pd-2d'] 
    190206               for p in ('npts', 'nsigmas', 'width')] 
    191207        #print(ret) 
     
    261277            # Check whether we have a list of ndarrays [qx,qy] 
    262278            qx, qy = qdist 
    263             partype = self._model.info['partype'] 
     279            partype = self._model_info['partype'] 
    264280            if not partype['orientation'] and not partype['magnetic']: 
    265281                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2)) 
     
    284300        to the card for each evaluation. 
    285301        """ 
     302        if self._kernel is None: 
     303            self._kernel = core.build_model(self._model_info) 
    286304        q_vectors = [np.asarray(q) for q in args] 
    287         fn = self._model(q_vectors) 
     305        fn = self._kernel(q_vectors) 
    288306        pars = [self.params[v] for v in fn.fixed_pars] 
    289307        pd_pars = [self._get_weights(p) for p in fn.pd_pars] 
     
    299317        :return: the value of the effective radius 
    300318        """ 
    301         ER = self._model.info.get('ER', None) 
     319        ER = self._model_info.get('ER', None) 
    302320        if ER is None: 
    303321            return 1.0 
     
    314332        :return: the value of the volf ratio 
    315333        """ 
    316         VR = self._model.info.get('VR', None) 
     334        VR = self._model_info.get('VR', None) 
    317335        if VR is None: 
    318336            return 1.0 
     
    358376        parameter set in the vector. 
    359377        """ 
    360         pars = self._model.info['partype']['volume'] 
     378        pars = self._model_info['partype']['volume'] 
    361379        return core.dispersion_mesh([self._get_weights(p) for p in pars]) 
    362380 
     
    368386        from . import weights 
    369387 
    370         relative = self._model.info['partype']['pd-rel'] 
    371         limits = self._model.info['limits'] 
     388        relative = self._model_info['partype']['pd-rel'] 
     389        limits = self._model_info['limits'] 
    372390        dis = self.dispersion[par] 
    373391        value, weight = weights.get_weights( 
     
    375393            self.params[par], limits[par], par in relative) 
    376394        return value, weight / np.sum(weight) 
     395 
Note: See TracChangeset for help on using the changeset viewer.