1 | """ |
---|
2 | Interface to page state for saved fits. |
---|
3 | |
---|
4 | The 4.x sasview gui builds the model, smearer, etc. directly inside the wx |
---|
5 | GUI code. This code separates the in-memory representation from the GUI. |
---|
6 | Initially it is used for a headless bumps fitting backend that operates |
---|
7 | directly on the saved XML; eventually it should provide methods to replace |
---|
8 | direct access to the PageState object so that the code for setting up and |
---|
9 | running fits in wx, qt, and headless versions of SasView shares a common |
---|
10 | in-memory representation. |
---|
11 | """ |
---|
12 | from __future__ import print_function, division |
---|
13 | |
---|
14 | import copy |
---|
15 | from collections import namedtuple |
---|
16 | |
---|
17 | import numpy as np |
---|
18 | |
---|
19 | from bumps.names import FitProblem |
---|
20 | |
---|
21 | from sasmodels.core import load_model_info |
---|
22 | from sasmodels.data import plot_theory |
---|
23 | from sasmodels.sasview_model import _make_standard_model, MultiplicationModel, load_custom_model |
---|
24 | from sasmodels.weights import MODELS as POLYDISPERSITY_MODELS |
---|
25 | |
---|
26 | from .pagestate import Reader, PageState, SimFitPageState, CUSTOM_MODEL |
---|
27 | from .BumpsFitting import SasFitness, ParameterExpressions |
---|
28 | from .AbstractFitEngine import FitData1D, FitData2D, Model |
---|
29 | from .models import PLUGIN_NAME_BASE, find_plugins_dir |
---|
30 | from .qsmearing import smear_selection |
---|
31 | |
---|
32 | # Monkey patch SasFitness class with plotter |
---|
33 | def sasfitness_plot(self, view='log'): |
---|
34 | data, theory, resid = self.data.sas_data, self.theory(), self.residuals() |
---|
35 | plot_theory(data, theory, resid, view) |
---|
36 | SasFitness.plot = sasfitness_plot |
---|
37 | |
---|
38 | # Use a named tuple for the sasview parameters |
---|
39 | PARAMETER_FIELDS = [ |
---|
40 | "fitted", "name", "value", "plusminus", "uncertainty", |
---|
41 | "lower", "upper", "units", |
---|
42 | ] |
---|
43 | SasviewParameter = namedtuple("Parameter", PARAMETER_FIELDS) |
---|
44 | |
---|
45 | class FitState(object): |
---|
46 | def __init__(self, fitfile): |
---|
47 | self.fitfile = fitfile |
---|
48 | self.simfit = None |
---|
49 | self.fits = [] |
---|
50 | |
---|
51 | reader = Reader(self._add_entry) |
---|
52 | datasets = reader.read(fitfile) |
---|
53 | self._set_constraints() |
---|
54 | #print("loaded", datasets) |
---|
55 | |
---|
56 | def _add_entry(self, state=None, datainfo=None, format=None): |
---|
57 | """ |
---|
58 | Handed to the reader to receive and accumulate the loaded state objects. |
---|
59 | """ |
---|
60 | # Note: datainfo is in state.data; format=.svs means reset fit panels |
---|
61 | if isinstance(state, PageState): |
---|
62 | # TODO: shouldn't the update be part of the load? |
---|
63 | state._convert_to_sasmodels() |
---|
64 | self.fits.append(state) |
---|
65 | elif isinstance(state, SimFitPageState): |
---|
66 | self.simfit = state |
---|
67 | else: |
---|
68 | # ignore empty fit info |
---|
69 | pass |
---|
70 | |
---|
71 | def __str__(self): |
---|
72 | return '<SasFit %s>'%self.fitfile |
---|
73 | |
---|
74 | def show(self): |
---|
75 | """ |
---|
76 | Summarize the fit pages in the state object. |
---|
77 | """ |
---|
78 | for k, fit in enumerate(self.fits): |
---|
79 | print("="*20, "Fit page", k+1) |
---|
80 | #print(fit) |
---|
81 | for attr, value in sorted(fit.__dict__.items()): |
---|
82 | if isinstance(value, (list, tuple)): |
---|
83 | print(attr) |
---|
84 | for item in value: |
---|
85 | print(" ", item) |
---|
86 | else: |
---|
87 | print(attr, value) |
---|
88 | if self.simfit: |
---|
89 | print("="*20, "Constraints") |
---|
90 | #print(self.simfit) |
---|
91 | for attr, value in sorted(self.simfit.__dict__.items()): |
---|
92 | if isinstance(value, (list, tuple)): |
---|
93 | print(attr) |
---|
94 | for item in value: |
---|
95 | print(" ", item) |
---|
96 | else: |
---|
97 | print(attr, value) |
---|
98 | |
---|
99 | def make_fitproblem(self): |
---|
100 | """ |
---|
101 | Build collection of bumps fitness calculators and return the FitProblem. |
---|
102 | """ |
---|
103 | # TODO: batch info not stored with project/analysis file (ticket #907) |
---|
104 | models = [make_fitness(state) for state in self.fits] |
---|
105 | if not models: |
---|
106 | raise RuntimeError("Nothing to fit") |
---|
107 | fit_problem = FitProblem(models) |
---|
108 | fit_problem.setp_hook = ParameterExpressions(models) |
---|
109 | return fit_problem |
---|
110 | |
---|
111 | def _set_constraints(self): |
---|
112 | """ |
---|
113 | Adds fit_page and constraints list to each model. |
---|
114 | """ |
---|
115 | # early return if no sim fit |
---|
116 | if self.simfit is None: |
---|
117 | for fit in self.fits: |
---|
118 | fit.fit_page = 'M1' |
---|
119 | fit.constraints = {} |
---|
120 | return |
---|
121 | |
---|
122 | # Note: simfitpage.py:load_from_save_state relabels the model and |
---|
123 | # constraint on load, replacing the model name in the constraint |
---|
124 | # expression with the new name for every constraint expression. We |
---|
125 | # don't bother to do that here since we don't need to relabel the |
---|
126 | # model ids. |
---|
127 | constraints = {} |
---|
128 | for item in self.simfit.constraints_list: |
---|
129 | # model_cbox in the constraints list should match fit_page_source |
---|
130 | # in the model list. |
---|
131 | pairs = constraints.setdefault(item['model_cbox'], []) |
---|
132 | pairs.append((item['param_cbox'], item['constraint'])) |
---|
133 | |
---|
134 | # No way to uniquely map the page id (M1, M2, etc.) to the different |
---|
135 | # items in fits; neither fit_number nor fit_page_source is in the |
---|
136 | # pagestate for the fit, and neither model_name nor data name are |
---|
137 | # unique. The usual case of one model per data file will get us most |
---|
138 | # of the way there but there will be ambiguity when the data file |
---|
139 | # is not unique, e.g., when different parts of the data set are |
---|
140 | # fit with different models. If the same model and same data are |
---|
141 | # used (e.g., with different background, scale or resolution in |
---|
142 | # different segments) then the model-fit association will be assigned |
---|
143 | # arbitrarily based on whichever happens to come first. |
---|
144 | used = [] |
---|
145 | for model in self.simfit.model_list: |
---|
146 | matched = 0 |
---|
147 | for fit in self.fits: |
---|
148 | #print(model['name'], fit.data_id, model_name(fit), model['model_name']) |
---|
149 | if (fit.data_id == model['name'] |
---|
150 | and model_name(fit) == model['model_name'] |
---|
151 | and fit not in used): |
---|
152 | fit.fit_page = model['fit_page_source'] |
---|
153 | fit.constraints = constraints.setdefault(fit.fit_page, []) |
---|
154 | used.append(fit) |
---|
155 | matched += 1 |
---|
156 | if matched > 1: |
---|
157 | raise ValueError("more than one model matches %s" |
---|
158 | % model['fit_page_source']) |
---|
159 | elif matched == 0: |
---|
160 | raise ValueError("could not find model %s in file" |
---|
161 | % model['fit_page_source']) |
---|
162 | |
---|
163 | |
---|
164 | def model_name(state): |
---|
165 | """ |
---|
166 | Build the model name out of form factor and structure factor (if present). |
---|
167 | |
---|
168 | This will be the name that is stored as the model name in the simultaneous |
---|
169 | fit model_list structure corresponding to the form factor and structure |
---|
170 | factor given on the individual fit pages. The model name is used to help |
---|
171 | disambiguate different SASentry sections with the same dataset. |
---|
172 | """ |
---|
173 | p_model, s_model = state.formfactorcombobox, state.structurecombobox |
---|
174 | if s_model is not None and s_model != "" and s_model.lower() != "none": |
---|
175 | return '*'.join((p_model, s_model)) |
---|
176 | else: |
---|
177 | return p_model |
---|
178 | |
---|
179 | def get_data_weight(state): |
---|
180 | """ |
---|
181 | Get error bars on data. These could be the values computed by reduction |
---|
182 | and stored in the file, the square root of the intensity (if instensity |
---|
183 | is approximately counts), the intensity itself (would be better as a |
---|
184 | percentage of the intensity, such as 2% or 5% depending on relative |
---|
185 | counting time), or one for equal weight uncertainty depending on the |
---|
186 | value of state.dI_*. |
---|
187 | """ |
---|
188 | # Cribbed from perspectives/fitting/utils.py:get_weight and |
---|
189 | # perspectives/fitting/fitpage.py: get_weight_flag |
---|
190 | weight = None |
---|
191 | if state.enable2D: |
---|
192 | dy_data = state.data.err_data |
---|
193 | data = state.data.data |
---|
194 | else: |
---|
195 | dy_data = state.data.dy |
---|
196 | data = state.data.y |
---|
197 | if state.dI_noweight: |
---|
198 | weight = np.ones_like(data) |
---|
199 | elif state.dI_didata: |
---|
200 | weight = dy_data |
---|
201 | elif state.dI_sqridata: |
---|
202 | weight = np.sqrt(np.abs(data)) |
---|
203 | elif state.dI_idata: |
---|
204 | weight = np.abs(data) |
---|
205 | return weight |
---|
206 | |
---|
207 | _MODEL_CACHE = {} |
---|
208 | def load_model(name): |
---|
209 | """ |
---|
210 | Given a model name load the Sasview shim model from sasmodels. |
---|
211 | |
---|
212 | If name starts with "[Plug-in]" then load it as a custom model from the |
---|
213 | plugins directory. This code does not go through the Sasview model manager |
---|
214 | interface since that loads all available models rather than just those |
---|
215 | needed. |
---|
216 | """ |
---|
217 | # Remember the models that are loaded so they are only loaded once. While |
---|
218 | # not strictly necessary (the models will use identical but different model |
---|
219 | # info structure) it saves a little time and memory for the usual case |
---|
220 | # where models are reused for simultaneous and batch fitting. |
---|
221 | if name in _MODEL_CACHE: |
---|
222 | return _MODEL_CACHE[name] |
---|
223 | if name.startswith(PLUGIN_NAME_BASE): |
---|
224 | name = name[len(PLUGIN_NAME_BASE):] |
---|
225 | plugins_dir = find_plugins_dir() |
---|
226 | path = os.path.abspath(os.path.join(plugins_dir, name + ".py")) |
---|
227 | #print("loading custom", path) |
---|
228 | model = load_custom_model(path) |
---|
229 | elif name and name is not None and name.lower() != "none": |
---|
230 | #print("loading standard", name) |
---|
231 | model = _make_standard_model(name) |
---|
232 | else: |
---|
233 | model = None |
---|
234 | _MODEL_CACHE[name] = model |
---|
235 | return model |
---|
236 | |
---|
237 | def parse_optional_float(value): |
---|
238 | """ |
---|
239 | Convert optional floating point from string to value, returning None |
---|
240 | if string is None, empty or contains the word "None" (case insensitive). |
---|
241 | """ |
---|
242 | if value is not None and value != "" and value.lower() != "none": |
---|
243 | return float(value) |
---|
244 | else: |
---|
245 | return None |
---|
246 | |
---|
247 | def make_fitness(state): |
---|
248 | # Load the model |
---|
249 | category_name = state.categorycombobox |
---|
250 | form_factor_name = state.formfactorcombobox |
---|
251 | structure_factor_name = state.structurecombobox |
---|
252 | multiplicity = state.multi_factor |
---|
253 | if category_name == CUSTOM_MODEL: |
---|
254 | assert form_factor_name.startswith(PLUGIN_NAME_BASE) |
---|
255 | form_factor_model = load_model(form_factor_name) |
---|
256 | structure_factor_model = load_model(structure_factor_name) |
---|
257 | model = form_factor_model(multiplicity) |
---|
258 | if structure_factor_model is not None: |
---|
259 | model = MultiplicationModel(model, structure_factor_model()) |
---|
260 | |
---|
261 | # Set the dispersity distributions for all model parameters. |
---|
262 | # Default to gaussian |
---|
263 | dists = {par_name + ".type": "gaussian" for par_name in model.dispersion} |
---|
264 | dists.update(state.disp_obj_dict) |
---|
265 | for par_name, dist_name in state.disp_obj_dict.items(): |
---|
266 | dispersion = POLYDISPERSITY_MODELS[dist_name]() |
---|
267 | if dist_name == "array": |
---|
268 | dispersion.set_weights(state.values[par_name], state.weights[par_name]) |
---|
269 | base_par = par_name.replace('.width', '') |
---|
270 | model.set_dispersion(base_par, dispersion) |
---|
271 | |
---|
272 | # Put parameter values and ranges into the model |
---|
273 | fitted = [] |
---|
274 | for par_tuple in state.parameters + state.fixed_param + state.fittable_param: |
---|
275 | par = SasviewParameter(*par_tuple) |
---|
276 | if par.name not in state.weights: |
---|
277 | # Don't try to set parameter values for array distributions |
---|
278 | # TODO: keep weights filename in the array distribution object |
---|
279 | model.setParam(par.name, parse_optional_float(par.value)) |
---|
280 | if par.fitted: |
---|
281 | fitted.append(par.name) |
---|
282 | if par.name in model.details: |
---|
283 | lower = parse_optional_float(par.lower[1]) |
---|
284 | upper = parse_optional_float(par.upper[1]) |
---|
285 | model.details[par.name] = [par.units, lower, upper] |
---|
286 | #print("pars", model.params) |
---|
287 | #print("limits", model.details) |
---|
288 | #print("fitted", fitted) |
---|
289 | |
---|
290 | # Set the resolution |
---|
291 | data = copy.deepcopy(state.data) |
---|
292 | if state.disable_smearer: |
---|
293 | smearer = None |
---|
294 | elif state.enable_smearer: |
---|
295 | smearer = smear_selection(data, model) |
---|
296 | elif state.pinhole_smearer: |
---|
297 | # see sasgui/perspectives/fitting/basepage.py: reset_page_helper |
---|
298 | dx_percent = state.dx_percent |
---|
299 | if state.dx_old: |
---|
300 | dx_percent = 100*(state.dx_percent / data.x[0]) |
---|
301 | # see sasgui/perspectives/fitting/fitpage.py: _set_pinhole_smear |
---|
302 | percent = dx_percent / 100. |
---|
303 | if state.enable2D: |
---|
304 | # smear_type is Pinhole2D. |
---|
305 | # dqx_data, dqy_data initialized to 0 |
---|
306 | data.dqx_data = percent * data.qx_data |
---|
307 | data.dqy_data = percent * data.qy_data |
---|
308 | else: |
---|
309 | data.dx = percent * data.x |
---|
310 | data.dxl = data.dxw = None # be sure it is not slit-smeared |
---|
311 | smearer = smear_selection(data, model) |
---|
312 | elif state.slit_smearer: |
---|
313 | # see sasgui/perspectives/fitting/fitpage.py: _set_pinhole_smear |
---|
314 | data_len = len(data.x) |
---|
315 | data.dx = None |
---|
316 | data.dxl = (state.dxl if state.dxl is not None else 0.) * np.ones(data_len) |
---|
317 | data.dxw = (state.dxw if state.dxw is not None else 0.) * np.ones(data_len) |
---|
318 | smearer = smear_selection(data, model) |
---|
319 | else: |
---|
320 | raise ValueError("expected resolution specification for fit") |
---|
321 | |
---|
322 | # Set the data weighting (dI, sqrt(I), I, or uniform) |
---|
323 | weight = get_data_weight(state) |
---|
324 | |
---|
325 | # Note: wx GUI makes a copy of the data and assigns weight to |
---|
326 | # data.err_data/data.dy instead of using the err_data/dy keywords |
---|
327 | # when creating the Fit object. |
---|
328 | |
---|
329 | # Make fit data object and set the data weights |
---|
330 | # TODO: check 2D masked data |
---|
331 | if state.enable2D: |
---|
332 | fitdata = FitData2D(sas_data2d=data, data=data.data, |
---|
333 | err_data=weight) |
---|
334 | else: |
---|
335 | data.mask = (np.isnan(data.y) if data.y is not None |
---|
336 | else np.zeros_like(data.x, dtype='bool')) |
---|
337 | fitdata = FitData1D(x=data.x, y=data.y, |
---|
338 | dx=data.dx, dy=weight, smearer=smearer) |
---|
339 | fitdata.set_fit_range(qmin=state.qmin, qmax=state.qmax) |
---|
340 | fitdata.sas_data = data |
---|
341 | |
---|
342 | # Don't need initial values since they have been stuffed into the model |
---|
343 | # If provided, then they should be one-to-one with the parameter names |
---|
344 | # listed in fitted. |
---|
345 | initial_values = None |
---|
346 | |
---|
347 | fitmodel = Model(model, fitdata) |
---|
348 | fitmodel.name = state.fit_page |
---|
349 | fitness = SasFitness( |
---|
350 | model=fitmodel, |
---|
351 | data=fitdata, |
---|
352 | constraints=state.constraints, |
---|
353 | fitted=fitted, |
---|
354 | initial_values=initial_values, |
---|
355 | ) |
---|
356 | |
---|
357 | return fitness |
---|
358 | |
---|
359 | class BumpsPlugin: |
---|
360 | """ |
---|
361 | Object holding methods for interacting with SasView using the direct |
---|
362 | bumps interface. |
---|
363 | """ |
---|
364 | #@staticmethod |
---|
365 | #def data_view(): |
---|
366 | # pass |
---|
367 | |
---|
368 | #@staticmethod |
---|
369 | #def model_view(): |
---|
370 | # pass |
---|
371 | |
---|
372 | @staticmethod |
---|
373 | def load_model(filename): |
---|
374 | fit = FitState(filename) |
---|
375 | #fit.show() |
---|
376 | #print("====\nfit", fit) |
---|
377 | problem = fit.make_fitproblem() |
---|
378 | #print(problem.show()) |
---|
379 | return problem |
---|
380 | |
---|
381 | #@staticmethod |
---|
382 | #def new_model(): |
---|
383 | # pass |
---|
384 | |
---|
385 | |
---|
386 | def setup_sasview(): |
---|
387 | from sas.sasview.sasview import setup_logging, setup_mpl, setup_sasmodels |
---|
388 | #setup_logging() |
---|
389 | #setup_mpl() |
---|
390 | setup_sasmodels() |
---|
391 | |
---|
392 | def setup_bumps(): |
---|
393 | """ |
---|
394 | Install the refl1d plugin into bumps, but don't run main. |
---|
395 | """ |
---|
396 | import os |
---|
397 | import bumps.cli |
---|
398 | bumps.cli.set_mplconfig(appdatadir=os.path.join('.sasview', 'bumpsfit')) |
---|
399 | bumps.cli.install_plugin(BumpsPlugin) |
---|
400 | |
---|
401 | def bumps_cli(): |
---|
402 | """ |
---|
403 | Install the SasView plugin into bumps and run the command line interface. |
---|
404 | """ |
---|
405 | setup_sasview() |
---|
406 | setup_bumps() |
---|
407 | import bumps.cli |
---|
408 | bumps.cli.main() |
---|
409 | |
---|
410 | if __name__ == "__main__": |
---|
411 | # Allow run with: |
---|
412 | # python -m sas.sascalc.fit.fitstate |
---|
413 | bumps_cli() |
---|