Changeset ce27e21 in sasmodels for sasmodels/core.py


Ignore:
Timestamp:
Aug 24, 2014 7:18:14 PM (10 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
1780d59
Parents:
14de349
Message:

first pass for sasview wrapper around opencl models

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/core.py

    r14de349 rce27e21  
    44import sys, os 
    55import datetime 
     6import warnings 
    67 
    78import numpy as np 
     
    1415    return gen.make(modelpath) 
    1516 
     17 
     18 
    1619def opencl_model(modelname, dtype="single"): 
    1720    from sasmodels import gpu 
    1821 
    19     source, meta, _ = load_model(modelname) 
     22    source, info, _ = load_model(modelname) 
    2023    # for debugging, save source to a .cl file, edit it, and reload as model 
    2124    #open(modelname+'.cl','w').write(source) 
    2225    #source = open(modelname+'.cl','r').read() 
    23     return gpu.GpuModel(source, meta, dtype) 
     26    return gpu.GpuModel(source, info, dtype) 
    2427 
    2528 
     
    3134    COMPILE = "cc -shared -fPIC -std=c99 -fopenmp -O2 -Wall %s -o %s -lm" 
    3235DLL_PATH = "/tmp" 
    33 def dll_path(meta): 
     36 
     37 
     38def dll_path(info): 
    3439    from os.path import join as joinpath, split as splitpath, splitext 
    35     basename = splitext(splitpath(meta['filename'])[1])[0] 
     40    basename = splitext(splitpath(info['filename'])[1])[0] 
    3641    return joinpath(DLL_PATH, basename+'.so') 
     42 
    3743 
    3844def dll_model(modelname): 
     
    4046    from sasmodels import dll 
    4147 
    42     source, meta, _ = load_model(modelname) 
    43     dllpath = dll_path(meta) 
     48    source, info, _ = load_model(modelname) 
     49    dllpath = dll_path(info) 
    4450    if not os.path.exists(dllpath) \ 
    45             or (os.path.getmtime(dllpath) < os.path.getmtime(meta['filename'])): 
     51            or (os.path.getmtime(dllpath) < os.path.getmtime(info['filename'])): 
    4652        # Replace with a proper temp file 
    4753        srcfile = '/tmp/%s.c'%modelname 
    4854        open(srcfile, 'w').write(source) 
    4955        os.system(COMPILE%(srcfile, dllpath)) 
    50     return dll.DllModel(dllpath, meta) 
     56    return dll.DllModel(dllpath, info) 
     57 
    5158 
    5259TIC = None 
     
    5764    return TIC 
    5865 
     66 
    5967def toc(): 
    6068    return TIC() 
     69 
    6170 
    6271def load_data(filename): 
     
    6877    return data 
    6978 
     79 
    7080def fake_data2D(qx, qy=None): 
    7181    from sans.dataloader.data_info import Data2D, Detector 
    72  
    7382 
    7483    if qy is None: 
     
    121130            data.mask &= (data.x<outer) 
    122131 
     132 
    123133def set_half(data, half): 
    124134    from sans.dataloader.manipulations import Boxcut 
     
    128138        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data) 
    129139 
     140 
    130141def set_top(data, max): 
    131142    from sans.dataloader.manipulations import Boxcut 
    132143    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data) 
     144 
    133145 
    134146def plot_data(data, iq, vmin=None, vmax=None): 
     
    141153               interpolation='nearest', aspect=1, origin='upper', 
    142154               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax) 
     155 
    143156 
    144157def plot_result2D(data, theory, view='linear'): 
     
    184197    mresid = masked_array((theory-data.y)/data.dy, mdata.mask) 
    185198 
    186     plt.subplot(1,2,1) 
     199    plt.subplot(121) 
    187200    plt.errorbar(data.x, mdata, yerr=data.dy) 
    188201    plt.plot(data.x, mtheory, '-', hold=True) 
    189202    plt.yscale(view) 
    190     plt.subplot(1, 2, 2) 
     203    plt.subplot(122) 
    191204    plt.plot(data.x, mresid, 'x') 
    192205    #plt.axhline(1, color='black', ls='--',lw=1, hold=True) 
     
    219232 
    220233        # create model 
    221         self.fn = model(input, cutoff=cutoff) 
     234        self.fn = model(input) 
     235        self.cutoff = cutoff 
    222236 
    223237        # define bumps parameters 
    224238        pars = [] 
    225         for p in model.meta['parameters']: 
     239        extras = [] 
     240        for p in model.info['parameters']: 
    226241            name, default, limits, ptype = p[0], p[2], p[3], p[4] 
    227242            value = kw.pop(name, default) 
    228243            setattr(self, name, Parameter.default(value, name=name, limits=limits)) 
    229244            pars.append(name) 
    230             if ptype != "": 
    231                 for xpart,xdefault,xlimits in [ 
    232                         ('_pd', 0, limits), 
    233                         ('_pd_n', 35, (0,1000)), 
    234                         ('_pd_nsigma', 3, (0,10)), 
    235                         ]: 
    236                     xname = name+xpart 
    237                     xvalue = kw.pop(xname, xdefault) 
    238                     setattr(self, xname, Parameter.default(xvalue, name=xname)) 
     245        for name in model.info['partype']['pd-2d']: 
     246            for xpart,xdefault,xlimits in [ 
     247                    ('_pd', 0, limits), 
     248                    ('_pd_n', 35, (0,1000)), 
     249                    ('_pd_nsigma', 3, (0, 10)), 
     250                    ('_pd_type', 'gaussian', None), 
     251                ]: 
     252                xname = name+xpart 
     253                xvalue = kw.pop(xname, xdefault) 
     254                if xlimits is not None: 
     255                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits) 
    239256                    pars.append(xname) 
     257                setattr(self, xname, xvalue) 
     258        self._parameter_names = pars 
    240259        if kw: 
    241260            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys())))) 
    242         self._parameter_names = pars 
    243261        self.update() 
    244262 
     
    254272    def theory(self): 
    255273        if 'theory' not in self._cache: 
    256             pars = dict((k,getattr(self,k).value) for k in self._parameter_names) 
     274            pars = [getattr(self,p).value for p in self.fn.fixed_pars] 
     275            pd_pars = [self._get_weights(p) for p in self.fn.pd_pars] 
    257276            #print pars 
    258             self._theory[self.index] = self.fn.eval(pars) 
    259             #self._theory[:] = self.fn.eval(pars) 
     277            self._theory[self.index] = self.fn(pars, pd_pars, self.cutoff) 
     278            #self._theory[:] = self.fn.eval(pars, pd_pars) 
    260279            self._cache['theory'] = self._theory 
    261280        return self._cache['theory'] 
     
    282301        pass 
    283302 
     303    def _get_weights(self, par): 
     304        from . import weights 
     305 
     306        relative = self.fn.info['partype']['pd-rel'] 
     307        limits = self.fn.info['limits'] 
     308        disperser,value,npts,width,nsigma = [getattr(self, par+ext) 
     309                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')] 
     310        v,w = weights.get_weights( 
     311            disperser, int(npts.value), width.value, nsigma.value, 
     312            value.value, limits[par], par in relative) 
     313        return v,w/w.max() 
     314 
     315 
    284316def demo(): 
    285317    data = load_data('DEC07086.DAT') 
     
    288320    import matplotlib.pyplot as plt; plt.show() 
    289321 
     322 
    290323if __name__ == "__main__": 
    291324    demo() 
Note: See TracChangeset for help on using the changeset viewer.