source: sasmodels/sasmodel.py @ 8a20be5

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 8a20be5 was 8a20be5, checked in by HMP1 <helen.park@…>, 10 years ago

Added a fit2 (fits two different models at different angles)
(preliminary) Added CoreshellCyl? and CapCyl? Kernels
(preliminary) Updated kernels to include functions

  • Property mode set to 100644
File size: 3.3 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import numpy as np
5import pyopencl as cl
6from bumps.names import Parameter
7from sans.dataloader.loader import Loader
8from sans.dataloader.manipulations import Ringcut
9
10
11def load_data(filename):
12    loader = Loader()
13    data = loader.load(filename)
14    if data is None:
15        raise IOError("Data %r could not be loaded"%filename)
16    return data
17
18
19def set_beam_stop(data, radius):
20    data.mask = Ringcut(0, radius)(data)
21
22
23def plot_data(data, iq):
24    from numpy.ma import masked_array
25    import matplotlib.pyplot as plt
26    img = masked_array(iq, data.mask)
27    xmin, xmax = min(data.qx_data), max(data.qx_data)
28    ymin, ymax = min(data.qy_data), max(data.qy_data)
29    plt.imshow(img.reshape(128,128),
30               interpolation='nearest', aspect=1, origin='upper',
31               extent=[xmin, xmax, ymin, ymax])
32
33
34def plot_result(data, theory):
35    import matplotlib.pyplot as plt
36    plt.subplot(1,3,1)
37    plot_data(data, data.data)
38    plt.subplot(1,3,2)
39    plot_data(data, theory)
40    plt.subplot(1,3,3)
41    plot_data(data, (theory-data.data)/data.err_data)
42    plt.colorbar()
43
44
45def demo():
46    data = load_data('JUN03289.DAT')
47    set_beam_stop(data, 0.004)
48    plot_data(data)
49    import matplotlib.pyplot as plt; plt.show()
50
51
52GPU_CONTEXT = None
53GPU_QUEUE = None
54def card():
55    global GPU_CONTEXT, GPU_QUEUE
56    if GPU_CONTEXT is None:
57        GPU_CONTEXT = cl.create_some_context()
58        GPU_QUEUE = cl.CommandQueue(GPU_CONTEXT)
59    return GPU_CONTEXT, GPU_QUEUE
60
61
62class SasModel(object):
63    def __init__(self, data, model, dtype='float32', **kw):
64        self.index = data.mask==0
65        self.iq = data.data[self.index]
66        self.diq = data.err_data[self.index]
67        self.data = data
68        self.qx = data.qx_data
69        self.qy = data.qy_data
70        self.gpu = model(self.qx, self.qy, dtype=dtype)
71        pd_pars = set(base+attr for base in model.PD_PARS for attr in ('_pd','_pd_n','_pd_nsigma'))
72        total_pars = set(model.PARS.keys()) | pd_pars
73        extra_pars = set(kw.keys()) - total_pars
74        if extra_pars:
75            raise TypeError("unexpected parameters %s"%(str(extra_pars,)))
76        pars = model.PARS.copy()
77        pars.update((base+'_pd', 0) for base in model.PD_PARS)
78        pars.update((base+'_pd_n', 35) for base in model.PD_PARS)
79        pars.update((base+'_pd_nsigma', 3) for base in model.PD_PARS)
80        pars.update(kw)
81        self._parameters = dict((k, Parameter(v, name=k)) for k, v in pars.items())
82
83    def numpoints(self):
84        return len(self.iq)
85
86    def parameters(self):
87        return self._parameters
88
89    def __getattr__(self, par):
90        return self._parameters[par]
91
92    def theory(self):
93        pars = dict((k,v.value) for k,v in self._parameters.items())
94        print pars
95        result = self.gpu.eval(pars)
96        return result
97
98    def residuals(self):
99        #if np.any(self.err ==0): print "zeros in err"
100        return (self.theory()[self.index]-self.iq)/self.diq
101
102    def nllf(self):
103        R = self.residuals()
104        #if np.any(np.isnan(R)): print "NaN in residuals"
105        return 0.5*np.sum(R**2)
106
107    def __call__(self):
108        return 2*self.nllf()/self.dof
109
110    def plot(self, view='linear'):
111        plot_result(self.data, self.theory())
112
113    def save(self, basename):
114        pass
115
116    def update(self):
117        pass
Note: See TracBrowser for help on using the repository browser.