source: sasmodels/example/batch_fit.py @ c5b059c

Last change on this file since c5b059c was 8a5f021, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

batch fit example: PEP8 cleanup; add .csv output; all pars on one plot; more flexible option handling

  • Property mode set to 100755
File size: 4.4 KB
Line 
1#!/usr/bin/env python
2"""
3Script to run a batch fit on a series of files and plot the fitted parameters.
4
5Usage syntax::
6
7    python batch_fit.py model.py sample1.dat sample2.dat ... other_sample.dat
8    (files named sample1.dat, sample2.dat, ..., other_sample.dat)
9
10or if the file names are numbers (and the extension is .dat)::
11
12    python batch_fit.py model.py 93190 93210
13    (files named 093190.dat, 093191.dat, ..., 093210.dat)
14
15or for Grasp-like naming::
16
17    python batch_fit.py model.py 93190 93210 200
18    (files named 093190_200.dat, 093191_201.dat, ..., 093210_220.dat)
19
20The script reads a series of files and fits the model defined by model.py.
21For example, model_ellipsoid_hayter_msa.py fits a model consisting of an
22ellipsoid form factor multiplied by a Hayter MSA structure factor.
23
24The file *model.py* must load the data using::
25
26    data = load_data(sys.argv[1])
27
28Include options to bumps (minimizer, steps, etc.) as desired.  For example::
29
30    python batch_fit.py model.py 93190 93210 200 --fit=lm --steps=200 --ftol=1.5e-8 --xtol=1.5e-8
31
32Fit options can come before or after the model and files.
33
34For each file a directory named Fit_filename is created. There the file
35model.par contains the fitted parameters.  These are gathered together into
36batch_fit.csv in the current directory.
37
38Finally the fitted parameters are plotted for the full series.
39
40Example::
41
42    python batch_fit.py model_ellipsoid_hayter_msa.py 93191 93195 201
43
44Note:
45
46    If sasmodels, sasview or bumps are not in the path, use the PYTHONPATH
47    environment variable to set them.
48"""
49from __future__ import print_function
50
51import sys
52import os
53
54import numpy as np
55import matplotlib.pyplot as plt
56from bumps.dream.views import tile_axes  # make a grid of plots
57
# GET INPUT AND ENSURE MODEL AND DATA FILES ARE DEFINED

# Anything starting with "--" is a fit option handed on to bumps untouched;
# the remaining positional arguments are the model file followed by either
# an explicit list of data files or a numeric range "first last [ext]".
fit_opts = [v for v in sys.argv[1:] if v.startswith('--')]
args = [v for v in sys.argv[1:] if not v.startswith('--')]

nargs = len(args)
if nargs < 2:
    print("Error in the list of arguments! \n")
    # Exit with a failure status so calling scripts can detect the error.
    sys.exit(1)

model_file = args[0]
if not model_file.endswith('.py'):
    print("Expected model.py as the first argument")
    sys.exit(1)

if '.' in args[1]:
    # Explicit file names (recognised by the presence of an extension).
    data_files = args[1:]
else:
    # Numeric range: validate the argument count before indexing into args
    # so that a missing "last" is reported instead of raising IndexError.
    if nargs < 3:
        print("Expected first and last file numbers after the model file")
        sys.exit(1)
    if nargs > 4:
        print("Unexpected arguments: " + " ".join(args[4:]))
        sys.exit(1)
    try:
        numbers = [int(v) for v in args[1:]]
    except ValueError:
        print("File numbers must be integers: " + " ".join(args[1:]))
        sys.exit(1)
    first, last = numbers[:2]
    count = last-first+1
    if count < 1:
        # An empty series would only fail later, in the plotting code.
        print("Last file number must not be smaller than the first")
        sys.exit(1)
    if nargs == 3:
        data_files = ['%06d.dat'%(first+i) for i in range(count)]
    else:
        # Grasp-like naming: the suffix number advances with the file number.
        ext = numbers[2]
        data_files = ['%06d_%d.dat'%(first+i, ext+i) for i in range(count)]
88
# CHECK THAT DATA FILES EXIST
# Collect every absent file first so that all of them are reported at once.
missing = []
for candidate in data_files:
    if os.path.isfile(candidate):
        continue
    missing.append(candidate)
if len(missing) > 0:
    print("Missing data files: %s" % ", ".join(missing))
    sys.exit(1)
94
# STORE DIRECTORY FOR BUMPS FITS
def fit_dir(filename):
    """
    Return the store directory name for the data file *filename*.

    The directory is ``Fit_<name>`` where ``<name>`` is the file name without
    its extension.  When *filename* includes a directory part, the ``Fit_``
    prefix is applied to the file name only and the store directory is placed
    beside the data file, so ``data/s1.dat`` maps to ``data/Fit_s1`` rather
    than the unusable ``Fit_data/s1``.  For a bare file name the result is
    simply ``Fit_<name>``.
    """
    folder, name = os.path.split(filename)
    return os.path.join(folder, "Fit_" + os.path.splitext(name)[0])
99
# LOOP OVER FILES AND CALL TO BUMPS FOR EACH OF THEM
import subprocess  # local import: only this section of the script needs it

# Run bumps with the interpreter that is running this script, so the fit
# uses the same environment, and pass the arguments as a list so that file
# names or options containing spaces or shell metacharacters reach bumps
# unchanged instead of being re-parsed by a shell.
bumps_cmd = [sys.executable or "python", "-m", "bumps.cli", "--batch"]
for data_file in data_files:
    store_opts = "--store=" + fit_dir(data_file)
    cmd = bumps_cmd + list(fit_opts) + [store_opts, model_file, data_file]
    status = subprocess.call(cmd)
    if status != 0:
        # Keep going with the remaining files, but make the failure visible.
        print("Warning: bumps exited with status %d for %s"
              % (status, data_file))
107
# GATHER RESULTS
# results maps each parameter label to its fitted values, one per data file
# in the order of data_files.
results = {}
# NOTE(review): assumes bumps writes "<model name>.par" directly inside the
# store directory, named after the model file without any directory part --
# confirm against the bumps version in use.
par_file = os.path.splitext(os.path.basename(model_file))[0] + '.par'
for data_file in data_files:
    with open(os.path.join(fit_dir(data_file), par_file), 'r') as fid:
        for line in fid:
            if not line.strip():
                continue  # tolerate blank lines
            # Split on the last run of whitespace only, so that parameter
            # labels which themselves contain spaces stay intact.
            parameter, value = line.rsplit(None, 1)
            results.setdefault(parameter, []).append(float(value))
116
# SAVE RESULTS INTO FILE
# One header row, then one row per data file; columns follow the sorted
# parameter names.
with open('batch_fit.csv', 'w') as fid:
    parameters = sorted(results)
    columns = [results[name] for name in parameters]
    # Transpose the per-parameter columns into per-file rows.
    values_by_file = zip(*columns)
    header = ['filename'] + parameters
    fid.write(','.join(header) + '\n')
    for filename, values in zip(data_files, values_by_file):
        row = [filename]
        row.extend(str(v) for v in values)
        fid.write(','.join(row) + '\n')
124
# SHOW FITTED PARAMETERS
# One panel per parameter (sorted by name), each showing the fitted value
# against the position of the data file in the series.
nplots = len(results)
nh, nw = tile_axes(nplots)
ticks = np.arange(1, len(data_files)+1)
labels = [os.path.splitext(filename)[0] for filename in data_files]
for k, (parameter, values) in enumerate(sorted(results.items())):
    plt.subplot(nh, nw, k+1)
    plt.plot(ticks, values)
    plt.xlim(ticks[0]-0.5, ticks[-1]+0.5)
    # plt.subplot numbers panels row by row, so the panel directly below
    # panel k is k+nw.  Only the lowest panel of each column (the one with
    # nothing below it) carries the file labels; the others get blank ticks.
    if k + nw >= nplots:
        plt.xticks(ticks, labels, rotation=30)
    else:
        plt.xticks(ticks, [' ']*len(labels))
    plt.ylabel(parameter)
plt.suptitle("Fit " + args[0])
plt.show()
Note: See TracBrowser for help on using the repository browser.