Changeset aa4946b in sasmodels for sasmodels/sasview_model.py


Ignore:
Timestamp:
Mar 11, 2015 11:15:40 PM (9 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
af1d68c
Parents:
49d1d42f
Message:

refactor so kernels are loaded via core.load_model

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/sasview_model.py

    r63b32bb raa4946b  
    1414""" 
    1515 
    16 # TODO: add a sasview=>sasmodels parameter translation layer 
    17 # this will allow us to use the new sasmodels as drop in replacements, and 
    18 # delay renaming parameters until all models have been converted. 
    19  
    2016import math 
    2117from copy import deepcopy 
     
    2420import numpy as np 
    2521 
    26 try: 
    27     from .kernelcl import load_model 
    28 except ImportError, exc: 
    29     warnings.warn(str(exc)) 
    30     warnings.warn("using ctypes instead") 
    31     from .kerneldll import load_model 
    32  
    33  
    34 def make_class(kernel_module, dtype='single', namestyle='name'): 
     22from . import core 
     23 
     24def make_class(model_definition, dtype='single', namestyle='name'): 
    3525    """ 
    3626    Load the sasview model defined in *kernel_module*. 
     
    3828    Returns a class that can be used directly as a sasview model. 
    3929 
    40     Defaults to using the new name for a model. Setting namestyle='name' 
    41     will produce a class with a name compatible with SasView 
     30    Defaults to using the new name for a model.  Setting 
     31    *namestyle='oldname'* will produce a class with a name 
     32    compatible with SasView. 
    4233    """ 
    43     model = load_model(kernel_module, dtype=dtype) 
     34    model = core.load_model(model_definition, dtype=dtype) 
    4435    def __init__(self, multfactor=1): 
    4536        SasviewModel.__init__(self, model) 
     
    313304            return 1.0 
    314305        else: 
    315             vol_pars = self._model.info['partype']['volume'] 
    316             values, weights = self._dispersion_mesh(vol_pars) 
     306            values, weights = self._dispersion_mesh() 
    317307            fv = ER(*values) 
    318308            #print values[0].shape, weights.shape, fv.shape 
     
    329319            return 1.0 
    330320        else: 
    331             vol_pars = self._model.info['partype']['volume'] 
    332             values, weights = self._dispersion_mesh(vol_pars) 
     321            values, weights = self._dispersion_mesh() 
    333322            whole, part = VR(*values) 
    334323            return np.sum(weights * part) / np.sum(weights * whole) 
     
    362351            raise ValueError("%r is not a dispersity or orientation parameter") 
    363352 
    364     def _dispersion_mesh(self, pars): 
     353    def _dispersion_mesh(self): 
    365354        """ 
    366355        Create a mesh grid of dispersion parameters and weights. 
     
    370359        parameter set in the vector. 
    371360        """ 
    372         values, weights = zip(*[self._get_weights(p) for p in pars]) 
    373         values = [v.flatten() for v in np.meshgrid(*values)] 
    374         weights = np.vstack([v.flatten() for v in np.meshgrid(*weights)]) 
    375         weights = np.prod(weights, axis=0) 
    376         return values, weights 
     361        pars = self._model.info['partype']['volume'] 
     362        return core.dispersion_mesh([self._get_weights(p) for p in pars]) 
    377363 
    378364    def _get_weights(self, par): 
Note: See TracChangeset for help on using the changeset viewer.