[f6f3fb4] | 1 | """ |
---|
| 2 | Interface to page state for saved fits. |
---|
| 3 | |
---|
| 4 | The 4.x sasview gui builds the model, smearer, etc. directly inside the wx |
---|
| 5 | GUI code. This code separates the in-memory representation from the GUI. |
---|
| 6 | Initially it is used for a headless bumps fitting backend that operates |
---|
| 7 | directly on the saved XML; eventually it should provide methods to replace |
---|
| 8 | direct access to the PageState object so that the code for setting up and |
---|
| 9 | running fits in wx, qt, and headless versions of SasView shares a common |
---|
| 10 | in-memory representation. |
---|
| 11 | """ |
---|
| 12 | from __future__ import print_function, division |
---|
| 13 | |
---|
| 14 | import copy |
---|
| 15 | from collections import namedtuple |
---|
| 16 | |
---|
| 17 | import numpy as np |
---|
| 18 | |
---|
| 19 | from bumps.names import FitProblem |
---|
| 20 | |
---|
| 21 | from sasmodels.core import load_model_info |
---|
| 22 | from sasmodels.data import plot_theory |
---|
| 23 | from sasmodels.sasview_model import _make_standard_model, MultiplicationModel, load_custom_model |
---|
| 24 | from sasmodels.weights import MODELS as POLYDISPERSITY_MODELS |
---|
| 25 | |
---|
| 26 | from .pagestate import Reader, PageState, SimFitPageState, CUSTOM_MODEL |
---|
| 27 | from .BumpsFitting import SasFitness, ParameterExpressions |
---|
| 28 | from .AbstractFitEngine import FitData1D, FitData2D, Model |
---|
| 29 | from .models import PLUGIN_NAME_BASE, find_plugins_dir |
---|
| 30 | from .qsmearing import smear_selection |
---|
| 31 | |
---|
# Monkey patch SasFitness class with plotter
def sasfitness_plot(self, view='log'):
    """
    Plot the data, the current theory and the residuals of this fitness
    object with sasmodels' plot_theory.

    *view* is passed straight through to plot_theory (default 'log').
    """
    sas_data = self.data.sas_data
    theory = self.theory()
    residuals = self.residuals()
    plot_theory(sas_data, theory, residuals, view)

SasFitness.plot = sasfitness_plot
---|
| 37 | |
---|
# Field layout of the parameter tuples saved on a fit page
# (state.parameters, state.fixed_param and state.fittable_param).
PARAMETER_FIELDS = [
    "fitted",
    "name",
    "value",
    "plusminus",
    "uncertainty",
    "lower",
    "upper",
    "units",
]
# Named-tuple view of one saved parameter entry so fields can be read by name.
SasviewParameter = namedtuple("Parameter", PARAMETER_FIELDS)
---|
| 44 | |
---|
class FitState(object):
    """
    In-memory view of the fit pages and simultaneous-fit settings stored in
    a saved SasView fit/project file.

    Attributes set by the constructor:

    fitfile -- the file name that was read
    fits -- list of PageState objects, one per fit page found in the file
    simfit -- the SimFitPageState from the file, or None if there is none
    """
    def __init__(self, fitfile):
        self.fitfile = fitfile
        self.simfit = None
        self.fits = []

        # The reader hands each loaded state to _add_entry, which collects
        # them; the return value of read() is therefore not needed here.
        reader = Reader(self._add_entry)
        reader.read(fitfile)
        self._set_constraints()

    def _add_entry(self, state=None, datainfo=None, format=None):
        """
        Callback handed to the reader to receive and accumulate the loaded
        state objects.

        A PageState is converted to sasmodels naming and appended to *fits*;
        a SimFitPageState is stored as *simfit*; anything else (e.g., empty
        fit info) is ignored. *datainfo* and *format* are accepted to match
        the reader's callback signature but are not used.
        """
        # Note: datainfo is in state.data; format=.svs means reset fit panels
        if isinstance(state, PageState):
            # TODO: shouldn't the update be part of the load?
            state._convert_to_sasmodels()
            self.fits.append(state)
        elif isinstance(state, SimFitPageState):
            self.simfit = state
        # else: ignore empty fit info

    def __str__(self):
        return '<SasFit %s>'%self.fitfile

    @staticmethod
    def _print_attributes(obj):
        """
        Print every attribute of *obj* in sorted order; list/tuple values
        are printed one item per line beneath the attribute name.
        """
        for attr, value in sorted(obj.__dict__.items()):
            if isinstance(value, (list, tuple)):
                print(attr)
                for item in value:
                    print(" ", item)
            else:
                print(attr, value)

    def show(self):
        """
        Summarize the fit pages (and the constraints, if any) in the state
        object by printing their attributes to stdout.
        """
        for k, fit in enumerate(self.fits):
            print("="*20, "Fit page", k+1)
            self._print_attributes(fit)
        if self.simfit:
            print("="*20, "Constraints")
            self._print_attributes(self.simfit)

    def make_fitproblem(self):
        """
        Build collection of bumps fitness calculators and return the FitProblem.

        Raises RuntimeError if the file contained no fit pages.
        """
        # TODO: batch info not stored with project/analysis file (ticket #907)
        models = [make_fitness(state) for state in self.fits]
        if not models:
            raise RuntimeError("Nothing to fit")
        fit_problem = FitProblem(models)
        # Evaluate the inter-model constraint expressions whenever bumps
        # updates the parameter vector.
        fit_problem.setp_hook = ParameterExpressions(models)
        return fit_problem

    def _set_constraints(self):
        """
        Assign fit_page and constraints attributes to each fit page state.

        Without a simultaneous fit every page is labelled 'M1' with no
        constraints. Otherwise each simfit.model_list entry is matched to the
        first not-yet-used page with the same data id and model name, which
        then receives that entry's page id and constraint pairs.

        Raises ValueError if a model_list entry has no matching page.
        """
        # early return if no sim fit
        if self.simfit is None:
            for fit in self.fits:
                fit.fit_page = 'M1'
                fit.constraints = {}
            return

        # Note: simfitpage.py:load_from_save_state relabels the model and
        # constraint on load, replacing the model name in the constraint
        # expression with the new name for every constraint expression. We
        # don't bother to do that here since we don't need to relabel the
        # model ids.
        constraints = {}
        for item in self.simfit.constraints_list:
            # model_cbox in the constraints list should match fit_page_source
            # in the model list.
            pairs = constraints.setdefault(item['model_cbox'], [])
            pairs.append((item['param_cbox'], item['constraint']))

        # No way to uniquely map the page id (M1, M2, etc.) to the different
        # items in fits; neither fit_number nor fit_page_source is in the
        # pagestate for the fit, and neither model_name nor data name are
        # unique. The usual case of one model per data file will get us most
        # of the way there but there will be ambiguity when the data file
        # is not unique, e.g., when different parts of the data set are
        # fit with different models. If the same model and same data are
        # used (e.g., with different background, scale or resolution in
        # different segments) then the model-fit association will be assigned
        # arbitrarily based on whichever happens to come first.
        used = set()  # ids of page states already assigned to a model entry
        for model in self.simfit.model_list:
            for fit in self.fits:
                if (id(fit) not in used
                        and fit.data_id == model['name']
                        and model_name(fit) == model['model_name']):
                    fit.fit_page = model['fit_page_source']
                    fit.constraints = constraints.setdefault(fit.fit_page, [])
                    used.add(id(fit))
                    break
            else:
                raise ValueError("could not find model %s in file"
                                 % model['fit_page_source'])
---|
| 158 | |
---|
| 159 | |
---|
def model_name(state):
    """
    Return the model name for a fit page.

    This is the form factor name alone, or "formfactor*structurefactor"
    when a structure factor is selected (a structure factor of None, "" or
    any capitalization of "none" counts as not selected). It matches the
    model name stored in the simultaneous fit model_list and is used to
    disambiguate SASentry sections that share the same dataset.
    """
    form_factor = state.formfactorcombobox
    structure_factor = state.structurecombobox
    no_structure = (structure_factor is None
                    or structure_factor == ""
                    or structure_factor.lower() == "none")
    if no_structure:
        return form_factor
    return '*'.join((form_factor, structure_factor))
---|
| 174 | |
---|
def get_data_weight(state):
    """
    Return the per-point uncertainty used to weight the fit.

    The first of the state.dI_* flags that is set selects the weighting:

    * dI_noweight -- all ones (equal weight uncertainty)
    * dI_didata -- the uncertainties stored with the data (err_data for 2D,
      dy for 1D) as computed by reduction; may be None
    * dI_sqridata -- sqrt(|I|), suitable when intensity is approximately counts
    * dI_idata -- |I| (a percentage of the intensity, such as 2% or 5%
      depending on relative counting time, would be better)

    Returns None when none of the flags is set.
    """
    # Cribbed from perspectives/fitting/utils.py:get_weight and
    # perspectives/fitting/fitpage.py: get_weight_flag
    if state.enable2D:
        uncertainty, intensity = state.data.err_data, state.data.data
    else:
        uncertainty, intensity = state.data.dy, state.data.y

    if state.dI_noweight:
        return np.ones_like(intensity)
    if state.dI_didata:
        return uncertainty
    if state.dI_sqridata:
        return np.sqrt(np.abs(intensity))
    if state.dI_idata:
        return np.abs(intensity)
    return None
---|
| 202 | |
---|
# Models already loaded by load_model, keyed by the name exactly as passed in.
_MODEL_CACHE = {}
def load_model(name):
    """
    Given a model name load the Sasview shim model from sasmodels.

    If name starts with PLUGIN_NAME_BASE (the "[Plug-in]" prefix) then load
    it as a custom model from the plugins directory. This code does not go
    through the Sasview model manager interface since that loads all
    available models rather than just those needed.

    Returns None when *name* is None, empty, or any capitalization of
    "none" (e.g., an unselected structure factor).
    """
    # os is not imported at module level, so bring it in here.
    import os

    # Remember the models that are loaded so they are only loaded once. While
    # not strictly necessary (the models will use identical but different model
    # info structure) it saves a little time and memory for the usual case
    # where models are reused for simultaneous and batch fitting.
    if name in _MODEL_CACHE:
        return _MODEL_CACHE[name]
    # Check for "no model" first so that None doesn't fail on .startswith.
    if name is None or name == "" or name.lower() == "none":
        model = None
    elif name.startswith(PLUGIN_NAME_BASE):
        # Use a separate variable for the stripped name so that the cache is
        # keyed on the full "[Plug-in]..." name; otherwise the plugin would
        # never be found in the cache and would shadow a standard model of
        # the same base name.
        plugin_name = name[len(PLUGIN_NAME_BASE):]
        plugins_dir = find_plugins_dir()
        path = os.path.abspath(os.path.join(plugins_dir, plugin_name + ".py"))
        model = load_custom_model(path)
    else:
        model = _make_standard_model(name)
    _MODEL_CACHE[name] = model
    return model
---|
| 232 | |
---|
def parse_optional_float(value):
    """
    Parse an optional floating point value saved as a string.

    Returns None when *value* is None, the empty string, or any
    capitalization of "None"; otherwise returns float(value).
    """
    missing = value is None or value == "" or value.lower() == "none"
    return None if missing else float(value)
---|
| 242 | |
---|
def _build_model(state):
    """
    Instantiate the model selected on a fit page.

    The result is the form factor (built with the page's multiplicity),
    multiplied by the structure factor when one is selected.

    Raises ValueError if the page is in the custom-model category but the
    form factor name lacks the plugin prefix.
    """
    form_factor_name = state.formfactorcombobox
    structure_factor_name = state.structurecombobox
    multiplicity = state.multi_factor
    # Explicit check rather than assert so that it still runs under python -O.
    if (state.categorycombobox == CUSTOM_MODEL
            and not form_factor_name.startswith(PLUGIN_NAME_BASE)):
        raise ValueError("custom model %r does not start with %r"
                         % (form_factor_name, PLUGIN_NAME_BASE))
    form_factor_model = load_model(form_factor_name)
    structure_factor_model = load_model(structure_factor_name)
    model = form_factor_model(multiplicity)
    if structure_factor_model is not None:
        model = MultiplicationModel(model, structure_factor_model())
    return model


def _set_dispersity(model, state):
    """
    Install on *model* the polydispersity distributions saved on the page.

    Only parameters listed in state.disp_obj_dict are touched. The ".width"
    suffix is stripped from each key to get the base parameter name.
    """
    for par_name, dist_name in state.disp_obj_dict.items():
        dispersion = POLYDISPERSITY_MODELS[dist_name]()
        if dist_name == "array":
            # Array distributions carry explicit value/weight tables.
            dispersion.set_weights(state.values[par_name],
                                   state.weights[par_name])
        model.set_dispersion(par_name.replace('.width', ''), dispersion)


def _set_parameters(model, state):
    """
    Copy saved parameter values and fit ranges into *model*.

    Returns the list of names of the parameters marked as fitted.
    """
    fitted = []
    for par_tuple in state.parameters + state.fixed_param + state.fittable_param:
        par = SasviewParameter(*par_tuple)
        if par.name not in state.weights:
            # Don't try to set parameter values for array distributions
            # TODO: keep weights filename in the array distribution object
            model.setParam(par.name, parse_optional_float(par.value))
        if par.fitted:
            fitted.append(par.name)
        if par.name in model.details:
            lower = parse_optional_float(par.lower[1])
            upper = parse_optional_float(par.upper[1])
            model.details[par.name] = [par.units, lower, upper]
    return fitted


def _make_smearer(data, model, state):
    """
    Apply the page's resolution settings to *data* (modified in place) and
    return the matching smearer, or None when smearing is disabled.

    Raises ValueError if no resolution option is selected on the page.
    """
    if state.disable_smearer:
        return None
    if state.enable_smearer:
        return smear_selection(data, model)
    if state.pinhole_smearer:
        # see sasgui/perspectives/fitting/basepage.py: reset_page_helper
        dx_percent = state.dx_percent
        if state.dx_old:
            # dx_old states are rescaled by the first q value, as in
            # reset_page_helper.
            dx_percent = 100*(state.dx_percent / data.x[0])
        # see sasgui/perspectives/fitting/fitpage.py: _set_pinhole_smear
        fraction = dx_percent / 100.
        if state.enable2D:
            # smear_type is Pinhole2D.
            # dqx_data, dqy_data initialized to 0
            data.dqx_data = fraction * data.qx_data
            data.dqy_data = fraction * data.qy_data
        else:
            data.dx = fraction * data.x
            data.dxl = data.dxw = None  # be sure it is not slit-smeared
        return smear_selection(data, model)
    if state.slit_smearer:
        # see sasgui/perspectives/fitting/fitpage.py: _set_pinhole_smear
        data_len = len(data.x)
        data.dx = None
        data.dxl = (state.dxl if state.dxl is not None else 0.) * np.ones(data_len)
        data.dxw = (state.dxw if state.dxw is not None else 0.) * np.ones(data_len)
        return smear_selection(data, model)
    raise ValueError("expected resolution specification for fit")


def make_fitness(state):
    """
    Build the bumps SasFitness object for one saved fit page.

    Model, dispersity, parameter values/ranges, resolution and data weighting
    all come from *state*, which must already carry the fit_page and
    constraints attributes (assigned by FitState._set_constraints).

    Raises ValueError if no resolution option is selected on the page, or if
    a custom-model page does not name a plugin model.
    """
    model = _build_model(state)
    _set_dispersity(model, state)
    fitted = _set_parameters(model, state)

    # Set the resolution on a copy so the saved state's data is not modified.
    data = copy.deepcopy(state.data)
    smearer = _make_smearer(data, model, state)

    # Set the data weighting (dI, sqrt(I), I, or uniform)
    weight = get_data_weight(state)

    # Make fit data object and set the data weights
    # TODO: check 2D masked data
    if state.enable2D:
        # NOTE(review): the 2D smearer computed above is not handed to
        # FitData2D, and err_data is replaced after construction; confirm
        # that FitData2D actually uses both resolution and weights.
        fitdata = FitData2D(sas_data2d=data, data=data.data,
                            err_data=data.err_data)
        fitdata.err_data = weight
    else:
        data.mask = (np.isnan(data.y) if data.y is not None
                     else np.zeros_like(data.x, dtype='bool'))
        fitdata = FitData1D(x=data.x, y=data.y,
                            dx=data.dx, dy=data.dy, smearer=smearer)
        fitdata.dy = weight
    fitdata.set_fit_range(qmin=state.qmin, qmax=state.qmax)
    fitdata.sas_data = data

    fitmodel = Model(model, fitdata)
    fitmodel.name = state.fit_page
    # Initial values aren't needed since they have been stuffed into the
    # model; if provided they must be one-to-one with the names in fitted.
    return SasFitness(
        model=fitmodel,
        data=fitdata,
        constraints=state.constraints,
        fitted=fitted,
        initial_values=None,
    )
---|
| 352 | |
---|
class BumpsPlugin:
    """
    Object holding methods for interacting with SasView using the direct
    bumps interface.

    Only load_model is provided; data_view, model_view and new_model are
    not defined.
    """
    @staticmethod
    def load_model(filename):
        """
        Read the saved SasView fit file *filename* and return the bumps
        FitProblem built from its fit pages.
        """
        return FitState(filename).make_fitproblem()
---|
| 378 | |
---|
| 379 | |
---|
def setup_sasview():
    """
    Perform the SasView start-up needed for headless fitting: call
    sas.sasview.sasview.setup_sasmodels().

    SasView's logging and matplotlib setup are deliberately not run.
    """
    # Import only what is used; setup_logging/setup_mpl are not called here.
    from sas.sasview.sasview import setup_sasmodels
    setup_sasmodels()
---|
| 385 | |
---|
def setup_bumps():
    """
    Install the SasView plugin (BumpsPlugin) into bumps, but don't run main.

    Also calls bumps.cli.set_mplconfig with appdatadir '.sasview/bumpsfit'
    so bumps has a matplotlib config location for SasView fits.
    """
    import os
    import bumps.cli
    bumps.cli.set_mplconfig(appdatadir=os.path.join('.sasview', 'bumpsfit'))
    bumps.cli.install_plugin(BumpsPlugin)
---|
| 394 | |
---|
def bumps_cli():
    """
    Set up sasmodels, install the SasView plugin into bumps, and run the
    bumps command line interface.
    """
    setup_sasview()
    setup_bumps()
    import bumps.cli
    bumps.cli.main()

if __name__ == "__main__":
    # Allow run with:
    #    python -m sas.sascalc.fit.fitstate
    bumps_cli()
---|