source: sasmodels/sasmodel.py @ 496b252

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 496b252 was 8faffcd, checked in by HMP1 <helen.park@…>, 11 years ago

Update for Aaron

  • Property mode set to 100644
File size: 3.5 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import numpy as np
5import pyopencl as cl
6from bumps.names import Parameter
7from sans.dataloader.loader import Loader
8from sans.dataloader.manipulations import Ringcut
9
10
11def load_data(filename):
12    loader = Loader()
13    data = loader.load(filename)
14    if data is None:
15        raise IOError("Data %r could not be loaded"%filename)
16    return data
17
18
19def set_beam_stop(data, radius):
20    data.mask = Ringcut(0, radius)(data)
21
22
23def plot_data(data, iq):
24    from numpy.ma import masked_array
25    import matplotlib.pyplot as plt
26    img = masked_array(iq, data.mask)
27    xmin, xmax = min(data.qx_data), max(data.qx_data)
28    ymin, ymax = min(data.qy_data), max(data.qy_data)
29    plt.imshow(img.reshape(128,128),
30               interpolation='nearest', aspect=1, origin='upper',
31               extent=[xmin, xmax, ymin, ymax])
32
33
34def plot_result(data, theory):
35    import matplotlib.pyplot as plt
36    plt.subplot(1,3,1)
37    plot_data(data, data.data)
38    plt.subplot(1,3,2)
39    plot_data(data, theory)
40    plt.subplot(1,3,3)
41    plot_data(data, (theory-data.data)/data.err_data)
42    plt.colorbar()
43
44
45def demo():
46    data = load_data('JUN03289.DAT')
47    set_beam_stop(data, 0.004)
48    plot_data(data)
49    import matplotlib.pyplot as plt; plt.show()
50
51
52GPU_CONTEXT = None
53GPU_QUEUE = None
54def card():
55    global GPU_CONTEXT, GPU_QUEUE
56    if GPU_CONTEXT is None:
57        GPU_CONTEXT = cl.create_some_context()
58        GPU_QUEUE = cl.CommandQueue(GPU_CONTEXT)
59    return GPU_CONTEXT, GPU_QUEUE
60
61
62class SasModel(object):
63    def __init__(self, data, model, dtype='float32', **kw):
64        self.__dict__['_parameters'] = {}
65        self.index = data.mask==0
66        self.iq = data.data[self.index]
67        self.diq = data.err_data[self.index]
68        self.data = data
69        self.qx = data.qx_data
70        self.qy = data.qy_data
71        self.gpu = model(self.qx, self.qy, dtype=dtype)
72        pd_pars = set(base+attr for base in model.PD_PARS for attr in ('_pd','_pd_n','_pd_nsigma'))
73        total_pars = set(model.PARS.keys()) | pd_pars
74        extra_pars = set(kw.keys()) - total_pars
75        if extra_pars:
76            raise TypeError("unexpected parameters %s"%(str(extra_pars,)))
77        pars = model.PARS.copy()
78        pars.update((base+'_pd', 0) for base in model.PD_PARS)
79        pars.update((base+'_pd_n', 35) for base in model.PD_PARS)
80        pars.update((base+'_pd_nsigma', 3) for base in model.PD_PARS)
81        pars.update(kw)
82        self._parameters = dict((k, Parameter.default(v, name=k)) for k, v in pars.items())
83
84    def numpoints(self):
85        return len(self.iq)
86
87    def parameters(self):
88        return self._parameters
89
90    def __getattr__(self, par):
91        return self._parameters[par]
92
93    def __setattr__(self, par, val):
94        if par in self._parameters:
95            self._parameters[par] = val
96        else:
97            self.__dict__[par] = val
98
99    def theory(self):
100        pars = dict((k,v.value) for k,v in self._parameters.items())
101        result = self.gpu.eval(pars)
102        return result
103
104    def residuals(self):
105        #if np.any(self.err ==0): print "zeros in err"
106        return (self.theory()[self.index]-self.iq)/self.diq
107
108    def nllf(self):
109        R = self.residuals()
110        #if np.any(np.isnan(R)): print "NaN in residuals"
111        return 0.5*np.sum(R**2)
112
113    def __call__(self):
114        return 2*self.nllf()/self.dof
115
116    def plot(self, view='linear'):
117        plot_result(self.data, self.theory())
118
119    def save(self, basename):
120        pass
121
122    def update(self):
123        pass
Note: See TracBrowser for help on using the repository browser.