Changeset 4d76711 in sasmodels for sasmodels/sasview_model.py


Ignore:
Timestamp:
Apr 5, 2016 10:33:44 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
3a45c2c, c4e7a5f
Parents:
cd0a808
Message:

adjust interface to sasview

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/sasview_model.py

    rf247314 r4d76711  
    1313using :func:`sasmodels.convert.convert`. 
    1414""" 
     15from __future__ import print_function 
    1516 
    1617import math 
    1718from copy import deepcopy 
    1819import collections 
     20import traceback 
     21import logging 
    1922 
    2023import numpy as np 
    2124 
    2225from . import core 
     26from . import custom 
    2327from . import generate 
    2428 
    25 def standard_models(): 
    26     return [make_class(model_name) for model_name in core.list_models()] 
    27  
    28 # TODO: rename to make_class_from_name and update sasview 
    29 def make_class(model_name): 
    30     """ 
    31     Load the sasview model defined in *kernel_module*. 
    32  
    33     Returns a class that can be used directly as a sasview model.t 
    34     """ 
    35     model_info = core.load_model_info(model_name) 
    36     return make_class_from_info(model_info) 
    37  
    38 def make_class_from_file(path): 
    39     model_info = core.load_model_info_from_path(path) 
    40     return make_class_from_info(model_info) 
    41  
    42 def make_class_from_info(model_info): 
     29def load_standard_models(): 
     30    """ 
     31    Load and return the list of predefined models. 
     32 
     33    If there is an error loading a model, then a traceback is logged and the 
     34    model is not returned. 
     35    """ 
     36    models = [] 
     37    for name in core.list_models(): 
     38        try: 
     39            models.append(_make_standard_model(name)) 
     40        except: 
     41            logging.error(traceback.format_exc()) 
     42    return models 
     43 
     44 
     45def load_custom_model(path): 
     46    """ 
     47    Load a custom model given the model path. 
     48    """ 
     49    kernel_module = custom.load_custom_kernel_module(path) 
     50    model_info = generate.make_model_info(kernel_module) 
     51    return _make_model_from_info(model_info) 
     52 
     53 
     54def _make_standard_model(name): 
     55    """ 
     56    Load the sasview model defined by *name*. 
     57 
     58    *name* can be a standard model name or a path to a custom model. 
     59 
     60    Returns a class that can be used directly as a sasview model. 
     61    """ 
     62    kernel_module = generate.load_kernel_module(name) 
     63    model_info = generate.make_model_info(kernel_module) 
     64    return _make_model_from_info(model_info) 
     65 
     66 
     67def _make_model_from_info(model_info): 
     68    """ 
     69    Convert *model_info* into a SasView model wrapper. 
     70    """ 
    4371    def __init__(self, multfactor=1): 
    4472        SasviewModel.__init__(self) 
     
    4775    return ConstructedModel 
    4876 
     77 
    4978class SasviewModel(object): 
    5079    """ 
    5180    Sasview wrapper for opencl/ctypes model. 
    5281    """ 
     82    _model_info = {} 
    5383    def __init__(self): 
    54         self._kernel = None 
     84        self._model = None 
    5585        model_info = self._model_info 
    5686 
     
    104134    def __get_state__(self): 
    105135        state = self.__dict__.copy() 
    106         model_id = self._model_info['id'] 
    107         state.pop('_kernel') 
     136        state.pop('_model') 
    108137        # May need to reload model info on set state since it has pointers 
    109138        # to python implementations of Iq, etc. 
     
    113142    def __set_state__(self, state): 
    114143        self.__dict__ = state 
    115         self._kernel = None 
     144        self._model = None 
    116145 
    117146    def __str__(self): 
     
    202231    def getDispParamList(self): 
    203232        """ 
    204         Return a list of all available parameters for the model 
     233        Return a list of polydispersity parameters for the model 
    205234        """ 
    206235        # TODO: fix test so that parameter order doesn't matter 
     
    303332        to the card for each evaluation. 
    304333        """ 
    305         if self._kernel is None: 
    306             self._kernel = core.build_model(self._model_info) 
     334        if self._model is None: 
     335            self._model = core.build_model(self._model_info) 
    307336        q_vectors = [np.asarray(q) for q in args] 
    308         fn = self._kernel(q_vectors) 
     337        fn = self._model.make_kernel(q_vectors) 
    309338        pars = [self.params[v] for v in fn.fixed_pars] 
    310339        pd_pars = [self._get_weights(p) for p in fn.pd_pars] 
     
    384413    def _get_weights(self, par): 
    385414        """ 
    386             Return dispersion weights 
    387             :param par parameter name 
     415        Return dispersion weights for parameter 
    388416        """ 
    389417        from . import weights 
    390  
    391418        relative = self._model_info['partype']['pd-rel'] 
    392419        limits = self._model_info['limits'] 
     
    397424        return value, weight / np.sum(weight) 
    398425 
     426 
     427def test_model(): 
     428    """ 
     429    Test that a sasview model (cylinder) can be run. 
     430    """ 
     431    Cylinder = _make_standard_model('cylinder') 
     432    cylinder = Cylinder() 
     433    return cylinder.evalDistribution([0.1,0.1]) 
     434 
     435 
     436def test_model_list(): 
     437    """ 
     438    Make sure that all models build as sasview models. 
     439    """ 
     440    from .exception import annotate_exception 
     441    for name in core.list_models(): 
     442        try: 
     443            _make_standard_model(name) 
     444        except: 
     445            annotate_exception("when loading "+name) 
     446            raise 
     447 
     448if __name__ == "__main__": 
     449    print("cylinder(0.1,0.1)=%g"%test_model()) 
Note: See TracChangeset for help on using the changeset viewer.