source: sasview/sansview/perspectives/fitting/fitting.py @ 442895f

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 442895f was 442895f, checked in by Gervaise Alina <gervyh@…>, 16 years ago

models added

  • Property mode set to 100644
File size: 22.6 KB
Line 
1import os,os.path, re
2import sys, wx, logging
3import string, numpy, pylab, math
4
5from copy import deepcopy
6from sans.guitools.plottables import Data1D, Theory1D
7from sans.guitools.PlotPanel import PlotPanel
8from sans.guicomm.events import NewPlotEvent, StatusEvent 
9from sans.fit.AbstractFitEngine import Model,Data
10from fitproblem import FitProblem
11from fitpanel import FitPanel
12
13import models
14import fitpage
15import park
16class PlottableDatas(Data,Data1D):
17    """ class plottable data: class allowing to plot Data type on panel"""
18   
19    def __init__(self,data=None,data1d=None):
20        Data.__init__(self,sans_data=data1d)
21        Data1D.__init__(self,x=data1d.x,y = data1d.y,dx = data1d.dx,dy = data1d.dy)
22        #self.x = data1d.x
23        #self.y = data1d.y
24        #self.dx = data1d.dx
25        #self.dy = data1d.dy
26        #self.data=data
27        self.group_id = data1d.group_id
28        #x_name, x_units = data1d.get_xaxis()
29        #y_name, y_units = data1d.get_yaxis()
30        #self.xaxis( x_name, x_units)
31        #self.yaxis( y_name, y_units )
32        #self.qmin = data.qmin
33        #self.qmax = data.qmax
34       
35
36class PlottableData(Data,Data1D):
37    """ class plottable data: class allowing to plot Data type on panel"""
38   
39    def __init__(self,data=None,data1d=None):
40        #Data.__init__(self,*args)
41        #Data1D.__init__(self,**kw)
42        self.x = data1d.x
43        self.y = data1d.y
44        self.dx = data1d.dx
45        self.dy = data1d.dy
46        self.data=data
47        self.group_id = data1d.group_id
48        x_name, x_units = data1d.get_xaxis() 
49        y_name, y_units = data1d.get_yaxis() 
50        self.xaxis( x_name, x_units)
51        self.yaxis( y_name, y_units )
52        self.qmin = data.qmin
53        self.qmax = data.qmax
54        def residuals(self, fn):
55            return self.data.residuals(fn)
56
57class Plugin:
58    """
59        Fitting plugin is used to perform fit
60    """
61    def __init__(self):
62        ## Plug-in name
63        self.sub_menu = "Fitting"
64       
65        ## Reference to the parent window
66        self.parent = None
67        self.menu_mng = models.ModelManager()
68        ## List of panels for the simulation perspective (names)
69        self.perspective = []
70        # Start with a good default
71        self.elapsed = 0.022
72        self.fitter  = None
73       
74        #Flag to let the plug-in know that it is running standalone
75        self.standalone=True
76        ## Fit engine
77        self._fit_engine = 'scipy'
78        # Log startup
79        logging.info("Fitting plug-in started")   
80
81    def populate_menu(self, id, owner):
82        """
83            Create a menu for the Fitting plug-in
84            @param id: id to create a menu
85            @param owner: owner of menu
86            @ return : list of information to populate the main menu
87        """
88        #Menu for fitting
89        self.menu1 = wx.Menu()
90        id1 = wx.NewId()
91        self.menu1.Append(id1, '&Show fit panel')
92        wx.EVT_MENU(owner, id1, self.on_perspective)
93        id3 = wx.NewId()
94        self.menu1.AppendCheckItem(id3, "park") 
95        wx.EVT_MENU(owner, id3, self._onset_engine)
96       
97        #menu for model
98        menu2 = wx.Menu()
99        self.menu_mng.populate_menu(menu2, owner)
100        id2 = wx.NewId()
101        owner.Bind(models.EVT_MODEL,self._on_model_menu)
102        self.fit_panel.set_owner(owner)
103        self.fit_panel.set_model_list(self.menu_mng.get_model_list())
104        owner.Bind(fitpage.EVT_MODEL_BOX,self._on_model_panel)
105        #create  menubar items
106        return [(id, self.menu1, "Fitting"),(id2, menu2, "Model")]
107   
108   
109    def help(self, evt):
110        """
111            Show a general help dialog.
112            TODO: replace the text with a nice image
113        """
114        pass
115   
116    def get_context_menu(self, graph=None):
117        """
118            Get the context menu items available for P(r)
119            @param graph: the Graph object to which we attach the context menu
120            @return: a list of menu items with call-back function
121        """
122        self.graph=graph
123        for item in graph.plottables:
124            if item.name==graph.selected_plottable and item.__class__.__name__ is not "Theory1D":
125                return [["Select Data", "Dialog with fitting parameters ", self._onSelect]] 
126        return []   
127
128
129    def get_panels(self, parent):
130        """
131            Create and return a list of panel objects
132        """
133        self.parent = parent
134        # Creation of the fit panel
135        self.fit_panel = FitPanel(self.parent, -1)
136        #Set the manager forthe main panel
137        self.fit_panel.set_manager(self)
138        # List of windows used for the perspective
139        self.perspective = []
140        self.perspective.append(self.fit_panel.window_name)
141        # take care of saving  data, model and page associated with each other
142        self.page_finder = {}
143        #index number to create random model name
144        self.index_model = 0
145        #create the fitting panel
146        return [self.fit_panel]
147   
148     
149    def get_perspective(self):
150        """
151            Get the list of panel names for this perspective
152        """
153        return self.perspective
154   
155   
156    def on_perspective(self, event):
157        """
158            Call back function for the perspective menu item.
159            We notify the parent window that the perspective
160            has changed.
161        """
162        self.parent.set_perspective(self.perspective)
163   
164   
165    def post_init(self):
166        """
167            Post initialization call back to close the loose ends
168            [Somehow openGL needs this call]
169        """
170        self.parent.set_perspective(self.perspective)
171       
172       
173    def _onSelect(self,event):
174        """
175            when Select data to fit a new page is created .Its reference is
176            added to self.page_finder
177        """
178        self.panel = event.GetEventObject()
179        for item in self.panel.graph.plottables:
180            if item.name == self.panel.graph.selected_plottable:
181                #find a name for the page created for notebook
182                try:
183                    name = item.group_id # item in Data1D
184                except:
185                    name = 'Fit'
186                try:
187                    page = self.fit_panel.add_fit_page(name)
188                    # add data associated to the page created
189                    page.set_data_name(item)
190                    #create a fitproblem storing all link to data,model,page creation
191                    self.page_finder[page]= FitProblem()
192                    #data_for_park= Data(sans_data=item)
193                    #datap = PlottableData(data=data_for_park,data1d=item)
194                    #self.page_finder[page].add_data(datap)
195                    self.page_finder[page].add_data(item)
196                except:
197                    #raise
198                    wx.PostEvent(self.parent, StatusEvent(status="Fitting error: \
199                    data already Selected "))
200                   
201                   
202    def get_page_finder(self):
203        """ @return self.page_finder used also by simfitpage.py""" 
204        return self.page_finder
205   
206   
207    def set_page_finder(self,modelname,names,values):
208        """
209             Used by simfitpage.py to reset a parameter given the string constrainst.
210             @param modelname: the name ot the model for with the parameter has to reset
211             @param value: can be a string in this case.
212             @param names: the paramter name
213             @note: expecting park used for fit.
214        """ 
215        sim_page=self.fit_panel.get_page(0)
216        for page, value in self.page_finder.iteritems():
217            if page != sim_page:
218                list=value.get_model()
219                model=list[0]
220                if model.name== modelname:
221                    value.set_model_param(names,values)
222                    break
223
224   
225                           
226    def split_string(self,item): 
227        """
228            receive a word containing dot and split it. used to split parameterset
229            name into model name and parameter name example:
230            paramaterset (item) = M1.A
231            @return model_name =M1 , parameter name =A
232        """
233        if string.find(item,".")!=-1:
234            param_names= re.split("\.",item)
235            model_name=param_names[0]
236            param_name=param_names[1] 
237            return model_name,param_name
238       
239       
240    def _single_fit_completed(self,result,pars,current_pg,qmin,qmax):
241        """
242            Display fit result on one page of the notebook.
243            @param result: result of fit
244            @param pars: list of names of parameters fitted
245            @param current_pg: the page where information will be displayed
246            @param qmin: the minimum value of x to replot the model
247            @param qmax: the maximum value of x to replot model
248         
249        """
250        try:
251            for page, value in self.page_finder.iteritems():
252                if page== current_pg:
253                    data = value.get_data()
254                    list = value.get_model()
255                    model= list[0]
256                    break
257            i = 0
258#            print "fitting: single fit pars ", pars
259            for name in pars:
260                if result.pvec.__class__==numpy.float64:
261                    model.setParam(name,result.pvec)
262                else:
263                    model.setParam(name,result.pvec[i])
264#                    print "fitting: single fit", name, result.pvec[i]
265                    i += 1
266#            print "fitting result : chisqr",result.fitness
267#            print "fitting result : pvec",result.pvec
268#            print "fitting result : stderr",result.stderr
269           
270            current_pg.onsetValues(result.fitness, result.pvec,result.stderr)
271            self.plot_helper(currpage=current_pg,qmin=qmin,qmax=qmax)
272        except:
273            raise
274            wx.PostEvent(self.parent, StatusEvent(status="Fitting error: %s" % sys.exc_value))
275           
276       
277    def _simul_fit_completed(self,result,qmin,qmax):
278        """
279            Parameter estimation completed,
280            display the results to the user
281            @param alpha: estimated best alpha
282            @param elapsed: computation time
283        """
284        try:
285            for page, value in self.page_finder.iteritems():
286                if value.get_scheduled()=='True':
287                    data = value.get_data()
288                    list = value.get_model()
289                    model= list[0]
290                   
291                    small_out = []
292                    small_cov = []
293                    i = 0
294                    #Separate result in to data corresponding to each page
295                    for p in result.parameters:
296                        model_name,param_name = self.split_string(p.name) 
297                        if model.name == model_name:
298                            small_out.append(p.value )
299                            small_cov.append(p.stderr)
300                            model.setParam(param_name,p.value) 
301                    # Display result on each page
302                    page.onsetValues(result.fitness, small_out,small_cov)
303                    #Replot model
304                    self.plot_helper(currpage= page,qmin= qmin,qmax= qmax) 
305        except:
306             wx.PostEvent(self.parent, StatusEvent(status="Fitting error: %s" % sys.exc_value))
307           
308   
309    def _on_single_fit(self,id=None,qmin=None,qmax=None):
310        """
311            perform fit for the  current page  and return chisqr,out and cov
312            @param engineName: type of fit to be performed
313            @param id: unique id corresponding to a fit problem(model, set of data)
314            @param model: model to fit
315           
316        """
317        #set an engine to perform fit
318        from sans.fit.Fitting import Fit
319        self.fitter= Fit(self._fit_engine)
320        #Setting an id to store model and data in fit engine
321        if id==None:
322            id=0
323        self.id = id
324        #Get information (model , data) related to the page on
325        #with the fit will be perform
326        current_pg=self.fit_panel.get_current_page() 
327        for page, value in self.page_finder.iteritems():
328            if page ==current_pg :
329                data = value.get_data()
330                list=value.get_model()
331                model=list[0]
332               
333                #Create list of parameters for fitting used
334                pars=[]
335                templist=[]
336                try:
337                    templist=current_pg.get_param_list()
338                except:
339                    wx.PostEvent(self.parent, StatusEvent(status="Fitting error: %s" % sys.exc_value))
340                    return
341             
342                for element in templist:
343                    try:
344                       pars.append(str(element[0].GetLabelText()))
345                    except:
346                        wx.PostEvent(self.parent, StatusEvent(status="Fitting error: %s" % sys.exc_value))
347                        return
348                # make sure to keep an alphabetic order
349                #of parameter names in the list
350                pars.sort()
351                #Do the single fit
352                try:
353                    self.fitter.set_model(Model(model), self.id, pars) 
354                    #print "fitting: data .x",data.x
355                    #print "fitting: data .y",data.y
356                    #print "fitting: data .dy",data.dy
357                    self.fitter.set_data(Data(sans_data=data),self.id,qmin,qmax)
358               
359                    result=self.fitter.fit()
360                    self._single_fit_completed(result,pars,current_pg,qmin,qmax)
361                   
362                except:
363                    raise
364                    wx.PostEvent(self.parent, StatusEvent(status="Single Fit error: %s" % sys.exc_value))
365                    return
366         
367    def _on_simul_fit(self, id=None,qmin=None,qmax=None):
368        """
369            perform fit for all the pages selected on simpage and return chisqr,out and cov
370            @param engineName: type of fit to be performed
371            @param id: unique id corresponding to a fit problem(model, set of data)
372             in park_integration
373            @param model: model to fit
374           
375        """
376        #set an engine to perform fit
377        from sans.fit.Fitting import Fit
378        self.fitter= Fit(self._fit_engine)
379       
380        #Setting an id to store model and data
381        if id==None:
382             id = 0
383        self.id = id
384       
385        for page, value in self.page_finder.iteritems():
386            try:
387                if value.get_scheduled()=='True':
388                    data = value.get_data()
389                    list = value.get_model()
390                    model= list[0]
391                    #Create dictionary of parameters for fitting used
392                    pars = []
393                    templist = []
394                    templist = page.get_param_list()
395                    for element in templist:
396                        try:
397                            name = str(element[0].GetLabelText())
398                            pars.append(name)
399                        except:
400                            wx.PostEvent(self.parent, StatusEvent(status="Fitting error: %s" % sys.exc_value))
401                            return
402                    self.fitter.set_model(Model(model), self.id, pars) 
403                    self.fitter.set_data(Data(sans_data=data),self.id,qmin,qmax)
404               
405                    self.id += 1 
406            except:
407                wx.PostEvent(self.parent, StatusEvent(status="Fitting error: %s" % sys.exc_value))
408                return 
409        #Do the simultaneous fit
410        try:
411            result=self.fitter.fit()
412            self._simul_fit_completed(result,qmin,qmax)
413        except:
414            wx.PostEvent(self.parent, StatusEvent(status="Simultaneous Fitting error: %s" % sys.exc_value))
415            return
416       
417       
418    def _onset_engine(self,event):
419        """ set engine to scipy"""
420        if self._fit_engine== 'park':
421            self._on_change_engine('scipy')
422        else:
423            self._on_change_engine('park')
424        wx.PostEvent(self.parent, StatusEvent(status="Engine set to: %s" % self._fit_engine))
425 
426   
427    def _on_change_engine(self, engine='park'):
428        """
429            Allow to select the type of engine to perform fit
430            @param engine: the key work of the engine
431        """
432        self._fit_engine = engine
433   
434   
435    def _on_model_panel(self, evt):
436        """
437            react to model selection on any combo box or model menu.plot the model 
438        """
439        model = evt.model
440        name = evt.name
441        sim_page=self.fit_panel.get_page(0)
442        current_pg = self.fit_panel.get_current_page() 
443        if current_pg != sim_page:
444            current_pg.set_model_name(name)
445            current_pg.set_panel(model)
446            try:
447                data=self.page_finder[current_pg].get_data()
448                M_name="M"+str(self.index_model)+"= "+name+"("+data.group_id+")"
449            except:
450                raise 
451                M_name="M"+str(self.index_model)+"= "+name
452            model.name="M"+str(self.index_model)
453            self.index_model += 1 
454           
455            self.page_finder[current_pg].set_model(model,M_name)
456            self.plot_helper(currpage= current_pg,qmin= None,qmax= None)
457            sim_page.add_model(self.page_finder)
458       
459           
460    def redraw_model(self,qmin= None,qmax= None):
461        """
462            Draw a theory according to model changes or data range.
463            @param qmin: the minimum value plotted for theory
464            @param qmax: the maximum value plotted for theory
465        """
466        current_pg=self.fit_panel.get_current_page()
467        for page, value in self.page_finder.iteritems():
468            if page ==current_pg :
469                break 
470        self.plot_helper(currpage=page,qmin= qmin,qmax= qmax)
471       
472    def plot_helper(self,currpage,qmin=None,qmax=None):
473        """
474            Plot a theory given a model and data
475            @param model: the model from where the theory is derived
476            @param currpage: page in a dictionary referring to some data
477        """
478        if self.fit_panel.get_page_count() >1:
479            for page in self.page_finder.iterkeys():
480                if  page==currpage : 
481                    break 
482            data=self.page_finder[page].get_data()
483            list=self.page_finder[page].get_model()
484            model=list[0]
485            if data!=None:
486                theory = Theory1D(x=[], y=[])
487                theory.name = "Model"
488                theory.group_id = data.group_id
489             
490                x_name, x_units = data.get_xaxis() 
491                y_name, y_units = data.get_yaxis() 
492                theory.xaxis(x_name, x_units)
493                theory.yaxis(y_name, y_units)
494                #print"fitting : redraw data.x",data.x
495                #print"fitting : redraw data.y",data.y
496                #print"fitting : redraw data.dy",data.dy
497                if qmin == None :
498                   qmin = min(data.x)
499                if qmax == None :
500                    qmax = max(data.x)
501                try:
502                    tempx = qmin
503                    tempy = model.run(qmin)
504                    theory.x.append(tempx)
505                    theory.y.append(tempy)
506                except :
507                        wx.PostEvent(self.parent, StatusEvent(status="fitting \
508                        skipping point x %g %s" %(qmin, sys.exc_value)))
509                           
510                for i in range(len(data.x)):
511                    try:
512                        if data.x[i]> qmin and data.x[i]< qmax:
513                            tempx = data.x[i]
514                            tempy = model.run(tempx)
515                           
516                            theory.x.append(tempx) 
517                            theory.y.append(tempy)
518                    except:
519                        wx.PostEvent(self.parent, StatusEvent(status="fitting \
520                        skipping point x %g %s" %(data.x[i], sys.exc_value)))   
521                try:
522                    tempx = qmax
523                    tempy = model.run(qmax)
524                    theory.x.append(tempx)
525                    theory.y.append(tempy)
526                except:
527                        wx.PostEvent(self.parent, StatusEvent(status="fitting \
528                        skipping point x %g %s" %(qmax, sys.exc_value)))
529                try:
530                    #print "fitting redraw for plot thoery .x",theory.x
531                    #print "fitting redraw for plot thoery .y",theory.y
532                    #print "fitting redraw for plot thoery .dy",theory.dy
533                    #rom sans.guicomm.events import NewPlotEvent
534                    wx.PostEvent(self.parent, NewPlotEvent(plot=theory, title="Analytical model"))
535                except:
536                    raise
537                    print "SimView.complete1D: could not import sans.guicomm.events"
538           
539           
540    def _on_model_menu(self, evt):
541        """
542            Plot a theory from a model selected from the menu
543        """
544        name="Model View"
545        model=evt.modelinfo.model()
546        description=evt.modelinfo.description
547        self.fit_panel.add_model_page(model,description,name)       
548        self.draw_model(model)
549       
550    def draw_model(self,model):
551        """
552             draw model with default data value
553        """
554        x = pylab.arange(0.001, 0.1, 0.001)
555        xlen = len(x)
556        dy = numpy.zeros(xlen)
557        y = numpy.zeros(xlen)
558       
559        for i in range(xlen):
560            y[i] = model.run(x[i])
561            dy[i] = math.sqrt(math.fabs(y[i]))
562        try:
563           
564            new_plot = Theory1D(x, y)
565            new_plot.name = "Model"
566            new_plot.xaxis("\\rm{Q}", 'A^{-1}')
567            new_plot.yaxis("\\rm{Intensity} ","cm^{-1}")
568            new_plot.group_id ="Fitness"
569         
570            wx.PostEvent(self.parent, NewPlotEvent(plot=new_plot, title="Analytical model"))
571        except:
572            print "SimView.complete1D: could not import sans.guicomm.events\n %s" % sys.exc_value
573            logging.error("SimView.complete1D: could not import sans.guicomm.events\n %s" % sys.exc_value)
574
575if __name__ == "__main__":
576    i = Plugin()
577   
578   
579   
580   
Note: See TracBrowser for help on using the repository browser.