source: sasview/prview/perspectives/pr/pr.py @ b2d9826

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since b2d9826 was 053c769, checked in by Gervaise Alina <gervyh@…>, 14 years ago

add todo in pr

  • Property mode set to 100644
File size: 48.8 KB
Line 
1
2################################################################################
3#This software was developed by the University of Tennessee as part of the
4#Distributed Data Analysis of Neutron Scattering Experiments (DANSE)
5#project funded by the US National Science Foundation.
6#
7#See the license text in license.txt
8#
9#copyright 2009, University of Tennessee
10################################################################################
11
12
13# Make sure the option of saving each curve is available
14# Use the I(q) curve as input and compare the output to P(r)
15
16import os
17import sys
18import wx
19import logging
20import time
21import copy
22import math
23import numpy
24import pylab
25from sans.guiframe.dataFitting import Data1D
26from sans.guiframe.events import NewPlotEvent
27from sans.guiframe.events import StatusEvent
28from sans.guiframe.gui_style import GUIFRAME_ID   
29from sans.pr.invertor import Invertor
30from DataLoader.loader import Loader
31
32from pr_widgets import load_error
33from sans.guiframe.plugin_base import PluginBase
34
35
36PR_FIT_LABEL       = r"$P_{fit}(r)$"
37PR_LOADED_LABEL    = r"$P_{loaded}(r)$"
38IQ_DATA_LABEL      = r"$I_{obs}(q)$"
39IQ_FIT_LABEL       = r"$I_{fit}(q)$"
40IQ_SMEARED_LABEL   = r"$I_{smeared}(q)$"
41
42
43class Plugin(PluginBase):
44    """
45    """
46    DEFAULT_ALPHA = 0.0001
47    DEFAULT_NFUNC = 10
48    DEFAULT_DMAX  = 140.0
49   
50    def __init__(self, standalone=True):
51        PluginBase.__init__(self, name="Pr inversion", standalone=standalone)
52        ## Simulation window manager
53        self.simview = None
54       
55        ## State data
56        self.alpha      = self.DEFAULT_ALPHA
57        self.nfunc      = self.DEFAULT_NFUNC
58        self.max_length = self.DEFAULT_DMAX
59        self.q_min      = None
60        self.q_max      = None
61        self.has_bck    = False
62        self.slit_height = 0
63        self.slit_width  = 0
64        ## Remember last plottable processed
65        self.last_data  = "sphere_60_q0_2.txt"
66        self._current_file_data = None
67        ## Time elapsed for last computation [sec]
68        # Start with a good default
69        self.elapsed = 0.022
70        self.iq_data_shown = False
71       
72        ## Current invertor
73        self.invertor    = None
74        self.pr          = None
75        # Copy of the last result in case we need to display it.
76        self._last_pr    = None
77        self._last_out   = None
78        self._last_cov   = None
79        ## Calculation thread
80        self.calc_thread = None
81        ## Estimation thread
82        self.estimation_thread = None
83        ## Result panel
84        self.control_panel = None
85        ## Currently views plottable
86        self.current_plottable = None
87        ## Number of P(r) points to display on the output plot
88        self._pr_npts = 51
89        ## Flag to let the plug-in know that it is running standalone
90        self.standalone = standalone
91        self._normalize_output = False
92        self._scale_output_unity = False
93       
94        ## List of added P(r) plots
95        self._added_plots = {}
96        self._default_Iq  = {}
97       
98        # Associate the inversion state reader with .prv files
99        from inversion_state import Reader
100         
101        # Create a CanSAS/Pr reader
102        self.state_reader = Reader(self.set_state)
103        self._extensions = '.prv'
104        l = Loader()
105        l.associate_file_reader('.prv', self.state_reader)
106        l.associate_file_reader(".svs", self.state_reader)
107               
108        # Log startup
109        logging.info("Pr(r) plug-in started")
110       
111    def get_data(self):
112        """
113        """
114        return self.current_plottable
115   
116    def set_state(self, state=None, datainfo=None):
117        """
118        Call-back method for the inversion state reader.
119        This method is called when a .prv file is loaded.
120       
121        :param state: InversionState object
122        :param datainfo: Data1D object [optional]
123       
124        """
125        try:
126            if datainfo.__class__.__name__ == 'list':
127                if len(datainfo) >= 1:
128                    data = datainfo[0]
129                else:
130                    data = None
131            else:
132                data = datainfo
133            if data is None:
134                raise RuntimeError, "Pr.set_state: datainfo parameter cannot be None in standalone mode"
135           
136            # Ensuring that plots are coordinated correctly
137            t = time.localtime(data.meta_data['prstate'].timestamp)
138            time_str = time.strftime("%b %d %H:%M", t)
139           
140            # Check that no time stamp is already appended
141            max_char = data.meta_data['prstate'].file.find("[")
142            if max_char < 0:
143                max_char = len(data.meta_data['prstate'].file)
144           
145            datainfo.meta_data['prstate'].file = data.meta_data['prstate'].file[0:max_char] +' [' + time_str + ']'
146            data.filename = data.meta_data['prstate'].file
147            # TODO:
148            #remove this call when state save all information about the gui data
149            # such as ID , Group_ID, etc...
150            #make self.current_plottable = datainfo directly
151            self.current_plottable = self.parent.create_gui_data(data,None)
152            self.current_plottable.group_id = data.meta_data['prstate'].file
153           
154            # Make sure the user sees the P(r) panel after loading
155            #self.parent.set_perspective(self.perspective) 
156            self.on_perspective(event=None)   
157           
158            # Load the P(r) results
159            #state = self.state_reader.get_state()
160            wx.PostEvent(self.parent, NewPlotEvent(plot=self.current_plottable,
161                                        title=self.current_plottable.title))
162            self.control_panel.set_state(state)
163        except:
164            logging.error("prview.set_state: %s" % sys.exc_value)
165
166 
167    def help(self, evt):
168        """
169        Show a general help dialog.
170       
171        :TODO: replace the text with a nice image
172       
173        """
174        from inversion_panel import HelpDialog
175        dialog = HelpDialog(None, -1)
176        if dialog.ShowModal() == wx.ID_OK:
177            dialog.Destroy()
178        else:
179            dialog.Destroy()
180   
181    def _fit_pr(self, evt):
182        """
183        """
184        from sans.pr.invertor import Invertor
185        # Generate P(r) for sphere
186        radius = 60.0
187        d_max  = 2*radius
188       
189        r = pylab.arange(0.01, d_max, d_max/51.0)
190        M = len(r)
191        y = numpy.zeros(M)
192        pr_err = numpy.zeros(M)
193       
194        sum = 0.0
195        for j in range(M):
196            value = self.pr_theory(r[j], radius)
197            sum += value
198            y[j] = value
199            pr_err[j] = math.sqrt(y[j])
200
201        y = y/sum*d_max/len(r)
202
203        # Perform fit
204        pr = Invertor()
205        pr.d_max = d_max
206        pr.alpha = 0
207        pr.x = r
208        pr.y = y
209        pr.err = pr_err
210        out, cov = pr.pr_fit()
211        for i in range(len(out)):
212            print "%g +- %g" % (out[i], math.sqrt(cov[i][i]))
213       
214        # Show input P(r)
215        title = "Pr"
216        new_plot = Data1D(pr.x, pr.y, dy=pr.err)
217        new_plot.name = "P_{obs}(r)"
218        new_plot.xaxis("\\rm{r}", 'A')
219        new_plot.yaxis("\\rm{P(r)} ","cm^{-3}")
220        group_id = "P_{obs}(r)"
221        if group_id not in new_plot.group_id:
222            new_plot.group_id.append(group_id)
223        new_plot.id = "P_{obs}(r)"
224        new_plot.title = title
225        wx.PostEvent(self.parent, NewPlotEvent(plot=new_plot, title=title))
226
227        # Show P(r) fit
228        self.show_pr(out, pr)
229       
230        # Show I(q) fit
231        q = pylab.arange(0.001, 0.1, 0.01/51.0)
232        self.show_iq(out, pr, q)
233       
234    def show_shpere(self, x, radius=70.0, x_range=70.0):
235        """
236        """
237        # Show P(r)
238        y_true = numpy.zeros(len(x))
239
240        sum_true = 0.0
241        for i in range(len(x)):
242            y_true[i] = self.pr_theory(x[i], radius)           
243            sum_true += y_true[i]
244           
245        y_true = y_true/sum_true*x_range/len(x)
246       
247        # Show the theory P(r)
248        new_plot = Data1D(x, y_true)
249        new_plot.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
250        new_plot.name = "P_{true}(r)"
251        new_plot.xaxis("\\rm{r}", 'A')
252        new_plot.yaxis("\\rm{P(r)} ","cm^{-3}")
253       
254        #Put this call in plottables/guitools   
255        wx.PostEvent(self.parent, NewPlotEvent(plot=new_plot, title="Sphere P(r)"))
256       
257       
258    def get_npts(self):
259        """
260        Returns the number of points in the I(q) data
261        """
262        try:
263            return len(self.pr.x)
264        except:
265            return 0
266       
267    def show_iq(self, out, pr, q=None):
268        """
269        """ 
270        qtemp = pr.x
271        if not q==None:
272            qtemp = q
273
274        # Make a plot
275        maxq = -1
276        for q_i in qtemp:
277            if q_i>maxq:
278                maxq=q_i
279               
280        minq = 0.001
281       
282        # Check for user min/max
283        if not pr.q_min==None:
284            minq = pr.q_min
285        if not pr.q_max==None:
286            maxq = pr.q_max
287               
288        x = pylab.arange(minq, maxq, maxq/301.0)
289        y = numpy.zeros(len(x))
290        err = numpy.zeros(len(x))
291        for i in range(len(x)):
292            value = pr.iq(out, x[i])
293            y[i] = value
294            try:
295                err[i] = math.sqrt(math.fabs(value))
296            except:
297                err[i] = 1.0
298                print "Error getting error", value, x[i]
299               
300        new_plot = Data1D(x, y)
301        new_plot.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
302        new_plot.name = IQ_FIT_LABEL
303        new_plot.xaxis("\\rm{Q}", 'A^{-1}')
304        new_plot.yaxis("\\rm{Intensity} ","cm^{-1}")
305        title = "I(q)"
306        new_plot.title = title
307       
308        # If we have a group ID, use it
309        if pr.info.has_key("plot_group_id"):
310            if len( pr.info["plot_group_id"]) > 0:
311                index =  len( pr.info["plot_group_id"]) - 1
312                new_plot.group_id.append( pr.info["plot_group_id"][index])
313        new_plot.id = IQ_FIT_LABEL
314        #new_plot.group_id.append(2)
315        wx.PostEvent(self.parent, NewPlotEvent(plot=new_plot, title=title))
316       
317        # If we have used slit smearing, plot the smeared I(q) too
318        if pr.slit_width>0 or pr.slit_height>0:
319            x = pylab.arange(minq, maxq, maxq/301.0)
320            y = numpy.zeros(len(x))
321            err = numpy.zeros(len(x))
322            for i in range(len(x)):
323                value = pr.iq_smeared(out, x[i])
324                y[i] = value
325                try:
326                    err[i] = math.sqrt(math.fabs(value))
327                except:
328                    err[i] = 1.0
329                    print "Error getting error", value, x[i]
330                   
331            new_plot = Data1D(x, y)
332            new_plot.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
333            new_plot.name = IQ_SMEARED_LABEL
334            new_plot.xaxis("\\rm{Q}", 'A^{-1}')
335            new_plot.yaxis("\\rm{Intensity} ","cm^{-1}")
336            # If we have a group ID, use it
337            if pr.info.has_key("plot_group_id"):
338                if len( pr.info["plot_group_id"]) > 0:
339                    index =  len( pr.info["plot_group_id"]) - 1
340                    new_plot.group_id.append( pr.info["plot_group_id"][index])
341           
342            new_plot.id = IQ_SMEARED_LABEL
343            new_plot.title = title
344            #new_plot.group_id.append(2)
345            wx.PostEvent(self.parent, NewPlotEvent(plot=new_plot, title=title))
346       
347       
348    def _on_pr_npts(self, evt):
349        """
350        Redisplay P(r) with a different number of points
351        """   
352        from inversion_panel import PrDistDialog
353        dialog = PrDistDialog(None, -1)
354        dialog.set_content(self._pr_npts)
355        if dialog.ShowModal() == wx.ID_OK:
356            self._pr_npts= dialog.get_content()
357            dialog.Destroy()
358            self.show_pr(self._last_out, self._last_pr, self._last_cov)
359        else:
360            dialog.Destroy()
361       
362       
363    def show_pr(self, out, pr, cov=None):
364        """
365        """     
366        # Show P(r)
367        x = pylab.arange(0.0, pr.d_max, pr.d_max/self._pr_npts)
368   
369        y = numpy.zeros(len(x))
370        dy = numpy.zeros(len(x))
371        y_true = numpy.zeros(len(x))
372
373        sum = 0.0
374        pmax = 0.0
375        cov2 = numpy.ascontiguousarray(cov)
376       
377        for i in range(len(x)):
378            if cov2==None:
379                value = pr.pr(out, x[i])
380            else:
381                (value, dy[i]) = pr.pr_err(out, cov2, x[i])
382            sum += value*pr.d_max/len(x)
383           
384            # keep track of the maximum P(r) value
385            if value>pmax:
386                pmax = value
387               
388            y[i] = value
389               
390        if self._normalize_output==True:
391            y = y/sum
392            dy = dy/sum
393        elif self._scale_output_unity==True:
394            y = y/pmax
395            dy = dy/pmax
396       
397        if cov2==None:
398            new_plot = Data1D(x, y)
399            new_plot.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
400        else:
401            new_plot = Data1D(x, y, dy=dy)
402        new_plot.name = PR_FIT_LABEL
403        new_plot.xaxis("\\rm{r}", 'A')
404        new_plot.yaxis("\\rm{P(r)} ","cm^{-3}")
405        new_plot.title = "P(r) fit"
406        new_plot.id = PR_FIT_LABEL
407        # Make sure that the plot is linear
408        new_plot.xtransform = "x"
409        new_plot.ytransform = "y"                 
410        wx.PostEvent(self.parent, NewPlotEvent(plot=new_plot, title="P(r) fit"))
411       
412        return x, pr.d_max
413       
414               
415    def load(self, data):
416        """
417        Load data. This will eventually be replaced
418        by our standard DataLoader class.
419        """
420        class FileData:
421            x = None
422            y = None
423            err = None
424            path = None
425           
426            def __init__(self, path):
427                self.path = path
428               
429        self._current_file_data = FileData(data.path)
430       
431        # Use data loader to load file
432        #dataread = Loader().load(path)
433        dataread = data
434        # Notify the user if we could not read the file
435        if dataread is None:
436            raise RuntimeError, "Invalid data"
437           
438        x = None
439        y = None
440        err = None
441        if dataread.__class__.__name__ == 'Data1D':
442            x = dataread.x
443            y = dataread.y
444            err = dataread.dy
445        else:
446            if isinstance(dataread, list) and len(dataread)>0:
447                x = dataread[0].x
448                y = dataread[0].y
449                err = dataread[0].dy
450                msg = "PrView only allows a single data set at a time. "
451                msg += "Only the first data set was loaded." 
452                wx.PostEvent(self.parent, StatusEvent(status=msg))
453            else:
454                if dataread is None:
455                    return x, y, err
456                raise RuntimeError, "This tool can only read 1D data"
457       
458        self._current_file_data.x = x
459        self._current_file_data.y = y
460        self._current_file_data.err = err
461        return x, y, err
462               
463    def load_columns(self, path = "sphere_60_q0_2.txt"):
464        """
465        Load 2- or 3- column ascii
466        """
467        import numpy, math, sys
468        # Read the data from the data file
469        data_x   = numpy.zeros(0)
470        data_y   = numpy.zeros(0)
471        data_err = numpy.zeros(0)
472        scale    = None
473        min_err  = 0.0
474        if not path == None:
475            input_f = open(path,'r')
476            buff    = input_f.read()
477            lines   = buff.split('\n')
478            for line in lines:
479                try:
480                    toks = line.split()
481                    x = float(toks[0])
482                    y = float(toks[1])
483                    if len(toks)>2:
484                        err = float(toks[2])
485                    else:
486                        if scale==None:
487                            scale = 0.05*math.sqrt(y)
488                            #scale = 0.05/math.sqrt(y)
489                            min_err = 0.01*y
490                        err = scale*math.sqrt(y)+min_err
491                        #err = 0
492                       
493                    data_x = numpy.append(data_x, x)
494                    data_y = numpy.append(data_y, y)
495                    data_err = numpy.append(data_err, err)
496                except:
497                    pass
498                   
499        if not scale==None:
500            message = "The loaded file had no error bars, statistical errors are assumed."
501            wx.PostEvent(self.parent, StatusEvent(status=message))
502        else:
503            wx.PostEvent(self.parent, StatusEvent(status=''))
504                       
505        return data_x, data_y, data_err     
506       
507    def load_abs(self, path):
508        """
509        Load an IGOR .ABS reduced file
510       
511        :param path: file path
512       
513        :return: x, y, err vectors
514       
515        """
516        import numpy, math, sys
517        # Read the data from the data file
518        data_x   = numpy.zeros(0)
519        data_y   = numpy.zeros(0)
520        data_err = numpy.zeros(0)
521        scale    = None
522        min_err  = 0.0
523       
524        data_started = False
525        if not path == None:
526            input_f = open(path,'r')
527            buff    = input_f.read()
528            lines   = buff.split('\n')
529            for line in lines:
530                if data_started==True:
531                    try:
532                        toks = line.split()
533                        x = float(toks[0])
534                        y = float(toks[1])
535                        if len(toks)>2:
536                            err = float(toks[2])
537                        else:
538                            if scale==None:
539                                scale = 0.05*math.sqrt(y)
540                                #scale = 0.05/math.sqrt(y)
541                                min_err = 0.01*y
542                            err = scale*math.sqrt(y)+min_err
543                            #err = 0
544                           
545                        data_x = numpy.append(data_x, x)
546                        data_y = numpy.append(data_y, y)
547                        data_err = numpy.append(data_err, err)
548                    except:
549                        pass
550                elif line.find("The 6 columns")>=0:
551                    data_started = True     
552                   
553        if not scale==None:
554            message = "The loaded file had no error bars, statistical errors are assumed."
555            wx.PostEvent(self.parent, StatusEvent(status=message))
556        else:
557            wx.PostEvent(self.parent, StatusEvent(status=''))
558                       
559        return data_x, data_y, data_err     
560       
561    def pr_theory(self, r, R):
562        """ 
563        """
564        if r<=2*R:
565            return 12.0* ((0.5*r/R)**2) * ((1.0-0.5*r/R)**2) * ( 2.0 + 0.5*r/R )
566        else:
567            return 0.0
568
569    def get_context_menu(self, plotpanel=None):
570        """
571        Get the context menu items available for P(r)
572       
573        :param graph: the Graph object to which we attach the context menu
574       
575        :return: a list of menu items with call-back function
576       
577        """
578        graph = plotpanel.graph
579        # Look whether this Graph contains P(r) data
580        if graph.selected_plottable not in plotpanel.plots:
581            return []
582        item = plotpanel.plots[graph.selected_plottable]
583        if item.id == PR_FIT_LABEL:
584            #add_data_hint = "Load a data file and display it on this plot"
585            #["Add P(r) data",add_data_hint , self._on_add_data],
586            change_n_hint = "Change the number of"
587            change_n_hint += " points on the P(r) output"
588            change_n_label = "Change number of P(r) points"
589            m_list = [[change_n_label, change_n_hint , self._on_pr_npts]]
590
591            if self._scale_output_unity or self._normalize_output:
592                hint = "Let the output P(r) keep the scale of the data"
593                m_list.append(["Disable P(r) scaling", hint, 
594                               self._on_disable_scaling])
595            if not self._scale_output_unity:
596                m_list.append(["Scale P_max(r) to unity", 
597                               "Scale P(r) so that its maximum is 1", 
598                               self._on_scale_unity])
599            if not self._normalize_output:
600                m_list.append(["Normalize P(r) to unity", 
601                               "Normalize the integral of P(r) to 1", 
602                               self._on_normalize])
603               
604            return m_list
605             
606        elif item.id in [PR_LOADED_LABEL, IQ_DATA_LABEL, IQ_FIT_LABEL,
607                          IQ_SMEARED_LABEL]:
608            return []
609        elif item.id == graph.selected_plottable:
610               if not self.standalone and issubclass(item.__class__, Data1D):
611                return [["Compute P(r)", 
612                             "Compute P(r) from distribution", 
613                             self._on_context_inversion]]     
614               
615        return []
616
617    def _on_disable_scaling(self, evt):
618        """
619        Disable P(r) scaling
620           
621        :param evt: Menu event
622       
623        """
624        self._normalize_output = False
625        self._scale_output_unity = False
626        self.show_pr(self._last_out, self._last_pr, self._last_cov)
627       
628        # Now replot the original added data
629        for plot in self._added_plots:
630            self._added_plots[plot].y = numpy.copy(self._default_Iq[plot])
631            wx.PostEvent(self.parent, 
632                         NewPlotEvent(plot=self._added_plots[plot], 
633                                      title=self._added_plots[plot].name,
634                                                   update=True))       
635       
636        # Need the update flag in the NewPlotEvent to protect against
637        # the plot no longer being there...
638       
639    def _on_normalize(self, evt):
640        """
641        Normalize the area under the P(r) curve to 1.
642        This operation is done for all displayed plots.
643       
644        :param evt: Menu event
645       
646        """
647        self._normalize_output = True
648        self._scale_output_unity = False
649           
650        self.show_pr(self._last_out, self._last_pr, self._last_cov)
651       
652        # Now scale the added plots too
653        for plot in self._added_plots:
654            sum = numpy.sum(self._added_plots[plot].y)
655            npts = len(self._added_plots[plot].x)
656            sum *= self._added_plots[plot].x[npts-1]/npts
657            y = self._added_plots[plot].y/sum
658           
659            new_plot = Data1D(self._added_plots[plot].x, y)
660            new_plot.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
661            index  = len(self._added_plots[plot].group_id) - 1
662            if group_id not in new_plot.group_id:
663                new_plot.group_id.append(group_id)
664            new_plot.id = self._added_plots[plot].id
665            new_plot.title = self._added_plots[plot].title
666            new_plot.name = self._added_plots[plot].name
667            new_plot.xaxis("\\rm{r}", 'A')
668            new_plot.yaxis("\\rm{P(r)} ","cm^{-3}")
669           
670            wx.PostEvent(self.parent, 
671                         NewPlotEvent(plot=new_plot, update=True,
672                                         title=self._added_plots[plot].name))
673       
674    def _on_scale_unity(self, evt):
675        """
676        Scale the maximum P(r) value on each displayed plot to 1.
677       
678        :param evt: Menu event
679       
680        """
681        self._scale_output_unity = True
682        self._normalize_output = False
683           
684        self.show_pr(self._last_out, self._last_pr, self._last_cov)
685       
686        # Now scale the added plots too
687        for plot in self._added_plots:
688            _max = 0
689            for y in self._added_plots[plot].y:
690                if y>_max: 
691                    _max = y
692            y = self._added_plots[plot].y/_max
693           
694            new_plot = Data1D(self._added_plots[plot].x, y)
695            new_plot.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
696            new_plot.name = self._added_plots[plot].name
697            new_plot.xaxis("\\rm{r}", 'A')
698            new_plot.yaxis("\\rm{P(r)} ","cm^{-3}")
699           
700            wx.PostEvent(self.parent, 
701                         NewPlotEvent(plot=new_plot, update=True,
702                                title=self._added_plots[plot].name))       
703       
704       
705    def _on_add_data(self, evt):
706        """
707        Add a data curve to the plot
708       
709        :WARNING: this will be removed once guiframe.plotting has
710             its full functionality
711        """
712        path = self.choose_file()
713        if path==None:
714            return
715       
716        #x, y, err = self.parent.load_ascii_1D(path)
717        # Use data loader to load file
718        try:
719            dataread = Loader().load(path)
720            x = None
721            y = None
722            err = None
723            if dataread.__class__.__name__ == 'Data1D':
724                x = dataread.x
725                y = dataread.y
726                err = dataread.dy
727            else:
728                if isinstance(dataread, list) and len(dataread)>0:
729                    x = dataread[0].x
730                    y = dataread[0].y
731                    err = dataread[0].dy
732                    msg = "PrView only allows a single data set at a time. "
733                    msg += "Only the first data set was loaded." 
734                    wx.PostEvent(self.parent, StatusEvent(status=msg))
735                else:
736                    msg = "This tool can only read 1D data"
737                    wx.PostEvent(self.parent, StatusEvent(status=msg))
738                    return
739           
740        except:
741            wx.PostEvent(self.parent, StatusEvent(status=sys.exc_value))
742            return
743       
744        filename = os.path.basename(path)
745
746        new_plot = Data1D(x, y)
747        new_plot.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
748        new_plot.name = filename
749        new_plot.xaxis("\\rm{r}", 'A')
750        new_plot.yaxis("\\rm{P(r)} ","cm^{-3}")
751           
752        # Store a ref to the plottable for later use
753        self._added_plots[filename] = new_plot
754        self._default_Iq[filename]  = numpy.copy(y)
755       
756        wx.PostEvent(self.parent, NewPlotEvent(plot=new_plot, title=filename))
757       
758       
759
760    def start_thread(self):
761        """
762        """
763        from pr_thread import CalcPr
764        from copy import deepcopy
765       
766        # If a thread is already started, stop it
767        if self.calc_thread != None and self.calc_thread.isrunning():
768            self.calc_thread.stop()
769               
770        pr = self.pr.clone()
771        self.calc_thread = CalcPr(pr, self.nfunc,
772                                   error_func=self._thread_error, 
773                                   completefn=self._completed, updatefn=None)
774        self.calc_thread.queue()
775        self.calc_thread.ready(2.5)
776   
777    def _thread_error(self, error):
778        """
779        """
780        wx.PostEvent(self.parent, StatusEvent(status=error))
781   
782    def _estimate_completed(self, alpha, message, elapsed):
783        """
784        Parameter estimation completed,
785        display the results to the user
786       
787        :param alpha: estimated best alpha
788        :param elapsed: computation time
789       
790        """
791        # Save useful info
792        self.elapsed = elapsed
793        self.control_panel.alpha_estimate = alpha
794        if not message==None:
795            wx.PostEvent(self.parent, StatusEvent(status=str(message)))
796           
797        self.perform_estimateNT()
798   
799
800   
801    def _estimateNT_completed(self, nterms, alpha, message, elapsed):
802        """
803        Parameter estimation completed,
804        display the results to the user
805       
806        :param alpha: estimated best alpha
807        :param nterms: estimated number of terms
808        :param elapsed: computation time
809       
810        """
811        # Save useful info
812        self.elapsed = elapsed
813        self.control_panel.nterms_estimate = nterms
814        self.control_panel.alpha_estimate = alpha
815        if not message==None:
816            wx.PostEvent(self.parent, StatusEvent(status=str(message)))
817   
818    def _completed(self, out, cov, pr, elapsed):
819        """
820        Method called with the results when the inversion
821        is done
822       
823        :param out: output coefficient for the base functions
824        :param cov: covariance matrix
825        :param pr: Invertor instance
826        :param elapsed: time spent computing
827       
828        """
829        from copy import deepcopy
830        # Save useful info
831        self.elapsed = elapsed
832        # Keep a copy of the last result
833        self._last_pr  = pr.clone()
834        self._last_out = out
835        self._last_cov = cov
836       
837        # Save Pr invertor
838        self.pr = pr
839       
840        #message = "Computation completed in"
841        #message +=  %g seconds [chi2=%g]" % (elapsed, pr.chi2)
842        #wx.PostEvent(self.parent, StatusEvent(status=message))
843
844        cov = numpy.ascontiguousarray(cov)
845
846        # Show result on control panel
847        self.control_panel.chi2 = pr.chi2
848        self.control_panel.elapsed = elapsed
849        self.control_panel.oscillation = pr.oscillations(out)
850        #print "OSCILL", pr.oscillations(out)
851        #print "PEAKS:", pr.get_peaks(out)
852        self.control_panel.positive = pr.get_positive(out)
853        self.control_panel.pos_err  = pr.get_pos_err(out, cov)
854        self.control_panel.rg = pr.rg(out)
855        self.control_panel.iq0 = pr.iq0(out)
856        self.control_panel.bck = pr.background
857       
858        if False:
859            for i in range(len(out)):
860                try:
861                    print "%d: %g +- %g" % (i, out[i],
862                                             math.sqrt(math.fabs(cov[i][i])))
863                except: 
864                    print sys.exc_value
865                    print "%d: %g +- ?" % (i, out[i])       
866       
867            # Make a plot of I(q) data
868            new_plot = Data1D(self.pr.x, self.pr.y, dy=self.pr.err)
869            new_plot.name = IQ_DATA_LABEL
870            new_plot.xaxis("\\rm{Q}", 'A^{-1}')
871            new_plot.yaxis("\\rm{Intensity} ","cm^{-1}")
872            if pr.info.has_key("plot_group_id"):
873                new_plot.group_id.append(pr.info["plot_group_id"])
874            new_plot.id = IQ_DATA_LABEL
875            wx.PostEvent(self.parent, NewPlotEvent(plot=new_plot, title="Iq"))
876               
877        # Show I(q) fit
878        self.show_iq(out, self.pr)
879       
880        # Show P(r) fit
881        x_values, x_range = self.show_pr(out, self.pr, cov) 
882       
883        # Popup result panel
884        #result_panel = InversionResults(self.parent,
885        #-1, style=wx.RAISED_BORDER)
886       
887    def show_data(self, path=None, data=None, reset=False):
888        """
889        Show data read from a file
890       
891        :param path: file path
892        :param reset: if True all other plottables will be cleared
893       
894        """
895        #if path is not None:
896        if data is not None:
897            try:
898                pr = self._create_file_pr(data)
899            except:
900                status = "Problem reading data: %s" % sys.exc_value
901                wx.PostEvent(self.parent, StatusEvent(status=status))
902                raise RuntimeError, status
903               
904            # If the file contains nothing, just return
905            if pr is None:
906                raise RuntimeError, "Loaded data is invalid"
907           
908            self.pr = pr
909       
910        # Make a plot of I(q) data
911        if self.pr.err == None:
912            new_plot = Data1D(self.pr.x, self.pr.y)
913            new_plot.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
914        else:
915            new_plot = Data1D(self.pr.x, self.pr.y, dy=self.pr.err)
916        new_plot.name = IQ_DATA_LABEL
917        new_plot.xaxis("\\rm{Q}", 'A^{-1}')
918        new_plot.yaxis("\\rm{Intensity} ","cm^{-1}")
919        new_plot.interactive = True
920        new_plot.group_id.append(IQ_DATA_LABEL)
921        new_plot.id = IQ_DATA_LABEL
922        new_plot.title = "I(q)"
923        wx.PostEvent(self.parent, 
924                     NewPlotEvent(plot=new_plot, title="I(q)", reset=reset))
925       
926        self.current_plottable = new_plot
927        # Get Q range
928        self.control_panel.q_min = self.pr.x.min()
929        self.control_panel.q_max = self.pr.x.max()
930           
931    def save_data(self, filepath, prstate=None):
932        """
933        Save data in provided state object.
934       
935        :TODO: move the state code away from inversion_panel and move it here.
936                Then remove the "prstate" input and make this method private.
937               
938        :param filepath: path of file to write to
939        :param prstate: P(r) inversion state
940       
941        """
942        #TODO: do we need this or can we use DataLoader.loader.save directly?
943       
944        # Add output data and coefficients to state
945        prstate.coefficients = self._last_out
946        prstate.covariance = self._last_cov
947       
948        # Write the output to file
949        # First, check that the data is of the right type
950        if issubclass(self.current_plottable.__class__,
951                       DataLoader.data_info.Data1D):
952            self.state_reader.write(filepath, self.current_plottable, prstate)
953        else:
954            msg = "pr.save_data: the data being saved is not a"
955            msg += " DataLoader.data_info.Data1D object" 
956            raise RuntimeError, msg
957       
958       
959    def setup_plot_inversion(self, alpha, nfunc, d_max, q_min=None, q_max=None, 
960                             bck=False, height=0, width=0):
961        """
962        """
963        self.alpha = alpha
964        self.nfunc = nfunc
965        self.max_length = d_max
966        self.q_min = q_min
967        self.q_max = q_max
968        self.has_bck = bck
969        self.slit_height = height
970        self.slit_width  = width
971       
972        try:
973            pr = self._create_plot_pr()
974            if not pr==None:
975                self.pr = pr
976                self.perform_inversion()
977        except:
978            wx.PostEvent(self.parent, StatusEvent(status=sys.exc_value))
979
980    def estimate_plot_inversion(self, alpha, nfunc, d_max, 
981                                q_min=None, q_max=None, 
982                                bck=False, height=0, width=0):
983        """
984        """
985        self.alpha = alpha
986        self.nfunc = nfunc
987        self.max_length = d_max
988        self.q_min = q_min
989        self.q_max = q_max
990        self.has_bck = bck
991        self.slit_height = height
992        self.slit_width  = width
993       
994        try:
995            pr = self._create_plot_pr()
996            if not pr==None:
997                self.pr = pr
998                self.perform_estimate()
999        except:
1000            wx.PostEvent(self.parent, StatusEvent(status=sys.exc_value))           
1001
1002    def _create_plot_pr(self, estimate=False):
1003        """
1004        Create and prepare invertor instance from
1005        a plottable data set.
1006       
1007        :param path: path of the file to read in
1008       
1009        """
1010        # Sanity check
1011        if self.current_plottable is None:
1012            msg = "Please load a valid data set before proceeding."
1013            wx.PostEvent(self.parent, StatusEvent(status=msg)) 
1014            return None   
1015       
1016        # Get the data from the chosen data set and perform inversion
1017        pr = Invertor()
1018        pr.d_max = self.max_length
1019        pr.alpha = self.alpha
1020        pr.q_min = self.q_min
1021        pr.q_max = self.q_max
1022        pr.x = self.current_plottable.x
1023        pr.y = self.current_plottable.y
1024        pr.has_bck = self.has_bck
1025        pr.slit_height = self.slit_height
1026        pr.slit_width = self.slit_width
1027       
1028        # Keep track of the plot window title to ensure that
1029        # we can overlay the plots
1030        if self.current_plottable.group_id:
1031            index = len(self.current_plottable.group_id) - 1
1032            group_id = self.current_plottable.group_id[index]
1033            pr.info["plot_group_id"] = self.current_plottable.group_id
1034       
1035        # Fill in errors if none were provided
1036        err = self.current_plottable.dy
1037        all_zeros = True
1038        if err == None:
1039            err = numpy.zeros(len(pr.y)) 
1040        else:   
1041            for i in range(len(err)):
1042                if err[i]>0:
1043                    all_zeros = False
1044       
1045        if all_zeros:       
1046            scale = None
1047            min_err = 0.0
1048            for i in range(len(pr.y)):
1049                # Scale the error so that we can fit over several decades of Q
1050                if scale==None:
1051                    scale = 0.05*math.sqrt(pr.y[i])
1052                    min_err = 0.01*pr.y[i]
1053                err[i] = scale*math.sqrt( math.fabs(pr.y[i]) ) + min_err
1054            message = "The loaded file had no error bars, "
1055            message += "statistical errors are assumed."
1056            wx.PostEvent(self.parent, StatusEvent(status=message))
1057
1058        pr.err = err
1059       
1060        return pr
1061
1062         
1063    def setup_file_inversion(self, alpha, nfunc, d_max, data,
1064                             path=None, q_min=None, q_max=None, 
1065                             bck=False, height=0, width=0):
1066        """
1067        """
1068        self.alpha = alpha
1069        self.nfunc = nfunc
1070        self.max_length = d_max
1071        self.q_min = q_min
1072        self.q_max = q_max
1073        self.has_bck = bck
1074        self.slit_height = height
1075        self.slit_width  = width
1076       
1077        try:
1078            #pr = self._create_file_pr(path)
1079            pr = self._create_file_pr(data)
1080            if not pr==None:
1081                self.pr = pr
1082                self.perform_inversion()
1083        except:
1084            wx.PostEvent(self.parent, StatusEvent(status=sys.exc_value))
1085         
1086    def estimate_file_inversion(self, alpha, nfunc, d_max, data,
1087                                path=None, q_min=None, q_max=None, 
1088                                bck=False, height=0, width=0):
1089        """
1090        """
1091        self.alpha = alpha
1092        self.nfunc = nfunc
1093        self.max_length = d_max
1094        self.q_min = q_min
1095        self.q_max = q_max
1096        self.has_bck = bck
1097        self.slit_height = height
1098        self.slit_width  = width
1099       
1100        try:
1101            pr = self._create_file_pr(data)
1102            #pr = self._create_file_pr(path)
1103            if not pr is None:
1104                self.pr = pr
1105                self.perform_estimate()
1106        except:
1107            wx.PostEvent(self.parent, StatusEvent(status=sys.exc_value))
1108               
1109         
1110    def _create_file_pr(self, data):
1111        """
1112        Create and prepare invertor instance from
1113        a file data set.
1114       
1115        :param path: path of the file to read in
1116       
1117        """
1118        # Load data
1119        #if os.path.isfile(path):
1120        """   
1121        if self._current_file_data is not None \
1122            and self._current_file_data.path==path:
1123            # Protect against corrupted data from
1124            # previous failed load attempt
1125            if self._current_file_data.x is None:
1126                return None
1127            x = self._current_file_data.x
1128            y = self._current_file_data.y
1129            err = self._current_file_data.err
1130           
1131            message = "The data from this file has already been loaded."
1132            wx.PostEvent(self.parent, StatusEvent(status=message))
1133        else:
1134        """
1135        # Reset the status bar so that we don't get mixed up
1136        # with old messages.
1137        #TODO: refactor this into a proper status handling
1138        wx.PostEvent(self.parent, StatusEvent(status=''))
1139        try:
1140            class FileData:
1141                x = None
1142                y = None
1143                err = None
1144                path = None
1145                def __init__(self, path):
1146                    self.path = path
1147               
1148            self._current_file_data = FileData(data.path)
1149            self._current_file_data.x = data.x
1150            self._current_file_data.y = data.y
1151            self._current_file_data.err = data.dy
1152            x, y, err = data.x, data.y, data.dy
1153        except:
1154            load_error(sys.exc_value)
1155            return None
1156       
1157        # If the file contains no data, just return
1158        if x is None or len(x) == 0:
1159            load_error("The loaded file contains no data")
1160            return None
1161       
1162        # If we have not errors, add statistical errors
1163        if err is not None and y is not None:
1164            err = numpy.zeros(len(y))
1165            scale = None
1166            min_err = 0.0
1167            for i in range(len(y)):
1168                # Scale the error so that we can fit over several decades of Q
1169                if scale == None:
1170                    scale = 0.05 * math.sqrt(y[i])
1171                    min_err = 0.01 * y[i]
1172                err[i] = scale * math.sqrt(math.fabs(y[i])) + min_err
1173            message = "The loaded file had no error bars, "
1174            message += "statistical errors are assumed."
1175            wx.PostEvent(self.parent, StatusEvent(status=message))
1176       
1177        try:
1178            # Get the data from the chosen data set and perform inversion
1179            pr = Invertor()
1180            pr.d_max = self.max_length
1181            pr.alpha = self.alpha
1182            pr.q_min = self.q_min
1183            pr.q_max = self.q_max
1184            pr.x = x
1185            pr.y = y
1186            pr.err = err
1187            pr.has_bck = self.has_bck
1188            pr.slit_height = self.slit_height
1189            pr.slit_width = self.slit_width
1190            return pr
1191        except:
1192            load_error(sys.exc_value)
1193        return None
1194       
1195    def perform_estimate(self):
1196        """
1197        """
1198        from pr_thread import EstimatePr
1199        from copy import deepcopy
1200       
1201        # If a thread is already started, stop it
1202        if self.estimation_thread != None and \
1203            self.estimation_thread.isrunning():
1204            self.estimation_thread.stop()
1205               
1206        pr = self.pr.clone()
1207        self.estimation_thread = EstimatePr(pr, self.nfunc,
1208                                             error_func=self._thread_error, 
1209                                         completefn = self._estimate_completed, 
1210                                            updatefn   = None)
1211        self.estimation_thread.queue()
1212        self.estimation_thread.ready(2.5)
1213   
1214    def perform_estimateNT(self):
1215        """
1216        """
1217        from pr_thread import EstimateNT
1218        from copy import deepcopy
1219       
1220        # If a thread is already started, stop it
1221        if self.estimation_thread != None and self.estimation_thread.isrunning():
1222            self.estimation_thread.stop()
1223               
1224        pr = self.pr.clone()
1225        # Skip the slit settings for the estimation
1226        # It slows down the application and it doesn't change the estimates
1227        pr.slit_height = 0.0
1228        pr.slit_width  = 0.0
1229        self.estimation_thread = EstimateNT(pr, self.nfunc, 
1230                                            error_func=self._thread_error, 
1231                                        completefn = self._estimateNT_completed, 
1232                                            updatefn   = None)
1233        self.estimation_thread.queue()
1234        self.estimation_thread.ready(2.5)
1235       
1236    def perform_inversion(self):
1237        """
1238        """
1239        # Time estimate
1240        #estimated = self.elapsed*self.nfunc**2
1241        #message = "Computation time may take up to %g seconds" % self.elapsed
1242        #wx.PostEvent(self.parent, StatusEvent(status=message))
1243       
1244        # Start inversion thread
1245        self.start_thread()
1246        return
1247       
1248        out, cov = self.pr.lstsq(self.nfunc)
1249       
1250        # Save useful info
1251        self.elapsed = self.pr.elapsed
1252       
1253        for i in range(len(out)):
1254            try:
1255                print "%d: %g +- %g" % (i, out[i],
1256                                         math.sqrt(math.fabs(cov[i][i])))
1257            except: 
1258                print "%d: %g +- ?" % (i, out[i])       
1259       
1260        # Make a plot of I(q) data
1261        new_plot = Data1D(self.pr.x, self.pr.y, dy=self.pr.err)
1262        new_plot.name = "I_{obs}(q)"
1263        new_plot.xaxis("\\rm{Q}", 'A^{-1}')
1264        new_plot.yaxis("\\rm{Intensity} ","cm^{-1}")
1265        wx.PostEvent(self.parent, NewPlotEvent(plot=new_plot, title="Iq"))
1266        # Show I(q) fit
1267        self.show_iq(out, self.pr)
1268        # Show P(r) fit
1269        x_values, x_range = self.show_pr(out, self.pr, cov=cov)
1270       
1271    def _on_context_inversion(self, event):
1272        """
1273        """
1274        panel = event.GetEventObject()
1275
1276        # If we have more than one displayed plot, make the user choose
1277        if len(panel.plots) > 1 and \
1278            panel.graph.selected_plottable in panel.plots:
1279            dataset = panel.graph.selected_plottable
1280        elif len(panel.plots) == 1:
1281            dataset = panel.plots.keys()[0]
1282        else:
1283            logging.info("Prview Error: No data is available")
1284            return
1285       
1286        # Store a reference to the current plottable
1287        # If we have a suggested value, use it.
1288        try:
1289            estimate = float(self.control_panel.alpha_estimate)
1290            self.control_panel.alpha = estimate
1291        except:
1292            self.control_panel.alpha = self.alpha
1293            logging.info("Prview :Alpha Not estimate yet")
1294            pass
1295        try:
1296            estimate = int(self.control_panel.nterms_estimate)
1297            self.control_panel.nfunc = estimate
1298        except:
1299            self.control_panel.nfunc = self.nfunc
1300            logging.info("Prview : ntemrs Not estimate yet")
1301            pass
1302       
1303        self.current_plottable = panel.plots[dataset]
1304        self.control_panel.plotname = dataset
1305        #self.control_panel.nfunc = self.nfunc
1306        self.control_panel.d_max = self.max_length
1307        self.parent.set_perspective(self.perspective)
1308        self.control_panel._on_invert(None)
1309           
1310    def get_panels(self, parent):
1311        """
1312            Create and return a list of panel objects
1313        """
1314        from inversion_panel import InversionControl
1315       
1316        self.parent = parent
1317        self.control_panel = InversionControl(self.parent, -1, 
1318                                              style=wx.RAISED_BORDER,
1319                                              standalone=self.standalone)
1320        self.control_panel.set_manager(self)
1321        self.control_panel.nfunc = self.nfunc
1322        self.control_panel.d_max = self.max_length
1323        self.control_panel.alpha = self.alpha
1324        self.perspective = []
1325        self.perspective.append(self.control_panel.window_name)
1326     
1327        return [self.control_panel]
1328   
1329    def set_data(self, data_list):
1330        """
1331        receive a list of data to compute pr
1332        """
1333        if len(data_list) > 1:
1334            msg = "Pr panel does not allow multiple Data.\n"
1335            msg += "Please select one!\n"
1336            from pr_widgets import DataDialog
1337            dlg = DataDialog(data_list=data_list, text=msg)
1338            if dlg.ShowModal() == wx.ID_OK:
1339                data = dlg.get_data()
1340                if issubclass(data.__class__, Data1D):
1341                    self.control_panel._change_file(evt=None, data=data)
1342                else:   
1343                    msg = "Pr cannot be computed for data of "
1344                    msg += "type %s" % (data_list[0].__class__.__name__)
1345                    wx.PostEvent(self.parent, 
1346                             StatusEvent(status=msg, info='error'))
1347        elif len(data_list) == 1:
1348            if issubclass(data_list[0].__class__, Data1D):
1349                self.control_panel._change_file(evt=None, data=data_list[0])
1350            else:
1351                msg = "Pr cannot be computed for"
1352                msg += " data of type %s" % (data_list[0].__class__.__name__)
1353                wx.PostEvent(self.parent, 
1354                             StatusEvent(status=msg, info='error'))
1355        else:
1356            msg = "Pr contain no data"
1357            wx.PostEvent(self.parent, StatusEvent(status=msg, info='warning'))
1358           
1359    def post_init(self):
1360        """
1361            Post initialization call back to close the loose ends
1362            [Somehow openGL needs this call]
1363        """
1364        if self.standalone:
1365            self.parent.set_perspective(self.perspective)
1366 
1367if __name__ == "__main__":
1368    i = Plugin()
1369    print i.perform_estimateNT()
1370   
1371   
1372   
1373   
Note: See TracBrowser for help on using the repository browser.