source: sasview/src/sas/sascalc/fit/fitstate.py @ f0b9bce

ticket-1094-headless
Last change on this file since f0b9bce was f6f3fb4, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

add ability to run bumps fit from saved project without the sasview gui

  • Property mode set to 100644
File size: 15.0 KB
Line 
1"""
2Interface to page state for saved fits.
3
4The 4.x sasview gui builds the model, smearer, etc. directly inside the wx
5GUI code.  This code separates the in-memory representation from the GUI.
6Initially it is used for a headless bumps fitting backend that operates
7directly on the saved XML; eventually it should provide methods to replace
8direct access to the PageState object so that the code for setting up and
9running fits in wx, qt, and headless versions of SasView shares a common
10in-memory representation.
11"""
12from __future__ import print_function, division
13
14import copy
15from collections import namedtuple
16
17import numpy as np
18
19from bumps.names import FitProblem
20
21from sasmodels.core import load_model_info
22from sasmodels.data import plot_theory
23from sasmodels.sasview_model import _make_standard_model, MultiplicationModel, load_custom_model
24from sasmodels.weights import MODELS as POLYDISPERSITY_MODELS
25
26from .pagestate import Reader, PageState, SimFitPageState, CUSTOM_MODEL
27from .BumpsFitting import SasFitness, ParameterExpressions
28from .AbstractFitEngine import FitData1D, FitData2D, Model
29from .models import PLUGIN_NAME_BASE, find_plugins_dir
30from .qsmearing import smear_selection
31
# Give SasFitness a plot() method by attaching this function to the class.
def sasfitness_plot(self, view='log'):
    """
    Plot the data, the current theory and the residuals for this fitness.

    The original SAS data object is taken from ``self.data.sas_data``; *view*
    is forwarded unchanged to sasmodels' plot_theory.
    """
    plot_theory(self.data.sas_data, self.theory(), self.residuals(), view)

SasFitness.plot = sasfitness_plot
37
# Field order of one saved parameter entry; make_fitness unpacks each entry
# of state.parameters / fixed_param / fittable_param into this named tuple.
PARAMETER_FIELDS = "fitted name value plusminus uncertainty lower upper units".split()
SasviewParameter = namedtuple("Parameter", PARAMETER_FIELDS)
44
class FitState(object):
    """
    Fit pages and simultaneous-fit constraints loaded from a saved fit file.

    Attributes:
      fitfile -- the file name passed to the constructor.
      fits    -- PageState objects, one per fit page, in the order the
                 reader delivered them.
      simfit  -- the SimFitPageState if the file contains one, else None.

    After loading, each matched page carries ``fit_page`` (a page id such as
    'M1') and ``constraints`` attributes, which make_fitness reads when
    building the bumps fitness object.
    """
    def __init__(self, fitfile):
        self.fitfile = fitfile
        self.simfit = None
        self.fits = []

        # The reader hands every loaded state to _add_entry, which collects
        # them, so the value returned by read() is not needed here.
        reader = Reader(self._add_entry)
        reader.read(fitfile)
        self._set_constraints()

    def _add_entry(self, state=None, datainfo=None, format=None):
        """
        Handed to the reader to receive and accumulate the loaded state objects.

        PageState objects are converted to sasmodels form and appended to
        *fits*; a SimFitPageState replaces *simfit*; anything else is ignored.
        """
        # Note: datainfo is in state.data; format=.svs means reset fit panels
        if isinstance(state, PageState):
            # TODO: shouldn't the update be part of the load?
            state._convert_to_sasmodels()
            self.fits.append(state)
        elif isinstance(state, SimFitPageState):
            self.simfit = state
        else:
            # ignore empty fit info
            pass

    def __str__(self):
        return '<SasFit %s>'%self.fitfile

    def show(self):
        """
        Summarize the fit pages (and simultaneous-fit constraints, if any).
        """
        for k, fit in enumerate(self.fits):
            print("="*20, "Fit page", k+1)
            self._print_attrs(fit)
        if self.simfit:
            print("="*20, "Constraints")
            self._print_attrs(self.simfit)

    @staticmethod
    def _print_attrs(obj):
        """Print obj's attributes sorted by name; list/tuple items one per line."""
        for attr, value in sorted(obj.__dict__.items()):
            if isinstance(value, (list, tuple)):
                print(attr)
                for item in value:
                    print("   ", item)
            else:
                print(attr, value)

    def make_fitproblem(self):
        """
        Build collection of bumps fitness calculators and return the FitProblem.

        Raises RuntimeError if the file contained no fit pages.
        """
        # TODO: batch info not stored with project/analysis file (ticket #907)
        models = [make_fitness(state) for state in self.fits]
        if not models:
            raise RuntimeError("Nothing to fit")
        fit_problem = FitProblem(models)
        # Constraint expressions are evaluated across all models on each setp.
        fit_problem.setp_hook = ParameterExpressions(models)
        return fit_problem

    def _set_constraints(self):
        """
        Adds fit_page and constraints list to each model.

        Without a simultaneous fit every page is labelled 'M1' with an empty
        constraint dict.  Otherwise each simfit model_list entry is matched to
        the first unused page with the same data id and model name.

        Raises ValueError if a model_list entry matches no page.
        """
        # early return if no sim fit
        if self.simfit is None:
            for fit in self.fits:
                fit.fit_page = 'M1'
                fit.constraints = {}
            return

        # Note: simfitpage.py:load_from_save_state relabels the model and
        # constraint on load, replacing the model name in the constraint
        # expression with the new name for every constraint expression. We
        # don't bother to do that here since we don't need to relabel the
        # model ids.
        constraints = {}
        for item in self.simfit.constraints_list:
            # model_cbox in the constraints list should match fit_page_source
            # in the model list.
            pairs = constraints.setdefault(item['model_cbox'], [])
            pairs.append((item['param_cbox'], item['constraint']))

        # No way to uniquely map the page id (M1, M2, etc.) to the different
        # items in fits; neither fit_number nor fit_page_source is in the
        # pagestate for the fit, and neither model_name nor data name are
        # unique.  The usual case of one model per data file will get us most
        # of the way there but there will be ambiguity when the data file
        # is not unique, e.g., when different parts of the data set are
        # fit with different models.  If the same model and same data are
        # used (e.g., with different background, scale or resolution in
        # different segments) then the model-fit association will be assigned
        # arbitrarily based on whichever happens to come first.
        #
        # NOTE(review): pages that do not appear in model_list never receive
        # fit_page/constraints, so make_fitness will fail on them; confirm
        # whether such pages should be dropped from the fit instead.
        used = set()  # ids of pages already matched; identity, not equality
        for model in self.simfit.model_list:
            for fit in self.fits:
                if (id(fit) not in used
                        and fit.data_id == model['name']
                        and model_name(fit) == model['model_name']):
                    fit.fit_page = model['fit_page_source']
                    fit.constraints = constraints.setdefault(fit.fit_page, [])
                    used.add(id(fit))
                    break
            else:
                raise ValueError("could not find model %s in file"
                                 % model['fit_page_source'])
158
159
def model_name(state):
    """
    Return the combined model name for a fit page.

    The name is the form factor alone, or "formfactor*structurefactor" when
    a structure factor is selected (anything other than None, "" or "none"
    in any case).  This matches the model_name recorded in the simultaneous
    fit model_list, and helps tell apart SASentry sections that share the
    same dataset.
    """
    form_factor = state.formfactorcombobox
    structure_factor = state.structurecombobox
    if (structure_factor is None
            or structure_factor == ""
            or structure_factor.lower() == "none"):
        return form_factor
    return '*'.join((form_factor, structure_factor))
174
def get_data_weight(state):
    """
    Return the per-point uncertainty used to weight the fit.

    The first weighting flag set on *state* decides the result:
      dI_noweight -- ones, i.e. equal weight for every point;
      dI_didata   -- the error bars stored with the data (dy for 1D,
                     err_data for 2D), e.g. as computed by reduction;
      dI_sqridata -- sqrt(|I|), suitable when intensity is roughly counts;
      dI_idata    -- |I| (a percentage of I such as 2% or 5% would be a
                     better choice, depending on relative counting time).
    None is returned when no flag is set.
    """
    # Cribbed from perspectives/fitting/utils.py:get_weight and
    # perspectives/fitting/fitpage.py: get_weight_flag
    source = state.data
    if state.enable2D:
        dy_data, intensity = source.err_data, source.data
    else:
        dy_data, intensity = source.dy, source.y

    if state.dI_noweight:
        return np.ones_like(intensity)
    if state.dI_didata:
        return dy_data
    if state.dI_sqridata:
        return np.sqrt(np.abs(intensity))
    if state.dI_idata:
        return np.abs(intensity)
    return None
202
# Models already loaded, keyed by the exact name passed to load_model.
_MODEL_CACHE = {}
def load_model(name):
    """
    Given a model name, load the Sasview shim model from sasmodels.

    If *name* starts with PLUGIN_NAME_BASE (the "[Plug-in]" marker) then the
    remainder is loaded as a custom model from "<plugins dir>/<remainder>.py".
    This code does not go through the Sasview model manager interface since
    that loads all available models rather than just those needed.

    Returns None when *name* is None, empty, or "none" (any case), which is
    how an absent structure factor is stored.
    """
    # os is not imported at module scope in this file, so import it here.
    import os

    # Remember the models that are loaded so they are only loaded once.  While
    # not strictly necessary (the models will use identical but different model
    # info structure) it saves a little time and memory for the usual case
    # where models are reused for simultaneous and batch fitting.
    if name in _MODEL_CACHE:
        return _MODEL_CACHE[name]
    if not name or name.lower() == "none":
        model = None
    elif name.startswith(PLUGIN_NAME_BASE):
        plugin_name = name[len(PLUGIN_NAME_BASE):]
        path = os.path.abspath(os.path.join(find_plugins_dir(),
                                            plugin_name + ".py"))
        model = load_custom_model(path)
    else:
        model = _make_standard_model(name)
    # Cache under the name as requested (not the stripped plugin name) so the
    # lookup above hits, and a plugin can never shadow a standard model that
    # shares its base name.
    _MODEL_CACHE[name] = model
    return model
232
def parse_optional_float(value):
    """
    Convert a saved string to a float, or return None for a missing value.

    None, the empty string, and the string "none" (compared case-insensitively
    against the whole string) all mean "no value".  Anything else is passed to
    float(), which raises ValueError if it is not a number.
    """
    if value is None or value == "" or value.lower() == "none":
        return None
    return float(value)
242
def _load_fit_model(state):
    """Construct the page's form factor model, times its structure factor if any."""
    form_factor_name = state.formfactorcombobox
    if (state.categorycombobox == CUSTOM_MODEL
            and not form_factor_name.startswith(PLUGIN_NAME_BASE)):
        # Explicit check rather than assert so it is not stripped by -O.
        raise ValueError("custom model %r should start with %r"
                         % (form_factor_name, PLUGIN_NAME_BASE))
    form_factor_model = load_model(form_factor_name)
    structure_factor_model = load_model(state.structurecombobox)
    if form_factor_model is None:
        raise ValueError("no form factor model for fit page (got %r)"
                         % (form_factor_name,))
    model = form_factor_model(state.multi_factor)
    if structure_factor_model is not None:
        model = MultiplicationModel(model, structure_factor_model())
    return model


def _set_dispersity(model, state):
    """Install the saved polydispersity distributions on *model*."""
    # Parameters without an entry in disp_obj_dict are left as the model
    # created them.
    for par_name, dist_name in state.disp_obj_dict.items():
        dispersion = POLYDISPERSITY_MODELS[dist_name]()
        if dist_name == "array":
            # Array distributions use the saved values and weights.
            dispersion.set_weights(state.values[par_name],
                                   state.weights[par_name])
        # set_dispersion takes the base parameter name, without '.width'.
        model.set_dispersion(par_name.replace('.width', ''), dispersion)


def _set_parameters(model, state):
    """Copy saved parameter values and limits into *model*; return fitted names."""
    fitted = []
    for par_tuple in state.parameters + state.fixed_param + state.fittable_param:
        par = SasviewParameter(*par_tuple)
        if par.name not in state.weights:
            # Don't try to set parameter values for array distributions
            # TODO: keep weights filename in the array distribution object
            model.setParam(par.name, parse_optional_float(par.value))
        if par.fitted:
            fitted.append(par.name)
        if par.name in model.details:
            # lower/upper are sequences whose second item is the bound text
            # (as indexed here); "None"/"" parse to None.
            lower = parse_optional_float(par.lower[1])
            upper = parse_optional_float(par.upper[1])
            model.details[par.name] = [par.units, lower, upper]
    return fitted


def _make_smearer(state, data, model):
    """
    Apply the saved resolution settings to *data* (modified in place) and
    return the smearer, or None when smearing is disabled.

    Raises ValueError if no resolution option is selected on the page.
    """
    if state.disable_smearer:
        return None
    if state.enable_smearer:
        # data is passed through unchanged.
        return smear_selection(data, model)
    if state.pinhole_smearer:
        # see sasgui/perspectives/fitting/basepage.py: reset_page_helper
        dx_percent = state.dx_percent
        if state.dx_old:
            # Rescale relative to the first q point, as the GUI does for
            # pages saved with dx_old set.
            dx_percent = 100*(state.dx_percent / data.x[0])
        # see sasgui/perspectives/fitting/fitpage.py: _set_pinhole_smear
        percent = dx_percent / 100.
        if state.enable2D:
            # smear_type is Pinhole2D.
            # dqx_data, dqy_data initialized to 0
            data.dqx_data = percent * data.qx_data
            data.dqy_data = percent * data.qy_data
        else:
            data.dx = percent * data.x
            data.dxl = data.dxw = None  # be sure it is not slit-smeared
        return smear_selection(data, model)
    if state.slit_smearer:
        # see sasgui/perspectives/fitting/fitpage.py: _set_slit_smear
        data_len = len(data.x)
        data.dx = None
        data.dxl = (state.dxl if state.dxl is not None else 0.) * np.ones(data_len)
        data.dxw = (state.dxw if state.dxw is not None else 0.) * np.ones(data_len)
        return smear_selection(data, model)
    raise ValueError("expected resolution specification for fit")


def _make_fit_data(state, data, smearer):
    """Wrap *data* in FitData1D/FitData2D with the saved weights and Q range."""
    # Set the data weighting (dI, sqrt(I), I, or uniform)
    weight = get_data_weight(state)
    if state.enable2D:
        # TODO: check 2D masked data
        # NOTE(review): the smearer is not attached to 2D fit data here;
        # confirm whether FitData2D.set_smearer should be called.
        fitdata = FitData2D(sas_data2d=data, data=data.data,
                            err_data=data.err_data)
        fitdata.err_data = weight
    else:
        # Mask out NaN intensities (or nothing when y is missing).
        data.mask = (np.isnan(data.y) if data.y is not None
                     else np.zeros_like(data.x, dtype='bool'))
        fitdata = FitData1D(x=data.x, y=data.y,
                            dx=data.dx, dy=data.dy, smearer=smearer)
        fitdata.dy = weight
    # Restrict both 1D and 2D fits to the Q range saved with the page.
    fitdata.set_fit_range(qmin=state.qmin, qmax=state.qmax)
    fitdata.sas_data = data
    return fitdata


def make_fitness(state):
    """
    Build the bumps SasFitness object for one saved fit page.

    *state* is a loaded PageState that already carries ``fit_page`` and
    ``constraints`` attributes.  The model is built from the saved page,
    including polydispersity, parameter values and limits, resolution, and
    data weighting.  state.data is deep-copied first, so the resolution
    setup does not modify the saved data.

    Raises ValueError if the page has no usable form factor or no resolution
    option selected.
    """
    model = _load_fit_model(state)
    _set_dispersity(model, state)
    fitted = _set_parameters(model, state)

    data = copy.deepcopy(state.data)
    smearer = _make_smearer(state, data, model)
    fitdata = _make_fit_data(state, data, smearer)

    fitmodel = Model(model, fitdata)
    fitmodel.name = state.fit_page
    # Initial values are not needed because they were written into the model
    # above; if given, they must be one-to-one with the names in *fitted*.
    return SasFitness(
        model=fitmodel,
        data=fitdata,
        constraints=state.constraints,
        fitted=fitted,
        initial_values=None,
        )
352
class BumpsPlugin:
    """
    Hooks that let the bumps command line fit saved SasView files directly.

    Only load_model is provided.  The data_view, model_view and new_model
    plugin hooks are not implemented.
    """
    @staticmethod
    def load_model(filename):
        """
        Read the saved fit file *filename* and return its bumps FitProblem.
        """
        return FitState(filename).make_fitproblem()
378
379
def setup_sasview():
    """
    Prepare the SasView environment needed by the headless fitter.

    Only sasmodels is configured.  SasView's logging and matplotlib setup are
    not run here.
    """
    # Import only what is used; setup_logging/setup_mpl are not called.
    from sas.sasview.sasview import setup_sasmodels
    setup_sasmodels()
385
def setup_bumps():
    """
    Install the SasView plugin (BumpsPlugin) into bumps, but don't run main.

    Also points bumps' matplotlib configuration at the '.sasview/bumpsfit'
    application-data subdirectory via bumps.cli.set_mplconfig.
    """
    import os
    import bumps.cli
    bumps.cli.set_mplconfig(appdatadir=os.path.join('.sasview', 'bumpsfit'))
    bumps.cli.install_plugin(BumpsPlugin)
394
def bumps_cli():
    """
    Set up SasView, install the SasView plugin into bumps, and run the bumps
    command line interface.
    """
    setup_sasview()
    setup_bumps()
    import bumps.cli
    bumps.cli.main()
403
if __name__ == "__main__":
    # Running this module directly starts the bumps command line with the
    # SasView plugin installed, e.g.
    #    python -m sas.sascalc.fit.fitstate
    bumps_cli()
Note: See TracBrowser for help on using the repository browser.