Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/sasview_model.py

    r81ec7c8 r60f03de  
    2121import logging 
    2222 
    23 import numpy as np 
     23import numpy as np  # type: ignore 
    2424 
    2525from . import core 
    2626from . import custom 
    2727from . import generate 
     28from . import weights 
     29from . import details 
     30from . import modelinfo 
    2831 
    2932try: 
    30     from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional 
     33    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable 
     34    from .modelinfo import ModelInfo, Parameter 
    3135    from .kernel import KernelModel 
    3236    MultiplicityInfoType = NamedTuple( 
     
    3438        [("number", int), ("control", str), ("choices", List[str]), 
    3539         ("x_axis_label", str)]) 
     40    SasviewModelType = Callable[[int], "SasviewModel"] 
    3641except ImportError: 
    3742    pass 
    3843 
    3944# TODO: separate x_axis_label from multiplicity info 
    40 # The x-axis label belongs with the profile generating function 
     45# The profile x-axis label belongs with the profile generating function 
    4146MultiplicityInfo = collections.namedtuple( 
    4247    'MultiplicityInfo', 
     
    4449) 
    4550 
     51# TODO: figure out how to say that the return type is a subclass 
    4652def load_standard_models(): 
     53    # type: () -> List[SasviewModelType] 
    4754    """ 
    4855    Load and return the list of predefined models. 
     
    5562        try: 
    5663            models.append(_make_standard_model(name)) 
    57         except: 
     64        except Exception: 
    5865            logging.error(traceback.format_exc()) 
    5966    return models 
     
    6168 
    6269def load_custom_model(path): 
     70    # type: (str) -> SasviewModelType 
    6371    """ 
    6472    Load a custom model given the model path. 
    6573    """ 
    6674    kernel_module = custom.load_custom_kernel_module(path) 
    67     model_info = generate.make_model_info(kernel_module) 
     75    model_info = modelinfo.make_model_info(kernel_module) 
    6876    return _make_model_from_info(model_info) 
    6977 
    7078 
    7179def _make_standard_model(name): 
     80    # type: (str) -> SasviewModelType 
    7281    """ 
    7382    Load the sasview model defined by *name*. 
     
    7887    """ 
    7988    kernel_module = generate.load_kernel_module(name) 
    80     model_info = generate.make_model_info(kernel_module) 
     89    model_info = modelinfo.make_model_info(kernel_module) 
    8190    return _make_model_from_info(model_info) 
    8291 
    8392 
    8493def _make_model_from_info(model_info): 
     94    # type: (ModelInfo) -> SasviewModelType 
    8595    """ 
    8696    Convert *model_info* into a SasView model wrapper. 
    8797    """ 
    88     model_info['variant_info'] = None  # temporary hack for older sasview 
    89     def __init__(self, multiplicity=1): 
     98    def __init__(self, multiplicity=None): 
    9099        SasviewModel.__init__(self, multiplicity=multiplicity) 
    91100    attrs = _generate_model_attributes(model_info) 
    92101    attrs['__init__'] = __init__ 
    93     ConstructedModel = type(model_info['id'], (SasviewModel,), attrs) 
     102    ConstructedModel = type(model_info.name, (SasviewModel,), attrs) # type: SasviewModelType 
    94103    return ConstructedModel 
    95104 
     
    104113    All the attributes should be immutable to avoid accidents. 
    105114    """ 
     115 
     116    # TODO: allow model to override axis labels input/output name/unit 
     117 
     118    # Process multiplicity 
     119    non_fittable = []  # type: List[str] 
     120    xlabel = model_info.profile_axes[0] if model_info.profile is not None else "" 
     121    variants = MultiplicityInfo(0, "", [], xlabel) 
     122    for p in model_info.parameters.kernel_parameters: 
     123        if p.name == model_info.control: 
     124            non_fittable.append(p.name) 
     125            variants = MultiplicityInfo( 
     126                len(p.choices), p.name, p.choices, xlabel 
     127            ) 
     128            break 
     129        elif p.is_control: 
     130            non_fittable.append(p.name) 
     131            variants = MultiplicityInfo( 
     132                int(p.limits[1]), p.name, p.choices, xlabel 
     133            ) 
     134            break 
     135 
     136    # Organize parameter sets 
     137    orientation_params = [] 
     138    magnetic_params = [] 
     139    fixed = [] 
     140    for p in model_info.parameters.user_parameters(): 
     141        if p.type == 'orientation': 
     142            orientation_params.append(p.name) 
     143            orientation_params.append(p.name+".width") 
     144            fixed.append(p.name+".width") 
     145        if p.type == 'magnetic': 
     146            orientation_params.append(p.name) 
     147            magnetic_params.append(p.name) 
     148            fixed.append(p.name+".width") 
     149 
     150    # Build class dictionary 
    106151    attrs = {}  # type: Dict[str, Any] 
    107152    attrs['_model_info'] = model_info 
    108     attrs['name'] = model_info['name'] 
    109     attrs['id'] = model_info['id'] 
    110     attrs['description'] = model_info['description'] 
    111     attrs['category'] = model_info['category'] 
    112  
    113     # TODO: allow model to override axis labels input/output name/unit 
    114  
    115     #self.is_multifunc = False 
    116     non_fittable = []  # type: List[str] 
    117     variants = MultiplicityInfo(0, "", [], "") 
    118     attrs['is_structure_factor'] = model_info['structure_factor'] 
    119     attrs['is_form_factor'] = model_info['ER'] is not None 
     153    attrs['name'] = model_info.name 
     154    attrs['id'] = model_info.id 
     155    attrs['description'] = model_info.description 
     156    attrs['category'] = model_info.category 
     157    attrs['is_structure_factor'] = model_info.structure_factor 
     158    attrs['is_form_factor'] = model_info.ER is not None 
    120159    attrs['is_multiplicity_model'] = variants[0] > 1 
    121160    attrs['multiplicity_info'] = variants 
    122  
    123     partype = model_info['partype'] 
    124     orientation_params = ( 
    125             partype['orientation'] 
    126             + [n + '.width' for n in partype['orientation']] 
    127             + partype['magnetic']) 
    128     magnetic_params = partype['magnetic'] 
    129     fixed = [n + '.width' for n in partype['pd-2d']] 
    130  
    131161    attrs['orientation_params'] = tuple(orientation_params) 
    132162    attrs['magnetic_params'] = tuple(magnetic_params) 
    133163    attrs['fixed'] = tuple(fixed) 
    134  
    135164    attrs['non_fittable'] = tuple(non_fittable) 
    136165 
     
    192221    dispersion = None  # type: Dict[str, Any] 
    193222    #: units and limits for each parameter 
    194     details = None     # type: Mapping[str, Tuple(str, float, float)] 
    195     #: multiplicity used, or None if no multiplicity controls 
     223    details = None     # type: Dict[str, Sequence[Any]] 
     224    #                  # actual type is Dict[str, List[str, float, float]] 
     225    #: multiplicity value, or None if no multiplicity on the model 
    196226    multiplicity = None     # type: Optional[int] 
    197  
    198     def __init__(self, multiplicity): 
    199         # type: () -> None 
    200         print("initializing", self.name) 
    201         #raise Exception("first initialization") 
    202         self._model = None 
    203  
    204         ## _persistency_dict is used by sas.perspectives.fitting.basepage 
    205         ## to store dispersity reference. 
     227    #: memory for polydispersity array if using ArrayDispersion (used by sasview). 
     228    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]] 
     229 
     230    def __init__(self, multiplicity=None): 
     231        # type: (Optional[int]) -> None 
     232 
     233        # TODO: _persistency_dict to persistency_dict throughout sasview 
     234        # TODO: refactor multiplicity to encompass variants 
     235        # TODO: dispersion should be a class 
     236        # TODO: refactor multiplicity info 
     237        # TODO: separate profile view from multiplicity 
     238        # The button label, x and y axis labels and scale need to be under 
     239        # the control of the model, not the fit page.  Maximum flexibility, 
     240        # the fit page would supply the canvas and the profile could plot 
     241        # how it wants, but this assumes matplotlib.  Next level is that 
     242        # we provide some sort of data description including title, labels 
     243        # and lines to plot. 
     244 
     245        # Get the list of hidden parameters given the mulitplicity 
     246        # Don't include multiplicity in the list of parameters 
     247        self.multiplicity = multiplicity 
     248        if multiplicity is not None: 
     249            hidden = self._model_info.get_hidden_parameters(multiplicity) 
     250            hidden |= set([self.multiplicity_info.control]) 
     251        else: 
     252            hidden = set() 
     253 
    206254        self._persistency_dict = {} 
    207  
    208         self.multiplicity = multiplicity 
    209  
    210255        self.params = collections.OrderedDict() 
    211256        self.dispersion = {} 
    212257        self.details = {} 
    213  
    214         for p in self._model_info['parameters']: 
     258        for p in self._model_info.parameters.user_parameters(): 
     259            if p.name in hidden: 
     260                continue 
    215261            self.params[p.name] = p.default 
    216             self.details[p.name] = [p.units] + p.limits 
    217  
    218         for name in self._model_info['partype']['pd-2d']: 
    219             self.dispersion[name] = { 
    220                 'width': 0, 
    221                 'npts': 35, 
    222                 'nsigmas': 3, 
    223                 'type': 'gaussian', 
    224             } 
     262            self.details[p.id] = [p.units, p.limits[0], p.limits[1]] 
     263            if p.polydisperse: 
     264                self.details[p.id+".width"] = [ 
     265                    "", 0.0, 1.0 if p.relative_pd else np.inf 
     266                ] 
     267                self.dispersion[p.name] = { 
     268                    'width': 0, 
     269                    'npts': 35, 
     270                    'nsigmas': 3, 
     271                    'type': 'gaussian', 
     272                } 
    225273 
    226274    def __get_state__(self): 
     275        # type: () -> Dict[str, Any] 
    227276        state = self.__dict__.copy() 
    228277        state.pop('_model') 
     
    233282 
    234283    def __set_state__(self, state): 
     284        # type: (Dict[str, Any]) -> None 
    235285        self.__dict__ = state 
    236286        self._model = None 
    237287 
    238288    def __str__(self): 
     289        # type: () -> str 
    239290        """ 
    240291        :return: string representation 
     
    243294 
    244295    def is_fittable(self, par_name): 
     296        # type: (str) -> bool 
    245297        """ 
    246298        Check if a given parameter is fittable or not 
     
    253305 
    254306 
    255     # pylint: disable=no-self-use 
    256307    def getProfile(self): 
     308        # type: () -> (np.ndarray, np.ndarray) 
    257309        """ 
    258310        Get SLD profile 
     
    261313                beta is a list of the corresponding SLD values 
    262314        """ 
    263         return None, None 
     315        args = [] # type: List[Union[float, np.ndarray]] 
     316        for p in self._model_info.parameters.kernel_parameters: 
     317            if p.id == self.multiplicity_info.control: 
     318                args.append(float(self.multiplicity)) 
     319            elif p.length == 1: 
     320                args.append(self.params.get(p.id, np.NaN)) 
     321            else: 
     322                args.append([self.params.get(p.id+str(k), np.NaN) 
     323                             for k in range(1,p.length+1)]) 
     324        return self._model_info.profile(*args) 
    264325 
    265326    def setParam(self, name, value): 
     327        # type: (str, float) -> None 
    266328        """ 
    267329        Set the value of a model parameter 
     
    290352 
    291353    def getParam(self, name): 
     354        # type: (str) -> float 
    292355        """ 
    293356        Set the value of a model parameter 
     
    313376 
    314377    def getParamList(self): 
     378        # type: () -> Sequence[str] 
    315379        """ 
    316380        Return a list of all available parameters for the model 
    317381        """ 
    318         param_list = self.params.keys() 
     382        param_list = list(self.params.keys()) 
    319383        # WARNING: Extending the list with the dispersion parameters 
    320384        param_list.extend(self.getDispParamList()) 
     
    322386 
    323387    def getDispParamList(self): 
     388        # type: () -> Sequence[str] 
    324389        """ 
    325390        Return a list of polydispersity parameters for the model 
    326391        """ 
    327392        # TODO: fix test so that parameter order doesn't matter 
    328         ret = ['%s.%s' % (d.lower(), p) 
    329                for d in self._model_info['partype']['pd-2d'] 
    330                for p in ('npts', 'nsigmas', 'width')] 
     393        ret = ['%s.%s' % (p.name.lower(), ext) 
     394               for p in self._model_info.parameters.user_parameters() 
     395               for ext in ('npts', 'nsigmas', 'width') 
     396               if p.polydisperse] 
    331397        #print(ret) 
    332398        return ret 
    333399 
    334400    def clone(self): 
     401        # type: () -> "SasviewModel" 
    335402        """ Return a identical copy of self """ 
    336403        return deepcopy(self) 
    337404 
    338405    def run(self, x=0.0): 
     406        # type: (Union[float, (float, float), List[float]]) -> float 
    339407        """ 
    340408        Evaluate the model 
     
    349417            # pylint: disable=unpacking-non-sequence 
    350418            q, phi = x 
    351             return self.calculate_Iq([q * math.cos(phi)], 
    352                                      [q * math.sin(phi)])[0] 
    353         else: 
    354             return self.calculate_Iq([float(x)])[0] 
     419            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0] 
     420        else: 
     421            return self.calculate_Iq([x])[0] 
    355422 
    356423 
    357424    def runXY(self, x=0.0): 
     425        # type: (Union[float, (float, float), List[float]]) -> float 
    358426        """ 
    359427        Evaluate the model in cartesian coordinates 
     
    366434        """ 
    367435        if isinstance(x, (list, tuple)): 
    368             return self.calculate_Iq([float(x[0])], [float(x[1])])[0] 
    369         else: 
    370             return self.calculate_Iq([float(x)])[0] 
     436            return self.calculate_Iq([x[0]], [x[1]])[0] 
     437        else: 
     438            return self.calculate_Iq([x])[0] 
    371439 
    372440    def evalDistribution(self, qdist): 
     441        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray 
    373442        r""" 
    374443        Evaluate a distribution of q-values. 
     
    401470            # Check whether we have a list of ndarrays [qx,qy] 
    402471            qx, qy = qdist 
    403             partype = self._model_info['partype'] 
    404             if not partype['orientation'] and not partype['magnetic']: 
     472            if not self._model_info.parameters.has_2d: 
    405473                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2)) 
    406474            else: 
     
    415483                            % type(qdist)) 
    416484 
    417     def calculate_Iq(self, *args): 
     485    def calculate_Iq(self, qx, qy=None): 
     486        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray 
    418487        """ 
    419488        Calculate Iq for one set of q with the current parameters. 
     
    426495        if self._model is None: 
    427496            self._model = core.build_model(self._model_info) 
    428         q_vectors = [np.asarray(q) for q in args] 
    429         fn = self._model.make_kernel(q_vectors) 
    430         pars = [self.params[v] for v in fn.fixed_pars] 
    431         pd_pars = [self._get_weights(p) for p in fn.pd_pars] 
    432         result = fn(pars, pd_pars, self.cutoff) 
    433         fn.q_input.release() 
    434         fn.release() 
     497        if qy is not None: 
     498            q_vectors = [np.asarray(qx), np.asarray(qy)] 
     499        else: 
     500            q_vectors = [np.asarray(qx)] 
     501        kernel = self._model.make_kernel(q_vectors) 
     502        pairs = [self._get_weights(p) 
     503                 for p in self._model_info.parameters.call_parameters] 
     504        call_details, weight, value = details.build_details(kernel, pairs) 
     505        result = kernel(call_details, weight, value, cutoff=self.cutoff) 
     506        kernel.release() 
    435507        return result 
    436508 
    437509    def calculate_ER(self): 
     510        # type: () -> float 
    438511        """ 
    439512        Calculate the effective radius for P(q)*S(q) 
     
    441514        :return: the value of the effective radius 
    442515        """ 
    443         ER = self._model_info.get('ER', None) 
    444         if ER is None: 
     516        if self._model_info.ER is None: 
    445517            return 1.0 
    446518        else: 
    447             values, weights = self._dispersion_mesh() 
    448             fv = ER(*values) 
     519            value, weight = self._dispersion_mesh() 
     520            fv = self._model_info.ER(*value) 
    449521            #print(values[0].shape, weights.shape, fv.shape) 
    450             return np.sum(weights * fv) / np.sum(weights) 
     522            return np.sum(weight * fv) / np.sum(weight) 
    451523 
    452524    def calculate_VR(self): 
     525        # type: () -> float 
    453526        """ 
    454527        Calculate the volf ratio for P(q)*S(q) 
     
    456529        :return: the value of the volf ratio 
    457530        """ 
    458         VR = self._model_info.get('VR', None) 
    459         if VR is None: 
     531        if self._model_info.VR is None: 
    460532            return 1.0 
    461533        else: 
    462             values, weights = self._dispersion_mesh() 
    463             whole, part = VR(*values) 
    464             return np.sum(weights * part) / np.sum(weights * whole) 
     534            value, weight = self._dispersion_mesh() 
     535            whole, part = self._model_info.VR(*value) 
     536            return np.sum(weight * part) / np.sum(weight * whole) 
    465537 
    466538    def set_dispersion(self, parameter, dispersion): 
     539        # type: (str, weights.Dispersion) -> Dict[str, Any] 
    467540        """ 
    468541        Set the dispersion object for a model parameter 
     
    487560            from . import weights 
    488561            disperser = weights.dispersers[dispersion.__class__.__name__] 
    489             dispersion = weights.models[disperser]() 
     562            dispersion = weights.MODELS[disperser]() 
    490563            self.dispersion[parameter] = dispersion.get_pars() 
    491564        else: 
     
    493566 
    494567    def _dispersion_mesh(self): 
     568        # type: () -> List[Tuple[np.ndarray, np.ndarray]] 
    495569        """ 
    496570        Create a mesh grid of dispersion parameters and weights. 
     
    500574        parameter set in the vector. 
    501575        """ 
    502         pars = self._model_info['partype']['volume'] 
    503         return core.dispersion_mesh([self._get_weights(p) for p in pars]) 
     576        pars = [self._get_weights(p) 
     577                for p in self._model_info.parameters.call_parameters 
     578                if p.type == 'volume'] 
     579        return details.dispersion_mesh(self._model_info, pars) 
    504580 
    505581    def _get_weights(self, par): 
     582        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray] 
    506583        """ 
    507584        Return dispersion weights for parameter 
    508585        """ 
    509         from . import weights 
    510         relative = self._model_info['partype']['pd-rel'] 
    511         limits = self._model_info['limits'] 
    512         dis = self.dispersion[par] 
    513         value, weight = weights.get_weights( 
    514             dis['type'], dis['npts'], dis['width'], dis['nsigmas'], 
    515             self.params[par], limits[par], par in relative) 
    516         return value, weight / np.sum(weight) 
    517  
     586        if par.name not in self.params: 
     587            if par.name == self.multiplicity_info.control: 
     588                return [self.multiplicity], [] 
     589            else: 
     590                return [np.NaN], [] 
     591        elif par.polydisperse: 
     592            dis = self.dispersion[par.name] 
     593            value, weight = weights.get_weights( 
     594                dis['type'], dis['npts'], dis['width'], dis['nsigmas'], 
     595                self.params[par.name], par.limits, par.relative_pd) 
     596            return value, weight / np.sum(weight) 
     597        else: 
     598            return [self.params[par.name]], [] 
    518599 
    519600def test_model(): 
     601    # type: () -> float 
    520602    """ 
    521603    Test that a sasview model (cylinder) can be run. 
     
    525607    return cylinder.evalDistribution([0.1,0.1]) 
    526608 
     609def test_rpa(): 
     610    # type: () -> float 
     611    """ 
     612    Test that a sasview model (cylinder) can be run. 
     613    """ 
     614    RPA = _make_standard_model('rpa') 
     615    rpa = RPA(3) 
     616    return rpa.evalDistribution([0.1,0.1]) 
     617 
    527618 
    528619def test_model_list(): 
     620    # type: () -> None 
    529621    """ 
    530622    Make sure that all models build as sasview models. 
Note: See TracChangeset for help on using the changeset viewer.