source: sasview/src/sas/sascalc/fit/fitstate.py @ 3f89c0e

ticket-1094-headless
Last change on this file since 3f89c0e was 3f89c0e, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

raise error if file loaded by headless fit is ambiguous

  • Property mode set to 100644
File size: 15.3 KB
Line 
1"""
2Interface to page state for saved fits.
3
4The 4.x sasview gui builds the model, smearer, etc. directly inside the wx
5GUI code.  This code separates the in-memory representation from the GUI.
6Initially it is used for a headless bumps fitting backend that operates
7directly on the saved XML; eventually it should provide methods to replace
8direct access to the PageState object so that the code for setting up and
9running fits in wx, qt, and headless versions of SasView shares a common
10in-memory representation.
11"""
12from __future__ import print_function, division
13
14import copy
15from collections import namedtuple
16
17import numpy as np
18
19from bumps.names import FitProblem
20
21from sasmodels.core import load_model_info
22from sasmodels.data import plot_theory
23from sasmodels.sasview_model import _make_standard_model, MultiplicationModel, load_custom_model
24from sasmodels.weights import MODELS as POLYDISPERSITY_MODELS
25
26from .pagestate import Reader, PageState, SimFitPageState, CUSTOM_MODEL
27from .BumpsFitting import SasFitness, ParameterExpressions
28from .AbstractFitEngine import FitData1D, FitData2D, Model
29from .models import PLUGIN_NAME_BASE, find_plugins_dir
30from .qsmearing import smear_selection
31
# Monkey patch SasFitness class with plotter
def sasfitness_plot(self, view='log'):
    """
    Plot the data, theory and residuals held by this fitness object.

    *view* is handed unchanged to sasmodels ``plot_theory``.
    """
    sas_data = self.data.sas_data
    theory = self.theory()
    residuals = self.residuals()
    plot_theory(sas_data, theory, residuals, view)
SasFitness.plot = sasfitness_plot
37
# Field names, in order, of one parameter entry of a saved fit page; the
# named tuple lets the entries be read by name instead of by position.
PARAMETER_FIELDS = (
    "fitted name value plusminus uncertainty "
    "lower upper units"
).split()
SasviewParameter = namedtuple("Parameter", PARAMETER_FIELDS)
44
class FitState(object):
    """
    In-memory representation of a saved fit file.

    The file is parsed by the pagestate ``Reader``.  Each fit page state it
    delivers is collected in *fits*, and the simultaneous fit state, if
    there is one, is kept in *simfit*.  After loading, fit pages carry the
    *fit_page* and *constraints* attributes that :func:`make_fitness` reads.
    """
    def __init__(self, fitfile):
        """
        Load *fitfile* and attach the constraints to the loaded fit pages.

        Raises ValueError if the simultaneous fit entries cannot be paired
        unambiguously with the fit pages in the file.
        """
        self.fitfile = fitfile
        self.simfit = None
        self.fits = []

        # The reader hands every state to _add_entry as it is parsed, so
        # the return value of read() is not needed here.
        reader = Reader(self._add_entry)
        reader.read(fitfile)
        self._set_constraints()

    def _add_entry(self, state=None, datainfo=None, format=None):
        """
        Handed to the reader to receive and accumulate the loaded state objects.
        """
        # Note: datainfo is in state.data; format=.svs means reset fit panels
        if isinstance(state, PageState):
            # TODO: shouldn't the update be part of the load?
            state._convert_to_sasmodels()
            self.fits.append(state)
        elif isinstance(state, SimFitPageState):
            self.simfit = state
        else:
            # ignore empty fit info
            pass

    def __str__(self):
        return '<SasFit %s>'%self.fitfile

    @staticmethod
    def _show_attributes(state):
        """
        Print the attributes of *state* sorted by name, one item per line
        for list and tuple values.
        """
        for attr, value in sorted(state.__dict__.items()):
            if isinstance(value, (list, tuple)):
                print(attr)
                for item in value:
                    print("   ", item)
            else:
                print(attr, value)

    def show(self):
        """
        Summarize the fit pages in the state object.
        """
        for k, fit in enumerate(self.fits):
            print("="*20, "Fit page", k+1)
            self._show_attributes(fit)
        if self.simfit:
            print("="*20, "Constraints")
            self._show_attributes(self.simfit)

    def make_fitproblem(self):
        """
        Build collection of bumps fitness calculators and return the FitProblem.

        Raises RuntimeError if the file contained no fit pages.
        """
        # TODO: batch info not stored with project/analysis file (ticket #907)
        models = [make_fitness(state) for state in self.fits]
        if not models:
            raise RuntimeError("Nothing to fit")
        fit_problem = FitProblem(models)
        fit_problem.setp_hook = ParameterExpressions(models)
        return fit_problem

    def _set_constraints(self):
        """
        Adds fit_page and constraints list to each model.

        Without a simultaneous fit state every fit page is labelled 'M1' and
        given an empty constraint set.  Otherwise each entry of the
        simultaneous fit model list is paired with exactly one fit page, and
        that page receives the (parameter, expression) pairs recorded for
        its page id.

        Raises ValueError if a model list entry matches no fit page, or
        matches more than one.
        """
        # early return if no sim fit
        if self.simfit is None:
            for fit in self.fits:
                fit.fit_page = 'M1'
                fit.constraints = {}
            return

        # Note: simfitpage.py:load_from_save_state relabels the model and
        # constraint on load, replacing the model name in the constraint
        # expression with the new name for every constraint expression. We
        # don't bother to do that here since we don't need to relabel the
        # model ids.
        constraints = {}
        for item in self.simfit.constraints_list:
            # model_cbox in the constraints list should match fit_page_source
            # in the model list.
            pairs = constraints.setdefault(item['model_cbox'], [])
            pairs.append((item['param_cbox'], item['constraint']))

        # No way to uniquely map the page id (M1, M2, etc.) to the different
        # items in fits; neither fit_number nor fit_page_source is in the
        # pagestate for the fit, and neither model_name nor data name are
        # unique.  The usual case of one model per data file will get us most
        # of the way there but there will be ambiguity when the data file
        # is not unique, e.g., when different parts of the data set are
        # fit with different models.  If the same model and same data are
        # used (e.g., with different background, scale or resolution in
        # different segments) then the association cannot be recovered and
        # the load is rejected with a ValueError rather than guessing.
        used = []
        for model in self.simfit.model_list:
            page_id = model['fit_page_source']
            # Collect the candidates before touching any of them so that a
            # failed match does not leave fit pages partially labelled.
            candidates = [
                fit for fit in self.fits
                if (fit.data_id == model['name']
                    and model_name(fit) == model['model_name']
                    and fit not in used)
            ]
            if len(candidates) > 1:
                raise ValueError("more than one model matches %s" % page_id)
            elif not candidates:
                raise ValueError("could not find model %s in file" % page_id)
            fit = candidates[0]
            fit.fit_page = page_id
            fit.constraints = constraints.setdefault(page_id, [])
            used.append(fit)
162
163
def model_name(state):
    """
    Return the combined model name for a fit page state.

    The name is the form factor alone, or ``form*structure`` when the page
    has a structure factor selected.  A structure factor of None, the empty
    string or "none" (any case) counts as not selected.  The simultaneous
    fit model_list is compared against this name when pairing its entries
    with the individual fit pages.
    """
    form_factor = state.formfactorcombobox
    structure_factor = state.structurecombobox
    no_structure = (structure_factor is None
                    or structure_factor == ""
                    or structure_factor.lower() == "none")
    if no_structure:
        return form_factor
    return '*'.join((form_factor, structure_factor))
178
def get_data_weight(state):
    """
    Return the uncertainty to use for each data point of a fit page.

    The choice follows the first of the state.dI_* flags that is set:

    * dI_noweight: ones, with the shape of the intensity
    * dI_didata: the uncertainty stored with the data
    * dI_sqridata: square root of the absolute intensity
    * dI_idata: absolute intensity

    Intensity and stored uncertainty come from data.data and data.err_data
    when state.enable2D is set, and from data.y and data.dy otherwise.
    Returns None if none of the flags is set.
    """
    # Cribbed from perspectives/fitting/utils.py:get_weight and
    # perspectives/fitting/fitpage.py: get_weight_flag
    if state.enable2D:
        uncertainty, intensity = state.data.err_data, state.data.data
    else:
        uncertainty, intensity = state.data.dy, state.data.y

    if state.dI_noweight:
        return np.ones_like(intensity)
    if state.dI_didata:
        return uncertainty
    if state.dI_sqridata:
        return np.sqrt(np.abs(intensity))
    if state.dI_idata:
        return np.abs(intensity)
    return None
206
_MODEL_CACHE = {}
def load_model(name):
    """
    Given a model name load the Sasview shim model from sasmodels.

    If name starts with the plugin prefix (PLUGIN_NAME_BASE) then load it as
    a custom model from the plugins directory.  This code does not go
    through the Sasview model manager interface since that loads all
    available models rather than just those needed.

    Returns None when *name* is None, empty or "none" (any case), so that an
    unset structure factor can be passed straight through.
    """
    # Remember the models that are loaded so they are only loaded once.  While
    # not strictly necessary (the models will use identical but different model
    # info structure) it saves a little time and memory for the usual case
    # where models are reused for simultaneous and batch fitting.
    if name in _MODEL_CACHE:
        return _MODEL_CACHE[name]

    if name is None or name == "" or name.lower() == "none":
        model = None
    elif name.startswith(PLUGIN_NAME_BASE):
        # os is not imported at the top of this module, so bring it into
        # scope here, the same way setup_bumps does.
        import os
        plugin_name = name[len(PLUGIN_NAME_BASE):]
        plugins_dir = find_plugins_dir()
        path = os.path.abspath(os.path.join(plugins_dir, plugin_name + ".py"))
        #print("loading custom", path)
        model = load_custom_model(path)
    else:
        #print("loading standard", name)
        model = _make_standard_model(name)

    # Store under the name exactly as requested, prefix included.  Using the
    # stripped plugin name would miss the cache on the next request and
    # could shadow a standard model that has the same bare name.
    _MODEL_CACHE[name] = model
    return model
236
def parse_optional_float(value):
    """
    Convert an optional floating point string to its value.

    Returns None if *value* is None, the empty string or the word "none"
    (case insensitive); anything else is converted with float().
    """
    missing = value is None or value == "" or value.lower() == "none"
    return None if missing else float(value)
246
def make_fitness(state):
    """
    Build the bumps fitness calculator for one fit page state.

    The model named on the page is loaded (multiplied by the structure
    factor if one is selected), the dispersity distributions, parameter
    values and limits are copied into it, and a copy of the page data is
    given the resolution and weighting selected on the page.

    *state* must already carry the *fit_page* and *constraints* attributes.

    Returns the SasFitness object for the page.

    Raises ValueError if the page is in the custom model category without a
    plugin model name, if no form factor is named, or if none of the
    resolution flags is set.
    """
    # Load the model
    category_name = state.categorycombobox
    form_factor_name = state.formfactorcombobox
    structure_factor_name = state.structurecombobox
    multiplicity = state.multi_factor
    if (category_name == CUSTOM_MODEL
            and not form_factor_name.startswith(PLUGIN_NAME_BASE)):
        # Explicit check rather than assert so it survives python -O.
        raise ValueError("custom model %r is not a plugin model name"
                         % (form_factor_name,))
    form_factor_model = load_model(form_factor_name)
    structure_factor_model = load_model(structure_factor_name)
    if form_factor_model is None:
        raise ValueError("no form factor model specified for fit")
    model = form_factor_model(multiplicity)
    if structure_factor_model is not None:
        model = MultiplicationModel(model, structure_factor_model())

    # Set the dispersity distribution for the parameters listed in
    # state.disp_obj_dict.  Parameters that are not listed are left untouched.
    for par_name, dist_name in state.disp_obj_dict.items():
        dispersion = POLYDISPERSITY_MODELS[dist_name]()
        if dist_name == "array":
            dispersion.set_weights(state.values[par_name], state.weights[par_name])
        base_par = par_name.replace('.width', '')
        model.set_dispersion(base_par, dispersion)

    # Put parameter values and ranges into the model
    fitted = []
    for par_tuple in state.parameters + state.fixed_param + state.fittable_param:
        par = SasviewParameter(*par_tuple)
        if par.name not in state.weights:
            # Don't try to set parameter values for array distributions
            # TODO: keep weights filename in the array distribution object
            model.setParam(par.name, parse_optional_float(par.value))
        if par.fitted:
            fitted.append(par.name)
        if par.name in model.details:
            # lower/upper are pairs; the limit is the second item
            lower = parse_optional_float(par.lower[1])
            upper = parse_optional_float(par.upper[1])
            model.details[par.name] = [par.units, lower, upper]
    #print("pars", model.params)
    #print("limits", model.details)
    #print("fitted", fitted)

    # Set the resolution.  Work on a copy since the resolution columns of
    # the data are overwritten below.
    data = copy.deepcopy(state.data)
    if state.disable_smearer:
        smearer = None
    elif state.enable_smearer:
        smearer = smear_selection(data, model)
    elif state.pinhole_smearer:
        # see sasgui/perspectives/fitting/basepage.py: reset_page_helper
        dx_percent = state.dx_percent
        if state.dx_old:
            dx_percent = 100*(state.dx_percent / data.x[0])
        # see sasgui/perspectives/fitting/fitpage.py: _set_pinhole_smear
        percent = dx_percent / 100.
        if state.enable2D:
            # smear_type is Pinhole2D.
            # dqx_data, dqy_data initialized to 0
            data.dqx_data = percent * data.qx_data
            data.dqy_data = percent * data.qy_data
        else:
            data.dx = percent * data.x
            data.dxl = data.dxw = None  # be sure it is not slit-smeared
        smearer = smear_selection(data, model)
    elif state.slit_smearer:
        # see sasgui/perspectives/fitting/fitpage.py (slit smearing setup)
        data_len = len(data.x)
        data.dx = None
        data.dxl = (state.dxl if state.dxl is not None else 0.) * np.ones(data_len)
        data.dxw = (state.dxw if state.dxw is not None else 0.) * np.ones(data_len)
        smearer = smear_selection(data, model)
    else:
        raise ValueError("expected resolution specification for fit")

    # Set the data weighting (dI, sqrt(I), I, or uniform)
    weight = get_data_weight(state)

    # Note: wx GUI makes a copy of the data and assigns weight to
    # data.err_data/data.dy instead of using the err_data/dy keywords
    # when creating the Fit object.

    # Make fit data object and set the data weights
    # TODO: check 2D masked data
    if state.enable2D:
        # NOTE(review): neither the smearer nor the q range is handed to the
        # 2D fit data here -- confirm that this is intended.
        fitdata = FitData2D(sas_data2d=data, data=data.data,
                            err_data=weight)
    else:
        data.mask = (np.isnan(data.y) if data.y is not None
                     else np.zeros_like(data.x, dtype='bool'))
        fitdata = FitData1D(x=data.x, y=data.y,
                            dx=data.dx, dy=weight, smearer=smearer)
        fitdata.set_fit_range(qmin=state.qmin, qmax=state.qmax)
    # The SasFitness plotter reads the original data from here.
    fitdata.sas_data = data

    # Don't need initial values since they have been stuffed into the model
    # If provided, then they should be one-to-one with the parameter names
    # listed in fitted.
    initial_values = None

    fitmodel = Model(model, fitdata)
    fitmodel.name = state.fit_page
    fitness = SasFitness(
        model=fitmodel,
        data=fitdata,
        constraints=state.constraints,
        fitted=fitted,
        initial_values=initial_values,
        )

    return fitness
358
class BumpsPlugin:
    """
    Collection of the hooks handed to bumps when SasView is used through
    the direct bumps interface.

    Only load_model is provided; the data_view, model_view and new_model
    hooks are not implemented here.
    """
    @staticmethod
    def load_model(filename):
        """
        Load the saved fit in *filename* and return its bumps FitProblem.
        """
        return FitState(filename).make_fitproblem()
384
385
def setup_sasview():
    """
    Run the sasmodels setup step of the SasView application.

    The logging and matplotlib setup steps of the application are not run.
    """
    # Imported here rather than at module level so that the application
    # module is only loaded when the command line interface is used.
    from sas.sasview.sasview import setup_sasmodels
    setup_sasmodels()
391
def setup_bumps():
    """
    Install the SasView plugin into bumps, but don't run main.

    The bumps matplotlib configuration is given the application data
    directory '.sasview/bumpsfit' (joined with the platform separator)
    before the plugin is installed.
    """
    import os
    import bumps.cli
    appdatadir = os.path.join('.sasview', 'bumpsfit')
    bumps.cli.set_mplconfig(appdatadir=appdatadir)
    bumps.cli.install_plugin(BumpsPlugin)
400
def bumps_cli():
    """
    Set up SasView, install its plugin into bumps, then hand control to the
    bumps command line interface.
    """
    setup_sasview()
    setup_bumps()
    from bumps.cli import main as run_bumps
    run_bumps()

if __name__ == "__main__":
    # Entry point for running the module directly:
    #    python -m sas.sascalc.fit.fitstate
    bumps_cli()
Note: See TracBrowser for help on using the repository browser.