source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingLogic.py @ c71b20a

Last change on this file since c71b20a was 61f0c75, checked in by Torin Cooper-Bennun <torin.cooper-bennun@…>, 6 years ago

fix crash on receiving non-array intermediate results

  • Property mode set to 100644
File size: 9.0 KB
RevLine 
[2add354]1import numpy as np
[4d457df]2
[dc5ef15]3from sas.qtgui.Plotting.PlotterData import Data1D
4from sas.qtgui.Plotting.PlotterData import Data2D
5
[4d457df]6from sas.sascalc.dataloader.data_info import Detector
7from sas.sascalc.dataloader.data_info import Source
8
9
10class FittingLogic(object):
11    """
12    All the data-related logic. This class deals exclusively with Data1D/2D
13    No QStandardModelIndex here.
14    """
15    def __init__(self, data=None):
16        self._data = data
[7248d75d]17        self.data_is_loaded = False
[87dfca4]18        #dq data presence in the dataset
19        self.dq_flag = False
20        #di data presence in the dataset
21        self.di_flag = False
[7248d75d]22        if data is not None:
23            self.data_is_loaded = True
[87dfca4]24            self.setDataProperties()
[4d457df]25
26    @property
27    def data(self):
28        return self._data
29
30    @data.setter
31    def data(self, value):
32        """ data setter """
33        self._data = value
34        self.data_is_loaded = True
[87dfca4]35        self.setDataProperties()
[4d457df]36
[180bd54]37    def isLoadedData(self):
38        """ accessor """
39        return self.data_is_loaded
40
[87dfca4]41    def setDataProperties(self):
42        """
43        Analyze data and set up some properties important for
44        the Presentation layer
45        """
46        if self._data.__class__.__name__ == "Data2D":
47            if self._data.err_data is not None and np.any(self._data.err_data):
48                self.di_flag = True
49            if self._data.dqx_data is not None and np.any(self._data.dqx_data):
50                self.dq_flag = True
51        else:
52            if self._data.dy is not None and np.any(self._data.dy):
53                self.di_flag = True
54            if self._data.dx is not None and np.any(self._data.dx):
55                self.dq_flag = True
56            elif self._data.dxl is not None and np.any(self._data.dxl):
57                self.dq_flag = True
58
[4d457df]59    def createDefault1dData(self, interval, tab_id=0):
60        """
61        Create default data for fitting perspective
62        Only when the page is on theory mode.
63        """
64        self._data = Data1D(x=interval)
65        self._data.xaxis('\\rm{Q}', "A^{-1}")
66        self._data.yaxis('\\rm{Intensity}', "cm^{-1}")
67        self._data.is_data = False
68        self._data.id = str(tab_id) + " data"
69        self._data.group_id = str(tab_id) + " Model1D"
70
71    def createDefault2dData(self, qmax, qstep, tab_id=0):
72        """
73        Create 2D data by default
74        Only when the page is on theory mode.
75        """
76        self._data = Data2D()
77        self._data.xaxis('\\rm{Q_{x}}', 'A^{-1}')
78        self._data.yaxis('\\rm{Q_{y}}', 'A^{-1}')
79        self._data.is_data = False
80        self._data.id = str(tab_id) + " data"
81        self._data.group_id = str(tab_id) + " Model2D"
82
83        # Default detector
84        self._data.detector.append(Detector())
85        index = len(self._data.detector) - 1
86        self._data.detector[index].distance = 8000   # mm
87        self._data.source.wavelength = 6             # A
88        self._data.detector[index].pixel_size.x = 5  # mm
89        self._data.detector[index].pixel_size.y = 5  # mm
90        self._data.detector[index].beam_center.x = qmax
91        self._data.detector[index].beam_center.y = qmax
92        # theory default: assume the beam
93        #center is located at the center of sqr detector
94        xmax = qmax
95        xmin = -qmax
96        ymax = qmax
97        ymin = -qmax
98
[2add354]99        x = np.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
100        y = np.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
[4d457df]101        # Use data info instead
[2add354]102        new_x = np.tile(x, (len(y), 1))
103        new_y = np.tile(y, (len(x), 1))
[4d457df]104        new_y = new_y.swapaxes(0, 1)
105
106        # all data required in 1d array
107        qx_data = new_x.flatten()
108        qy_data = new_y.flatten()
[2add354]109        q_data = np.sqrt(qx_data * qx_data + qy_data * qy_data)
[4d457df]110
111        # set all True (standing for unmasked) as default
[2add354]112        mask = np.ones(len(qx_data), dtype=bool)
[4d457df]113        # calculate the range of qx and qy: this way,
114        # it is a little more independent
115        # store x and y bin centers in q space
116        x_bins = x
117        y_bins = y
118
119        self._data.source = Source()
[2add354]120        self._data.data = np.ones(len(mask))
121        self._data.err_data = np.ones(len(mask))
[4d457df]122        self._data.qx_data = qx_data
123        self._data.qy_data = qy_data
124        self._data.q_data = q_data
125        self._data.mask = mask
126        self._data.x_bins = x_bins
127        self._data.y_bins = y_bins
128        # max and min taking account of the bin sizes
129        self._data.xmin = xmin
130        self._data.xmax = xmax
131        self._data.ymin = ymin
132        self._data.ymax = ymax
133
[3ae9179]134    def _create1DPlot(self, tab_id, x, y, model, data, component=None):
[4d457df]135        """
[44777ee]136        For internal use: create a new 1D data instance based on fitting results.
137        'component' is a string indicating the model component, e.g. "P(Q)"
[4d457df]138        """
139        # Create the new plot
140        new_plot = Data1D(x=x, y=y)
141        new_plot.is_data = False
[2add354]142        new_plot.dy = np.zeros(len(y))
[4d457df]143        _yaxis, _yunit = data.get_yaxis()
144        _xaxis, _xunit = data.get_xaxis()
145
146        new_plot.group_id = data.group_id
[3ae9179]147        new_plot.id = str(tab_id) + " " + ("[" + component + "] " if component else "") + model.id
[d6e38661]148
[3ae9179]149        # use data.filename for data, use model.id for theory
150        id_str = data.filename if data.filename else model.id
151        new_plot.name = model.name + ((" " + component) if component else "") + " [" + id_str + "]"
[d6e38661]152
[0268aed]153        new_plot.title = new_plot.name
[4d457df]154        new_plot.xaxis(_xaxis, _xunit)
155        new_plot.yaxis(_yaxis, _yunit)
156
[6fd4e36]157        return new_plot
[4d457df]158
[3ae9179]159    def new1DPlot(self, return_data, tab_id):
160        """
161        Create a new 1D data instance based on fitting results
162        """
[dcabba7]163        return self._create1DPlot(tab_id, return_data['x'], return_data['y'],
164                                  return_data['model'], return_data['data'])
[3ae9179]165
[4d457df]166    def new2DPlot(self, return_data):
167        """
168        Create a new 2D data instance based on fitting results
169        """
[dcabba7]170        image = return_data['image']
171        data = return_data['data']
172        model = return_data['model']
[4d457df]173
[2add354]174        np.nan_to_num(image)
[4d457df]175        new_plot = Data2D(image=image, err_image=data.err_data)
176        new_plot.name = model.name + '2d'
177        new_plot.title = "Analytical model 2D "
[dcabba7]178        new_plot.id = str(return_data['page_id']) + " " + data.name
179        new_plot.group_id = str(return_data['page_id']) + " Model2D"
[4d457df]180        new_plot.detector = data.detector
181        new_plot.source = data.source
182        new_plot.is_data = False
183        new_plot.qx_data = data.qx_data
184        new_plot.qy_data = data.qy_data
185        new_plot.q_data = data.q_data
186        new_plot.mask = data.mask
187        ## plot boundaries
188        new_plot.ymin = data.ymin
189        new_plot.ymax = data.ymax
190        new_plot.xmin = data.xmin
191        new_plot.xmax = data.xmax
192
193        title = data.title
194
195        new_plot.is_data = False
196        if data.is_data:
197            data_name = str(data.name)
198        else:
199            data_name = str(model.__class__.__name__) + '2d'
200
201        if len(title) > 1:
202            new_plot.title = "Model2D for %s " % model.name + data_name
203        new_plot.name = model.name + " [" + \
204                                    data_name + "]"
205
[6fd4e36]206        return new_plot
[4d457df]207
[3ae9179]208    def new1DProductPlots(self, return_data, tab_id):
209        """
[b4d05bd]210        If return_data contains separated P(Q) and/or S(Q) data, create 1D plots for each and return as the tuple
211        (pq_plot, sq_plot). If either are unavailable, the corresponding plot is None.
[3ae9179]212        """
[40975f8]213        plots = []
[9ba91b7]214        for name, result in return_data['intermediate_results'].items():
[61f0c75]215            if not isinstance(result, np.ndarray):
216                continue
[9ba91b7]217            plots.append(self._create1DPlot(tab_id, return_data['x'], result,
218                         return_data['model'], return_data['data'],
219                         component=name))
[40975f8]220        return plots
[3ae9179]221
[4d457df]222    def computeDataRange(self):
223        """
[ee18d33]224        Wrapper for calculating the data range based on local dataset
225        """
226        return self.computeRangeFromData(self.data)
227
228    def computeRangeFromData(self, data):
229        """
[4d457df]230        Compute the minimum and the maximum range of the data
231        return the npts contains in data
232        """
233        qmin, qmax, npts = None, None, None
[ee18d33]234        if isinstance(data, Data1D):
[4d457df]235            try:
[ee18d33]236                qmin = min(data.x)
237                qmax = max(data.x)
238                npts = len(data.x)
[4d457df]239            except (ValueError, TypeError):
240                msg = "Unable to find min/max/length of \n data named %s" % \
241                            self.data.filename
[b3e8629]242                raise ValueError(msg)
[4d457df]243
244        else:
245            qmin = 0
246            try:
[ee18d33]247                x = max(np.fabs(data.xmin), np.fabs(data.xmax))
248                y = max(np.fabs(data.ymin), np.fabs(data.ymax))
[4d457df]249            except (ValueError, TypeError):
250                msg = "Unable to find min/max of \n data named %s" % \
251                            self.data.filename
[b3e8629]252                raise ValueError(msg)
[2add354]253            qmax = np.sqrt(x * x + y * y)
[ee18d33]254            npts = len(data.data)
[4d457df]255        return qmin, qmax, npts
Note: See TracBrowser for help on using the repository browser.