Changeset 04dc697 in sasmodels


Ignore:
Timestamp:
Apr 13, 2016 11:39:09 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
60f03de
Parents:
a18c5b3
Message:

more type hinting

Location:
sasmodels
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/bumps_model.py

    r7ae2b7f r04dc697  
    1111 
    1212""" 
    13  
    14 import warnings 
     13from __future__ import print_function 
     14 
     15__all__ = [ "Model", "Experiment" ] 
    1516 
    1617import numpy as np  # type: ignore 
     
    1920from .direct_model import DataMixin 
    2021 
    21 __all__ = [ 
    22     "Model", "Experiment", 
    23     ] 
    24  
    25 # CRUFT: old style bumps wrapper which doesn't separate data and model 
    26 # pylint: disable=invalid-name 
    27 def BumpsModel(data, model, cutoff=1e-5, **kw): 
    28     r""" 
    29     Bind a model to data, along with a polydispersity cutoff. 
    30  
    31     *data* is a :class:`data.Data1D`, :class:`data.Data2D` or 
    32     :class:`data.Sesans` object.  Use :func:`data.empty_data1D` or 
    33     :func:`data.empty_data2D` to define $q, \Delta q$ calculation 
    34     points for displaying the SANS curve when there is no measured data. 
    35  
    36     *model* is a runnable module as returned from :func:`core.load_model`. 
    37  
    38     *cutoff* is the polydispersity weight cutoff. 
    39  
    40     Any additional *key=value* pairs are model dependent parameters. 
    41  
    42     Returns an :class:`Experiment` object. 
    43  
    44     Note that the usual Bumps semantics is not fully supported, since 
    45     assigning *M.name = parameter* on the returned experiment object 
    46     does not set that parameter in the model.  Range setting will still 
    47     work as expected though. 
    48  
    49     .. deprecated:: 0.1 
    50         Use :class:`Experiment` instead. 
    51     """ 
    52     warnings.warn("Use of BumpsModel is deprecated.  Use bumps_model.Experiment instead.") 
    53  
    54     # Create the model and experiment 
    55     model = Model(model, **kw) 
    56     experiment = Experiment(data=data, model=model, cutoff=cutoff) 
    57  
    58     # Copy the model parameters up to the experiment object. 
    59     for k, v in model.parameters().items(): 
    60         setattr(experiment, k, v) 
    61     return experiment 
     22try: 
     23    from typing import Dict, Union, Tuple, Any 
     24    from .data import Data1D, Data2D 
     25    from .kernel import KernelModel 
     26    from .modelinfo import ModelInfo 
     27    Data = Union[Data1D, Data2D] 
     28except ImportError: 
     29    pass 
     30 
     31try: 
     32    # Optional import. This allows the doc builder and nosetests to run even 
     33    # when bumps is not on the path. 
     34    from bumps.names import Parameter # type: ignore 
     35except ImportError: 
     36    pass 
    6237 
    6338 
    6439def create_parameters(model_info, **kwargs): 
     40    # type: (ModelInfo, **Union[float, str, Parameter]) -> Tuple[Dict[str, Parameter], Dict[str, str]] 
    6541    """ 
    6642    Generate Bumps parameters from the model info. 
     
    7147    Any additional *key=value* pairs are initial values for the parameters 
    7248    to the models.  Uninitialized parameters will use the model default 
    73     value. 
     49    value.  The value can be a float, a bumps parameter, or in the case 
     50    of the distribution type parameter, a string. 
    7451 
    7552    Returns a dictionary of *{name: Parameter}* containing the bumps 
     
    7754    *{name: str}* containing the polydispersity distribution types. 
    7855    """ 
    79     # lazy import; this allows the doc builder and nosetests to run even 
    80     # when bumps is not on the path. 
    81     from bumps.names import Parameter  # type: ignore 
    82  
    83     pars = {}     # floating point parameters 
    84     pd_types = {} # distribution names 
     56    pars = {}     # type: Dict[str, Parameter] 
     57    pd_types = {} # type: Dict[str, str] 
    8558    for p in model_info.parameters.call_parameters: 
    8659        value = kwargs.pop(p.name, p.default) 
     
    9669                pars[name] = Parameter.default(value, name=name, limits=limits) 
    9770            name = p.name + '_pd_type' 
    98             pd_types[name] = kwargs.pop(name, 'gaussian') 
     71            pd_types[name] = str(kwargs.pop(name, 'gaussian')) 
    9972 
    10073    if kwargs:  # args not corresponding to parameters 
     
    11588    """ 
    11689    def __init__(self, model, **kwargs): 
    117         self._sasmodel = model 
     90        # type: (KernelModel, **Dict[str, Union[float, Parameter]]) -> None 
     91        self.sasmodel = model 
    11892        pars, pd_types = create_parameters(model.info, **kwargs) 
    11993        for k, v in pars.items(): 
     
    12599 
    126100    def parameters(self): 
     101        # type: () -> Dict[str, Parameter] 
    127102        """ 
    128103        Return a dictionary of parameters objects for the parameters, 
     
    132107 
    133108    def state(self): 
     109        # type: () -> Dict[str, Union[Parameter, str]] 
    134110        """ 
    135111        Return a dictionary of current values for all the parameters, 
     
    156132    The resulting model can be used directly in a Bumps FitProblem call. 
    157133    """ 
     134    _cache = None # type: Dict[str, np.ndarray] 
    158135    def __init__(self, data, model, cutoff=1e-5): 
    159  
     136        # type: (Data, Model, float) -> None 
    160137        # remember inputs so we can inspect from outside 
    161138        self.model = model 
    162139        self.cutoff = cutoff 
    163         self._interpret_data(data, model._sasmodel) 
    164         self.update() 
     140        self._interpret_data(data, model.sasmodel) 
     141        self._cache = {} 
    165142 
    166143    def update(self): 
     144        # type: () -> None 
    167145        """ 
    168146        Call when model parameters have changed and theory needs to be 
    169147        recalculated. 
    170148        """ 
    171         self._cache = {} 
     149        self._cache.clear() 
    172150 
    173151    def numpoints(self): 
     152        # type: () -> float 
    174153        """ 
    175154        Return the number of data points 
     
    178157 
    179158    def parameters(self): 
     159        # type: () -> Dict[str, Parameter] 
    180160        """ 
    181161        Return a dictionary of parameters 
     
    184164 
    185165    def theory(self): 
     166        # type: () -> np.ndarray 
    186167        """ 
    187168        Return the theory corresponding to the model parameters. 
     
    196177 
    197178    def residuals(self): 
     179        # type: () -> np.ndarray 
    198180        """ 
    199181        Return theory minus data normalized by uncertainty. 
     
    203185 
    204186    def nllf(self): 
     187        # type: () -> float 
    205188        """ 
    206189        Return the negative log likelihood of seeing data given the model 
     
    210193        delta = self.residuals() 
    211194        #if np.any(np.isnan(R)): print("NaN in residuals") 
    212         return 0.5 * np.sum(delta ** 2) 
     195        return 0.5 * np.sum(delta**2) 
    213196 
    214197    #def __call__(self): 
     
    216199 
    217200    def plot(self, view='log'): 
     201        # type: (str) -> None 
    218202        """ 
    219203        Plot the data and residuals. 
     
    223207 
    224208    def simulate_data(self, noise=None): 
     209        # type: (float) -> None 
    225210        """ 
    226211        Generate simulated data. 
     
    230215 
    231216    def save(self, basename): 
     217        # type: (str) -> None 
    232218        """ 
    233219        Save the model parameters and data into a file. 
     
    240226 
    241227    def __getstate__(self): 
     228        # type: () -> Dict[str, Any] 
    242229        # Can't pickle gpu functions, so instead make them lazy 
    243230        state = self.__dict__.copy() 
     
    246233 
    247234    def __setstate__(self, state): 
     235        # type: (Dict[str, Any]) -> None 
    248236        # pylint: disable=attribute-defined-outside-init 
    249237        self.__dict__ = state 
     238 
  • sasmodels/kernel.py

    r7ae2b7f r04dc697  
    1919 
    2020class KernelModel(object): 
     21    info = None  # type: ModelInfo 
    2122    def make_kernel(self, q_vectors): 
    2223        # type: (List[np.ndarray]) -> "Kernel" 
  • sasmodels/modelinfo.py

    r04045f4 r04dc697  
    2424    from .details import CallDetails 
    2525    Limits = Tuple[float, float] 
    26     LimitsOrChoice = Union[Limits, Tuple[Sequence[str]]] 
     26    #LimitsOrChoice = Union[Limits, Tuple[Sequence[str]]] 
    2727    ParameterDef = Tuple[str, str, float, Limits, str, str] 
    2828    ParameterSetUser = Dict[str, Union[float, List[float]]] 
     
    7171def parse_parameter(name, units='', default=np.NaN, 
    7272                    user_limits=None, ptype='', description=''): 
    73     # type: (str, str, float, LimitsOrChoice, str, str) -> Parameter 
     73    # type: (str, str, float, Limits, str, str) -> Parameter 
    7474    """ 
    7575    Parse an individual parameter from the parameter definition block. 
     
    732732    profile = None          # type: Optional[Callable[[np.ndarray], None]] 
    733733    sesans = None           # type: Optional[Callable[[np.ndarray], np.ndarray]] 
    734     hidden = None           # type: Optional[Callable[int], Set[str]] 
     734    hidden = None           # type: Optional[Callable[[int], Set[str]]] 
    735735    mono_details = None     # type: CallDetails 
    736736 
  • sasmodels/sasview_model.py

    ra18c5b3 r04dc697  
    3131 
    3232try: 
    33     from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional 
     33    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union 
    3434    from .modelinfo import ModelInfo, Parameter 
    3535    from .kernel import KernelModel 
     
    4242 
    4343# TODO: separate x_axis_label from multiplicity info 
    44 # The x-axis label belongs with the profile generating function 
     44# The profile x-axis label belongs with the profile generating function 
    4545MultiplicityInfo = collections.namedtuple( 
    4646    'MultiplicityInfo', 
     
    220220    dispersion = None  # type: Dict[str, Any] 
    221221    #: units and limits for each parameter 
    222     details = None     # type: Mapping[str, Tuple(str, float, float)] 
    223     #: multiplicity used, or None if no multiplicity controls 
     222    details = None     # type: Mapping[str, Tuple[str, float, float]] 
     223    #: multiplicity value, or None if no multiplicity on the model 
    224224    multiplicity = None     # type: Optional[int] 
     225    #: memory for polydispersity array if using ArrayDispersion (used by sasview). 
     226    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]] 
    225227 
    226228    def __init__(self, multiplicity=None): 
    227         # type: () -> None 
    228  
    229         ## _persistency_dict is used by sas.perspectives.fitting.basepage 
    230         ## to store dispersity reference. 
    231         self._persistency_dict = {} 
     229        # type: (Optional[int]) -> None 
    232230 
    233231        # TODO: _persistency_dict to persistency_dict throughout sasview 
     
    252250            hidden = set() 
    253251 
     252        self._persistency_dict = {} 
    254253        self.params = collections.OrderedDict() 
    255254        self.dispersion = {} 
     
    375374 
    376375    def getParamList(self): 
    377         # type: () - > Sequence[str] 
     376        # type: () -> Sequence[str] 
    378377        """ 
    379378        Return a list of all available parameters for the model 
    380379        """ 
    381         param_list = self.params.keys() 
     380        param_list = list(self.params.keys()) 
    382381        # WARNING: Extending the list with the dispersion parameters 
    383382        param_list.extend(self.getDispParamList()) 
     
    385384 
    386385    def getDispParamList(self): 
    387         # type: () - > Sequence[str] 
     386        # type: () -> Sequence[str] 
    388387        """ 
    389388        Return a list of polydispersity parameters for the model 
     
    398397 
    399398    def clone(self): 
    400         # type: () - > "SasviewModel" 
     399        # type: () -> "SasviewModel" 
    401400        """ Return a identical copy of self """ 
    402401        return deepcopy(self) 
     
    439438 
    440439    def evalDistribution(self, qdist): 
    441         # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]) -> np.ndarray 
     440        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray 
    442441        r""" 
    443442        Evaluate a distribution of q-values. 
Note: See TracChangeset for help on using the changeset viewer.