source: sasview/src/sas/sasgui/perspectives/fitting/fitproblem.py @ 728b291

Last change on this file since 728b291 was 251ef684, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

code cleanup for fitproblem

  • Property mode set to 100644
File size: 15.8 KB
Line 
1"""
2Inferface containing information to store data, model, range of data, etc...
3and retreive this information. This is an inferface
4for a fitProblem i.e relationship between data and model.
5"""
6################################################################################
7#This software was developed by the University of Tennessee as part of the
8#Distributed Data Analysis of Neutron Scattering Experiments (DANSE)
9#project funded by the US National Science Foundation.
10#
11#See the license text in license.txt
12#
13#copyright 2009, University of Tennessee
14################################################################################
15import copy
16
17from sas.sascalc.fit.qsmearing import smear_selection
18
19class FitProblem(object):
20    """
21    Define the relationship between data and model, including range, weights,
22    etc.
23    """
24    def __init__(self):
25        """
26        contains information about data and model to fit
27        """
28        ## data used for fitting
29        self.fit_data = None
30        self.theory_data = None
31        self.residuals = None
32        # original data: should not be modified
33        self.original_data = None
34        ## the current model
35        self.model = None
36        ## if 1 this fit problem will be selected to fit , if 0
37        ## it will not be selected for fit
38        self.schedule = 0
39        ##list containing parameter name and value
40        self.list_param = []
41        self.list_param2fit = []
42        ## smear object to smear or not data1D
43        self.smearer_computed = False
44        self.smearer_enable = False
45        self.smearer_computer_value = None
46        ## fitting range
47        self.qmin = None
48        self.qmax = None
49        # fit weight
50        self.weight = None
51        self.result = None
52        self.fit_tab_caption = None
53        self.name_per_page = None
54
55    def enable_smearing(self, flag=False):
56        """
57        :param flag: bool.When flag is 1 get the computer smear value. When
58            flag is 0 ingore smear value.
59        """
60        self.smearer_enable = flag
61
62    def set_smearer(self, smearer):
63        """
64        save reference of  smear object on fitdata
65
66        :param smear: smear object from DataLoader
67
68        """
69        self.smearer_computer_value = smearer
70
71    def get_smearer(self):
72        """
73        return smear object
74        """
75        if not self.smearer_enable:
76            return None
77        if not self.smearer_computed:
78            #smeari_selection should be call only once per fitproblem
79            self.smearer_computer_value = smear_selection(self.fit_data,
80                                                          self.model)
81            self.smearer_computed = True
82        return self.smearer_computer_value
83
84    def save_model_name(self, name):
85        """
86        """
87        self.name_per_page = name
88
89    def get_name(self):
90        """
91        """
92        return self.name_per_page
93
94    def set_model(self, model):
95        """
96        associates each model with its new created name
97        :param model: model selected
98        :param name: name created for model
99        """
100        self.model = model
101        self.smearer_computer_value = smear_selection(self.fit_data,
102                                                      self.model)
103        self.smearer_computed = True
104
105    def get_model(self):
106        """
107        :return: saved model
108        """
109        return self.model
110
111    def set_residuals(self, residuals):
112        """
113        save a copy of residual
114        :param data: data selected
115        """
116        self.residuals = residuals
117
118    def get_residuals(self):
119        """
120        :return: residuals
121        """
122        return self.residuals
123
124    def set_theory_data(self, data):
125        """
126        save a copy of the data select to fit
127
128        :param data: data selected
129
130        """
131        self.theory_data = copy.deepcopy(data)
132
133    def get_theory_data(self):
134        """
135        :return: theory generated with the current model and data of this class
136        """
137        return self.theory_data
138
139    def set_fit_data(self, data):
140        """
141        Store data associated with this class
142        :param data: list of data selected
143        """
144        self.original_data = None
145        self.fit_data = None
146        # original data: should not be modified
147        self.original_data = data
148        # fit data: used for fit and can be modified for convenience
149        self.fit_data = copy.deepcopy(data)
150        self.smearer_computer_value = smear_selection(self.fit_data, self.model)
151        self.smearer_computed = True
152        self.result = None
153
154    def get_fit_data(self):
155        """
156        :return: data associate with this class
157        """
158        return self.fit_data
159
160    def get_origin_data(self):
161        """
162        """
163        return self.original_data
164
165    def set_weight(self, is2d, flag=None):
166        """
167        Received flag and compute error on data.
168        :param flag: flag to transform error of data.
169        :param is2d: flag to distinguish 1D to 2D Data
170        """
171        from sas.sasgui.perspectives.fitting.utils import get_weight
172        # send original data for weighting
173        self.weight = get_weight(data=self.original_data, is2d=is2d, flag=flag)
174        if is2d:
175            self.fit_data.err_data = self.weight
176        else:
177            self.fit_data.dy = self.weight
178
179    def get_weight(self):
180        """
181        returns weight array
182        """
183        return self.weight
184
185    def set_param2fit(self, list):
186        """
187        Store param names to fit (checked)
188        :param list: list of the param names
189        """
190        self.list_param2fit = list
191
192    def get_param2fit(self):
193        """
194        return the list param names to fit
195        """
196        return self.list_param2fit
197
198    def set_model_param(self, name, value=None):
199        """
200        Store the name and value of a parameter of this fitproblem's model
201        :param name: name of the given parameter
202        :param value: value of that parameter
203        """
204        self.list_param.append([name, value])
205
206    def get_model_param(self):
207        """
208        return list of couple of parameter name and value
209        """
210        return self.list_param
211
212    def schedule_tofit(self, schedule=0):
213        """
214        set schedule to true to decide if this fit  must be performed
215        """
216        self.schedule = schedule
217
218    def get_scheduled(self):
219        """
220        return true or false if a problem as being schedule for fitting
221        """
222        return self.schedule
223
224    def set_range(self, qmin=None, qmax=None):
225        """
226        set fitting range
227        :param qmin: minimum value to consider for the fit range
228        :param qmax: maximum value to consider for the fit range
229        """
230        self.qmin = qmin
231        self.qmax = qmax
232
233    def get_range(self):
234        """
235        :return: fitting range
236
237        """
238        return self.qmin, self.qmax
239
240    def clear_model_param(self):
241        """
242        clear constraint info
243        """
244        self.list_param = []
245
246    def set_fit_tab_caption(self, caption):
247        """
248        """
249        self.fit_tab_caption = str(caption)
250
251    def get_fit_tab_caption(self):
252        """
253        """
254        return self.fit_tab_caption
255
256    def set_graph_id(self, id):
257        """
258        Set graph id (from data_group_id at the time the graph produced)
259        """
260        self.graph_id = id
261
262    def get_graph_id(self):
263        """
264        Get graph_id
265        """
266        return self.graph_id
267
268    def set_result(self, result):
269        """
270        """
271        self.result = result
272
273    def get_result(self):
274        """
275        get result
276        """
277        return self.result
278
279
280class FitProblemDictionary(dict):
281    """
282    This module implements a dictionary of fitproblem objects
283    """
284    def __init__(self):
285        dict.__init__(self)
286        ## the current model
287        self.model = None
288        ## if 1 this fit problem will be selected to fit , if 0
289        ## it will not be selected for fit
290        self.schedule = 0
291        ##list containing parameter name and value
292        self.list_param = []
293        ## fitting range
294        self.qmin = None
295        self.qmax = None
296        self.graph_id = None
297        self._smear_on = False
298        self.scheduled = 0
299        self.fit_tab_caption = ''
300        self.nbr_residuals_computed = 0
301        self.batch_inputs = {}
302        self.batch_outputs = {}
303
304    def enable_smearing(self, flag=False, fid=None):
305        """
306        :param flag: bool.When flag is 1 get the computer smear value. When
307            flag is 0 ingore smear value.
308        """
309        self._smear_on = flag
310        if fid is None:
311            for value in self.values():
312                value.enable_smearing(flag)
313        elif fid in self:
314            self[fid].enable_smearing(flag)
315
316    def set_smearer(self, smearer, fid=None):
317        """
318        save reference of  smear object on fitdata
319        :param smear: smear object from DataLoader
320        """
321        if fid is None:
322            for value in self.values():
323                value.set_smearer(smearer)
324        elif fid in self:
325            self[fid].set_smearer(smearer)
326
327    def get_smearer(self, fid=None):
328        """
329        return smear object
330        """
331        if fid in self:
332            return self[fid].get_smearer()
333
334    def save_model_name(self, name, fid=None):
335        """
336        """
337        if fid is None:
338            for value in self.values():
339                value.save_model_name(name)
340        elif fid in self:
341            self[fid].save_model_name(name)
342
343    def get_name(self, fid=None):
344        """
345        """
346        result = []
347        if fid is None:
348            for value in self.values():
349                result.append(value.get_name())
350        elif fid in self:
351            result.append(self[fid].get_name())
352        return result
353
354    def set_model(self, model, fid=None):
355        """
356        associates each model with its new created name
357        :param model: model selected
358        :param name: name created for model
359        """
360        self.model = model
361        if fid is None:
362            for value in self.values():
363                value.set_model(self.model)
364        elif fid in self:
365            self[fid].set_model(self.model)
366
367    def get_model(self, fid):
368        """
369        :return: saved model
370        """
371        if fid in self:
372            return self[fid].get_model()
373
374    def set_fit_tab_caption(self, caption):
375        """
376        store the caption of the page associated with object
377        """
378        self.fit_tab_caption = caption
379
380    def get_fit_tab_caption(self):
381        """
382        Return the caption of the page associated with object
383        """
384        return self.fit_tab_caption
385
386    def set_residuals(self, residuals, fid):
387        """
388        save a copy of residual
389        :param data: data selected
390        """
391        if fid in self:
392            self[fid].set_residuals(residuals)
393
394    def get_residuals(self, fid):
395        """
396        :return: residuals
397        """
398        if fid in self:
399            return self[fid].get_residuals()
400
401    def set_theory_data(self, fid, data=None):
402        """
403        save a copy of the data select to fit
404        :param data: data selected
405        """
406        if fid in self:
407            self[fid].set_theory_data(data)
408
409    def get_theory_data(self, fid):
410        """
411        :return: list of data dList
412        """
413        if fid in self:
414            return self[fid].get_theory_data()
415
416    def add_data(self, data):
417        """
418        Add data to the current dictionary of fitproblem. if data id does not
419        exist create a new fit problem.
420        :note: only data changes in the fit problem
421        """
422        if data.id not in self:
423            self[data.id] = FitProblem()
424        self[data.id].set_fit_data(data)
425
426    def set_fit_data(self, data):
427        """
428        save a copy of the data select to fit
429        :param data: data selected
430
431        """
432        self.clear()
433        if data is None:
434            data = []
435        for d in data:
436            if d is not None:
437                if d.id not in self:
438                    self[d.id] = FitProblem()
439                self[d.id].set_fit_data(d)
440                self[d.id].set_model(self.model)
441                self[d.id].set_range(self.qmin, self.qmax)
442
443    def get_fit_data(self, fid):
444        """
445        return data for the given fitproblem id
446        :param fid: key representing a fitproblem, usually extract from data id
447        """
448        if fid in self:
449            return self[fid].get_fit_data()
450
451    def set_model_param(self, name, value=None, fid=None):
452        """
453        Store the name and value of a parameter of this fitproblem's model
454        :param name: name of the given parameter
455        :param value: value of that parameter
456        """
457        if fid is None:
458            for value in self.values():
459                value.set_model_param(name, value)
460        elif fid in self:
461            self[fid].set_model_param(name, value)
462
463    def get_model_param(self, fid):
464        """
465        return list of couple of parameter name and value
466        """
467        if fid in self:
468            return self[fid].get_model_param()
469
470    def set_param2fit(self, list):
471        """
472        Store param names to fit (checked)
473        :param list: list of the param names
474        """
475        self.list_param2fit = list
476
477    def get_param2fit(self):
478        """
479        return the list param names to fit
480        """
481        return self.list_param2fit
482
483    def schedule_tofit(self, schedule=0):
484        """
485        set schedule to true to decide if this fit  must be performed
486        """
487        self.scheduled = schedule
488        for value in self.values():
489            value.schedule_tofit(schedule)
490
491    def get_scheduled(self):
492        """
493        return true or false if a problem as being schedule for fitting
494        """
495        return self.scheduled
496
497    def set_range(self, qmin=None, qmax=None, fid=None):
498        """
499        set fitting range
500        """
501        self.qmin = qmin
502        self.qmax = qmax
503        if fid is None:
504            for value in self.values():
505                value.set_range(self.qmin, self.qmax)
506        elif fid in self:
507            self[fid].value.set_range(self.qmin, self.qmax)
508
509    def get_range(self, fid):
510        """
511        :return: fitting range
512        """
513        if fid in self:
514            return self[fid].get_range()
515
516    def set_weight(self, is2d, flag=None, fid=None):
517        """
518        fit weight
519        """
520        if fid is None:
521            for value in self.values():
522                value.set_weight(flag=flag, is2d=is2d)
523        elif fid in self:
524            self[fid].set_weight(flag=flag, is2d=is2d)
525
526    def get_weight(self, fid=None):
527        """
528        return fit weight
529        """
530        if fid in self:
531            return self[fid].get_weight()
532
533    def clear_model_param(self, fid=None):
534        """
535        clear constraint info
536        """
537        if fid is None:
538            for value in self.values():
539                value.clear_model_param()
540        elif fid in self:
541            self[fid].clear_model_param()
542
543    def get_fit_problem(self):
544        """
545        return fitproblem contained in this dictionary
546        """
547        return self.values()
548
549    def set_result(self, result, fid):
550        """
551        """
552        if fid in self:
553            self[fid].set_result(result)
554
555    def set_batch_result(self, batch_inputs, batch_outputs):
556        """
557        set a list of result
558        """
559        self.batch_inputs = batch_inputs
560        self.batch_outputs = batch_outputs
561
562    def get_result(self, fid):
563        """
564        get result
565        """
566        if fid in self:
567            return self[fid].get_result()
568
569    def get_batch_result(self):
570        """
571        get result
572        """
573        return self.batch_inputs, self.batch_outputs
574
575    def set_graph_id(self, id):
576        """
577        Set graph id (from data_group_id at the time the graph produced)
578        """
579        self.graph_id = id
580
581    def get_graph_id(self):
582        """
583        Get graph_id
584        """
585        return self.graph_id
Note: See TracBrowser for help on using the repository browser.