source: sasmodels/explore/precision.py @ eb2946f

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since eb2946f was eb2946f, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

improve precision of sasmodels special functions

  • Property mode set to 100755
File size: 15.5 KB
Line 
1#!/usr/bin/env python
2r"""
3Show numerical precision of $2 J_1(x)/x$.
4"""
5from __future__ import division, print_function
6
7import sys
8import os
9sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
10
11import numpy as np
12from numpy import pi, inf
13import scipy.special
14try:
15    from mpmath import mp
16except ImportError:
17    # CRUFT: mpmath split out into its own package
18    from sympy.mpmath import mp
19#import matplotlib; matplotlib.use('TkAgg')
20import pylab
21
22from sasmodels import core, data, direct_model, modelinfo
23
24class Comparator(object):
25    def __init__(self, name, mp_function, np_function, ocl_function, xaxis, limits):
26        self.name = name
27        self.mp_function = mp_function
28        self.np_function = np_function
29        self.ocl_function = ocl_function
30        self.xaxis = xaxis
31        self.limits = limits
32
33    def __repr__(self):
34        return "Comparator(%s)"%self.name
35
36    def call_mpmath(self, vec, bits=500):
37        """
38        Direct calculation using mpmath extended precision library.
39        """
40        with mp.workprec(bits):
41            return [self.mp_function(mp.mpf(x)) for x in vec]
42
43    def call_numpy(self, x, dtype):
44        """
45        Direct calculation using numpy/scipy.
46        """
47        x = np.asarray(x, dtype)
48        return self.np_function(x)
49
50    def call_ocl(self, x, dtype, platform='ocl'):
51        """
52        Calculation using sasmodels ocl libraries.
53        """
54        x = np.asarray(x, dtype)
55        model = core.build_model(self.ocl_function, dtype=dtype)
56        calculator = direct_model.DirectModel(data.empty_data1D(x), model)
57        return calculator(background=0)
58
59    def run(self, xrange="log", diff=True):
60        r"""
61        Compare accuracy of different methods for computing f.
62
63        *xrange* is log=[10^-3,10^5], linear=[1,1000], zoom[1000,1010],
64        or neg=[-100,100].
65
66        *diff* is False if showing function value rather than relative error.
67
68        *x_bits* is the precision with which the x values are specified.  The
69        default 23 should reproduce the equivalent of a single precisio
70        """
71        linear = xrange != "log"
72        if xrange == "zoom":
73            lin_min, lin_max, lin_steps = 1000, 1010, 2000
74        elif xrange == "neg":
75            lin_min, lin_max, lin_steps = -100.1, 100.1, 2000
76        else:
77            lin_min, lin_max, lin_steps = 1, 1000, 2000
78        lin_min = max(lin_min, self.limits[0])
79        lin_max = min(lin_max, self.limits[1])
80        log_min, log_max, log_steps = -3, 5, 400
81        with mp.workprec(500):
82            if linear:
83                qrf = np.linspace(lin_min, lin_max, lin_steps, dtype='single')
84                qr = [mp.mpf(float(v)) for v in qrf]
85                #qr = mp.linspace(lin_min, lin_max, lin_steps)
86            else:
87                qrf = np.logspace(log_min, log_max, log_steps, dtype='single')
88                qr = [mp.mpf(float(v)) for v in qrf]
89                #qr = [10**v for v in mp.linspace(log_min, log_max, log_steps)]
90
91        target = self.call_mpmath(qr, bits=500)
92        pylab.subplot(121)
93        self.compare(qr, 'single', target, linear, diff)
94        pylab.legend(loc='best')
95        pylab.subplot(122)
96        self.compare(qr, 'double', target, linear, diff)
97        pylab.legend(loc='best')
98        pylab.suptitle(self.name + " compared to 500-bit mpmath")
99
100    def compare(self, x, precision, target, linear=False, diff=True):
101        r"""
102        Compare the different computation methods using the given precision.
103        """
104        if precision == 'single':
105            #n=11; plotdiff(x, target, self.call_mpmath(x, n), 'mp %d bits'%n, diff=diff)
106            #n=23; plotdiff(x, target, self.call_mpmath(x, n), 'mp %d bits'%n, diff=diff)
107            pass
108        elif precision == 'double':
109            #n=53; plotdiff(x, target, self.call_mpmath(x, n), 'mp %d bits'%n, diff=diff)
110            #n=83; plotdiff(x, target, self.call_mpmath(x, n), 'mp %d bits'%n, diff=diff)
111            pass
112        plotdiff(x, target, self.call_numpy(x, precision), 'numpy '+precision, diff=diff)
113        plotdiff(x, target, self.call_ocl(x, precision, 0), 'OpenCL '+precision, diff=diff)
114        pylab.xlabel(self.xaxis)
115        if diff:
116            pylab.ylabel("relative error")
117        else:
118            pylab.ylabel(self.name)
119            pylab.semilogx(x, target, '-', label="true value")
120        if linear:
121            pylab.xscale('linear')
122
123def plotdiff(x, target, actual, label, diff=True):
124    """
125    Plot the computed value.
126
127    Use relative error if SHOW_DIFF, otherwise just plot the value directly.
128    """
129    if diff:
130        err = np.array([abs((t-a)/t) for t, a in zip(target, actual)], 'd')
131        #err = np.clip(err, 0, 1)
132        pylab.loglog(x, err, '-', label=label)
133    else:
134        limits = np.min(target), np.max(target)
135        pylab.semilogx(x, np.clip(actual, *limits), '-', label=label)
136
137def make_ocl(function, name, source=[]):
138    class Kernel(object):
139        pass
140    Kernel.__file__ = name+".py"
141    Kernel.name = name
142    Kernel.parameters = []
143    Kernel.source = source
144    Kernel.Iq = function
145    model_info = modelinfo.make_model_info(Kernel)
146    return model_info
147
148
149# =============== FUNCTION DEFINITIONS ================
150
151FUNCTIONS = {}
152def add_function(name, mp_function, np_function, ocl_function,
153                 shortname=None, xaxis="x", limits=(-inf, inf)):
154    if shortname is None:
155        shortname = name.replace('(x)', '').replace(' ', '')
156    FUNCTIONS[shortname] = Comparator(name, mp_function, np_function, ocl_function, xaxis, limits)
157
158add_function(
159    name="J0(x)",
160    mp_function=mp.j0,
161    np_function=scipy.special.j0,
162    ocl_function=make_ocl("return sas_J0(q);", "sas_J0", ["lib/polevl.c", "lib/sas_J0.c"]),
163)
164add_function(
165    name="J1(x)",
166    mp_function=mp.j1,
167    np_function=scipy.special.j1,
168    ocl_function=make_ocl("return sas_J1(q);", "sas_J1", ["lib/polevl.c", "lib/sas_J1.c"]),
169)
170add_function(
171    name="JN(-3, x)",
172    mp_function=lambda x: mp.besselj(-3, x),
173    np_function=lambda x: scipy.special.jn(-3, x),
174    ocl_function=make_ocl("return sas_JN(-3, q);", "sas_JN",
175                          ["lib/polevl.c", "lib/sas_J0.c", "lib/sas_J1.c", "lib/sas_JN.c"]),
176    shortname="J-3",
177)
178add_function(
179    name="JN(3, x)",
180    mp_function=lambda x: mp.besselj(3, x),
181    np_function=lambda x: scipy.special.jn(3, x),
182    ocl_function=make_ocl("return sas_JN(3, q);", "sas_JN",
183                          ["lib/polevl.c", "lib/sas_J0.c", "lib/sas_J1.c", "lib/sas_JN.c"]),
184    shortname="J3",
185)
186add_function(
187    name="JN(2, x)",
188    mp_function=lambda x: mp.besselj(2, x),
189    np_function=lambda x: scipy.special.jn(2, x),
190    ocl_function=make_ocl("return sas_JN(2, q);", "sas_JN",
191                          ["lib/polevl.c", "lib/sas_J0.c", "lib/sas_J1.c", "lib/sas_JN.c"]),
192    shortname="J2",
193)
194add_function(
195    name="2 J1(x)/x",
196    mp_function=lambda x: 2*mp.j1(x)/x,
197    np_function=lambda x: 2*scipy.special.j1(x)/x,
198    ocl_function=make_ocl("return sas_2J1x_x(q);", "sas_2J1x_x", ["lib/polevl.c", "lib/sas_J1.c"]),
199)
200add_function(
201    name="J1(x)",
202    mp_function=mp.j1,
203    np_function=scipy.special.j1,
204    ocl_function=make_ocl("return sas_J1(q);", "sas_J1", ["lib/polevl.c", "lib/sas_J1.c"]),
205)
206add_function(
207    name="Si(x)",
208    mp_function=mp.si,
209    np_function=lambda x: scipy.special.sici(x)[0],
210    ocl_function=make_ocl("return sas_Si(q);", "sas_Si", ["lib/sas_Si.c"]),
211)
212#import fnlib
213#add_function(
214#    name="fnlibJ1",
215#    mp_function=mp.j1,
216#    np_function=fnlib.J1,
217#    ocl_function=make_ocl("return sas_J1(q);", "sas_J1", ["lib/polevl.c", "lib/sas_J1.c"]),
218#)
219add_function(
220    name="sin(x)",
221    mp_function=mp.sin,
222    np_function=np.sin,
223    #ocl_function=make_ocl("double sn, cn; SINCOS(q,sn,cn); return sn;", "sas_sin"),
224    ocl_function=make_ocl("return sin(q);", "sas_sin"),
225)
226add_function(
227    name="sin(x)/x",
228    mp_function=lambda x: mp.sin(x)/x if x != 0 else 1,
229    ## scipy sinc function is inaccurate and has an implied pi*x term
230    #np_function=lambda x: scipy.special.sinc(x/pi),
231    ## numpy sin(x)/x needs to check for x=0
232    np_function=lambda x: np.sin(x)/x,
233    ocl_function=make_ocl("return sas_sinx_x(q);", "sas_sinc"),
234)
235add_function(
236    name="cos(x)",
237    mp_function=mp.cos,
238    np_function=np.cos,
239    #ocl_function=make_ocl("double sn, cn; SINCOS(q,sn,cn); return cn;", "sas_cos"),
240    ocl_function=make_ocl("return cos(q);", "sas_cos"),
241)
242add_function(
243    name="gamma(x)",
244    mp_function=mp.gamma,
245    np_function=scipy.special.gamma,
246    ocl_function=make_ocl("return sas_gamma(q);", "sas_gamma", ["lib/sas_gamma.c"]),
247    limits=(-3.1,10),
248)
249add_function(
250    name="erf(x)",
251    mp_function=mp.erf,
252    np_function=scipy.special.erf,
253    ocl_function=make_ocl("return sas_erf(q);", "sas_erf", ["lib/polevl.c", "lib/sas_erf.c"]),
254    limits=(-5.,5.),
255)
256add_function(
257    name="erfc(x)",
258    mp_function=mp.erfc,
259    np_function=scipy.special.erfc,
260    ocl_function=make_ocl("return sas_erfc(q);", "sas_erfc", ["lib/polevl.c", "lib/sas_erf.c"]),
261    limits=(-5.,5.),
262)
263add_function(
264    name="arctan(x)",
265    mp_function=mp.atan,
266    np_function=np.arctan,
267    ocl_function=make_ocl("return atan(q);", "sas_arctan"),
268)
269add_function(
270    name="3 j1(x)/x",
271    mp_function=lambda x: 3*(mp.sin(x)/x - mp.cos(x))/(x*x),
272    # Note: no taylor expansion near 0
273    np_function=lambda x: 3*(np.sin(x)/x - np.cos(x))/(x*x),
274    ocl_function=make_ocl("return sas_3j1x_x(q);", "sas_j1c", ["lib/sas_3j1x_x.c"]),
275)
276add_function(
277    name="fmod_2pi",
278    mp_function=lambda x: mp.fmod(x, 2*mp.pi),
279    np_function=lambda x: np.fmod(x, 2*np.pi),
280    ocl_function=make_ocl("return fmod(q, 2*M_PI);", "sas_fmod"),
281)
282
283RADIUS=3000
284LENGTH=30
285THETA=45
286def mp_cyl(x):
287    f = mp.mpf
288    theta = f(THETA)*mp.pi/f(180)
289    qr = x * f(RADIUS)*mp.sin(theta)
290    qh = x * f(LENGTH)/f(2)*mp.cos(theta)
291    return (f(2)*mp.j1(qr)/qr * mp.sin(qh)/qh)**f(2)
292def np_cyl(x):
293    f = np.float64 if x.dtype == np.float64 else np.float32
294    theta = f(THETA)*f(np.pi)/f(180)
295    qr = x * f(RADIUS)*np.sin(theta)
296    qh = x * f(LENGTH)/f(2)*np.cos(theta)
297    return (f(2)*scipy.special.j1(qr)/qr*np.sin(qh)/qh)**f(2)
298ocl_cyl = """\
299    double THETA = %(THETA).15e*M_PI_180;
300    double qr = q*%(RADIUS).15e*sin(THETA);
301    double qh = q*0.5*%(LENGTH).15e*cos(THETA);
302    return square(sas_2J1x_x(qr)*sas_sinx_x(qh));
303"""%{"LENGTH":LENGTH, "RADIUS": RADIUS, "THETA": THETA}
304add_function(
305    name="cylinder(r=%g, l=%g, theta=%g)"%(RADIUS, LENGTH, THETA),
306    mp_function=mp_cyl,
307    np_function=np_cyl,
308    ocl_function=make_ocl(ocl_cyl, "ocl_cyl", ["lib/polevl.c", "lib/sas_J1.c"]),
309    shortname="cylinder",
310    xaxis="$q/A^{-1}$",
311)
312
313lanczos_gamma = """\
314    const double coeff[] = {
315            76.18009172947146,     -86.50532032941677,
316            24.01409824083091,     -1.231739572450155,
317            0.1208650973866179e-2,-0.5395239384953e-5
318            };
319    const double x = q;
320    double tmp  = x + 5.5;
321    tmp -= (x + 0.5)*log(tmp);
322    double ser = 1.000000000190015;
323    for (int k=0; k < 6; k++) ser += coeff[k]/(x + k+1);
324    return -tmp + log(2.5066282746310005*ser/x);
325"""
326add_function(
327    name="log gamma(x)",
328    mp_function=mp.loggamma,
329    np_function=scipy.special.gammaln,
330    ocl_function=make_ocl(lanczos_gamma, "lgamma"),
331)
332
333# Alternate versions of 3 j1(x)/x, for posterity
334def taylor_3j1x_x(x):
335    """
336    Calculation using taylor series.
337    """
338    # Generate coefficients using the precision of the target value.
339    n = 5
340    cinv = [3991680, -45360, 840, -30, 3]
341    three = x.dtype.type(3)
342    p = three/np.array(cinv, x.dtype)
343    return np.polyval(p[-n:], x*x)
344add_function(
345    name="3 j1(x)/x: taylor",
346    mp_function=lambda x: 3*(mp.sin(x)/x - mp.cos(x))/(x*x),
347    np_function=taylor_3j1x_x,
348    ocl_function=make_ocl("return sas_3j1x_x(q);", "sas_j1c", ["lib/sas_3j1x_x.c"]),
349)
350def trig_3j1x_x(x):
351    r"""
352    Direct calculation using linear combination of sin/cos.
353
354    Use the following trig identity:
355
356    .. math::
357
358        a \sin(x) + b \cos(x) = c \sin(x + \phi)
359
360    where $c = \surd(a^2+b^2)$ and $\phi = \tan^{-1}(b/a) to calculate the
361    numerator $\sin(x) - x\cos(x)$.
362    """
363    one = x.dtype.type(1)
364    three = x.dtype.type(3)
365    c = np.sqrt(one + x*x)
366    phi = np.arctan2(-x, one)
367    return three*(c*np.sin(x+phi))/(x*x*x)
368add_function(
369    name="3 j1(x)/x: trig",
370    mp_function=lambda x: 3*(mp.sin(x)/x - mp.cos(x))/(x*x),
371    np_function=trig_3j1x_x,
372    ocl_function=make_ocl("return sas_3j1x_x(q);", "sas_j1c", ["lib/sas_3j1x_x.c"]),
373)
374def np_2J1x_x(x):
375    """
376    numpy implementation of 2J1(x)/x using single precision algorithm
377    """
378    # pylint: disable=bad-continuation
379    f = x.dtype.type
380    ax = abs(x)
381    if ax < f(8.0):
382        y = x*x
383        ans1 = f(2)*(f(72362614232.0)
384                  + y*(f(-7895059235.0)
385                  + y*(f(242396853.1)
386                  + y*(f(-2972611.439)
387                  + y*(f(15704.48260)
388                  + y*(f(-30.16036606)))))))
389        ans2 = (f(144725228442.0)
390                  + y*(f(2300535178.0)
391                  + y*(f(18583304.74)
392                  + y*(f(99447.43394)
393                  + y*(f(376.9991397)
394                  + y)))))
395        return ans1/ans2
396    else:
397        y = f(64.0)/(ax*ax)
398        xx = ax - f(2.356194491)
399        ans1 = (f(1.0)
400                  + y*(f(0.183105e-2)
401                  + y*(f(-0.3516396496e-4)
402                  + y*(f(0.2457520174e-5)
403                  + y*f(-0.240337019e-6)))))
404        ans2 = (f(0.04687499995)
405                  + y*(f(-0.2002690873e-3)
406                  + y*(f(0.8449199096e-5)
407                  + y*(f(-0.88228987e-6)
408                  + y*f(0.105787412e-6)))))
409        sn, cn = np.sin(xx), np.cos(xx)
410        ans = np.sqrt(f(0.636619772)/ax) * (cn*ans1 - (f(8.0)/ax)*sn*ans2) * f(2)/x
411        return -ans if (x < f(0.0)) else ans
412add_function(
413    name="2 J1(x)/x:alt",
414    mp_function=lambda x: 2*mp.j1(x)/x,
415    np_function=lambda x: np.asarray([np_2J1x_x(v) for v in x], x.dtype),
416    ocl_function=make_ocl("return sas_2J1x_x(q);", "sas_2J1x_x", ["lib/polevl.c", "lib/sas_J1.c"]),
417)
418
419ALL_FUNCTIONS = set(FUNCTIONS.keys())
420ALL_FUNCTIONS.discard("loggamma")  # OCL version not ready yet
421ALL_FUNCTIONS.discard("3j1/x:taylor")
422ALL_FUNCTIONS.discard("3j1/x:trig")
423ALL_FUNCTIONS.discard("2J1/x:alt")
424
425# =============== MAIN PROGRAM ================
426
427def usage():
428    names = ", ".join(sorted(ALL_FUNCTIONS))
429    print("""\
430usage: precision.py [-f] [--log|--linear|--zoom|--neg] name...
431where
432    -f indicates that the function value should be plotted rather than error,
433    --log indicates log stepping in [10^-3, 10^5]
434    --linear indicates linear stepping in [1, 1000]
435    --zoom indicates linear stepping in [1000, 1010]
436    --neg indicates linear stepping in [-100.1, 100.1]
437and name is "all [first]" or one of:
438    """+names)
439    sys.exit(1)
440
441def main():
442    import sys
443    diff = True
444    xrange = "log"
445    args = sys.argv[1:]
446    if '-f' in args:
447        args.remove('-f')
448        diff = False
449    for k in "log linear zoom neg".split():
450        if '--'+k in args:
451            args.remove('--'+k)
452            xrange = k
453    if not args:
454        usage()
455    if args[0] == "all":
456        cutoff = args[1] if len(args) > 1 else ""
457        args = list(sorted(ALL_FUNCTIONS))
458        args = [k for k in args if k >= cutoff]
459    if any(k not in FUNCTIONS for k in args):
460        usage()
461    multiple = len(args) > 1
462    pylab.interactive(multiple)
463    for k in args:
464        pylab.clf()
465        comparator = FUNCTIONS[k]
466        comparator.run(xrange=xrange, diff=diff)
467        if multiple:
468            raw_input()
469    if not multiple:
470        pylab.show()
471
472if __name__ == "__main__":
473    main()
Note: See TracBrowser for help on using the repository browser.