source: sasview/src/sas/sascalc/fit/fitstate.py @ 1d56359

ticket-1094-headless
Last change on this file since 1d56359 was 1d56359, checked in by Paul Kienzle <pkienzle@…>, 9 months ago

code cleanup: move shared code to function

  • Property mode set to 100644
File size: 15.2 KB
Line 
1"""
2Interface to page state for saved fits.
3
4The 4.x sasview gui builds the model, smearer, etc. directly inside the wx
5GUI code.  This code separates the in-memory representation from the GUI.
6Initially it is used for a headless bumps fitting backend that operates
7directly on the saved XML; eventually it should provide methods to replace
8direct access to the PageState object so that the code for setting up and
9running fits in wx, qt, and headless versions of SasView shares a common
10in-memory representation.
11"""
12from __future__ import print_function, division
13
14import copy
15from collections import namedtuple
16
17import numpy as np
18
19from bumps.names import FitProblem
20
21from sasmodels.core import load_model_info
22from sasmodels.data import plot_theory
23from sasmodels.sasview_model import _make_standard_model, MultiplicationModel, load_custom_model
24from sasmodels.weights import MODELS as POLYDISPERSITY_MODELS
25
26from .pagestate import Reader, PageState, SimFitPageState, CUSTOM_MODEL
27from .BumpsFitting import SasFitness, ParameterExpressions
28from .AbstractFitEngine import FitData1D, FitData2D, Model
29from .models import PLUGIN_NAME_BASE, find_plugins_dir
30from .qsmearing import smear_selection
31
# Give SasFitness a plot method (monkey patch) that draws the data, theory
# and residuals using sasmodels' plot_theory.
def sasfitness_plot(self, view='log'):
    """
    Plot the data, the current theory and the residuals for this fitness.

    *view* is passed straight through to plot_theory ('log' by default).
    """
    sas_data = self.data.sas_data
    curve = self.theory()
    residuals = self.residuals()
    plot_theory(sas_data, curve, residuals, view)

SasFitness.plot = sasfitness_plot
37
# Field names, in order, of the parameter tuples saved on a fit page; the
# named tuple gives attribute access to them (see make_fitness).
PARAMETER_FIELDS = (
    "fitted name value plusminus uncertainty lower upper units"
    ).split()
SasviewParameter = namedtuple("Parameter", PARAMETER_FIELDS)
44
class FitState(object):
    """
    Fit pages and simultaneous-fit constraints loaded from a saved fit file.

    *fits* holds one PageState per fit page found in the file, and *simfit*
    holds the SimFitPageState (None if the file has no simultaneous fit
    page).  After loading, each matched fit page carries a *fit_page* id and
    a *constraints* collection, which make_fitness reads.
    """
    def __init__(self, fitfile):
        self.fitfile = fitfile
        self.simfit = None
        self.fits = []

        # The reader hands each state object it finds to _add_entry, which
        # accumulates them on self; the value returned by read() is unused.
        reader = Reader(self._add_entry)
        reader.read(fitfile)
        self._set_constraints()

    def _add_entry(self, state=None, datainfo=None, format=None):
        """
        Handed to the reader to receive and accumulate the loaded state objects.

        A PageState is converted to sasmodels naming and appended to *fits*;
        a SimFitPageState becomes *simfit*; anything else is ignored.
        """
        # Note: datainfo is in state.data; format=.svs means reset fit panels
        if isinstance(state, PageState):
            # TODO: shouldn't the update be part of the load?
            state._convert_to_sasmodels()
            self.fits.append(state)
        elif isinstance(state, SimFitPageState):
            self.simfit = state
        # else: ignore empty fit info

    def __str__(self):
        return '<SasFit %s>'%self.fitfile

    def show(self):
        """
        Summarize the fit pages in the state object.

        Prints every attribute of each fit page, then of the simultaneous
        fit (constraints) page if there is one.
        """
        # Note: _dump_attrs isn't a closure, but putting it here anyway
        # because it is specific to show and doesn't need to be efficient.
        def _dump_attrs(obj, label=""):
            print("="*20, label)
            for attr, value in sorted(obj.__dict__.items()):
                if isinstance(value, (list, tuple)):
                    print(attr)
                    for item in value:
                        print("   ", item)
                else:
                    print(attr, value)
        for k, fit in enumerate(self.fits):
            _dump_attrs(fit, label="Fit page "+str(k+1))
        if self.simfit:
            _dump_attrs(self.simfit, label="Constraints")

    def make_fitproblem(self):
        """
        Build collection of bumps fitness calculators and return the FitProblem.

        Raises RuntimeError if the file contains no fit pages.
        """
        # TODO: batch info not stored with project/analysis file (ticket #907)
        models = [make_fitness(state) for state in self.fits]
        if not models:
            raise RuntimeError("Nothing to fit")
        fit_problem = FitProblem(models)
        # Evaluate the inter-model constraint expressions whenever the
        # parameters are updated.
        fit_problem.setp_hook = ParameterExpressions(models)
        return fit_problem

    def _set_constraints(self):
        """
        Add a fit_page id and constraints list to each fit page.

        Without a simultaneous fit page, every fit is labelled 'M1' with no
        constraints.  Otherwise each entry of simfit.model_list is assigned
        to the first not-yet-used fit page with the same data id and model
        name, and receives the (parameter, expression) constraint pairs
        listed for its page id.

        Raises ValueError if a model_list entry matches no fit page.
        """
        # early return if no sim fit
        if self.simfit is None:
            for fit in self.fits:
                fit.fit_page = 'M1'
                fit.constraints = {}
            return

        # Note: simfitpage.py:load_from_save_state relabels the model and
        # constraint on load, replacing the model name in the constraint
        # expression with the new name for every constraint expression. We
        # don't bother to do that here since we don't need to relabel the
        # model ids.
        constraints = {}
        for item in self.simfit.constraints_list:
            # model_cbox in the constraints list should match fit_page_source
            # in the model list.
            pairs = constraints.setdefault(item['model_cbox'], [])
            pairs.append((item['param_cbox'], item['constraint']))

        # No way to uniquely map the page id (M1, M2, etc.) to the different
        # items in fits; neither fit_number nor fit_page_source is in the
        # pagestate for the fit, and neither model_name nor data name are
        # unique.  The usual case of one model per data file will get us most
        # of the way there but there will be ambiguity when the data file
        # is not unique, e.g., when different parts of the data set are
        # fit with different models.  If the same model and same data are
        # used (e.g., with different background, scale or resolution in
        # different segments) then the model-fit association will be assigned
        # arbitrarily based on whichever happens to come first.
        #
        # NOTE(review): fit pages that do not appear in model_list are left
        # without fit_page/constraints, which make_fitness reads; confirm
        # that every saved page is listed in model_list.
        used = []
        for model in self.simfit.model_list:
            match = None
            for fit in self.fits:
                if (fit.data_id == model['name']
                        and model_name(fit) == model['model_name']
                        and fit not in used):
                    # Take the first unused candidate so that duplicate
                    # data/model pages are paired off in order.
                    match = fit
                    break
            if match is None:
                raise ValueError("could not find model %s in file"
                                 % model['fit_page_source'])
            match.fit_page = model['fit_page_source']
            match.constraints = constraints.setdefault(match.fit_page, [])
            used.append(match)
158
159
def model_name(state):
    """
    Return the combined model name for a fit page.

    The result is "<form factor>*<structure factor>" when a structure factor
    is selected on the page (i.e. it is not None, empty, or "none" in any
    case), and just the form factor name otherwise.  This matches the model
    name stored in the simultaneous fit model_list, and is used to help
    disambiguate different SASentry sections with the same dataset.
    """
    form_factor = state.formfactorcombobox
    structure_factor = state.structurecombobox
    no_structure = (structure_factor in (None, "")
                    or structure_factor.lower() == "none")
    if no_structure:
        return form_factor
    return '*'.join((form_factor, structure_factor))
174
def get_data_weight(state):
    """
    Get the error bars (weights) to use for the data on a fit page.

    The choice depends on which state.dI_* flag is set, checked in order:

    * dI_noweight: one for every point (equal weighting);
    * dI_didata: the uncertainties stored with the data, e.g. as computed
      by reduction;
    * dI_sqridata: the square root of the absolute intensity (appropriate
      if the intensity is approximately counts);
    * dI_idata: the absolute intensity itself (a percentage of the
      intensity, such as 2% or 5% depending on relative counting time,
      would be better).

    2D pages (state.enable2D) use data.data / data.err_data; 1D pages use
    data.y / data.dy.  Returns None if none of the flags is set.
    """
    # Cribbed from perspectives/fitting/utils.py:get_weight and
    # perspectives/fitting/fitpage.py: get_weight_flag
    if state.enable2D:
        dy_data = state.data.err_data
        data = state.data.data
    else:
        dy_data = state.data.dy
        data = state.data.y
    if state.dI_noweight:
        return np.ones_like(data)
    if state.dI_didata:
        return dy_data
    if state.dI_sqridata:
        return np.sqrt(np.abs(data))
    if state.dI_idata:
        return np.abs(data)
    return None
202
_MODEL_CACHE = {}
def load_model(name):
    """
    Given a model name load the Sasview shim model from sasmodels.

    If name starts with PLUGIN_NAME_BASE (the "[Plug-in]" prefix) then load
    it as a custom model from the plugins directory.  This code does not go
    through the Sasview model manager interface since that loads all
    available models rather than just those needed.

    Returns None if name is None, empty, or "none" (any case), i.e. when no
    model is selected.
    """
    # Remember the models that are loaded so they are only loaded once.  While
    # not strictly necessary (the models will use identical but different model
    # info structure) it saves a little time and memory for the usual case
    # where models are reused for simultaneous and batch fitting.
    # The cache is keyed on the name exactly as given, so a plug-in and a
    # standard model with the same base name never share an entry.
    if name in _MODEL_CACHE:
        return _MODEL_CACHE[name]
    if not name or name.lower() == "none":
        # No model selected (e.g., no structure factor on the page).
        model = None
    elif name.startswith(PLUGIN_NAME_BASE):
        import os
        plugin_name = name[len(PLUGIN_NAME_BASE):]
        plugins_dir = find_plugins_dir()
        path = os.path.abspath(os.path.join(plugins_dir, plugin_name + ".py"))
        model = load_custom_model(path)
    else:
        model = _make_standard_model(name)
    _MODEL_CACHE[name] = model
    return model
232
def parse_optional_float(value):
    """
    Parse an optional float stored as text.

    Returns None when *value* is None, the empty string, or the word "None"
    in any capitalization; otherwise returns float(value).
    """
    missing = value is None or value == "" or value.lower() == "none"
    return None if missing else float(value)
242
def make_fitness(state):
    """
    Build a bumps SasFitness calculator for one fit page.

    *state* is a PageState that FitState has already tagged with *fit_page*
    and *constraints*.  The model is built from the form factor and optional
    structure factor named on the page.  The page's dispersity, parameter
    values, limits and fitted flags are then copied into it.  The resolution
    and data weighting are reconstructed on a copy of the page data, and the
    pieces are wrapped in a SasFitness.

    Raises ValueError if a custom-model page does not name a plug-in model,
    or if the page selects no resolution option.
    """
    # Load the model
    category_name = state.categorycombobox
    form_factor_name = state.formfactorcombobox
    structure_factor_name = state.structurecombobox
    multiplicity = state.multi_factor
    if (category_name == CUSTOM_MODEL
            and not form_factor_name.startswith(PLUGIN_NAME_BASE)):
        # Explicit check rather than assert so it still runs under python -O.
        raise ValueError("custom model %r does not start with %r"
                         % (form_factor_name, PLUGIN_NAME_BASE))
    form_factor_model = load_model(form_factor_name)
    structure_factor_model = load_model(structure_factor_name)
    model = form_factor_model(multiplicity)
    if structure_factor_model is not None:
        model = MultiplicationModel(model, structure_factor_model())

    # Set the dispersity distributions named on the page.  Parameters not
    # listed in disp_obj_dict are left as the model created them.
    for par_name, dist_name in state.disp_obj_dict.items():
        dispersion = POLYDISPERSITY_MODELS[dist_name]()
        if dist_name == "array":
            # Array distributions take their points from the saved values
            # and weights for this parameter.
            dispersion.set_weights(state.values[par_name],
                                   state.weights[par_name])
        # The dispersity key may carry a '.width' suffix; set_dispersion is
        # given the name with that suffix stripped.
        base_par = par_name.replace('.width', '')
        model.set_dispersion(base_par, dispersion)

    # Put parameter values and ranges into the model
    fitted = []
    for par_tuple in state.parameters + state.fixed_param + state.fittable_param:
        par = SasviewParameter(*par_tuple)
        if par.name not in state.weights:
            # Don't try to set parameter values for array distributions
            # TODO: keep weights filename in the array distribution object
            model.setParam(par.name, parse_optional_float(par.value))
        if par.fitted:
            fitted.append(par.name)
        if par.name in model.details:
            # Element 1 of each saved limit is parsed as the limit value
            # (None when it is empty or "None").
            lower = parse_optional_float(par.lower[1])
            upper = parse_optional_float(par.upper[1])
            model.details[par.name] = [par.units, lower, upper]

    # Set the resolution on a private copy of the data so the page state
    # itself is not modified.
    data = copy.deepcopy(state.data)
    if state.disable_smearer:
        smearer = None
    elif state.enable_smearer:
        # Use the data as loaded, with no resolution overrides.
        smearer = smear_selection(data, model)
    elif state.pinhole_smearer:
        # see sasgui/perspectives/fitting/basepage.py: reset_page_helper
        dx_percent = state.dx_percent
        if state.dx_old:
            # Older files stored an absolute dx; convert it to a percentage
            # of the first q value.
            dx_percent = 100*(state.dx_percent / data.x[0])
        # see sasgui/perspectives/fitting/fitpage.py: _set_pinhole_smear
        percent = dx_percent / 100.
        if state.enable2D:
            # smear_type is Pinhole2D.
            # dqx_data, dqy_data initialized to 0
            data.dqx_data = percent * data.qx_data
            data.dqy_data = percent * data.qy_data
        else:
            data.dx = percent * data.x
            data.dxl = data.dxw = None  # be sure it is not slit-smeared
        smearer = smear_selection(data, model)
    elif state.slit_smearer:
        # Mirrors the slit smearing setup in
        # sasgui/perspectives/fitting/fitpage.py (presumably _set_slit_smear;
        # verify).  A missing slit length/width is treated as 0.
        data_len = len(data.x)
        data.dx = None
        data.dxl = (state.dxl if state.dxl is not None else 0.) * np.ones(data_len)
        data.dxw = (state.dxw if state.dxw is not None else 0.) * np.ones(data_len)
        smearer = smear_selection(data, model)
    else:
        raise ValueError("expected resolution specification for fit")

    # Set the data weighting (dI, sqrt(I), I, or uniform)
    weight = get_data_weight(state)

    # Note: wx GUI makes a copy of the data and assigns weight to
    # data.err_data/data.dy instead of using the err_data/dy keywords
    # when creating the Fit object.

    # Make fit data object and set the data weights
    # TODO: check 2D masked data
    if state.enable2D:
        # NOTE(review): the smearer computed above and state.qmin/qmax are
        # not passed to the 2D fit data here; confirm whether FitData2D
        # needs them set explicitly.
        fitdata = FitData2D(sas_data2d=data, data=data.data,
                            err_data=weight)
    else:
        # Mark NaN intensities in the mask (all False when there is no y).
        data.mask = (np.isnan(data.y) if data.y is not None
                     else np.zeros_like(data.x, dtype='bool'))
        fitdata = FitData1D(x=data.x, y=data.y,
                            dx=data.dx, dy=weight, smearer=smearer)
        fitdata.set_fit_range(qmin=state.qmin, qmax=state.qmax)
    # Keep the full data object around for plotting (see sasfitness_plot).
    fitdata.sas_data = data

    # Don't need initial values since they have been stuffed into the model
    # If provided, then they should be one-to-one with the parameter names
    # listed in fitted.
    initial_values = None

    fitmodel = Model(model, fitdata)
    fitmodel.name = state.fit_page
    fitness = SasFitness(
        model=fitmodel,
        data=fitdata,
        constraints=state.constraints,
        fitted=fitted,
        initial_values=initial_values,
        )

    return fitness
354
class BumpsPlugin:
    """
    Plugin object that lets the bumps interface load saved SasView fits.

    Only the model loader is provided; the other bumps plugin hooks
    (data_view, model_view, new_model) are not implemented here.
    """

    @staticmethod
    def load_model(filename):
        """Read the saved fit *filename* and return its bumps FitProblem."""
        return FitState(filename).make_fitproblem()
376
377    #@staticmethod
378    #def new_model():
379    #    pass
380
381
def setup_sasview():
    """
    Prepare the SasView environment for headless fitting.

    Only the sasmodels setup from sas.sasview.sasview is run; the SasView
    logging and matplotlib setup functions are not called on this path.
    """
    from sas.sasview.sasview import setup_sasmodels
    setup_sasmodels()
387
def setup_bumps():
    """
    Install the SasView plugin (BumpsPlugin) into bumps, but don't run main.

    Also calls bumps.cli.set_mplconfig with appdatadir '.sasview/bumpsfit'
    (presumably to give matplotlib a configuration directory; see bumps.cli).
    """
    import os
    import bumps.cli
    bumps.cli.set_mplconfig(appdatadir=os.path.join('.sasview', 'bumpsfit'))
    bumps.cli.install_plugin(BumpsPlugin)
396
def bumps_cli():
    """
    Run the bumps command line interface with the SasView plugin installed.

    SasView is configured first, then the plugin is installed into bumps,
    and finally control is handed to bumps' command line main().
    """
    for setup_step in (setup_sasview, setup_bumps):
        setup_step()
    from bumps import cli as bumps_command_line
    bumps_command_line.main()
405
if __name__ == "__main__":
    # Entry point when executed as a module, e.g.:
    #    python -m sas.sascalc.fit.fitstate
    bumps_cli()
Note: See TracBrowser for help on using the repository browser.