source: sasview/src/sas/sascalc/fit/fitstate.py @ e66f9c1

ticket-1094-headless
Last change on this file since e66f9c1 was e66f9c1, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

use dQ/|Q| for user-defined 2D pinhole smearing in headless fits

  • Property mode set to 100644
File size: 15.2 KB
Line 
1"""
2Interface to page state for saved fits.
3
4The 4.x sasview gui builds the model, smearer, etc. directly inside the wx
5GUI code.  This code separates the in-memory representation from the GUI.
6Initially it is used for a headless bumps fitting backend that operates
7directly on the saved XML; eventually it should provide methods to replace
8direct access to the PageState object so that the code for setting up and
9running fits in wx, qt, and headless versions of SasView shares a common
10in-memory representation.
11"""
12from __future__ import print_function, division
13
14import copy
15from collections import namedtuple
16
17import numpy as np
18
19from bumps.names import FitProblem
20
21from sasmodels.core import load_model_info
22from sasmodels.data import plot_theory
23from sasmodels.sasview_model import _make_standard_model, MultiplicationModel, load_custom_model
24from sasmodels.weights import MODELS as POLYDISPERSITY_MODELS
25
26from .pagestate import Reader, PageState, SimFitPageState, CUSTOM_MODEL
27from .BumpsFitting import SasFitness, ParameterExpressions
28from .AbstractFitEngine import FitData1D, FitData2D, Model
29from .models import PLUGIN_NAME_BASE, find_plugins_dir
30from .qsmearing import smear_selection
31
# Monkey patch SasFitness class with plotter
def sasfitness_plot(self, view='log'):
    """
    Plot the data, the current theory and the residuals for this fitness.

    Installed below as the ``plot`` method of SasFitness.  *view* is passed
    straight through to sasmodels.data.plot_theory.
    """
    sas_data = self.data.sas_data
    theory = self.theory()
    resid = self.residuals()
    plot_theory(sas_data, theory, resid, view)
SasFitness.plot = sasfitness_plot
37
# Field names, in order, of the per-parameter tuples stored on a page state
# (state.parameters, state.fixed_param and state.fittable_param).
PARAMETER_FIELDS = [
    "fitted",
    "name",
    "value",
    "plusminus",
    "uncertainty",
    "lower",
    "upper",
    "units",
]
# Named-tuple view of one of those parameter tuples.  The tuple type itself
# is called "Parameter"; the module binds it as SasviewParameter.
SasviewParameter = namedtuple("Parameter", PARAMETER_FIELDS)
44
class FitState(object):
    """
    Fit pages and simultaneous-fit information loaded from a saved fit file.

    After construction, ``fits`` holds one PageState per fit page (already
    converted to sasmodels), ``simfit`` holds the SimFitPageState if the
    file has one (otherwise None), and the fit pages have been given the
    ``fit_page`` and ``constraints`` attributes read by make_fitness.
    """
    def __init__(self, fitfile):
        self.fitfile = fitfile
        self.simfit = None
        self.fits = []

        # The reader delivers each loaded state object to _add_entry, which
        # fills in self.fits and self.simfit; the return value of read() is
        # not needed.
        reader = Reader(self._add_entry)
        reader.read(fitfile)
        self._set_constraints()

    def _add_entry(self, state=None, datainfo=None, format=None):
        """
        Handed to the reader to receive and accumulate the loaded state objects.

        PageState objects are converted to sasmodels and appended to
        ``fits``; a SimFitPageState is stored as ``simfit``; anything else
        (empty fit info) is ignored.
        """
        # Note: datainfo is in state.data; format=.svs means reset fit panels
        if isinstance(state, PageState):
            # TODO: shouldn't the update be part of the load?
            state._convert_to_sasmodels()
            self.fits.append(state)
        elif isinstance(state, SimFitPageState):
            self.simfit = state
        else:
            # ignore empty fit info
            pass

    def __str__(self):
        return '<SasFit %s>'%self.fitfile

    def show(self):
        """
        Summarize the fit pages in the state object.
        """
        # Note: _dump_attrs isn't a closure, but putting it here anyway
        # because it is specific to show and doesn't need to be efficient.
        def _dump_attrs(obj, label=""):
            print("="*20, label)
            for attr, value in sorted(obj.__dict__.items()):
                if isinstance(value, (list, tuple)):
                    print(attr)
                    for item in value:
                        print("   ", item)
                else:
                    print(attr, value)
        for k, fit in enumerate(self.fits):
            _dump_attrs(fit, label="Fit page "+str(k+1))
        if self.simfit:
            _dump_attrs(self.simfit, label="Constraints")

    def make_fitproblem(self):
        """
        Build collection of bumps fitness calculators and return the FitProblem.

        Raises RuntimeError if the file contains no fit pages.
        """
        # TODO: batch info not stored with project/analysis file (ticket #907)
        models = [make_fitness(state) for state in self.fits]
        if not models:
            raise RuntimeError("Nothing to fit")
        fit_problem = FitProblem(models)
        fit_problem.setp_hook = ParameterExpressions(models)
        return fit_problem

    def _set_constraints(self):
        """
        Adds fit_page and constraints list to each model.

        Without a simultaneous fit every page is labelled 'M1' and given no
        constraints.  Otherwise each entry of simfit.model_list is assigned
        to the first not-yet-used fit page with the same data id and model
        name, and that page receives the (param, expression) pairs listed
        for it in simfit.constraints_list.

        Raises ValueError if a model_list entry has no matching fit page.
        """
        # early return if no sim fit
        if self.simfit is None:
            for fit in self.fits:
                fit.fit_page = 'M1'
                fit.constraints = {}
            return

        # Note: simfitpage.py:load_from_save_state relabels the model and
        # constraint on load, replacing the model name in the constraint
        # expression with the new name for every constraint expression. We
        # don't bother to do that here since we don't need to relabel the
        # model ids.
        constraints = {}
        for item in self.simfit.constraints_list:
            # model_cbox in the constraints list should match fit_page_source
            # in the model list.
            pairs = constraints.setdefault(item['model_cbox'], [])
            pairs.append((item['param_cbox'], item['constraint']))

        # No way to uniquely map the page id (M1, M2, etc.) to the different
        # items in fits; neither fit_number nor fit_page_source is in the
        # pagestate for the fit, and neither model_name nor data name are
        # unique.  The usual case of one model per data file will get us most
        # of the way there but there will be ambiguity when the data file
        # is not unique, e.g., when different parts of the data set are
        # fit with different models.  If the same model and same data are
        # used (e.g., with different background, scale or resolution in
        # different segments) then the model-fit association will be assigned
        # arbitrarily based on whichever happens to come first.
        # NOTE(review): fit pages not named in simfit.model_list are left
        # without fit_page/constraints attributes, which make_fitness reads.
        used = []
        for model in self.simfit.model_list:
            for fit in self.fits:
                if (fit.data_id == model['name']
                        and model_name(fit) == model['model_name']
                        and fit not in used):
                    fit.fit_page = model['fit_page_source']
                    fit.constraints = constraints.setdefault(fit.fit_page, [])
                    used.append(fit)
                    # Take only the first unused match so that identical
                    # data/model pages are handed out in file order rather
                    # than all being claimed by this one model entry.
                    break
            else:
                raise ValueError("could not find model %s in file"
                                 % model['fit_page_source'])
158
159
def model_name(state):
    """
    Return the combined model name for a fit page.

    The result is the form factor name alone, or "form*structure" when a
    structure factor is selected (the structure factor entry is not None,
    not empty, and not "none" in any letter case).  This is the name stored
    as the model name in the simultaneous fit model_list, and it helps
    disambiguate SASentry sections that share the same dataset.
    """
    form_factor = state.formfactorcombobox
    structure_factor = state.structurecombobox
    has_structure_factor = (
        structure_factor is not None
        and structure_factor != ""
        and structure_factor.lower() != "none"
    )
    if not has_structure_factor:
        return form_factor
    return '*'.join((form_factor, structure_factor))
174
def get_data_weight(state):
    """
    Return the per-point uncertainty used to weight the fit.

    The dI_* flags on *state* are checked in this order:

    * dI_noweight: ones (equal weighting)
    * dI_didata: the error bars stored with the data (err_data for 2D,
      dy for 1D), e.g. as computed by reduction
    * dI_sqridata: sqrt(|I|), suitable when the intensity is roughly counts
    * dI_idata: |I| (a percentage of the intensity, such as 2% or 5%
      depending on relative counting time, would be better)

    Returns None if none of the flags is set.
    """
    # Cribbed from perspectives/fitting/utils.py:get_weight and
    # perspectives/fitting/fitpage.py: get_weight_flag
    if state.enable2D:
        errors, intensity = state.data.err_data, state.data.data
    else:
        errors, intensity = state.data.dy, state.data.y

    if state.dI_noweight:
        return np.ones_like(intensity)
    if state.dI_didata:
        return errors
    if state.dI_sqridata:
        return np.sqrt(np.abs(intensity))
    if state.dI_idata:
        return np.abs(intensity)
    return None
202
# Models already loaded by load_model, keyed by the name that was requested.
_MODEL_CACHE = {}
def load_model(name):
    """
    Given a model name load the Sasview shim model from sasmodels.

    If name starts with PLUGIN_NAME_BASE (the "[Plug-in]" prefix) then load
    it as a custom model from the plugins directory.  This code does not go
    through the Sasview model manager interface since that loads all
    available models rather than just those needed.

    Returns None if *name* is None, empty, or "none" (any letter case), as
    happens when no structure factor is selected.
    """
    # Remember the models that are loaded so they are only loaded once.  While
    # not strictly necessary (the models will use identical but different model
    # info structure) it saves a little time and memory for the usual case
    # where models are reused for simultaneous and batch fitting.
    if name in _MODEL_CACHE:
        return _MODEL_CACHE[name]
    # Check for "no model" first so that None never reaches name.startswith.
    if name is None or name == "" or name.lower() == "none":
        model = None
    elif name.startswith(PLUGIN_NAME_BASE):
        # os is not imported at module level, so import it here.
        import os
        plugin_name = name[len(PLUGIN_NAME_BASE):]
        plugins_dir = find_plugins_dir()
        path = os.path.abspath(os.path.join(plugins_dir, plugin_name + ".py"))
        model = load_custom_model(path)
    else:
        model = _make_standard_model(name)
    # Cache under the name exactly as requested (including any plug-in
    # prefix) so repeat lookups hit and plug-ins cannot shadow standard
    # models that share the bare name.
    _MODEL_CACHE[name] = model
    return model
232
def parse_optional_float(value):
    """
    Convert an optional floating point value to float.

    Returns None if *value* is None, an empty or blank string, or the word
    "None" (case insensitive, surrounding whitespace ignored).  Other
    strings are parsed with float().  Values without string methods (e.g.
    numbers that are already numeric) are passed directly to float().

    Raises whatever float() raises (ValueError/TypeError) for input it
    cannot convert.
    """
    if value is None:
        return None
    try:
        text = value.strip()
    except AttributeError:
        # Not a string: let float() perform the conversion.
        return float(value)
    if text == "" or text.lower() == "none":
        return None
    return float(text)
242
def make_fitness(state):
    """
    Build the bumps fitness object for one SasView fit page.

    *state* is a PageState that has already been given ``fit_page`` and
    ``constraints`` attributes (see FitState._set_constraints).  The model,
    its dispersity, parameter values and limits, resolution smearing and
    data weighting are all taken from the page state.  The page data is
    deep-copied first so that the resolution and mask settings applied here
    do not modify the page state.

    Returns a SasFitness object.

    Raises ValueError if the page has no form factor model or no recognized
    resolution option.
    """
    model, fitted = _make_model(state)

    # Set the resolution on a private copy of the data.
    data = copy.deepcopy(state.data)
    smearer = _make_smearer(state, data, model)

    # Set the data weighting (dI, sqrt(I), I, or uniform)
    weight = get_data_weight(state)

    # Note: wx GUI makes a copy of the data and assigns weight to
    # data.err_data/data.dy instead of using the err_data/dy keywords
    # when creating the Fit object.

    # Make fit data object and set the data weights
    # TODO: check 2D masked data
    if state.enable2D:
        fitdata = FitData2D(sas_data2d=data, data=data.data,
                            err_data=weight)
        # NOTE(review): the smearer computed above is not attached to the 2D
        # fit data; confirm whether FitData2D should be given it.
    else:
        data.mask = (np.isnan(data.y) if data.y is not None
                     else np.zeros_like(data.x, dtype='bool'))
        fitdata = FitData1D(x=data.x, y=data.y,
                            dx=data.dx, dy=weight, smearer=smearer)
    # Restrict the fit to the Q range stored on the page, for 1D and 2D.
    fitdata.set_fit_range(qmin=state.qmin, qmax=state.qmax)
    fitdata.sas_data = data

    # Don't need initial values since they have been stuffed into the model
    # If provided, then they should be one-to-one with the parameter names
    # listed in fitted.
    initial_values = None

    fitmodel = Model(model, fitdata)
    fitmodel.name = state.fit_page
    fitness = SasFitness(
        model=fitmodel,
        data=fitdata,
        constraints=state.constraints,
        fitted=fitted,
        initial_values=initial_values,
        )

    return fitness


def _make_model(state):
    """Build the sasview model for *state*; return (model, fitted_names)."""
    # Load the model
    category_name = state.categorycombobox
    form_factor_name = state.formfactorcombobox
    structure_factor_name = state.structurecombobox
    multiplicity = state.multi_factor
    if category_name == CUSTOM_MODEL:
        # Custom-category models are expected to come from the plugins dir.
        assert form_factor_name.startswith(PLUGIN_NAME_BASE)
    form_factor_model = load_model(form_factor_name)
    if form_factor_model is None:
        raise ValueError("no form factor model selected for fit page")
    structure_factor_model = load_model(structure_factor_name)
    model = form_factor_model(multiplicity)
    if structure_factor_model is not None:
        model = MultiplicationModel(model, structure_factor_model())

    # Set the dispersity distributions named in the page state.  Keys are
    # "<par>.width"; parameters not listed keep the distribution the model
    # was created with.
    for par_name, dist_name in state.disp_obj_dict.items():
        dispersion = POLYDISPERSITY_MODELS[dist_name]()
        if dist_name == "array":
            dispersion.set_weights(state.values[par_name], state.weights[par_name])
        base_par = par_name.replace('.width', '')
        model.set_dispersion(base_par, dispersion)

    # Put parameter values and ranges into the model
    fitted = []
    for par_tuple in state.parameters + state.fixed_param + state.fittable_param:
        par = SasviewParameter(*par_tuple)
        if par.name not in state.weights:
            # Don't try to set parameter values for array distributions
            # TODO: keep weights filename in the array distribution object
            model.setParam(par.name, parse_optional_float(par.value))
        if par.fitted:
            fitted.append(par.name)
        if par.name in model.details:
            lower = parse_optional_float(par.lower[1])
            upper = parse_optional_float(par.upper[1])
            model.details[par.name] = [par.units, lower, upper]
    return model, fitted


def _make_smearer(state, data, model):
    """
    Return the resolution smearer selected on the page, or None.

    *data* is modified in place to carry any user-specified pinhole or slit
    resolution, so callers should pass a copy of the page data.

    Raises ValueError if no resolution option is set on the page.
    """
    if state.disable_smearer:
        return None
    if state.enable_smearer:
        return smear_selection(data, model)
    if state.pinhole_smearer:
        # see sasgui/perspectives/fitting/basepage.py: reset_page_helper
        dx_percent = state.dx_percent
        if state.dx_old:
            # dx_old values are rescaled relative to the first x value.
            dx_percent = 100*(state.dx_percent / data.x[0])
        # see sasgui/perspectives/fitting/fitpage.py: _set_pinhole_smear
        percent = dx_percent / 100.
        if state.enable2D:
            # smear_type is Pinhole2D: dQ = percent*|Q| in both directions.
            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
            data.dqx_data = data.dqy_data = percent*q
        else:
            data.dx = percent * data.x
            data.dxl = data.dxw = None  # be sure it is not slit-smeared
        return smear_selection(data, model)
    if state.slit_smearer:
        # see sasgui/perspectives/fitting/fitpage.py (slit smear setup)
        data_len = len(data.x)
        data.dx = None
        data.dxl = (state.dxl if state.dxl is not None else 0.) * np.ones(data_len)
        data.dxw = (state.dxw if state.dxw is not None else 0.) * np.ones(data_len)
        return smear_selection(data, model)
    raise ValueError("expected resolution specification for fit")
353
class BumpsPlugin:
    """
    Bumps plugin that lets the bumps command line load SasView fit files
    through the direct bumps interface.

    Only load_model is provided; the data_view, model_view and new_model
    hooks are not implemented.
    """
    @staticmethod
    def load_model(filename):
        """Read the saved fit *filename* and return its bumps FitProblem."""
        return FitState(filename).make_fitproblem()
379
380
def setup_sasview():
    """
    Run SasView's sasmodels setup for headless use.

    setup_logging and setup_mpl are imported alongside setup_sasmodels but
    are not called here.
    """
    from sas.sasview.sasview import setup_logging, setup_mpl, setup_sasmodels
    setup_sasmodels()
386
def setup_bumps():
    """
    Install the SasView plugin (BumpsPlugin) into bumps, but don't run main.

    Before installing the plugin, calls bumps.cli.set_mplconfig with
    appdatadir set to os.path.join('.sasview', 'bumpsfit').
    """
    import os
    import bumps.cli
    bumps.cli.set_mplconfig(appdatadir=os.path.join('.sasview', 'bumpsfit'))
    bumps.cli.install_plugin(BumpsPlugin)
395
def bumps_cli():
    """
    Run the bumps command line interface with SasView support installed.

    Performs the SasView sasmodels setup and installs the SasView bumps
    plugin, then hands control to bumps.cli.main().
    """
    setup_sasview()
    setup_bumps()
    from bumps import cli as bumps_cli_module
    bumps_cli_module.main()
404
if __name__ == "__main__":
    # Entry point when executed as a module, e.g.:
    #    python -m sas.sascalc.fit.fitstate
    bumps_cli()
Note: See TracBrowser for help on using the repository browser.