source: sasview/src/sas/sasgui/perspectives/invariant/invariant_state.py @ 6015eee

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249unittest-saveload
Last change on this file since 6015eee was 2469df7, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

lint: update 'if x==True/False?' to 'if x/not x:'

  • Property mode set to 100644
File size: 30.9 KB
Line 
1"""
2    State class for the invariant UI
3"""
4
5# import time
6import os
7import sys
8import logging
9import copy
10import sas.sascalc.dataloader
11# from xml.dom.minidom import parse
12from lxml import etree
13from sas.sascalc.dataloader.readers.cansas_reader import Reader as CansasReader
14from sas.sascalc.dataloader.readers.cansas_reader import get_content
15from sas.sasgui.guiframe.utils import format_number
16from sas.sasgui.guiframe.gui_style import GUIFRAME_ID
17from sas.sasgui.guiframe.dataFitting import Data1D
18
19logger = logging.getLogger(__name__)
20
21INVNODE_NAME = 'invariant'
22CANSAS_NS = "cansas1d/1.0"
23
24# default state
25DEFAULT_STATE = {'file': 'None',
26                 'compute_num':0,
27                 'state_num':0,
28                 'is_time_machine':False,
29                 'background_tcl':0.0,
30                 'scale_tcl':1.0,
31                 'contrast_tcl':1.0,
32                 'porod_constant_tcl':'',
33                 'npts_low_tcl':10,
34                 'npts_high_tcl':10,
35                 'power_high_tcl':4.0,
36                 'power_low_tcl': 4.0,
37                 'enable_high_cbox':False,
38                 'enable_low_cbox':False,
39                 'guinier': True,
40                 'power_law_high': False,
41                 'power_law_low': False,
42                 'fit_enable_high': False,
43                 'fit_enable_low': False,
44                 'fix_enable_high':True,
45                 'fix_enable_low':True,
46                 'volume_tcl':'',
47                 'volume_err_tcl':'',
48                 'surface_tcl':'',
49                 'surface_err_tcl':''}
50# list of states: This list will be filled as panel
51# init and the number of states increases
52state_list = {}
53bookmark_list = {}
54# list of input parameters (will be filled up on panel init) used by __str__
55input_list = {'background_tcl':0,
56              'scale_tcl':0,
57              'contrast_tcl':0,
58              'porod_constant_tcl':'',
59              'npts_low_tcl':0,
60              'npts_high_tcl':0,
61              'power_high_tcl':0,
62              'power_low_tcl': 0}
63# list of output parameters (order sensitive) used by __str__
64output_list = [["qstar_low", "Q* from low Q extrapolation [1/(cm*A)]"],
65               ["qstar_low_err", "dQ* from low Q extrapolation"],
66               ["qstar_low_percent", "Q* percent from low Q extrapolation"],
67               ["qstar", "Q* from data [1/(cm*A)]"],
68               ["qstar_err", "dQ* from data"],
69               ["qstar_percent", "Q* percent from data"],
70               ["qstar_high", "Q* from high Q extrapolation [1/(cm*A)]"],
71               ["qstar_high_err", "dQ* from high Q extrapolation"],
72               ["qstar_high_percent", "Q* percent from low Q extrapolation"],
73               ["qstar_total", "total Q* [1/(cm*A)]"],
74               ["qstar_total_err", "total dQ*"],
75               ["volume", "volume fraction"],
76               ["volume_err", "volume fraction error"],
77               ["surface", "specific surface"],
78               ["surface_err", "specific surface error"]]
79
80
81
82class InvariantState(object):
83    """
84    Class to hold the state information of the InversionControl panel.
85    """
86    def __init__(self):
87        """
88        Default values
89        """
90        # Input
91        self.file = None
92        self.data = Data1D(x=[], y=[], dx=None, dy=None)
93        self.theory_lowQ = Data1D(x=[], y=[], dy=None)
94        self.theory_lowQ.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
95        self.theory_highQ = Data1D(x=[], y=[], dy=None)
96        self.theory_highQ.symbol = GUIFRAME_ID.CURVE_SYMBOL_NUM
97        # self.is_time_machine = False
98        self.saved_state = DEFAULT_STATE
99        self.state_list = state_list
100        self.bookmark_list = bookmark_list
101        self.input_list = input_list
102        self.output_list = output_list
103
104        self.compute_num = 0
105        self.state_num = 0
106        self.timestamp = ('00:00:00', '00/00/0000')
107        self.container = None
108        # plot image
109        self.wximbmp = None
110        # report_html strings
111        import sas.sasgui.perspectives.invariant as invariant
112        path = invariant.get_data_path(media='media')
113        path_report_html = os.path.join(path, "report_template.html")
114        html_template = open(path_report_html, "r")
115        self.template_str = html_template.read()
116        self.report_str = self.template_str
117        # self.report_str_save = None
118        html_template.close()
119
120    def __str__(self):
121        """
122        Pretty print
123
124        : return: string representing the state
125        """
126        # Input string
127        compute_num = self.saved_state['compute_num']
128        compute_state = self.state_list[str(compute_num)]
129        my_time, date = self.timestamp
130        file_name = self.file
131
132        state_num = int(self.saved_state['state_num'])
133        state = "\n[Invariant computation for %s: " % file_name
134        state += "performed at %s on %s] \n" % (my_time, date)
135        state += "State No.: %d \n" % state_num
136        state += "\n=== Inputs ===\n"
137
138        # text ctl general inputs ( excluding extrapolation text ctl)
139        for key, value in self.input_list.iteritems():
140            if value == '':
141                continue
142            key_split = key.split('_')
143            max_ind = len(key_split) - 1
144            if key_split[max_ind] == 'tcl':
145                name = ""
146                if key_split[1] == 'low' or key_split[1] == 'high':
147                    continue
148                for ind in range(0, max_ind):
149                    name += " %s" % key_split[ind]
150                state += "%s:   %s\n" % (name.lstrip(" "), value)
151
152        # other input parameters
153        extra_lo = compute_state['enable_low_cbox']
154        if compute_state['enable_low_cbox']:
155            if compute_state['guinier']:
156                extra_lo = 'Guinier'
157            else:
158                extra_lo = 'Power law'
159        extra_hi = compute_state['enable_high_cbox']
160        if compute_state['enable_high_cbox']:
161            extra_hi = 'Power law'
162        state += "\nExtrapolation:  High=%s; Low=%s\n" % (extra_hi, extra_lo)
163        low_off = False
164        high_off = False
165        for key, value in self.input_list.iteritems():
166            key_split = key.split('_')
167            max_ind = len(key_split) - 1
168            if key_split[max_ind] == 'tcl':
169                name = ""
170                # check each buttons whether or not ON or OFF
171                if key_split[1] == 'low' or key_split[1] == 'high':
172                    if not compute_state['enable_low_cbox'] and \
173                        key_split[max_ind - 1] == 'low':
174                        low_off = True
175                        continue
176                    elif not compute_state['enable_high_cbox'] and \
177                        key_split[max_ind - 1] == 'high':
178                        high_off = True
179                        continue
180                    elif extra_lo == 'Guinier' and key_split[0] == 'power' and \
181                        key_split[max_ind - 1] == 'low':
182                        continue
183                    for ind in range(0, max_ind):
184                        name += " %s" % key_split[ind]
185                    name = name.lstrip(" ")
186                    if name == "power low":
187                        if compute_state['fix_enable_low']:
188                            name += ' (Fixed)'
189                        else:
190                            name += ' (Fitted)'
191                    if name == "power high":
192                        if compute_state['fix_enable_high']:
193                            name += ' (Fixed)'
194                        else:
195                            name += ' (Fitted)'
196                    state += "%s:   %s\n" % (name, value)
197        # Outputs
198        state += "\n=== Outputs ==="
199        for item in output_list:
200            item_split = item[0].split('_')
201            # Exclude the extrapolation that turned off
202            if len(item_split) > 1:
203                if low_off and item_split[1] == 'low':
204                    continue
205                if high_off and item_split[1] == 'high':
206                    continue
207            max_ind = len(item_split) - 1
208            value = None
209            if hasattr(self.container, item[0]):
210                # Q* outputs
211                value = getattr(self.container, item[0])
212            else:
213                # other outputs than Q*
214                name = item[0] + "_tcl"
215                if name in self.saved_state.keys():
216                    value = self.saved_state[name]
217
218            # Exclude the outputs w/''
219            if value == '':
220                continue
221            # Error outputs
222            if item_split[max_ind] == 'err':
223                state += "+- %s " % format_number(value)
224            # Percentage outputs
225            elif item_split[max_ind] == 'percent':
226                value = float(value) * 100
227                state += "(%s %s)" % (format_number(value), '%')
228            # Outputs
229            else:
230                state += "\n%s:   %s " % (item[1],
231                                          format_number(value, high=True))
232        # Include warning msg
233        if self.container is not None:
234            state += "\n\nNote:\n" + self.container.warning_msg
235        return state
236
237    def clone_state(self):
238        """
239        deepcopy the state
240        """
241        return copy.deepcopy(self.saved_state)
242
243    def toXML(self, file="inv_state.inv", doc=None, entry_node=None):
244        """
245        Writes the state of the InversionControl panel to file, as XML.
246
247        Compatible with standalone writing, or appending to an
248        already existing XML document. In that case, the XML document
249        is required. An optional entry node in the XML document
250        may also be given.
251
252        : param file: file to write to
253        : param doc: XML document object [optional]
254        : param entry_node: XML node within the document at which we will append the data [optional]
255        """
256        # TODO: Get this to work
257        from xml.dom.minidom import getDOMImplementation
258        import time
259        timestamp = time.time()
260        # Check whether we have to write a standalone XML file
261        if doc is None:
262            impl = getDOMImplementation()
263
264            doc_type = impl.createDocumentType(INVNODE_NAME, "1.0", "1.0")
265
266            newdoc = impl.createDocument(None, INVNODE_NAME, doc_type)
267            top_element = newdoc.documentElement
268        else:
269            # We are appending to an existing document
270            newdoc = doc
271            top_element = newdoc.createElement(INVNODE_NAME)
272            if entry_node is None:
273                newdoc.documentElement.appendChild(top_element)
274            else:
275                entry_node.appendChild(top_element)
276
277        attr = newdoc.createAttribute("version")
278        attr.nodeValue = '1.0'
279        top_element.setAttributeNode(attr)
280
281        # File name
282        element = newdoc.createElement("filename")
283        if self.file is not None and self.file != '':
284            element.appendChild(newdoc.createTextNode(str(self.file)))
285        else:
286            element.appendChild(newdoc.createTextNode(str(file)))
287        top_element.appendChild(element)
288
289        element = newdoc.createElement("timestamp")
290        element.appendChild(newdoc.createTextNode(time.ctime(timestamp)))
291        attr = newdoc.createAttribute("epoch")
292        attr.nodeValue = str(timestamp)
293        element.setAttributeNode(attr)
294        top_element.appendChild(element)
295
296        # Current state
297        state = newdoc.createElement("state")
298        top_element.appendChild(state)
299
300        for name, value in self.saved_state.iteritems():
301            element = newdoc.createElement(str(name))
302            element.appendChild(newdoc.createTextNode(str(value)))
303            state.appendChild(element)
304
305        # State history list
306        history = newdoc.createElement("history")
307        top_element.appendChild(history)
308
309        for name, value in self.state_list.iteritems():
310            history_element = newdoc.createElement('state_' + str(name))
311            for state_name, state_value in value.iteritems():
312                state_element = newdoc.createElement(str(state_name))
313                child = newdoc.createTextNode(str(state_value))
314                state_element.appendChild(child)
315                history_element.appendChild(state_element)
316            # history_element.appendChild(state_list_element)
317            history.appendChild(history_element)
318
319        # Bookmarks  bookmark_list[self.bookmark_num] = [\
320        # my_time,date,state,comp_state]
321        bookmark = newdoc.createElement("bookmark")
322        top_element.appendChild(bookmark)
323        item_list = ['time', 'date', 'state', 'comp_state']
324        for name, value_list in self.bookmark_list.iteritems():
325            element = newdoc.createElement('mark_' + str(name))
326            _, date, state, comp_state = value_list
327            time_element = newdoc.createElement('time')
328            time_element.appendChild(newdoc.createTextNode(str(value_list[0])))
329            date_element = newdoc.createElement('date')
330            date_element.appendChild(newdoc.createTextNode(str(value_list[1])))
331            state_list_element = newdoc.createElement('state')
332            comp_state_list_element = newdoc.createElement('comp_state')
333            for state_name, state_value in value_list[2].iteritems():
334                state_element = newdoc.createElement(str(state_name))
335                child = newdoc.createTextNode(str(state_value))
336                state_element.appendChild(child)
337                state_list_element.appendChild(state_element)
338            for comp_name, comp_value in value_list[3].iteritems():
339                comp_element = newdoc.createElement(str(comp_name))
340                comp_element.appendChild(newdoc.createTextNode(str(comp_value)))
341                comp_state_list_element.appendChild(comp_element)
342            element.appendChild(time_element)
343            element.appendChild(date_element)
344            element.appendChild(state_list_element)
345            element.appendChild(comp_state_list_element)
346            bookmark.appendChild(element)
347
348        # Save the file
349        if doc is None:
350            fd = open('test000', 'w')
351            fd.write(newdoc.toprettyxml())
352            fd.close()
353            return None
354        else:
355            return newdoc
356
357    def fromXML(self, file=None, node=None):
358        """
359        Load invariant states from a file
360
361        : param file: .inv file
362        : param node: node of a XML document to read from
363        """
364        if file is not None:
365            msg = "InvariantSate no longer supports non-CanSAS"
366            msg += " format for invariant files"
367            raise RuntimeError, msg
368
369        if node.get('version')\
370            and node.get('version') == '1.0':
371
372            # Get file name
373            entry = get_content('ns:filename', node)
374            if entry is not None:
375                file_name = entry.text.strip()
376
377            # Get time stamp
378            entry = get_content('ns:timestamp', node)
379            if entry is not None and entry.get('epoch'):
380                try:
381                    timestamp = (entry.get('epoch'))
382                except:
383                    msg = "InvariantSate.fromXML: Could not read"
384                    msg += " timestamp\n %s" % sys.exc_value
385                    logger.error(msg)
386
387            # Parse bookmarks
388            entry_bookmark = get_content('ns:bookmark', node)
389
390            for ind in range(1, len(entry_bookmark) + 1):
391                temp_state = {}
392                temp_bookmark = {}
393                entry = get_content('ns:mark_%s' % ind, entry_bookmark)
394
395                if entry is not None:
396                    my_time = get_content('ns:time', entry)
397                    val_time = str(my_time.text.strip())
398                    date = get_content('ns:date', entry)
399                    val_date = str(date.text.strip())
400                    state_entry = get_content('ns:state', entry)
401                    for item in DEFAULT_STATE:
402                        input_field = get_content('ns:%s' % item, state_entry)
403                        val = str(input_field.text.strip())
404                        if input_field is not None:
405                            temp_state[item] = val
406                    comp_entry = get_content('ns:comp_state', entry)
407
408                    for item in DEFAULT_STATE:
409                        input_field = get_content('ns:%s' % item, comp_entry)
410                        val = str(input_field.text.strip())
411                        if input_field is not None:
412                            temp_bookmark[item] = val
413                    try:
414                        self.bookmark_list[ind] = [val_time, val_date, temp_state, temp_bookmark]
415                    except:
416                        raise "missing components of bookmarks..."
417            # Parse histories
418            entry_history = get_content('ns:history', node)
419
420            for ind in range(0, len(entry_history)):
421                temp_state = {}
422                entry = get_content('ns:state_%s' % ind, entry_history)
423
424                if entry is not None:
425                    for item in DEFAULT_STATE:
426                        input_field = get_content('ns:%s' % item, entry)
427                        if input_field.text is not None:
428                            val = str(input_field.text.strip())
429                        else:
430                            val = ''
431                        if input_field is not None:
432                            temp_state[item] = val
433                            self.state_list[str(ind)] = temp_state
434
435            # Parse current state (ie, saved_state)
436            entry = get_content('ns:state', node)
437            if entry is not None:
438                for item in DEFAULT_STATE:
439                    input_field = get_content('ns:%s' % item, entry)
440                    if input_field.text is not None:
441                        val = str(input_field.text.strip())
442                    else:
443                        val = ''
444                    if input_field is not None:
445                        self.set_saved_state(name=item, value=val)
446            self.file = file_name
447
448    def set_report_string(self):
449        """
450        Get the values (strings) from __str__ for report
451        """
452        strings = self.__str__()
453
454        # default string values
455        for num in range(1, 19):
456            exec "s_%s = 'NA'" % str(num)
457        lines = strings.split('\n')
458        # get all string values from __str__()
459        for line in range(0, len(lines)):
460            if line == 1:
461                s_1 = lines[1]
462            elif line == 2:
463                s_2 = lines[2]
464            else:
465                item = lines[line].split(':')
466                item[0] = item[0].strip()
467                if item[0] == "scale":
468                    s_3 = item[1]
469                elif item[0] == "porod constant":
470                    s_4 = item[1]
471                elif item[0] == "background":
472                    s_5 = item[1]
473                elif item[0] == "contrast":
474                    s_6 = item[1]
475                elif item[0] == "Extrapolation":
476                    extra = item[1].split(";")
477                    bool_0 = extra[0].split("=")
478                    bool_1 = extra[1].split("=")
479                    s_8 = " " + bool_0[0] + "Q region = " + bool_0[1]
480                    s_7 = " " + bool_1[0] + "Q region = " + bool_1[1]
481                elif item[0] == "npts low":
482                    s_9 = item[1]
483                elif item[0] == "npts high":
484                    s_10 = item[1]
485                elif item[0] == "volume fraction":
486                    val = item[1].split("+-")[0].strip()
487                    error = item[1].split("+-")[1].strip()
488                    s_17 = val + " &plusmn; " + error
489                elif item[0] == "specific surface":
490                    val = item[1].split("+-")[0].strip()
491                    error = item[1].split("+-")[1].strip()
492                    s_18 = val + " &plusmn; " + error
493                elif item[0].split("(")[0].strip() == "power low":
494                    s_11 = item[0] + " =" + item[1]
495                elif item[0].split("(")[0].strip() == "power high":
496                    s_12 = item[0] + " =" + item[1]
497                elif item[0].split("[")[0].strip() == "Q* from low Q extrapolation":
498                    # looks messy but this way the symbols +_ and % work on html
499                    val = item[1].split("+-")[0].strip()
500                    error = item[1].split("+-")[1].strip()
501                    err = error.split("%")[0].strip()
502                    percent = error.split("%")[1].strip()
503                    s_13 = val + " &plusmn; " + err + "&#37" + percent
504                elif item[0].split("[")[0].strip() == "Q* from data":
505                    val = item[1].split("+-")[0].strip()
506                    error = item[1].split("+-")[1].strip()
507                    err = error.split("%")[0].strip()
508                    percent = error.split("%")[1].strip()
509                    s_14 = val + " &plusmn; " + err + "&#37" + percent
510                elif item[0].split("[")[0].strip() == "Q* from high Q extrapolation":
511                    val = item[1].split("+-")[0].strip()
512                    error = item[1].split("+-")[1].strip()
513                    err = error.split("%")[0].strip()
514                    percent = error.split("%")[1].strip()
515                    s_15 = val + " &plusmn; " + err + "&#37" + percent
516                elif item[0].split("[")[0].strip() == "total Q*":
517                    val = item[1].split("+-")[0].strip()
518                    error = item[1].split("+-")[1].strip()
519                    s_16 = val + " &plusmn; " + error
520                else:
521                    continue
522
523        s_1 = self._check_html_format(s_1)
524        file_name = self._check_html_format(self.file)
525
526        # make plot image
527        self.set_plot_state(extra_high=bool_0[1], extra_low=bool_1[1])
528        # get ready for report with setting all the html strings
529        self.report_str = str(self.template_str) % (s_1, s_2,
530                                                    s_3, s_4, s_5, s_6, s_7, s_8,
531                                                    s_9, s_10, s_11, s_12, s_13, s_14, s_15,
532                                                    s_16, s_17, s_18, file_name, "%s")
533
534    def _check_html_format(self, name):
535        """
536        Check string '%' for html format
537        """
538        if name.count('%'):
539            name = name.replace('%', '&#37')
540
541        return name
542
543    def set_saved_state(self, name, value):
544        """
545        Set the state list
546
547        : param name: name of the state component
548        : param value: value of the state component
549        """
550        rb_list = [['power_law_low', 'guinier'],
551                   ['fit_enable_low', 'fix_enable_low'],
552                   ['fit_enable_high', 'fix_enable_high']]
553
554        self.name = value
555        self.saved_state[name] = value
556        # set the count part of radio button clicked
557        # False for the saved_state
558        for title, content in rb_list:
559            if name == title:
560                name = content
561                value = False
562            elif name == content:
563                name = title
564                value = False
565        self.saved_state[name] = value
566        self.state_num = self.saved_state['state_num']
567
568    def set_plot_state(self, extra_high=False, extra_low=False):
569        """
570        Build image state that wx.html understand
571        by plotting, putting it into wx.FileSystem image object
572
573        : extrap_high,extra_low: low/high extrapolations
574        are possible extra-plots
575        """
576        # some imports
577        import wx
578        import matplotlib.pyplot as plt
579        from matplotlib.backends.backend_agg import FigureCanvasAgg
580
581        # we use simple plot, not plotpanel
582        # make matlab figure
583        fig = plt.figure()
584        fig.set_facecolor('w')
585        graph = fig.add_subplot(111)
586
587        # data plot
588        graph.errorbar(self.data.x, self.data.y, yerr=self.data.dy, fmt='o')
589        # low Q extrapolation fit plot
590        if not extra_low == 'False':
591            graph.plot(self.theory_lowQ.x, self.theory_lowQ.y)
592        # high Q extrapolation fit plot
593        if not extra_high == 'False':
594            graph.plot(self.theory_highQ.x, self.theory_highQ.y)
595        graph.set_xscale("log", nonposx='clip')
596        graph.set_yscale("log", nonposy='clip')
597        graph.set_xlabel('$\\rm{Q}(\\AA^{-1})$', fontsize=12)
598        graph.set_ylabel('$\\rm{Intensity}(cm^{-1})$', fontsize=12)
599        canvas = FigureCanvasAgg(fig)
600        # actually make image
601        canvas.draw()
602
603        # make python.Image object
604        # size
605        w, h = canvas.get_width_height()
606        # convert to wx.Image
607        wximg = wx.EmptyImage(w, h)
608        # wxim.SetData(img.convert('RGB').tostring() )
609        wximg.SetData(canvas.tostring_rgb())
610        # get the dynamic image for the htmlwindow
611        wximgbmp = wx.BitmapFromImage(wximg)
612        # store the image in wx.FileSystem Object
613        wx.FileSystem.AddHandler(wx.MemoryFSHandler())
614        # use wx.MemoryFSHandler
615        self.imgRAM = wx.MemoryFSHandler()
616        # AddFile, image can be retrieved with 'memory:filename'
617        self.imgRAM.AddFile('img_inv.png', wximgbmp, wx.BITMAP_TYPE_PNG)
618
619        self.wximgbmp = 'memory:img_inv.png'
620        self.image = fig
621
622class Reader(CansasReader):
623    """
624    Class to load a .inv invariant file
625    """
626    # # File type
627    type_name = "Invariant"
628
629    # # Wildcards
630    type = ["Invariant file (*.inv)|*.inv",
631            "SASView file (*.svs)|*.svs"]
632    # # List of allowed extensions
633    ext = ['.inv', '.INV', '.svs', 'SVS']
634
635    def __init__(self, call_back, cansas=True):
636        """
637        Initialize the call-back method to be called
638        after we load a file
639
640        : param call_back: call-back method
641        : param cansas:  True = files will be written/read in CanSAS format
642                        False = write CanSAS format
643        """
644        # # Call back method to be executed after a file is read
645        self.call_back = call_back
646        # # CanSAS format flag
647        self.cansas = cansas
648        self.state = None
649
650    def read(self, path):
651        """
652        Load a new invariant state from file
653
654        : param path: file path
655        : return: None
656        """
657        if self.cansas:
658            return self._read_cansas(path)
659        else:
660            return self._read_standalone(path)
661
662    def _read_standalone(self, path):
663        """
664        Load a new invariant state from file.
665        The invariant node is assumed to be the top element.
666
667        : param path: file path
668        : return: None
669        """
670        # Read the new state from file
671        state = InvariantState()
672
673        state.fromXML(file=path)
674
675        # Call back to post the new state
676        self.call_back(state)
677        return None
678
679    def _parse_state(self, entry):
680        """
681        Read an invariant result from an XML node
682
683        : param entry: XML node to read from
684        : return: InvariantState object
685        """
686        state = None
687        # Locate the invariant node
688        try:
689            nodes = entry.xpath('ns:%s' % INVNODE_NAME,
690                                namespaces={'ns': CANSAS_NS})
691            # Create an empty state
692            if nodes != []:
693                state = InvariantState()
694                state.fromXML(node=nodes[0])
695        except:
696            msg = "XML document does not contain invariant"
697            msg += " information.\n %s" % sys.exc_value
698            logger.info(msg)
699        return state
700
701    def _read_cansas(self, path):
702        """
703        Load data and invariant information from a CanSAS XML file.
704
705        : param path: file path
706        : return: Data1D object if a single SASentry was found,
707                    or a list of Data1D objects if multiple entries were found,
708                    or None of nothing was found
709        : raise RuntimeError: when the file can't be opened
710        : raise ValueError: when the length of the data vectors are inconsistent
711        """
712        output = []
713        if os.path.isfile(path):
714            basename = os.path.basename(path)
715            root, extension = os.path.splitext(basename)
716
717            if  extension.lower() in self.ext or \
718                extension.lower() == '.xml':
719                tree = etree.parse(path, parser=etree.ETCompatXMLParser())
720
721                # Check the format version number
722                # Specifying the namespace will take care of
723                # the file format version
724                root = tree.getroot()
725
726                entry_list = root.xpath('/ns:SASroot/ns:SASentry',
727                                        namespaces={'ns': CANSAS_NS})
728
729                for entry in entry_list:
730                    invstate = self._parse_state(entry)
731                    # invstate could be None when .svs file is loaded
732                    # in this case, skip appending to output
733                    if invstate is not None:
734                        sas_entry, _ = self._parse_entry(entry)
735                        sas_entry.meta_data['invstate'] = invstate
736                        sas_entry.filename = invstate.file
737                        output.append(sas_entry)
738        else:
739            raise RuntimeError, "%s is not a file" % path
740
741        # Return output consistent with the loader's api
742        if len(output) == 0:
743            return None
744        elif len(output) == 1:
745            # Call back to post the new state
746            self.state = output[0].meta_data['invstate']
747            self.call_back(state=output[0].meta_data['invstate'],
748                           datainfo=output[0])
749            return output[0]
750        else:
751            return output
752
753    def get_state(self):
754        return self.state
755
756    def write(self, filename, datainfo=None, invstate=None):
757        """
758        Write the content of a Data1D as a CanSAS XML file
759
760        : param filename: name of the file to write
761        : param datainfo: Data1D object
762        : param invstate: InvariantState object
763        """
764        # Sanity check
765        if self.cansas:
766            doc = self.write_toXML(datainfo, invstate)
767            # Write the XML document
768            fd = open(filename, 'w')
769            fd.write(doc.toprettyxml())
770            fd.close()
771        else:
772            invstate.toXML(file=filename)
773
774    def write_toXML(self, datainfo=None, state=None):
775        """
776        Write toXML, a helper for write()
777
778        : return: xml doc
779        """
780        if datainfo is None:
781            datainfo = sas.sascalc.dataloader.data_info.Data1D(x=[], y=[])
782        elif not issubclass(datainfo.__class__, sas.sascalc.dataloader.data_info.Data1D):
783            msg = "The cansas writer expects a Data1D"
784            msg += " instance: %s" % str(datainfo.__class__.__name__)
785            raise RuntimeError, msg
786        # make sure title and data run is filled up.
787        if datainfo.title is None or datainfo.title == '':
788            datainfo.title = datainfo.name
789        if datainfo.run_name is None or datainfo.run_name == {}:
790            datainfo.run = [str(datainfo.name)]
791            datainfo.run_name[0] = datainfo.name
792        # Create basic XML document
793        doc, sasentry = self._to_xml_doc(datainfo)
794        # Add the invariant information to the XML document
795        if state is not None:
796            doc = state.toXML(datainfo.name, doc=doc, entry_node=sasentry)
797        return doc
Note: See TracBrowser for help on using the repository browser.