source: sasmodels/sasmodel.py @ 79fcc40

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 79fcc40 was 09e15be, checked in by HMP1 <helen.park@…>, 10 years ago

Attempt at faster kernel for TEST,
updated fit.py,
errors in the kernels fixed

  • Property mode set to 100644
File size: 4.3 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import numpy as np
6import pyopencl as cl
7from bumps.names import Parameter
8from sans.dataloader.loader import Loader
9from sans.dataloader.manipulations import Ringcut, Boxcut
10
11
12def load_data(filename):
13    loader = Loader()
14    data = loader.load(filename)
15    if data is None:
16        raise IOError("Data %r could not be loaded"%filename)
17    return data
18
19
20def set_beam_stop(data, radius, outer=None):
21    data.mask = Ringcut(0, radius)(data)
22    if outer is not None:
23        data.mask += Ringcut(outer,np.inf)(data)
24
25def set_half(data, half):
26    if half == 'left':
27        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
28    if half == 'right':
29        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
30
31
32
33def plot_data(data, iq, vmin=None, vmax=None):
34    from numpy.ma import masked_array
35    import matplotlib.pyplot as plt
36    img = masked_array(iq, data.mask)
37    xmin, xmax = min(data.qx_data), max(data.qx_data)
38    ymin, ymax = min(data.qy_data), max(data.qy_data)
39    plt.imshow(img.reshape(128,128),
40               interpolation='nearest', aspect=1, origin='upper',
41               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
42
43
44def plot_result(data, theory):
45    import matplotlib.pyplot as plt
46    plt.subplot(1, 3, 1)
47    #print "not a number",sum(np.isnan(data.data))
48    #data.data[data.data<0.05] = 0.5
49    logdata = np.log10(data.data)
50    #print data.data.min(), data.data.max()
51    clean = logdata[~np.isnan(logdata)]
52    vmin, vmax = clean.min(), clean.max()
53    vmin, vmax = np.percentile(clean, 5), 1.05*vmax
54    #vmin,vmax = None,None
55    plot_data(data, logdata, vmin=vmin, vmax=vmax)
56    plt.colorbar()
57    plt.subplot(1, 3, 2)
58    plot_data(data, np.log10(theory), vmin=vmin, vmax=vmax)
59    plt.colorbar()
60    plt.subplot(1, 3, 3)
61    plot_data(data, (theory-data.data)/data.err_data)
62    plt.colorbar()
63
64
65def demo():
66    data = load_data('JUN03289.DAT')
67    set_beam_stop(data, 0.004)
68    plot_data(data)
69    import matplotlib.pyplot as plt; plt.show()
70
71
72GPU_CONTEXT = None
73GPU_QUEUE = None
74def card():
75    global GPU_CONTEXT, GPU_QUEUE
76    if GPU_CONTEXT is None:
77        GPU_CONTEXT = cl.create_some_context()
78        GPU_QUEUE = cl.CommandQueue(GPU_CONTEXT)
79    return GPU_CONTEXT, GPU_QUEUE
80
81
82class SasModel(object):
83    def __init__(self, data, model, dtype='float32', **kw):
84        self.__dict__['_parameters'] = {}
85        self.index = (data.mask==0) & (~np.isnan(data.data))
86        self.iq = data.data[self.index]
87        self.diq = data.err_data[self.index]
88        self.data = data
89        self.qx = data.qx_data
90        self.qy = data.qy_data
91        self.gpu = model(self.qx, self.qy, dtype=dtype)
92        pd_pars = set(base+attr for base in model.PD_PARS for attr in ('_pd','_pd_n','_pd_nsigma'))
93        total_pars = set(model.PARS.keys()) | pd_pars
94        extra_pars = set(kw.keys()) - total_pars
95        if extra_pars:
96            raise TypeError("unexpected parameters %s"%(str(extra_pars,)))
97        pars = model.PARS.copy()
98        pars.update((base+'_pd', 0) for base in model.PD_PARS)
99        pars.update((base+'_pd_n', 35) for base in model.PD_PARS)
100        pars.update((base+'_pd_nsigma', 3) for base in model.PD_PARS)
101        pars.update(kw)
102        self._parameters = dict((k, Parameter.default(v, name=k)) for k, v in pars.items())
103
104    def numpoints(self):
105        return len(self.iq)
106
107    def parameters(self):
108        return self._parameters
109
110    def __getattr__(self, par):
111        return self._parameters[par]
112
113    def __setattr__(self, par, val):
114        if par in self._parameters:
115            self._parameters[par] = val
116        else:
117            self.__dict__[par] = val
118
119    def theory(self):
120        pars = dict((k,v.value) for k,v in self._parameters.items())
121        result = self.gpu.eval(pars)
122        return result
123
124    def residuals(self):
125        #if np.any(self.err ==0): print "zeros in err"
126        return (self.theory()[self.index]-self.iq)/self.diq
127
128    def nllf(self):
129        R = self.residuals()
130        #if np.any(np.isnan(R)): print "NaN in residuals"
131        return 0.5*np.sum(R**2)
132
133    def __call__(self):
134        return 2*self.nllf()/self.dof
135
136    def plot(self, view='linear'):
137        plot_result(self.data, self.theory())
138
139    def save(self, basename):
140        pass
141
142    def update(self):
143        pass
Note: See TracBrowser for help on using the repository browser.