source: sasview/src/sas/sascalc/calculator/sas_gen.py @ 14e1ff0

ESS_GUIESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_opencl
Last change on this file since 14e1ff0 was b8080e1, checked in by Piotr Rozyczko <rozyczko@…>, 6 years ago

cherry picking sascalc changes from master SASVIEW-996
minor unit test fixes

  • Property mode set to 100644
File size: 40.2 KB
Line 
1# pylint: disable=invalid-name
2"""
3SAS generic computation and sld file readers
4"""
5from __future__ import print_function
6
7import os
8import sys
9import copy
10import logging
11
12from periodictable import formula
13from periodictable import nsf
14import numpy as np
15
16from .core import sld2i as mod
17from .BaseComponent import BaseComponent
18
19logger = logging.getLogger(__name__)
20
21if sys.version_info[0] < 3:
22    def decode(s):
23        return s
24else:
25    def decode(s):
26        return s.decode() if isinstance(s, bytes) else s
27
28MFACTOR_AM = 2.853E-12
29MFACTOR_MT = 2.3164E-9
30METER2ANG = 1.0E+10
31#Avogadro constant [1/mol]
32NA = 6.02214129e+23
33
34def mag2sld(mag, v_unit=None):
35    """
36    Convert magnetization to magnatic SLD
37    sldm = Dm * mag where Dm = gamma * classical elec. radius/(2*Bohr magneton)
38    Dm ~ 2.853E-12 [A^(-2)] ==> Shouldn't be 2.90636E-12 [A^(-2)]???
39    """
40    if v_unit == "A/m":
41        factor = MFACTOR_AM
42    elif v_unit == "mT":
43        factor = MFACTOR_MT
44    else:
45        raise ValueError("Invalid valueunit")
46    sld_m = factor * mag
47    return sld_m
48
49def transform_center(pos_x, pos_y, pos_z):
50    """
51    re-center
52    :return: posx, posy, posz   [arrays]
53    """
54    posx = pos_x - (min(pos_x) + max(pos_x)) / 2.0
55    posy = pos_y - (min(pos_y) + max(pos_y)) / 2.0
56    posz = pos_z - (min(pos_z) + max(pos_z)) / 2.0
57    return posx, posy, posz
58
59class GenSAS(BaseComponent):
60    """
61    Generic SAS computation Model based on sld (n & m) arrays
62    """
63    def __init__(self):
64        """
65        Init
66        :Params sld_data: MagSLD object
67        """
68        # Initialize BaseComponent
69        BaseComponent.__init__(self)
70        self.sld_data = None
71        self.data_pos_unit = None
72        self.data_x = None
73        self.data_y = None
74        self.data_z = None
75        self.data_sldn = None
76        self.data_mx = None
77        self.data_my = None
78        self.data_mz = None
79        self.data_vol = None #[A^3]
80        self.is_avg = False
81        ## Name of the model
82        self.name = "GenSAS"
83        ## Define parameters
84        self.params = {}
85        self.params['scale'] = 1.0
86        self.params['background'] = 0.0
87        self.params['solvent_SLD'] = 0.0
88        self.params['total_volume'] = 1.0
89        self.params['Up_frac_in'] = 1.0
90        self.params['Up_frac_out'] = 1.0
91        self.params['Up_theta'] = 0.0
92        self.description = 'GenSAS'
93        ## Parameter details [units, min, max]
94        self.details = {}
95        self.details['scale'] = ['', 0.0, np.inf]
96        self.details['background'] = ['[1/cm]', 0.0, np.inf]
97        self.details['solvent_SLD'] = ['1/A^(2)', -np.inf, np.inf]
98        self.details['total_volume'] = ['A^(3)', 0.0, np.inf]
99        self.details['Up_frac_in'] = ['[u/(u+d)]', 0.0, 1.0]
100        self.details['Up_frac_out'] = ['[u/(u+d)]', 0.0, 1.0]
101        self.details['Up_theta'] = ['[deg]', -np.inf, np.inf]
102        # fixed parameters
103        self.fixed = []
104
105    def set_pixel_volumes(self, volume):
106        """
107        Set the volume of a pixel in (A^3) unit
108        :Param volume: pixel volume [float]
109        """
110        if self.data_vol is None:
111            raise TypeError("data_vol is missing")
112        self.data_vol = volume
113
114    def set_is_avg(self, is_avg=False):
115        """
116        Sets is_avg: [bool]
117        """
118        self.is_avg = is_avg
119
120    def _gen(self, qx, qy):
121        """
122        Evaluate the function
123        :Param x: array of x-values
124        :Param y: array of y-values
125        :Param i: array of initial i-value
126        :return: function value
127        """
128        pos_x = self.data_x
129        pos_y = self.data_y
130        pos_z = self.data_z
131        if self.is_avg is None:
132            pos_x, pos_y, pos_z = transform_center(pos_x, pos_y, pos_z)
133        sldn = copy.deepcopy(self.data_sldn)
134        sldn -= self.params['solvent_SLD']
135        # **** WARNING **** new_GenI holds pointers to numpy vectors
136        # be sure that they are contiguous double precision arrays and make
137        # sure the GC doesn't eat them before genicom is called.
138        # TODO: rewrite so that the parameters are passed directly to genicom
139        args = (
140            (1 if self.is_avg else 0),
141            pos_x, pos_y, pos_z,
142            sldn, self.data_mx, self.data_my,
143            self.data_mz, self.data_vol,
144            self.params['Up_frac_in'],
145            self.params['Up_frac_out'],
146            self.params['Up_theta'])
147        model = mod.new_GenI(*args)
148        if len(qy):
149            qx, qy = _vec(qx), _vec(qy)
150            I_out = np.empty_like(qx)
151            #print("npoints", qx.shape, "npixels", pos_x.shape)
152            mod.genicomXY(model, qx, qy, I_out)
153            #print("I_out after", I_out)
154        else:
155            qx = _vec(qx)
156            I_out = np.empty_like(qx)
157            mod.genicom(model, qx, I_out)
158        vol_correction = self.data_total_volume / self.params['total_volume']
159        result = (self.params['scale'] * vol_correction * I_out
160                  + self.params['background'])
161        return result
162
163    def set_sld_data(self, sld_data=None):
164        """
165        Sets sld_data
166        """
167        self.sld_data = sld_data
168        self.data_pos_unit = sld_data.pos_unit
169        self.data_x = _vec(sld_data.pos_x)
170        self.data_y = _vec(sld_data.pos_y)
171        self.data_z = _vec(sld_data.pos_z)
172        self.data_sldn = _vec(sld_data.sld_n)
173        self.data_mx = _vec(sld_data.sld_mx)
174        self.data_my = _vec(sld_data.sld_my)
175        self.data_mz = _vec(sld_data.sld_mz)
176        self.data_vol = _vec(sld_data.vol_pix)
177        self.data_total_volume = sum(sld_data.vol_pix)
178        self.params['total_volume'] = sum(sld_data.vol_pix)
179
180    def getProfile(self):
181        """
182        Get SLD profile
183        : return: sld_data
184        """
185        return self.sld_data
186
187    def run(self, x=0.0):
188        """
189        Evaluate the model
190        :param x: simple value
191        :return: (I value)
192        """
193        if isinstance(x, list):
194            if len(x[1]) > 0:
195                msg = "Not a 1D."
196                raise ValueError(msg)
197            # 1D I is found at y =0 in the 2D pattern
198            out = self._gen(x[0], [])
199            return out
200        else:
201            msg = "Q must be given as list of qx's and qy's"
202            raise ValueError(msg)
203
204    def runXY(self, x=0.0):
205        """
206        Evaluate the model
207        :param x: simple value
208        :return: I value
209        :Use this runXY() for the computation
210        """
211        if isinstance(x, list):
212            return self._gen(x[0], x[1])
213        else:
214            msg = "Q must be given as list of qx's and qy's"
215            raise ValueError(msg)
216
217    def evalDistribution(self, qdist):
218        """
219        Evaluate a distribution of q-values.
220
221        :param qdist: ndarray of scalar q-values (for 1D) or list [qx,qy]
222                      where qx,qy are 1D ndarrays (for 2D).
223        """
224        if isinstance(qdist, list):
225            return self.run(qdist) if len(qdist[1]) < 1 else self.runXY(qdist)
226        else:
227            mesg = "evalDistribution is expecting an ndarray of "
228            mesg += "a list [qx,qy] where qx,qy are arrays."
229            raise RuntimeError(mesg)
230
231def _vec(v):
232    return np.ascontiguousarray(v, 'd')
233
234class OMF2SLD(object):
235    """
236    Convert OMFData to MAgData
237    """
238    def __init__(self):
239        """
240        Init
241        """
242        self.pos_x = None
243        self.pos_y = None
244        self.pos_z = None
245        self.mx = None
246        self.my = None
247        self.mz = None
248        self.sld_n = None
249        self.vol_pix = None
250        self.output = None
251        self.omfdata = None
252
253    def set_data(self, omfdata, shape='rectangular'):
254        """
255        Set all data
256        """
257        self.omfdata = omfdata
258        length = int(omfdata.xnodes * omfdata.ynodes * omfdata.znodes)
259        pos_x = np.arange(omfdata.xmin,
260                             omfdata.xnodes*omfdata.xstepsize + omfdata.xmin,
261                             omfdata.xstepsize)
262        pos_y = np.arange(omfdata.ymin,
263                             omfdata.ynodes*omfdata.ystepsize + omfdata.ymin,
264                             omfdata.ystepsize)
265        pos_z = np.arange(omfdata.zmin,
266                             omfdata.znodes*omfdata.zstepsize + omfdata.zmin,
267                             omfdata.zstepsize)
268        self.pos_x = np.tile(pos_x, int(omfdata.ynodes * omfdata.znodes))
269        self.pos_y = pos_y.repeat(int(omfdata.xnodes))
270        self.pos_y = np.tile(self.pos_y, int(omfdata.znodes))
271        self.pos_z = pos_z.repeat(int(omfdata.xnodes * omfdata.ynodes))
272        self.mx = omfdata.mx
273        self.my = omfdata.my
274        self.mz = omfdata.mz
275        self.sld_n = np.zeros(length)
276
277        if omfdata.mx is None:
278            self.mx = np.zeros(length)
279        if omfdata.my is None:
280            self.my = np.zeros(length)
281        if omfdata.mz is None:
282            self.mz = np.zeros(length)
283
284        self._check_data_length(length)
285        self.remove_null_points(False, False)
286        mask = np.ones(len(self.sld_n), dtype=bool)
287        if shape.lower() == 'ellipsoid':
288            try:
289                # Pixel (step) size included
290                x_c = max(self.pos_x) + min(self.pos_x)
291                y_c = max(self.pos_y) + min(self.pos_y)
292                z_c = max(self.pos_z) + min(self.pos_z)
293                x_d = max(self.pos_x) - min(self.pos_x)
294                y_d = max(self.pos_y) - min(self.pos_y)
295                z_d = max(self.pos_z) - min(self.pos_z)
296                x_r = (x_d + omfdata.xstepsize) / 2.0
297                y_r = (y_d + omfdata.ystepsize) / 2.0
298                z_r = (z_d + omfdata.zstepsize) / 2.0
299                x_dir2 = ((self.pos_x - x_c / 2.0) / x_r)
300                x_dir2 *= x_dir2
301                y_dir2 = ((self.pos_y - y_c / 2.0) / y_r)
302                y_dir2 *= y_dir2
303                z_dir2 = ((self.pos_z - z_c / 2.0) / z_r)
304                z_dir2 *= z_dir2
305                mask = (x_dir2 + y_dir2 + z_dir2) <= 1.0
306            except Exception:
307                logger.error(sys.exc_value)
308        self.output = MagSLD(self.pos_x[mask], self.pos_y[mask],
309                             self.pos_z[mask], self.sld_n[mask],
310                             self.mx[mask], self.my[mask], self.mz[mask])
311        self.output.set_pix_type('pixel')
312        self.output.set_pixel_symbols('pixel')
313
314    def get_omfdata(self):
315        """
316        Return all data
317        """
318        return self.omfdata
319
320    def get_output(self):
321        """
322        Return output
323        """
324        return self.output
325
326    def _check_data_length(self, length):
327        """
328        Check if the data lengths are consistent
329        :Params length: data length
330        """
331        parts = (self.pos_x, self.pos_y, self.pos_z, self.mx, self.my, self.mz)
332        if any(len(v) != length for v in parts):
333            raise ValueError("Error: Inconsistent data length.")
334
335    def remove_null_points(self, remove=False, recenter=False):
336        """
337        Removes any mx, my, and mz = 0 points
338        """
339        if remove:
340            is_nonzero = (np.fabs(self.mx) + np.fabs(self.my) +
341                          np.fabs(self.mz)).nonzero()
342            if len(is_nonzero[0]) > 0:
343                self.pos_x = self.pos_x[is_nonzero]
344                self.pos_y = self.pos_y[is_nonzero]
345                self.pos_z = self.pos_z[is_nonzero]
346                self.sld_n = self.sld_n[is_nonzero]
347                self.mx = self.mx[is_nonzero]
348                self.my = self.my[is_nonzero]
349                self.mz = self.mz[is_nonzero]
350        if recenter:
351            self.pos_x -= (min(self.pos_x) + max(self.pos_x)) / 2.0
352            self.pos_y -= (min(self.pos_y) + max(self.pos_y)) / 2.0
353            self.pos_z -= (min(self.pos_z) + max(self.pos_z)) / 2.0
354
355    def get_magsld(self):
356        """
357        return MagSLD
358        """
359        return self.output
360
361
362class OMFReader(object):
363    """
364    Class to load omf/ascii files (3 columns w/header).
365    """
366    ## File type
367    type_name = "OMF ASCII"
368
369    ## Wildcards
370    type = ["OMF files (*.OMF, *.omf)|*.omf"]
371    ## List of allowed extensions
372    ext = ['.omf', '.OMF']
373
374    def read(self, path):
375        """
376        Load data file
377        :param path: file path
378        :return: x, y, z, sld_n, sld_mx, sld_my, sld_mz
379        """
380        desc = ""
381        mx = np.zeros(0)
382        my = np.zeros(0)
383        mz = np.zeros(0)
384        try:
385            input_f = open(path, 'rb')
386            buff = decode(input_f.read())
387            lines = buff.split('\n')
388            input_f.close()
389            output = OMFData()
390            valueunit = None
391            for line in lines:
392                line = line.strip()
393                # Read data
394                if line and not line.startswith('#'):
395                    try:
396                        toks = line.split()
397                        _mx = float(toks[0])
398                        _my = float(toks[1])
399                        _mz = float(toks[2])
400                        _mx = mag2sld(_mx, valueunit)
401                        _my = mag2sld(_my, valueunit)
402                        _mz = mag2sld(_mz, valueunit)
403                        mx = np.append(mx, _mx)
404                        my = np.append(my, _my)
405                        mz = np.append(mz, _mz)
406                    except Exception as exc:
407                        # Skip non-data lines
408                        logger.error(str(exc)+" when processing %r"%line)
409                #Reading Header; Segment count ignored
410                s_line = line.split(":", 1)
411                if s_line[0].lower().count("oommf") > 0:
412                    oommf = s_line[1].lstrip()
413                if s_line[0].lower().count("title") > 0:
414                    title = s_line[1].lstrip()
415                if s_line[0].lower().count("desc") > 0:
416                    desc += s_line[1].lstrip()
417                    desc += '\n'
418                if s_line[0].lower().count("meshtype") > 0:
419                    meshtype = s_line[1].lstrip()
420                if s_line[0].lower().count("meshunit") > 0:
421                    meshunit = s_line[1].lstrip()
422                    if meshunit.count("m") < 1:
423                        msg = "Error: \n"
424                        msg += "We accept only m as meshunit"
425                        raise ValueError(msg)
426                if s_line[0].lower().count("xbase") > 0:
427                    xbase = s_line[1].lstrip()
428                if s_line[0].lower().count("ybase") > 0:
429                    ybase = s_line[1].lstrip()
430                if s_line[0].lower().count("zbase") > 0:
431                    zbase = s_line[1].lstrip()
432                if s_line[0].lower().count("xstepsize") > 0:
433                    xstepsize = s_line[1].lstrip()
434                if s_line[0].lower().count("ystepsize") > 0:
435                    ystepsize = s_line[1].lstrip()
436                if s_line[0].lower().count("zstepsize") > 0:
437                    zstepsize = s_line[1].lstrip()
438                if s_line[0].lower().count("xnodes") > 0:
439                    xnodes = s_line[1].lstrip()
440                if s_line[0].lower().count("ynodes") > 0:
441                    ynodes = s_line[1].lstrip()
442                if s_line[0].lower().count("znodes") > 0:
443                    znodes = s_line[1].lstrip()
444                if s_line[0].lower().count("xmin") > 0:
445                    xmin = s_line[1].lstrip()
446                if s_line[0].lower().count("ymin") > 0:
447                    ymin = s_line[1].lstrip()
448                if s_line[0].lower().count("zmin") > 0:
449                    zmin = s_line[1].lstrip()
450                if s_line[0].lower().count("xmax") > 0:
451                    xmax = s_line[1].lstrip()
452                if s_line[0].lower().count("ymax") > 0:
453                    ymax = s_line[1].lstrip()
454                if s_line[0].lower().count("zmax") > 0:
455                    zmax = s_line[1].lstrip()
456                if s_line[0].lower().count("valueunit") > 0:
457                    valueunit = s_line[1].lstrip().rstrip()
458                if s_line[0].lower().count("valuemultiplier") > 0:
459                    valuemultiplier = s_line[1].lstrip()
460                if s_line[0].lower().count("valuerangeminmag") > 0:
461                    valuerangeminmag = s_line[1].lstrip()
462                if s_line[0].lower().count("valuerangemaxmag") > 0:
463                    valuerangemaxmag = s_line[1].lstrip()
464                if s_line[0].lower().count("end") > 0:
465                    output.filename = os.path.basename(path)
466                    output.oommf = oommf
467                    output.title = title
468                    output.desc = desc
469                    output.meshtype = meshtype
470                    output.xbase = float(xbase) * METER2ANG
471                    output.ybase = float(ybase) * METER2ANG
472                    output.zbase = float(zbase) * METER2ANG
473                    output.xstepsize = float(xstepsize) * METER2ANG
474                    output.ystepsize = float(ystepsize) * METER2ANG
475                    output.zstepsize = float(zstepsize) * METER2ANG
476                    output.xnodes = float(xnodes)
477                    output.ynodes = float(ynodes)
478                    output.znodes = float(znodes)
479                    output.xmin = float(xmin) * METER2ANG
480                    output.ymin = float(ymin) * METER2ANG
481                    output.zmin = float(zmin) * METER2ANG
482                    output.xmax = float(xmax) * METER2ANG
483                    output.ymax = float(ymax) * METER2ANG
484                    output.zmax = float(zmax) * METER2ANG
485                    output.valuemultiplier = valuemultiplier
486                    output.valuerangeminmag = mag2sld(float(valuerangeminmag), \
487                                                      valueunit)
488                    output.valuerangemaxmag = mag2sld(float(valuerangemaxmag), \
489                                                      valueunit)
490            output.set_m(mx, my, mz)
491            return output
492        except Exception:
493            msg = "%s is not supported: \n" % path
494            msg += "We accept only Text format OMF file."
495            raise RuntimeError(msg)
496
497class PDBReader(object):
498    """
499    PDB reader class: limited for reading the lines starting with 'ATOM'
500    """
501    type_name = "PDB"
502    ## Wildcards
503    type = ["pdb files (*.PDB, *.pdb)|*.pdb"]
504    ## List of allowed extensions
505    ext = ['.pdb', '.PDB']
506
507    def read(self, path):
508        """
509        Load data file
510
511        :param path: file path
512        :return: MagSLD
513        :raise RuntimeError: when the file can't be opened
514        """
515        pos_x = np.zeros(0)
516        pos_y = np.zeros(0)
517        pos_z = np.zeros(0)
518        sld_n = np.zeros(0)
519        sld_mx = np.zeros(0)
520        sld_my = np.zeros(0)
521        sld_mz = np.zeros(0)
522        vol_pix = np.zeros(0)
523        pix_symbol = np.zeros(0)
524        x_line = []
525        y_line = []
526        z_line = []
527        x_lines = []
528        y_lines = []
529        z_lines = []
530        try:
531            input_f = open(path, 'rb')
532            buff = decode(input_f.read())
533            lines = buff.split('\n')
534            input_f.close()
535            num = 0
536            for line in lines:
537                try:
538                    # check if line starts with "ATOM"
539                    if line[0:6].strip().count('ATM') > 0 or \
540                                line[0:6].strip() == 'ATOM':
541                        # define fields of interest
542                        atom_name = line[12:16].strip()
543                        try:
544                            float(line[12])
545                            atom_name = atom_name[1].upper()
546                        except Exception:
547                            if len(atom_name) == 4:
548                                atom_name = atom_name[0].upper()
549                            elif line[12] != ' ':
550                                atom_name = atom_name[0].upper() + \
551                                        atom_name[1].lower()
552                            else:
553                                atom_name = atom_name[0].upper()
554                        _pos_x = float(line[30:38].strip())
555                        _pos_y = float(line[38:46].strip())
556                        _pos_z = float(line[46:54].strip())
557                        pos_x = np.append(pos_x, _pos_x)
558                        pos_y = np.append(pos_y, _pos_y)
559                        pos_z = np.append(pos_z, _pos_z)
560                        try:
561                            val = nsf.neutron_sld(atom_name)[0]
562                            # sld in Ang^-2 unit
563                            val *= 1.0e-6
564                            sld_n = np.append(sld_n, val)
565                            atom = formula(atom_name)
566                            # cm to A units
567                            vol = 1.0e+24 * atom.mass / atom.density / NA
568                            vol_pix = np.append(vol_pix, vol)
569                        except Exception:
570                            logger.error("Error: set the sld of %s to zero"% atom_name)
571                            sld_n = np.append(sld_n, 0.0)
572                        sld_mx = np.append(sld_mx, 0)
573                        sld_my = np.append(sld_my, 0)
574                        sld_mz = np.append(sld_mz, 0)
575                        pix_symbol = np.append(pix_symbol, atom_name)
576                    elif line[0:6].strip().count('CONECT') > 0:
577                        toks = line.split()
578                        num = int(toks[1]) - 1
579                        val_list = []
580                        for val in toks[2:]:
581                            try:
582                                int_val = int(val)
583                            except Exception:
584                                break
585                            if int_val == 0:
586                                break
587                            val_list.append(int_val)
588                        #need val_list ordered
589                        for val in val_list:
590                            index = val - 1
591                            if (pos_x[index], pos_x[num]) in x_line and \
592                               (pos_y[index], pos_y[num]) in y_line and \
593                               (pos_z[index], pos_z[num]) in z_line:
594                                continue
595                            x_line.append((pos_x[num], pos_x[index]))
596                            y_line.append((pos_y[num], pos_y[index]))
597                            z_line.append((pos_z[num], pos_z[index]))
598                    if len(x_line) > 0:
599                        x_lines.append(x_line)
600                        y_lines.append(y_line)
601                        z_lines.append(z_line)
602                except Exception:
603                    logger.error(sys.exc_value)
604
605            output = MagSLD(pos_x, pos_y, pos_z, sld_n, sld_mx, sld_my, sld_mz)
606            output.set_conect_lines(x_line, y_line, z_line)
607            output.filename = os.path.basename(path)
608            output.set_pix_type('atom')
609            output.set_pixel_symbols(pix_symbol)
610            output.set_nodes()
611            output.set_pixel_volumes(vol_pix)
612            output.sld_unit = '1/A^(2)'
613            return output
614        except Exception:
615            raise RuntimeError("%s is not a sld file" % path)
616
617    def write(self, path, data):
618        """
619        Write
620        """
621        print("Not implemented... ")
622
623class SLDReader(object):
624    """
625    Class to load ascii files (7 columns).
626    """
627    ## File type
628    type_name = "SLD ASCII"
629    ## Wildcards
630    type = ["sld files (*.SLD, *.sld)|*.sld",
631            "txt files (*.TXT, *.txt)|*.txt",
632            "all files (*.*)|*.*"]
633    ## List of allowed extensions
634    ext = ['.sld', '.SLD', '.txt', '.TXT', '.*']
635    def read(self, path):
636        """
637        Load data file
638        :param path: file path
639        :return MagSLD: x, y, z, sld_n, sld_mx, sld_my, sld_mz
640        :raise RuntimeError: when the file can't be opened
641        :raise ValueError: when the length of the data vectors are inconsistent
642        """
643        try:
644            pos_x = np.zeros(0)
645            pos_y = np.zeros(0)
646            pos_z = np.zeros(0)
647            sld_n = np.zeros(0)
648            sld_mx = np.zeros(0)
649            sld_my = np.zeros(0)
650            sld_mz = np.zeros(0)
651            try:
652                # Use numpy to speed up loading
653                input_f = np.loadtxt(path, dtype='float', skiprows=1,
654                                        ndmin=1, unpack=True)
655                pos_x = np.array(input_f[0])
656                pos_y = np.array(input_f[1])
657                pos_z = np.array(input_f[2])
658                sld_n = np.array(input_f[3])
659                sld_mx = np.array(input_f[4])
660                sld_my = np.array(input_f[5])
661                sld_mz = np.array(input_f[6])
662                ncols = len(input_f)
663                if ncols == 8:
664                    vol_pix = np.array(input_f[7])
665                elif ncols == 7:
666                    vol_pix = None
667            except Exception:
668                # For older version of numpy
669                input_f = open(path, 'rb')
670                buff = decode(input_f.read())
671                lines = buff.split('\n')
672                input_f.close()
673                for line in lines:
674                    toks = line.split()
675                    try:
676                        _pos_x = float(toks[0])
677                        _pos_y = float(toks[1])
678                        _pos_z = float(toks[2])
679                        _sld_n = float(toks[3])
680                        _sld_mx = float(toks[4])
681                        _sld_my = float(toks[5])
682                        _sld_mz = float(toks[6])
683                        pos_x = np.append(pos_x, _pos_x)
684                        pos_y = np.append(pos_y, _pos_y)
685                        pos_z = np.append(pos_z, _pos_z)
686                        sld_n = np.append(sld_n, _sld_n)
687                        sld_mx = np.append(sld_mx, _sld_mx)
688                        sld_my = np.append(sld_my, _sld_my)
689                        sld_mz = np.append(sld_mz, _sld_mz)
690                        try:
691                            _vol_pix = float(toks[7])
692                            vol_pix = np.append(vol_pix, _vol_pix)
693                        except Exception:
694                            vol_pix = None
695                    except Exception:
696                        # Skip non-data lines
697                        logger.error(sys.exc_value)
698            output = MagSLD(pos_x, pos_y, pos_z, sld_n,
699                            sld_mx, sld_my, sld_mz)
700            output.filename = os.path.basename(path)
701            output.set_pix_type('pixel')
702            output.set_pixel_symbols('pixel')
703            if vol_pix is not None:
704                output.set_pixel_volumes(vol_pix)
705            return output
706        except Exception:
707            raise RuntimeError("%s is not a sld file" % path)
708
709    def write(self, path, data):
710        """
711        Write sld file
712        :Param path: file path
713        :Param data: MagSLD data object
714        """
715        if path is None:
716            raise ValueError("Missing the file path.")
717        if data is None:
718            raise ValueError("Missing the data to save.")
719        x_val = data.pos_x
720        y_val = data.pos_y
721        z_val = data.pos_z
722        vol_pix = data.vol_pix
723        length = len(x_val)
724        sld_n = data.sld_n
725        if sld_n is None:
726            sld_n = np.zeros(length)
727        sld_mx = data.sld_mx
728        if sld_mx is None:
729            sld_mx = np.zeros(length)
730            sld_my = np.zeros(length)
731            sld_mz = np.zeros(length)
732        else:
733            sld_my = data.sld_my
734            sld_mz = data.sld_mz
735        out = open(path, 'w')
736        # First Line: Column names
737        out.write("X  Y  Z  SLDN SLDMx  SLDMy  SLDMz VOLUMEpix")
738        for ind in range(length):
739            out.write("\n%g  %g  %g  %g  %g  %g  %g %g" % \
740                      (x_val[ind], y_val[ind], z_val[ind], sld_n[ind],
741                       sld_mx[ind], sld_my[ind], sld_mz[ind], vol_pix[ind]))
742        out.close()
743
744
745class OMFData(object):
746    """
747    OMF Data.
748    """
749    _meshunit = "A"
750    _valueunit = "A^(-2)"
751    def __init__(self):
752        """
753        Init for mag SLD
754        """
755        self.filename = 'default'
756        self.oommf = ''
757        self.title = ''
758        self.desc = ''
759        self.meshtype = ''
760        self.meshunit = self._meshunit
761        self.valueunit = self._valueunit
762        self.xbase = 0.0
763        self.ybase = 0.0
764        self.zbase = 0.0
765        self.xstepsize = 6.0
766        self.ystepsize = 6.0
767        self.zstepsize = 6.0
768        self.xnodes = 10.0
769        self.ynodes = 10.0
770        self.znodes = 10.0
771        self.xmin = 0.0
772        self.ymin = 0.0
773        self.zmin = 0.0
774        self.xmax = 60.0
775        self.ymax = 60.0
776        self.zmax = 60.0
777        self.mx = None
778        self.my = None
779        self.mz = None
780        self.valuemultiplier = 1.
781        self.valuerangeminmag = 0
782        self.valuerangemaxmag = 0
783
784    def __str__(self):
785        """
786        doc strings
787        """
788        _str = "Type:            %s\n" % self.__class__.__name__
789        _str += "File:            %s\n" % self.filename
790        _str += "OOMMF:           %s\n" % self.oommf
791        _str += "Title:           %s\n" % self.title
792        _str += "Desc:            %s\n" % self.desc
793        _str += "meshtype:        %s\n" % self.meshtype
794        _str += "meshunit:        %s\n" % str(self.meshunit)
795        _str += "xbase:           %s [%s]\n" % (str(self.xbase), self.meshunit)
796        _str += "ybase:           %s [%s]\n" % (str(self.ybase), self.meshunit)
797        _str += "zbase:           %s [%s]\n" % (str(self.zbase), self.meshunit)
798        _str += "xstepsize:       %s [%s]\n" % (str(self.xstepsize),
799                                                self.meshunit)
800        _str += "ystepsize:       %s [%s]\n" % (str(self.ystepsize),
801                                                self.meshunit)
802        _str += "zstepsize:       %s [%s]\n" % (str(self.zstepsize),
803                                                self.meshunit)
804        _str += "xnodes:          %s\n" % str(self.xnodes)
805        _str += "ynodes:          %s\n" % str(self.ynodes)
806        _str += "znodes:          %s\n" % str(self.znodes)
807        _str += "xmin:            %s [%s]\n" % (str(self.xmin), self.meshunit)
808        _str += "ymin:            %s [%s]\n" % (str(self.ymin), self.meshunit)
809        _str += "zmin:            %s [%s]\n" % (str(self.zmin), self.meshunit)
810        _str += "xmax:            %s [%s]\n" % (str(self.xmax), self.meshunit)
811        _str += "ymax:            %s [%s]\n" % (str(self.ymax), self.meshunit)
812        _str += "zmax:            %s [%s]\n" % (str(self.zmax), self.meshunit)
813        _str += "valueunit:       %s\n" % self.valueunit
814        _str += "valuemultiplier: %s\n" % str(self.valuemultiplier)
815        _str += "ValueRangeMinMag:%s [%s]\n" % (str(self.valuerangeminmag),
816                                                self.valueunit)
817        _str += "ValueRangeMaxMag:%s [%s]\n" % (str(self.valuerangemaxmag),
818                                                self.valueunit)
819        return _str
820
821    def set_m(self, mx, my, mz):
822        """
823        Set the Mx, My, Mz values
824        """
825        self.mx = mx
826        self.my = my
827        self.mz = mz
828
829class MagSLD(object):
830    """
831    Magnetic SLD.
832    """
833    pos_x = None
834    pos_y = None
835    pos_z = None
836    sld_n = None
837    sld_mx = None
838    sld_my = None
839    sld_mz = None
840    # Units
841    _pos_unit = 'A'
842    _sld_unit = '1/A^(2)'
843    _pix_type = 'pixel'
844
845    def __init__(self, pos_x, pos_y, pos_z, sld_n=None,
846                 sld_mx=None, sld_my=None, sld_mz=None, vol_pix=None):
847        """
848        Init for mag SLD
849        :params : All should be numpy 1D array
850        """
851        self.is_data = True
852        self.filename = ''
853        self.xstepsize = 6.0
854        self.ystepsize = 6.0
855        self.zstepsize = 6.0
856        self.xnodes = 10.0
857        self.ynodes = 10.0
858        self.znodes = 10.0
859        self.has_stepsize = False
860        self.has_conect = False
861        self.pos_unit = self._pos_unit
862        self.sld_unit = self._sld_unit
863        self.pix_type = 'pixel'
864        self.pos_x = pos_x
865        self.pos_y = pos_y
866        self.pos_z = pos_z
867        self.sld_n = sld_n
868        self.line_x = None
869        self.line_y = None
870        self.line_z = None
871        self.sld_mx = sld_mx
872        self.sld_my = sld_my
873        self.sld_mz = sld_mz
874        self.vol_pix = vol_pix
875        self.sld_m = None
876        self.sld_phi = None
877        self.sld_theta = None
878        self.pix_symbol = None
879        if sld_mx is not None and sld_my is not None and sld_mz is not None:
880            self.set_sldms(sld_mx, sld_my, sld_mz)
881        self.set_nodes()
882
883    def __str__(self):
884        """
885        doc strings
886        """
887        _str = "Type:       %s\n" % self.__class__.__name__
888        _str += "File:       %s\n" % self.filename
889        _str += "Axis_unit:  %s\n" % self.pos_unit
890        _str += "SLD_unit:   %s\n" % self.sld_unit
891        return _str
892
893    def set_pix_type(self, pix_type):
894        """
895        Set pixel type
896        :Param pix_type: string, 'pixel' or 'atom'
897        """
898        self.pix_type = pix_type
899
900    def set_sldn(self, sld_n):
901        """
902        Sets neutron SLD
903        """
904        if sld_n.__class__.__name__ == 'float':
905            if self.is_data:
906                # For data, put the value to only the pixels w non-zero M
907                is_nonzero = (np.fabs(self.sld_mx) +
908                              np.fabs(self.sld_my) +
909                              np.fabs(self.sld_mz)).nonzero()
910                self.sld_n = np.zeros(len(self.pos_x))
911                if len(self.sld_n[is_nonzero]) > 0:
912                    self.sld_n[is_nonzero] = sld_n
913                else:
914                    self.sld_n.fill(sld_n)
915            else:
916                # For non-data, put the value to all the pixels
917                self.sld_n = np.ones(len(self.pos_x)) * sld_n
918        else:
919            self.sld_n = sld_n
920
921    def set_sldms(self, sld_mx, sld_my, sld_mz):
922        r"""
923        Sets mx, my, mz and abs(m).
924        """ # Note: escaping
925        if sld_mx.__class__.__name__ == 'float':
926            self.sld_mx = np.ones(len(self.pos_x)) * sld_mx
927        else:
928            self.sld_mx = sld_mx
929        if sld_my.__class__.__name__ == 'float':
930            self.sld_my = np.ones(len(self.pos_x)) * sld_my
931        else:
932            self.sld_my = sld_my
933        if sld_mz.__class__.__name__ == 'float':
934            self.sld_mz = np.ones(len(self.pos_x)) * sld_mz
935        else:
936            self.sld_mz = sld_mz
937
938        sld_m = np.sqrt(sld_mx * sld_mx + sld_my * sld_my + \
939                                sld_mz * sld_mz)
940        self.sld_m = sld_m
941
942    def set_pixel_symbols(self, symbol='pixel'):
943        """
944        Set pixel
945        :Params pixel: str; pixel or atomic symbol, or array of strings
946        """
947        if self.sld_n is None:
948            return
949        if symbol.__class__.__name__ == 'str':
950            self.pix_symbol = np.repeat(symbol, len(self.sld_n))
951        else:
952            self.pix_symbol = symbol
953
954    def set_pixel_volumes(self, vol):
955        """
956        Set pixel volumes
957        :Params pixel: str; pixel or atomic symbol, or array of strings
958        """
959        if self.sld_n is None:
960            return
961        if vol.__class__.__name__ == 'ndarray':
962            self.vol_pix = vol
963        elif vol.__class__.__name__.count('float') > 0:
964            self.vol_pix = np.repeat(vol, len(self.sld_n))
965        else:
966            self.vol_pix = None
967
968    def get_sldn(self):
969        """
970        Returns nuclear sld
971        """
972        return self.sld_n
973
974    def set_nodes(self):
975        """
976        Set xnodes, ynodes, and znodes
977        """
978        self.set_stepsize()
979        if self.pix_type == 'pixel':
980            try:
981                xdist = (max(self.pos_x) - min(self.pos_x)) / self.xstepsize
982                ydist = (max(self.pos_y) - min(self.pos_y)) / self.ystepsize
983                zdist = (max(self.pos_z) - min(self.pos_z)) / self.zstepsize
984                self.xnodes = int(xdist) + 1
985                self.ynodes = int(ydist) + 1
986                self.znodes = int(zdist) + 1
987            except Exception:
988                self.xnodes = None
989                self.ynodes = None
990                self.znodes = None
991        else:
992            self.xnodes = None
993            self.ynodes = None
994            self.znodes = None
995
996    def set_stepsize(self):
997        """
998        Set xtepsize, ystepsize, and zstepsize
999        """
1000        if self.pix_type == 'pixel':
1001            try:
1002                xpos_pre = self.pos_x[0]
1003                ypos_pre = self.pos_y[0]
1004                zpos_pre = self.pos_z[0]
1005                for x_pos in self.pos_x:
1006                    if xpos_pre != x_pos:
1007                        self.xstepsize = np.fabs(x_pos - xpos_pre)
1008                        break
1009                for y_pos in self.pos_y:
1010                    if ypos_pre != y_pos:
1011                        self.ystepsize = np.fabs(y_pos - ypos_pre)
1012                        break
1013                for z_pos in self.pos_z:
1014                    if zpos_pre != z_pos:
1015                        self.zstepsize = np.fabs(z_pos - zpos_pre)
1016                        break
1017                #default pix volume
1018                self.vol_pix = np.ones(len(self.pos_x))
1019                vol = self.xstepsize * self.ystepsize * self.zstepsize
1020                self.set_pixel_volumes(vol)
1021                self.has_stepsize = True
1022            except Exception:
1023                self.xstepsize = None
1024                self.ystepsize = None
1025                self.zstepsize = None
1026                self.vol_pix = None
1027                self.has_stepsize = False
1028        else:
1029            self.xstepsize = None
1030            self.ystepsize = None
1031            self.zstepsize = None
1032            self.has_stepsize = True
1033        return self.xstepsize, self.ystepsize, self.zstepsize
1034
1035    def set_conect_lines(self, line_x, line_y, line_z):
1036        """
1037        Set bonding line data if taken from pdb
1038        """
1039        if line_x.__class__.__name__ != 'list' or len(line_x) < 1:
1040            return
1041        if line_y.__class__.__name__ != 'list' or len(line_y) < 1:
1042            return
1043        if line_z.__class__.__name__ != 'list' or len(line_z) < 1:
1044            return
1045        self.has_conect = True
1046        self.line_x = line_x
1047        self.line_y = line_y
1048        self.line_z = line_z
1049
1050def _get_data_path(*path_parts):
1051    from os.path import realpath, join as joinpath, dirname, abspath
1052    # in sas/sascalc/calculator;  want sas/sasview/test
1053    return joinpath(dirname(realpath(__file__)),
1054                    '..', '..', 'sasview', 'test', *path_parts)
1055
1056def test_load():
1057    """
1058        Test code
1059    """
1060    from mpl_toolkits.mplot3d import Axes3D
1061    tfpath = _get_data_path("1d_data", "CoreXY_ShellZ.txt")
1062    ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf")
1063    if not os.path.isfile(tfpath) or not os.path.isfile(ofpath):
1064        raise ValueError("file(s) not found: %r, %r"%(tfpath, ofpath))
1065    reader = SLDReader()
1066    oreader = OMFReader()
1067    output = reader.read(tfpath)
1068    ooutput = oreader.read(ofpath)
1069    foutput = OMF2SLD()
1070    foutput.set_data(ooutput)
1071
1072    import matplotlib.pyplot as plt
1073    fig = plt.figure()
1074    ax = Axes3D(fig)
1075    ax.plot(output.pos_x, output.pos_y, output.pos_z, '.', c="g",
1076            alpha=0.7, markeredgecolor='gray', rasterized=True)
1077    gap = 7
1078    max_mx = max(output.sld_mx)
1079    max_my = max(output.sld_my)
1080    max_mz = max(output.sld_mz)
1081    max_m = max(max_mx, max_my, max_mz)
1082    x2 = output.pos_x+output.sld_mx/max_m * gap
1083    y2 = output.pos_y+output.sld_my/max_m * gap
1084    z2 = output.pos_z+output.sld_mz/max_m * gap
1085    x_arrow = np.column_stack((output.pos_x, x2))
1086    y_arrow = np.column_stack((output.pos_y, y2))
1087    z_arrow = np.column_stack((output.pos_z, z2))
1088    unit_x2 = output.sld_mx / max_m
1089    unit_y2 = output.sld_my / max_m
1090    unit_z2 = output.sld_mz / max_m
1091    color_x = np.fabs(unit_x2 * 0.8)
1092    color_y = np.fabs(unit_y2 * 0.8)
1093    color_z = np.fabs(unit_z2 * 0.8)
1094    colors = np.column_stack((color_x, color_y, color_z))
1095    plt.show()
1096
1097def test_save():
1098    ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf")
1099    if not os.path.isfile(ofpath):
1100        raise ValueError("file(s) not found: %r"%(ofpath,))
1101    oreader = OMFReader()
1102    omfdata = oreader.read(ofpath)
1103    omf2sld = OMF2SLD()
1104    omf2sld.set_data(omfdata)
1105    writer = SLDReader()
1106    writer.write("out.txt", omf2sld.output)
1107
1108def test():
1109    """
1110        Test code
1111    """
1112    ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf")
1113    if not os.path.isfile(ofpath):
1114        raise ValueError("file(s) not found: %r"%(ofpath,))
1115    oreader = OMFReader()
1116    omfdata = oreader.read(ofpath)
1117    omf2sld = OMF2SLD()
1118    omf2sld.set_data(omfdata)
1119    model = GenSAS()
1120    model.set_sld_data(omf2sld.output)
1121    x = np.linspace(0, 0.1, 11)[1:]
1122    return model.runXY([x, x])
1123
1124if __name__ == "__main__":
1125    #test_load()
1126    #test_save()
1127    #print(test())
1128    test()
Note: See TracBrowser for help on using the repository browser.