source: sasview/park-1.2.1/park/fit.py @ bbbed8c

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since bbbed8c was 3570545, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

Adding park Part 2

  • Property mode set to 100644
File size: 11.2 KB
Line 
1# This program is public domain
2"""
3Fitting service interface.
4
5A fit consists of a set of models and a fitting engine.  The models are
6collected in an assembly, which manages the parameter set and the
7constraints between them.  The models themselves are tightly coupled
8to the data that they are modeling and the data is invisible to the fit.
9
10The fitting engine can use a variety of methods depending on model.
11
12
13Usage
14=====
15
16The fitter can be run directly on the local machine::
17
18    import park
19    M1 = park.models.Peaks(datafile=park.sampledata('peak.dat'))
20    M1.add_peak('P1', 'gaussian', A=[4,6], mu=[0.2, 0.5], sigma=0.1)
21    result = park.fit(models=[M1])
22    print result
23
24The default settings print results every time the fit improves, and
25print a global result when the fit is complete.  This is a suitable
26interface for a fitting script.
27
28For larger fit jobs you will want to run the fit on a remote server.
29The model setup is identical, but the fit call is different::
30
31    service = park.FitService('server:port')
32    result = park.fit(models=[M1], service=service)
33    print result
34
35Again, the default settings print results every time the fit improves,
36and print a global result when the fit is complete.
37
38For long running fit jobs, you want to be able to disconnect from
39the server after submitting the job, and later reconnect to fetch
40the results.  An additional email field will send notification by
41email when the fit starts and ends, and daily updates on the status
42of all fits::
43
44    service = park.FitService('server:port')
45    service.notify(email='me@my.email.address',update='daily')
46    fit = park.Fit(models=[M1])
47    id = service.submit_job(fit, jobname='peaks')
48    print id
49
50The results can be retrieved either by id returned from the server,
51or by the given jobname::
52
53    import park
54    service = park.FitService('server:port',user='userid')
55    fitlist = service.retrieve('peaks')
56    for fit in fitlist:
57        print fit.summary()
58
59The fit itself is a complicated object, including the model, the
60optimizer, and the type of uncertainty analysis to perform.
61
62GUI Usage
63=========
64
65When used from a graphical user interface, a different programming
66interface is needed.  In this case, the user may want to watch
67the progress of the fit and perhaps stop it.  Also, as fits can
68take some time to complete, the user would like to be able to
69set up additional fits and run them at the same time, switching
70between them as necessary to monitor progress.
71
72"""
73import time, thread
74
75import numpy
76
77import assembly, fitresult
78
class Objective(object):
    """
    Abstract fitness-function interface consumed by the park minimizers.

    `park.assembly.Assembly` is the concrete implementation shipped with
    park.  Subclasses override the methods below; every method that is not
    overridden raises NotImplementedError when called, except `abort`,
    which silently does nothing.

    TODO: add a results() method that hands model-specific information
    back to the fit handler.
    """

    def residuals(self, p):
        """
        Return the residuals vector for parameter set *p*.

        Optional.  Optimizers that work directly on the residuals (notably
        Levenberg-Marquardt) cannot be used when this is not provided.
        """
        raise NotImplementedError

    def residuals_deriv(self, p):
        """
        Return the residuals together with their derivatives with respect
        to the fit parameters, evaluated at *p*.

        Optional.  A model lacking analytic derivatives may approximate them
        numerically, although a derivative-free optimizer (such as coliny or
        cobyla) usually spends its function evaluations more efficiently.
        Either way, computing these values is the objective's job.
        """
        raise NotImplementedError

    def fit_parameters(self):
        """
        Return the list of fit parameters, each carrying a name, an initial
        value and a range (see `park.fitresult.FitParameter`).

        Every parameter set later handed to the objective lists its values
        in this same order.
        """
        raise NotImplementedError

    def __call__(self, p):
        """
        Evaluate the objective at parameter set *p* and return its value.
        """
        raise NotImplementedError

    def abort(self):
        """
        Request that the in-progress evaluation stop and return inf.

        Optional; the default implementation does nothing.  It is invoked
        from a thread other than the one running the evaluation, so an
        objective with an expensive calculation should clear an abort flag
        at the start of each evaluation and poll it periodically.
        """
        return None
142
class Fitter(object):
    """Abstract interface for a fitness optimizer.

    A fitter has a single public method, `fit`, which takes an objective
    function (`park.fit.Objective`) and a handler.  Concrete subclasses
    supply the optimization itself by overriding `_fit`.

    For a concrete instance see `park.fitmc.FitMC`.
    """
    def __init__(self, **kw):
        """
        Override fitter settings by keyword.

        Only names that already exist on the fitter (for example class-level
        defaults declared by a subclass) may be set.  Keywords are applied
        in iteration order.

        :raises AttributeError: if a keyword does not name an existing
            attribute of the fitter.
        """
        for k, v in kw.items():
            if not hasattr(self, k):
                raise AttributeError(k+" is not an attribute of "+self.__class__.__name__)
            setattr(self, k, v)

    def _threaded(self, fn, *args, **kw):
        """Call ``fn(*args, **kw)`` in a new thread via `thread.start_new_thread`."""
        thread.start_new_thread(fn, args, kw)

    def _fit(self, objective, x0, bounds):
        """
        Run the actual fit; subclasses must override this.

        Implementations may run the fit in a separate thread and return
        before it is complete.  They are expected to drive the handler as
        follows.

        Each cycle k of n:
            self.handler.progress(k,n)
        Each improvement:
            self.handler.result.update(x,fx,ncalls)
            self.handler.improvement()
        On completion (if not already performed):
            self.handler.result.update(x,fx,ncalls)
            self.handler.done = True
            self.handler.finalize()

        :param objective: the `Objective` being minimized
        :param x0: list of initial parameter values, in the order returned
            by ``objective.fit_parameters()``
        :param bounds: numpy array of the parameter ranges, transposed;
            presumably shape (2, n) with lower bounds in row 0 and upper
            bounds in row 1 when each range is a [low, high] pair -- confirm
            against concrete fitters
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def fit(self, fitness, handler):
        """
        Global optimizer.

        Reads the initial values and ranges from ``fitness.fit_parameters()``,
        resets ``handler`` (``handler.done = False`` and a fresh
        ``handler.result``), stores it as ``self.handler`` and delegates to
        `_fit`.  This function should return immediately; whether the work
        continues in the background is up to the subclass's `_fit`.

        :param fitness: objective implementing the `Objective` interface
        :param handler: fit handler that receives progress and results
        """
        # Determine initial value and bounds
        pars = fitness.fit_parameters()
        bounds = numpy.array([p.range for p in pars]).T
        x0 = [p.value for p in pars]

        # Initialize the monitor and results.
        # Need to make our own copy of the fit results so that the
        # values don't get stomped on by the next fit iteration.
        handler.done = False
        self.handler = handler
        fitpars = [fitresult.FitParameter(p.name, p.range, v)
                   for p, v in zip(pars, x0)]
        # numpy.nan rather than numpy.NaN: the NaN alias was removed in
        # NumPy 2.0, while numpy.nan exists in every NumPy version.
        handler.result = fitresult.FitResult(fitpars, 0, numpy.nan)

        # Run the fit (fit should perform _progress and _improvement updates)
        # This function may return before the fit is complete.
        self._fit(fitness, x0, bounds)
201
class FitJob(object):
    """
    A unit of work pairing an objective, a fitter and a handler.

    This implements `park.job.Job`: `run` hands the objective and the
    handler to the fitter.
    """
    def __init__(self, objective=None, fitter=None, handler=None):
        # Kept as plain attributes; run() reads them when the job starts.
        self.objective, self.fitter, self.handler = objective, fitter, handler

    def run(self):
        """Start the fit by calling ``fitter.fit(objective, handler)``."""
        fitter = self.fitter
        fitter.fit(self.objective, self.handler)
214
class LocalQueue(object):
    """
    Minimal in-process stand-in for the job queue.

    Currently supports start and wait.  Still missing: stop and status, real
    queueing, and running jobs in separate processes according to priority
    -- all the essentials of the remote queuing system without the remote
    calls.

    Unlike the remote queue, the local queue need not maintain persistence.
    """
    # Class-level flag; nothing in this class ever updates it.
    running = False

    def start(self, job):
        """
        Make *job* the current job, run it immediately in the calling
        thread, and return ``id(job)`` as its identifier.
        """
        self.job = job
        job_id = id(job)
        job.run()
        return job_id

    def wait(self, interval=.1):
        """
        Block until the current job's handler reports ``done``, polling
        every *interval* seconds, then return the handler's ``result``.

        Scripts use this to get a synchronous interface to the fitting
        service.
        """
        while True:
            if self.job.handler.done:
                break
            time.sleep(interval)
        return self.job.handler.result
239
def fit(models=None, fitter=None, service=None, handler=None):
    """
    Start a fit with a set of models and wait for the result.

    When *models* is a list it must be in a form accepted by
    `park.assembly.Assembly`, which wraps it into the fit objective; any
    other non-None value is used directly as the objective (see
    `Objective`).

    This is a convenience function which fills in a default for each
    component the caller does not supply:

    * fitter: ``fitmc.FitMC()``, a monte carlo Nelder-Mead simplex local
      optimizer with 100 random start points.
    * service: `LocalQueue`, which starts the job in the calling thread and
      then polls the handler until it reports done.
    * handler: ``fitresult.FitHandler()``.  Pass a
      ``fitresult.ConsoleUpdate`` instead to show progress on the console
      during the fit.

    TODO: run local jobs in a separate process so that python's global
    interpreter lock does not interfere with parallelism.

    :param models: list of models, or a ready-made objective
    :param fitter: optimizer implementing the `Fitter` interface
    :param service: queue providing ``start(job)`` and ``wait()``
    :param handler: fit handler receiving progress updates and the result
    :returns: whatever ``service.wait()`` returns (for `LocalQueue`, the
        handler's ``result``)
    :raises RuntimeError: if *models* is None
    """
    if models is None:
        raise RuntimeError('fit expected a list of models')
    if service is None:
        service = LocalQueue()
    if fitter is None:
        import fitmc
        fitter = fitmc.FitMC()
    if handler is None:
        handler = fitresult.FitHandler()

    # A plain list of models is wrapped in an Assembly; anything else is
    # assumed to already implement the Objective interface.
    objective = assembly.Assembly(models) if isinstance(models, list) else models
    job = FitJob(objective, fitter, handler)
    service.start(job)
    return service.wait()
274
275
def assembly_example(show_progress=False, detailed=False):
    """
    Fit the problem returned by `park.assembly.example`, print the fit
    summary, and print the target values the fit should recover.

    Called with no arguments, this runs the same fit as before: the handler
    is whatever `fit` chooses by default.

    :param show_progress: if true, report progress on the console during
        the fit with ``fitresult.ConsoleUpdate(improvement_delta=0.1,
        progress_delta=1)``.  Previously this handler was always built but
        never used.
    :param detailed: if true, also print the parameter vector, the
        residuals, and per-model data after the fit.  This replaces the old
        ``if False:`` debug block.
    """
    import park
    problem = park.assembly.example()
    if show_progress:
        handler = fitresult.ConsoleUpdate(improvement_delta=0.1, progress_delta=1)
    else:
        handler = None  # let fit() choose its default handler
    result = fit(problem, handler=handler)
    # Single-argument print(...) behaves the same under Python 2 and 3.
    print("=== Fit complete ===")
    result.print_summary()
    print("=== Target values ===")
    print("M1: a=1, c=1.5")
    print("M2: a=2.5, c=3")

    if detailed:
        print("parameter vector %s" % (result.pvec,))
        problem(result.pvec)
        print("residuals %s" % (problem.residuals,))
        for k, m in enumerate(problem.parts):
            print("Model %s chisq %s weight %s" % (k, m.chisq, m.weight))
            print("pars %s %s" % (m.fitness.model.a, m.fitness.model.c))
            print("x %s" % (m.fitness.data.fit_x,))
            print("y %s" % (m.fitness.data.fit_y,))
            print("f(x) %s" % (m.fitness.data.fx,))
            print("(y-f(x))/dy %s" % (m.residuals,))
301
302
def demo(fitter=None):
    """
    Multiple minima example.

    Minimizes f(x) = x**2 + sin(2*pi*x + 3*pi/2) over the range [-5, 5],
    starting from x = 1, reports progress on the console, and blocks until
    the handler reports that the fit is done.

    :param fitter: optimizer implementing the `Fitter` interface.  Defaults
        to ``fitmc.FitMC()``, the same default `fit` uses.  Previously the
        None default crashed with AttributeError on ``fitter.fit``.
    """
    import time, math
    if fitter is None:
        import fitmc
        fitter = fitmc.FitMC()

    class MultiMin(object):
        """One-parameter objective with several local minima."""
        def fit_parameters(self):
            return [fitresult.FitParameter('x1', [-5, 5], 1)]

        def __call__(self, p):
            x = p[0]
            fx = x**2 + math.sin(2*math.pi*x + 3*math.pi/2)
            return fx

    handler = fitresult.ConsoleUpdate() # Show updates on the console
    handler.progress_delta = 1          # Update user every second
    handler.improvement_delta = 0.1     # Show improvements almost immediately
    fitter.fit(MultiMin(), handler)
    # fitter.fit may return before the fit completes; poll until done.
    while not handler.done:
        time.sleep(1)
318
def demo2(fitter=None):
    """
    Fit the problem returned by `park.assembly.example`, report progress on
    the console, and block until the handler reports that the fit is done.

    :param fitter: optimizer implementing the `Fitter` interface.  Defaults
        to ``fitmc.FitMC()``, the same default `fit` uses.  Previously the
        None default crashed with AttributeError on ``fitter.fit``.
    """
    import park, time
    if fitter is None:
        import fitmc
        fitter = fitmc.FitMC()
    problem = park.assembly.example()
    handler = fitresult.ConsoleUpdate() # Show updates on the console
    handler.progress_delta = 1          # Update user every second
    handler.improvement_delta = 1       # Show improvements at the same rate
    fitter.fit(problem, handler)
    # fitter.fit may return before the fit completes; poll until done.
    while not handler.done:
        time.sleep(1)
327
328
329
if __name__ == "__main__":
    # Running this module as a script fits the assembly example.
    assembly_example()
Note: See TracBrowser for help on using the repository browser.