Ignore:
Timestamp:
Nov 7, 2017 1:05:12 PM (6 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
f926abb
Parents:
0957bb3a
Message:

fix C interface to sldi after py3 conversion

File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sas/sascalc/calculator/sas_gen.py

    rb58265c3 r54b0650  
    118118        self.is_avg = is_avg 
    119119 
    120     def _gen(self, x, y, i): 
     120    def _gen(self, qx, qy): 
    121121        """ 
    122122        Evaluate the function 
     
    129129        pos_y = self.data_y 
    130130        pos_z = self.data_z 
    131         len_x = len(pos_x) 
    132131        if self.is_avg is None: 
    133             len_x *= -1 
    134132            pos_x, pos_y, pos_z = transform_center(pos_x, pos_y, pos_z) 
    135         len_q = len(x) 
    136133        sldn = copy.deepcopy(self.data_sldn) 
    137134        sldn -= self.params['solvent_SLD'] 
    138         model = mod.new_GenI(len_x, pos_x, pos_y, pos_z, 
     135        model = mod.new_GenI((1 if self.is_avg else 0), 
     136                             pos_x, pos_y, pos_z, 
    139137                             sldn, self.data_mx, self.data_my, 
    140138                             self.data_mz, self.data_vol, 
     
    142140                             self.params['Up_frac_out'], 
    143141                             self.params['Up_theta']) 
    144         if y == []: 
    145             mod.genicom(model, len_q, x, i) 
    146         else: 
    147             mod.genicomXY(model, len_q, x, y, i) 
     142        if len(qy): 
     143            qx, qy = _vec(qx), _vec(qy) 
     144            I_out = np.empty_like(qx) 
     145            mod.genicomXY(model, qx, qy, I_out) 
     146        else: 
     147            qx = _vec(qx) 
     148            I_out = np.empty_like(qx) 
     149            mod.genicom(model, qx, I_out) 
    148150        vol_correction = self.data_total_volume / self.params['total_volume'] 
    149         return  self.params['scale'] * vol_correction * i + \ 
    150                         self.params['background'] 
     151        result = (self.params['scale'] * vol_correction * I_out 
     152                  + self.params['background']) 
     153        return result 
    151154 
    152155    def set_sld_data(self, sld_data=None): 
     
    156159        self.sld_data = sld_data 
    157160        self.data_pos_unit = sld_data.pos_unit 
    158         self.data_x = sld_data.pos_x 
    159         self.data_y = sld_data.pos_y 
    160         self.data_z = sld_data.pos_z 
    161         self.data_sldn = sld_data.sld_n 
    162         self.data_mx = sld_data.sld_mx 
    163         self.data_my = sld_data.sld_my 
    164         self.data_mz = sld_data.sld_mz 
    165         self.data_vol = sld_data.vol_pix 
     161        self.data_x = _vec(sld_data.pos_x) 
     162        self.data_y = _vec(sld_data.pos_y) 
     163        self.data_z = _vec(sld_data.pos_z) 
     164        self.data_sldn = _vec(sld_data.sld_n) 
     165        self.data_mx = _vec(sld_data.sld_mx) 
     166        self.data_my = _vec(sld_data.sld_my) 
     167        self.data_mz = _vec(sld_data.sld_mz) 
     168        self.data_vol = _vec(sld_data.vol_pix) 
    166169        self.data_total_volume = sum(sld_data.vol_pix) 
    167170        self.params['total_volume'] = sum(sld_data.vol_pix) 
     
    180183        :return: (I value) 
    181184        """ 
    182         if x.__class__.__name__ == 'list': 
     185        if isinstance(x, list): 
    183186            if len(x[1]) > 0: 
    184187                msg = "Not a 1D." 
    185188                raise ValueError(msg) 
    186             i_out = np.zeros_like(x[0]) 
    187189            # 1D I is found at y =0 in the 2D pattern 
    188             out = self._gen(x[0], [], i_out) 
     190            out = self._gen(x[0], []) 
    189191            return out 
    190192        else: 
     
    199201        :Use this runXY() for the computation 
    200202        """ 
    201         if x.__class__.__name__ == 'list': 
    202             i_out = np.zeros_like(x[0]) 
    203             out = self._gen(x[0], x[1], i_out) 
    204             return out 
     203        if isinstance(x, list): 
     204            return self._gen(x[0], x[1]) 
    205205        else: 
    206206            msg = "Q must be given as list of qx's and qy's" 
     
    214214                      where qx,qy are 1D ndarrays (for 2D). 
    215215        """ 
    216         if qdist.__class__.__name__ == 'list': 
    217             if len(qdist[1]) < 1: 
    218                 out = self.run(qdist) 
    219             else: 
    220                 out = self.runXY(qdist) 
    221             return out 
     216        if isinstance(qdist, list): 
     217            return self.run(qdist) if len(qdist[1]) < 1 else self.runXY(qdist) 
    222218        else: 
    223219            mesg = "evalDistribution is expecting an ndarray of " 
    224220            mesg += "a list [qx,qy] where qx,qy are arrays." 
    225221            raise RuntimeError(mesg) 
     222 
     223def _vec(v): 
     224    return np.ascontiguousarray(v, 'd') 
    226225 
    227226class OMF2SLD(object): 
     
    10411040        self.line_z = line_z 
    10421041 
     1042def _get_data_path(*path_parts): 
     1043    from os.path import realpath, join as joinpath, dirname, abspath 
     1044    # in sas/sascalc/calculator;  want sas/sasview/test 
     1045    return joinpath(dirname(realpath(__file__)), 
     1046                    '..', '..', 'sasview', 'test', *path_parts) 
     1047 
    10431048def test_load(): 
    10441049    """ 
     
    10461051    """ 
    10471052    from mpl_toolkits.mplot3d import Axes3D 
    1048     current_dir = os.path.abspath(os.path.curdir) 
    1049     print(current_dir) 
    1050     for i in range(6): 
    1051         current_dir, _ = os.path.split(current_dir) 
    1052         tfile = os.path.join(current_dir, "test", "CoreXY_ShellZ.txt") 
    1053         ofile = os.path.join(current_dir, "test", "A_Raw_Example-1.omf") 
    1054         if os.path.isfile(tfile): 
    1055             tfpath = tfile 
    1056             ofpath = ofile 
    1057             break 
     1053    tfpath = _get_data_path("1d_data", "CoreXY_ShellZ.txt") 
     1054    ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf") 
     1055    if not os.path.isfile(tfpath) or not os.path.isfile(ofpath): 
     1056        raise ValueError("file(s) not found: %r, %r"%(tfpath, ofpath)) 
    10581057    reader = SLDReader() 
    10591058    oreader = OMFReader() 
     
    10921091        Test code 
    10931092    """ 
    1094     current_dir = os.path.abspath(os.path.curdir) 
    1095     for i in range(3): 
    1096         current_dir, _ = os.path.split(current_dir) 
    1097         ofile = os.path.join(current_dir, "test", "A_Raw_Example-1.omf") 
    1098         if os.path.isfile(ofile): 
    1099             ofpath = ofile 
    1100             break 
     1093    ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf") 
     1094    if not os.path.isfile(ofpath): 
     1095        raise ValueError("file(s) not found: %r"%(ofpath,)) 
    11011096    oreader = OMFReader() 
    11021097    ooutput = decode(oreader.read(ofpath)) 
     
    11041099    foutput.set_data(ooutput) 
    11051100    writer = SLDReader() 
    1106     writer.write(os.path.join(os.path.dirname(ofpath), "out.txt"), 
    1107                  foutput.output) 
     1101    writer.write("out.txt", foutput.output) 
    11081102    model = GenSAS() 
    11091103    model.set_sld_data(foutput.output) 
     
    11141108 
    11151109if __name__ == "__main__": 
     1110    #test_load() 
    11161111    test() 
    1117     test_load() 
Note: See TracChangeset for help on using the changeset viewer.