source: sasview/src/sas/sascalc/calculator/sas_gen.py @ b796c72

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalcmagnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since b796c72 was 1d014cb, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

sas_gen tests now pass cleanly

  • Property mode set to 100644
File size: 39.7 KB
Line 
1# pylint: disable=invalid-name
2"""
3SAS generic computation and sld file readers
4"""
5from __future__ import print_function
6
7import os
8import sys
9import copy
10import logging
11
12from periodictable import formula
13from periodictable import nsf
14import numpy as np
15
16from .core import sld2i as mod
17from .BaseComponent import BaseComponent
18
19logger = logging.getLogger(__name__)
20
21if sys.version_info[0] < 3:
22    def decode(s):
23        return s
24else:
25    def decode(s):
26        return s.decode() if isinstance(s, bytes) else s
27
28MFACTOR_AM = 2.853E-12
29MFACTOR_MT = 2.3164E-9
30METER2ANG = 1.0E+10
31#Avogadro constant [1/mol]
32NA = 6.02214129e+23
33
34def mag2sld(mag, v_unit=None):
35    """
36    Convert magnetization to magnatic SLD
37    sldm = Dm * mag where Dm = gamma * classical elec. radius/(2*Bohr magneton)
38    Dm ~ 2.853E-12 [A^(-2)] ==> Shouldn't be 2.90636E-12 [A^(-2)]???
39    """
40    if v_unit == "A/m":
41        factor = MFACTOR_AM
42    elif v_unit == "mT":
43        factor = MFACTOR_MT
44    else:
45        raise ValueError("Invalid valueunit")
46    sld_m = factor * mag
47    return sld_m
48
49def transform_center(pos_x, pos_y, pos_z):
50    """
51    re-center
52    :return: posx, posy, posz   [arrays]
53    """
54    posx = pos_x - (min(pos_x) + max(pos_x)) / 2.0
55    posy = pos_y - (min(pos_y) + max(pos_y)) / 2.0
56    posz = pos_z - (min(pos_z) + max(pos_z)) / 2.0
57    return posx, posy, posz
58
59class GenSAS(BaseComponent):
60    """
61    Generic SAS computation Model based on sld (n & m) arrays
62    """
63    def __init__(self):
64        """
65        Init
66        :Params sld_data: MagSLD object
67        """
68        # Initialize BaseComponent
69        BaseComponent.__init__(self)
70        self.sld_data = None
71        self.data_pos_unit = None
72        self.data_x = None
73        self.data_y = None
74        self.data_z = None
75        self.data_sldn = None
76        self.data_mx = None
77        self.data_my = None
78        self.data_mz = None
79        self.data_vol = None #[A^3]
80        self.is_avg = False
81        ## Name of the model
82        self.name = "GenSAS"
83        ## Define parameters
84        self.params = {}
85        self.params['scale'] = 1.0
86        self.params['background'] = 0.0
87        self.params['solvent_SLD'] = 0.0
88        self.params['total_volume'] = 1.0
89        self.params['Up_frac_in'] = 1.0
90        self.params['Up_frac_out'] = 1.0
91        self.params['Up_theta'] = 0.0
92        self.description = 'GenSAS'
93        ## Parameter details [units, min, max]
94        self.details = {}
95        self.details['scale'] = ['', 0.0, np.inf]
96        self.details['background'] = ['[1/cm]', 0.0, np.inf]
97        self.details['solvent_SLD'] = ['1/A^(2)', -np.inf, np.inf]
98        self.details['total_volume'] = ['A^(3)', 0.0, np.inf]
99        self.details['Up_frac_in'] = ['[u/(u+d)]', 0.0, 1.0]
100        self.details['Up_frac_out'] = ['[u/(u+d)]', 0.0, 1.0]
101        self.details['Up_theta'] = ['[deg]', -np.inf, np.inf]
102        # fixed parameters
103        self.fixed = []
104
105    def set_pixel_volumes(self, volume):
106        """
107        Set the volume of a pixel in (A^3) unit
108        :Param volume: pixel volume [float]
109        """
110        if self.data_vol is None:
111            raise TypeError("data_vol is missing")
112        self.data_vol = volume
113
114    def set_is_avg(self, is_avg=False):
115        """
116        Sets is_avg: [bool]
117        """
118        self.is_avg = is_avg
119
120    def _gen(self, x, y, i):
121        """
122        Evaluate the function
123        :Param x: array of x-values
124        :Param y: array of y-values
125        :Param i: array of initial i-value
126        :return: function value
127        """
128        pos_x = self.data_x
129        pos_y = self.data_y
130        pos_z = self.data_z
131        len_x = len(pos_x)
132        if self.is_avg is None:
133            len_x *= -1
134            pos_x, pos_y, pos_z = transform_center(pos_x, pos_y, pos_z)
135        len_q = len(x)
136        sldn = copy.deepcopy(self.data_sldn)
137        sldn -= self.params['solvent_SLD']
138        model = mod.new_GenI(len_x, pos_x, pos_y, pos_z,
139                             sldn, self.data_mx, self.data_my,
140                             self.data_mz, self.data_vol,
141                             self.params['Up_frac_in'],
142                             self.params['Up_frac_out'],
143                             self.params['Up_theta'])
144        if y == []:
145            mod.genicom(model, len_q, x, i)
146        else:
147            mod.genicomXY(model, len_q, x, y, i)
148        vol_correction = self.data_total_volume / self.params['total_volume']
149        return  self.params['scale'] * vol_correction * i + \
150                        self.params['background']
151
152    def set_sld_data(self, sld_data=None):
153        """
154        Sets sld_data
155        """
156        self.sld_data = sld_data
157        self.data_pos_unit = sld_data.pos_unit
158        self.data_x = sld_data.pos_x
159        self.data_y = sld_data.pos_y
160        self.data_z = sld_data.pos_z
161        self.data_sldn = sld_data.sld_n
162        self.data_mx = sld_data.sld_mx
163        self.data_my = sld_data.sld_my
164        self.data_mz = sld_data.sld_mz
165        self.data_vol = sld_data.vol_pix
166        self.data_total_volume = sum(sld_data.vol_pix)
167        self.params['total_volume'] = sum(sld_data.vol_pix)
168
169    def getProfile(self):
170        """
171        Get SLD profile
172        : return: sld_data
173        """
174        return self.sld_data
175
176    def run(self, x=0.0):
177        """
178        Evaluate the model
179        :param x: simple value
180        :return: (I value)
181        """
182        if x.__class__.__name__ == 'list':
183            if len(x[1]) > 0:
184                msg = "Not a 1D."
185                raise ValueError(msg)
186            i_out = np.zeros_like(x[0])
187            # 1D I is found at y =0 in the 2D pattern
188            out = self._gen(x[0], [], i_out)
189            return out
190        else:
191            msg = "Q must be given as list of qx's and qy's"
192            raise ValueError(msg)
193
194    def runXY(self, x=0.0):
195        """
196        Evaluate the model
197        :param x: simple value
198        :return: I value
199        :Use this runXY() for the computation
200        """
201        if x.__class__.__name__ == 'list':
202            i_out = np.zeros_like(x[0])
203            out = self._gen(x[0], x[1], i_out)
204            return out
205        else:
206            msg = "Q must be given as list of qx's and qy's"
207            raise ValueError(msg)
208
209    def evalDistribution(self, qdist):
210        """
211        Evaluate a distribution of q-values.
212
213        :param qdist: ndarray of scalar q-values (for 1D) or list [qx,qy]
214                      where qx,qy are 1D ndarrays (for 2D).
215        """
216        if qdist.__class__.__name__ == 'list':
217            if len(qdist[1]) < 1:
218                out = self.run(qdist)
219            else:
220                out = self.runXY(qdist)
221            return out
222        else:
223            mesg = "evalDistribution is expecting an ndarray of "
224            mesg += "a list [qx,qy] where qx,qy are arrays."
225            raise RuntimeError(mesg)
226
227class OMF2SLD(object):
228    """
229    Convert OMFData to MAgData
230    """
231    def __init__(self):
232        """
233        Init
234        """
235        self.pos_x = None
236        self.pos_y = None
237        self.pos_z = None
238        self.mx = None
239        self.my = None
240        self.mz = None
241        self.sld_n = None
242        self.vol_pix = None
243        self.output = None
244        self.omfdata = None
245
246    def set_data(self, omfdata, shape='rectangular'):
247        """
248        Set all data
249        """
250        self.omfdata = omfdata
251        length = int(omfdata.xnodes * omfdata.ynodes * omfdata.znodes)
252        pos_x = np.arange(omfdata.xmin,
253                             omfdata.xnodes*omfdata.xstepsize + omfdata.xmin,
254                             omfdata.xstepsize)
255        pos_y = np.arange(omfdata.ymin,
256                             omfdata.ynodes*omfdata.ystepsize + omfdata.ymin,
257                             omfdata.ystepsize)
258        pos_z = np.arange(omfdata.zmin,
259                             omfdata.znodes*omfdata.zstepsize + omfdata.zmin,
260                             omfdata.zstepsize)
261        self.pos_x = np.tile(pos_x, int(omfdata.ynodes * omfdata.znodes))
262        self.pos_y = pos_y.repeat(int(omfdata.xnodes))
263        self.pos_y = np.tile(self.pos_y, int(omfdata.znodes))
264        self.pos_z = pos_z.repeat(int(omfdata.xnodes * omfdata.ynodes))
265        self.mx = omfdata.mx
266        self.my = omfdata.my
267        self.mz = omfdata.mz
268        self.sld_n = np.zeros(length)
269
270        if omfdata.mx is None:
271            self.mx = np.zeros(length)
272        if omfdata.my is None:
273            self.my = np.zeros(length)
274        if omfdata.mz is None:
275            self.mz = np.zeros(length)
276
277        self._check_data_length(length)
278        self.remove_null_points(False, False)
279        mask = np.ones(len(self.sld_n), dtype=bool)
280        if shape.lower() == 'ellipsoid':
281            try:
282                # Pixel (step) size included
283                x_c = max(self.pos_x) + min(self.pos_x)
284                y_c = max(self.pos_y) + min(self.pos_y)
285                z_c = max(self.pos_z) + min(self.pos_z)
286                x_d = max(self.pos_x) - min(self.pos_x)
287                y_d = max(self.pos_y) - min(self.pos_y)
288                z_d = max(self.pos_z) - min(self.pos_z)
289                x_r = (x_d + omfdata.xstepsize) / 2.0
290                y_r = (y_d + omfdata.ystepsize) / 2.0
291                z_r = (z_d + omfdata.zstepsize) / 2.0
292                x_dir2 = ((self.pos_x - x_c / 2.0) / x_r)
293                x_dir2 *= x_dir2
294                y_dir2 = ((self.pos_y - y_c / 2.0) / y_r)
295                y_dir2 *= y_dir2
296                z_dir2 = ((self.pos_z - z_c / 2.0) / z_r)
297                z_dir2 *= z_dir2
298                mask = (x_dir2 + y_dir2 + z_dir2) <= 1.0
299            except Exception:
300                logger.error(sys.exc_value)
301        self.output = MagSLD(self.pos_x[mask], self.pos_y[mask],
302                             self.pos_z[mask], self.sld_n[mask],
303                             self.mx[mask], self.my[mask], self.mz[mask])
304        self.output.set_pix_type('pixel')
305        self.output.set_pixel_symbols('pixel')
306
307    def get_omfdata(self):
308        """
309        Return all data
310        """
311        return self.omfdata
312
313    def get_output(self):
314        """
315        Return output
316        """
317        return self.output
318
319    def _check_data_length(self, length):
320        """
321        Check if the data lengths are consistent
322        :Params length: data length
323        """
324        parts = (self.pos_x, self.pos_y, self.pos_z, self.mx, self.my, self.mz)
325        if any(len(v) != length for v in parts):
326            raise ValueError("Error: Inconsistent data length.")
327
328    def remove_null_points(self, remove=False, recenter=False):
329        """
330        Removes any mx, my, and mz = 0 points
331        """
332        if remove:
333            is_nonzero = (np.fabs(self.mx) + np.fabs(self.my) +
334                          np.fabs(self.mz)).nonzero()
335            if len(is_nonzero[0]) > 0:
336                self.pos_x = self.pos_x[is_nonzero]
337                self.pos_y = self.pos_y[is_nonzero]
338                self.pos_z = self.pos_z[is_nonzero]
339                self.sld_n = self.sld_n[is_nonzero]
340                self.mx = self.mx[is_nonzero]
341                self.my = self.my[is_nonzero]
342                self.mz = self.mz[is_nonzero]
343        if recenter:
344            self.pos_x -= (min(self.pos_x) + max(self.pos_x)) / 2.0
345            self.pos_y -= (min(self.pos_y) + max(self.pos_y)) / 2.0
346            self.pos_z -= (min(self.pos_z) + max(self.pos_z)) / 2.0
347
348    def get_magsld(self):
349        """
350        return MagSLD
351        """
352        return self.output
353
354
355class OMFReader(object):
356    """
357    Class to load omf/ascii files (3 columns w/header).
358    """
359    ## File type
360    type_name = "OMF ASCII"
361
362    ## Wildcards
363    type = ["OMF files (*.OMF, *.omf)|*.omf"]
364    ## List of allowed extensions
365    ext = ['.omf', '.OMF']
366
367    def read(self, path):
368        """
369        Load data file
370        :param path: file path
371        :return: x, y, z, sld_n, sld_mx, sld_my, sld_mz
372        """
373        desc = ""
374        mx = np.zeros(0)
375        my = np.zeros(0)
376        mz = np.zeros(0)
377        try:
378            input_f = open(path, 'rb')
379            buff = decode(input_f.read())
380            lines = buff.split('\n')
381            input_f.close()
382            output = OMFData()
383            valueunit = None
384            for line in lines:
385                line = line.strip()
386                # Read data
387                if line and not line.startswith('#'):
388                    try:
389                        toks = line.split()
390                        _mx = float(toks[0])
391                        _my = float(toks[1])
392                        _mz = float(toks[2])
393                        _mx = mag2sld(_mx, valueunit)
394                        _my = mag2sld(_my, valueunit)
395                        _mz = mag2sld(_mz, valueunit)
396                        mx = np.append(mx, _mx)
397                        my = np.append(my, _my)
398                        mz = np.append(mz, _mz)
399                    except Exception as exc:
400                        # Skip non-data lines
401                        logger.error(str(exc)+" when processing %r"%line)
402                #Reading Header; Segment count ignored
403                s_line = line.split(":", 1)
404                if s_line[0].lower().count("oommf") > 0:
405                    oommf = s_line[1].lstrip()
406                if s_line[0].lower().count("title") > 0:
407                    title = s_line[1].lstrip()
408                if s_line[0].lower().count("desc") > 0:
409                    desc += s_line[1].lstrip()
410                    desc += '\n'
411                if s_line[0].lower().count("meshtype") > 0:
412                    meshtype = s_line[1].lstrip()
413                if s_line[0].lower().count("meshunit") > 0:
414                    meshunit = s_line[1].lstrip()
415                    if meshunit.count("m") < 1:
416                        msg = "Error: \n"
417                        msg += "We accept only m as meshunit"
418                        raise ValueError(msg)
419                if s_line[0].lower().count("xbase") > 0:
420                    xbase = s_line[1].lstrip()
421                if s_line[0].lower().count("ybase") > 0:
422                    ybase = s_line[1].lstrip()
423                if s_line[0].lower().count("zbase") > 0:
424                    zbase = s_line[1].lstrip()
425                if s_line[0].lower().count("xstepsize") > 0:
426                    xstepsize = s_line[1].lstrip()
427                if s_line[0].lower().count("ystepsize") > 0:
428                    ystepsize = s_line[1].lstrip()
429                if s_line[0].lower().count("zstepsize") > 0:
430                    zstepsize = s_line[1].lstrip()
431                if s_line[0].lower().count("xnodes") > 0:
432                    xnodes = s_line[1].lstrip()
433                if s_line[0].lower().count("ynodes") > 0:
434                    ynodes = s_line[1].lstrip()
435                if s_line[0].lower().count("znodes") > 0:
436                    znodes = s_line[1].lstrip()
437                if s_line[0].lower().count("xmin") > 0:
438                    xmin = s_line[1].lstrip()
439                if s_line[0].lower().count("ymin") > 0:
440                    ymin = s_line[1].lstrip()
441                if s_line[0].lower().count("zmin") > 0:
442                    zmin = s_line[1].lstrip()
443                if s_line[0].lower().count("xmax") > 0:
444                    xmax = s_line[1].lstrip()
445                if s_line[0].lower().count("ymax") > 0:
446                    ymax = s_line[1].lstrip()
447                if s_line[0].lower().count("zmax") > 0:
448                    zmax = s_line[1].lstrip()
449                if s_line[0].lower().count("valueunit") > 0:
450                    valueunit = s_line[1].lstrip().rstrip()
451                if s_line[0].lower().count("valuemultiplier") > 0:
452                    valuemultiplier = s_line[1].lstrip()
453                if s_line[0].lower().count("valuerangeminmag") > 0:
454                    valuerangeminmag = s_line[1].lstrip()
455                if s_line[0].lower().count("valuerangemaxmag") > 0:
456                    valuerangemaxmag = s_line[1].lstrip()
457                if s_line[0].lower().count("end") > 0:
458                    output.filename = os.path.basename(path)
459                    output.oommf = oommf
460                    output.title = title
461                    output.desc = desc
462                    output.meshtype = meshtype
463                    output.xbase = float(xbase) * METER2ANG
464                    output.ybase = float(ybase) * METER2ANG
465                    output.zbase = float(zbase) * METER2ANG
466                    output.xstepsize = float(xstepsize) * METER2ANG
467                    output.ystepsize = float(ystepsize) * METER2ANG
468                    output.zstepsize = float(zstepsize) * METER2ANG
469                    output.xnodes = float(xnodes)
470                    output.ynodes = float(ynodes)
471                    output.znodes = float(znodes)
472                    output.xmin = float(xmin) * METER2ANG
473                    output.ymin = float(ymin) * METER2ANG
474                    output.zmin = float(zmin) * METER2ANG
475                    output.xmax = float(xmax) * METER2ANG
476                    output.ymax = float(ymax) * METER2ANG
477                    output.zmax = float(zmax) * METER2ANG
478                    output.valuemultiplier = valuemultiplier
479                    output.valuerangeminmag = mag2sld(float(valuerangeminmag), \
480                                                      valueunit)
481                    output.valuerangemaxmag = mag2sld(float(valuerangemaxmag), \
482                                                      valueunit)
483            output.set_m(mx, my, mz)
484            return output
485        except Exception:
486            msg = "%s is not supported: \n" % path
487            msg += "We accept only Text format OMF file."
488            raise RuntimeError(msg)
489
490class PDBReader(object):
491    """
492    PDB reader class: limited for reading the lines starting with 'ATOM'
493    """
494    type_name = "PDB"
495    ## Wildcards
496    type = ["pdb files (*.PDB, *.pdb)|*.pdb"]
497    ## List of allowed extensions
498    ext = ['.pdb', '.PDB']
499
500    def read(self, path):
501        """
502        Load data file
503
504        :param path: file path
505        :return: MagSLD
506        :raise RuntimeError: when the file can't be opened
507        """
508        pos_x = np.zeros(0)
509        pos_y = np.zeros(0)
510        pos_z = np.zeros(0)
511        sld_n = np.zeros(0)
512        sld_mx = np.zeros(0)
513        sld_my = np.zeros(0)
514        sld_mz = np.zeros(0)
515        vol_pix = np.zeros(0)
516        pix_symbol = np.zeros(0)
517        x_line = []
518        y_line = []
519        z_line = []
520        x_lines = []
521        y_lines = []
522        z_lines = []
523        try:
524            input_f = open(path, 'rb')
525            buff = decode(input_f.read())
526            lines = buff.split('\n')
527            input_f.close()
528            num = 0
529            for line in lines:
530                try:
531                    # check if line starts with "ATOM"
532                    if line[0:6].strip().count('ATM') > 0 or \
533                                line[0:6].strip() == 'ATOM':
534                        # define fields of interest
535                        atom_name = line[12:16].strip()
536                        try:
537                            float(line[12])
538                            atom_name = atom_name[1].upper()
539                        except Exception:
540                            if len(atom_name) == 4:
541                                atom_name = atom_name[0].upper()
542                            elif line[12] != ' ':
543                                atom_name = atom_name[0].upper() + \
544                                        atom_name[1].lower()
545                            else:
546                                atom_name = atom_name[0].upper()
547                        _pos_x = float(line[30:38].strip())
548                        _pos_y = float(line[38:46].strip())
549                        _pos_z = float(line[46:54].strip())
550                        pos_x = np.append(pos_x, _pos_x)
551                        pos_y = np.append(pos_y, _pos_y)
552                        pos_z = np.append(pos_z, _pos_z)
553                        try:
554                            val = nsf.neutron_sld(atom_name)[0]
555                            # sld in Ang^-2 unit
556                            val *= 1.0e-6
557                            sld_n = np.append(sld_n, val)
558                            atom = formula(atom_name)
559                            # cm to A units
560                            vol = 1.0e+24 * atom.mass / atom.density / NA
561                            vol_pix = np.append(vol_pix, vol)
562                        except Exception:
563                            print("Error: set the sld of %s to zero"% atom_name)
564                            sld_n = np.append(sld_n, 0.0)
565                        sld_mx = np.append(sld_mx, 0)
566                        sld_my = np.append(sld_my, 0)
567                        sld_mz = np.append(sld_mz, 0)
568                        pix_symbol = np.append(pix_symbol, atom_name)
569                    elif line[0:6].strip().count('CONECT') > 0:
570                        toks = line.split()
571                        num = int(toks[1]) - 1
572                        val_list = []
573                        for val in toks[2:]:
574                            try:
575                                int_val = int(val)
576                            except Exception:
577                                break
578                            if int_val == 0:
579                                break
580                            val_list.append(int_val)
581                        #need val_list ordered
582                        for val in val_list:
583                            index = val - 1
584                            if (pos_x[index], pos_x[num]) in x_line and \
585                               (pos_y[index], pos_y[num]) in y_line and \
586                               (pos_z[index], pos_z[num]) in z_line:
587                                continue
588                            x_line.append((pos_x[num], pos_x[index]))
589                            y_line.append((pos_y[num], pos_y[index]))
590                            z_line.append((pos_z[num], pos_z[index]))
591                    if len(x_line) > 0:
592                        x_lines.append(x_line)
593                        y_lines.append(y_line)
594                        z_lines.append(z_line)
595                except Exception:
596                    logger.error(sys.exc_value)
597
598            output = MagSLD(pos_x, pos_y, pos_z, sld_n, sld_mx, sld_my, sld_mz)
599            output.set_conect_lines(x_line, y_line, z_line)
600            output.filename = os.path.basename(path)
601            output.set_pix_type('atom')
602            output.set_pixel_symbols(pix_symbol)
603            output.set_nodes()
604            output.set_pixel_volumes(vol_pix)
605            output.sld_unit = '1/A^(2)'
606            return output
607        except Exception:
608            raise RuntimeError("%s is not a sld file" % path)
609
610    def write(self, path, data):
611        """
612        Write
613        """
614        print("Not implemented... ")
615
616class SLDReader(object):
617    """
618    Class to load ascii files (7 columns).
619    """
620    ## File type
621    type_name = "SLD ASCII"
622    ## Wildcards
623    type = ["sld files (*.SLD, *.sld)|*.sld",
624            "txt files (*.TXT, *.txt)|*.txt",
625            "all files (*.*)|*.*"]
626    ## List of allowed extensions
627    ext = ['.sld', '.SLD', '.txt', '.TXT', '.*']
628    def read(self, path):
629        """
630        Load data file
631        :param path: file path
632        :return MagSLD: x, y, z, sld_n, sld_mx, sld_my, sld_mz
633        :raise RuntimeError: when the file can't be opened
634        :raise ValueError: when the length of the data vectors are inconsistent
635        """
636        try:
637            pos_x = np.zeros(0)
638            pos_y = np.zeros(0)
639            pos_z = np.zeros(0)
640            sld_n = np.zeros(0)
641            sld_mx = np.zeros(0)
642            sld_my = np.zeros(0)
643            sld_mz = np.zeros(0)
644            try:
645                # Use numpy to speed up loading
646                input_f = np.loadtxt(path, dtype='float', skiprows=1,
647                                        ndmin=1, unpack=True)
648                pos_x = np.array(input_f[0])
649                pos_y = np.array(input_f[1])
650                pos_z = np.array(input_f[2])
651                sld_n = np.array(input_f[3])
652                sld_mx = np.array(input_f[4])
653                sld_my = np.array(input_f[5])
654                sld_mz = np.array(input_f[6])
655                ncols = len(input_f)
656                if ncols == 8:
657                    vol_pix = np.array(input_f[7])
658                elif ncols == 7:
659                    vol_pix = None
660            except Exception:
661                # For older version of numpy
662                input_f = open(path, 'rb')
663                buff = decode(input_f.read())
664                lines = buff.split('\n')
665                input_f.close()
666                for line in lines:
667                    toks = line.split()
668                    try:
669                        _pos_x = float(toks[0])
670                        _pos_y = float(toks[1])
671                        _pos_z = float(toks[2])
672                        _sld_n = float(toks[3])
673                        _sld_mx = float(toks[4])
674                        _sld_my = float(toks[5])
675                        _sld_mz = float(toks[6])
676                        pos_x = np.append(pos_x, _pos_x)
677                        pos_y = np.append(pos_y, _pos_y)
678                        pos_z = np.append(pos_z, _pos_z)
679                        sld_n = np.append(sld_n, _sld_n)
680                        sld_mx = np.append(sld_mx, _sld_mx)
681                        sld_my = np.append(sld_my, _sld_my)
682                        sld_mz = np.append(sld_mz, _sld_mz)
683                        try:
684                            _vol_pix = float(toks[7])
685                            vol_pix = np.append(vol_pix, _vol_pix)
686                        except Exception:
687                            vol_pix = None
688                    except Exception:
689                        # Skip non-data lines
690                        logger.error(sys.exc_value)
691            output = MagSLD(pos_x, pos_y, pos_z, sld_n,
692                            sld_mx, sld_my, sld_mz)
693            output.filename = os.path.basename(path)
694            output.set_pix_type('pixel')
695            output.set_pixel_symbols('pixel')
696            if vol_pix is not None:
697                output.set_pixel_volumes(vol_pix)
698            return output
699        except Exception:
700            raise RuntimeError("%s is not a sld file" % path)
701
702    def write(self, path, data):
703        """
704        Write sld file
705        :Param path: file path
706        :Param data: MagSLD data object
707        """
708        if path is None:
709            raise ValueError("Missing the file path.")
710        if data is None:
711            raise ValueError("Missing the data to save.")
712        x_val = data.pos_x
713        y_val = data.pos_y
714        z_val = data.pos_z
715        vol_pix = data.vol_pix
716        length = len(x_val)
717        sld_n = data.sld_n
718        if sld_n is None:
719            sld_n = np.zeros(length)
720        sld_mx = data.sld_mx
721        if sld_mx is None:
722            sld_mx = np.zeros(length)
723            sld_my = np.zeros(length)
724            sld_mz = np.zeros(length)
725        else:
726            sld_my = data.sld_my
727            sld_mz = data.sld_mz
728        out = open(path, 'w')
729        # First Line: Column names
730        out.write("X  Y  Z  SLDN SLDMx  SLDMy  SLDMz VOLUMEpix")
731        for ind in range(length):
732            out.write("\n%g  %g  %g  %g  %g  %g  %g %g" % \
733                      (x_val[ind], y_val[ind], z_val[ind], sld_n[ind],
734                       sld_mx[ind], sld_my[ind], sld_mz[ind], vol_pix[ind]))
735        out.close()
736
737
738class OMFData(object):
739    """
740    OMF Data.
741    """
742    _meshunit = "A"
743    _valueunit = "A^(-2)"
744    def __init__(self):
745        """
746        Init for mag SLD
747        """
748        self.filename = 'default'
749        self.oommf = ''
750        self.title = ''
751        self.desc = ''
752        self.meshtype = ''
753        self.meshunit = self._meshunit
754        self.valueunit = self._valueunit
755        self.xbase = 0.0
756        self.ybase = 0.0
757        self.zbase = 0.0
758        self.xstepsize = 6.0
759        self.ystepsize = 6.0
760        self.zstepsize = 6.0
761        self.xnodes = 10.0
762        self.ynodes = 10.0
763        self.znodes = 10.0
764        self.xmin = 0.0
765        self.ymin = 0.0
766        self.zmin = 0.0
767        self.xmax = 60.0
768        self.ymax = 60.0
769        self.zmax = 60.0
770        self.mx = None
771        self.my = None
772        self.mz = None
773        self.valuemultiplier = 1.
774        self.valuerangeminmag = 0
775        self.valuerangemaxmag = 0
776
777    def __str__(self):
778        """
779        doc strings
780        """
781        _str = "Type:            %s\n" % self.__class__.__name__
782        _str += "File:            %s\n" % self.filename
783        _str += "OOMMF:           %s\n" % self.oommf
784        _str += "Title:           %s\n" % self.title
785        _str += "Desc:            %s\n" % self.desc
786        _str += "meshtype:        %s\n" % self.meshtype
787        _str += "meshunit:        %s\n" % str(self.meshunit)
788        _str += "xbase:           %s [%s]\n" % (str(self.xbase), self.meshunit)
789        _str += "ybase:           %s [%s]\n" % (str(self.ybase), self.meshunit)
790        _str += "zbase:           %s [%s]\n" % (str(self.zbase), self.meshunit)
791        _str += "xstepsize:       %s [%s]\n" % (str(self.xstepsize),
792                                                self.meshunit)
793        _str += "ystepsize:       %s [%s]\n" % (str(self.ystepsize),
794                                                self.meshunit)
795        _str += "zstepsize:       %s [%s]\n" % (str(self.zstepsize),
796                                                self.meshunit)
797        _str += "xnodes:          %s\n" % str(self.xnodes)
798        _str += "ynodes:          %s\n" % str(self.ynodes)
799        _str += "znodes:          %s\n" % str(self.znodes)
800        _str += "xmin:            %s [%s]\n" % (str(self.xmin), self.meshunit)
801        _str += "ymin:            %s [%s]\n" % (str(self.ymin), self.meshunit)
802        _str += "zmin:            %s [%s]\n" % (str(self.zmin), self.meshunit)
803        _str += "xmax:            %s [%s]\n" % (str(self.xmax), self.meshunit)
804        _str += "ymax:            %s [%s]\n" % (str(self.ymax), self.meshunit)
805        _str += "zmax:            %s [%s]\n" % (str(self.zmax), self.meshunit)
806        _str += "valueunit:       %s\n" % self.valueunit
807        _str += "valuemultiplier: %s\n" % str(self.valuemultiplier)
808        _str += "ValueRangeMinMag:%s [%s]\n" % (str(self.valuerangeminmag),
809                                                self.valueunit)
810        _str += "ValueRangeMaxMag:%s [%s]\n" % (str(self.valuerangemaxmag),
811                                                self.valueunit)
812        return _str
813
814    def set_m(self, mx, my, mz):
815        """
816        Set the Mx, My, Mz values
817        """
818        self.mx = mx
819        self.my = my
820        self.mz = mz
821
822class MagSLD(object):
823    """
824    Magnetic SLD.
825    """
826    pos_x = None
827    pos_y = None
828    pos_z = None
829    sld_n = None
830    sld_mx = None
831    sld_my = None
832    sld_mz = None
833    # Units
834    _pos_unit = 'A'
835    _sld_unit = '1/A^(2)'
836    _pix_type = 'pixel'
837
838    def __init__(self, pos_x, pos_y, pos_z, sld_n=None,
839                 sld_mx=None, sld_my=None, sld_mz=None, vol_pix=None):
840        """
841        Init for mag SLD
842        :params : All should be numpy 1D array
843        """
844        self.is_data = True
845        self.filename = ''
846        self.xstepsize = 6.0
847        self.ystepsize = 6.0
848        self.zstepsize = 6.0
849        self.xnodes = 10.0
850        self.ynodes = 10.0
851        self.znodes = 10.0
852        self.has_stepsize = False
853        self.has_conect = False
854        self.pos_unit = self._pos_unit
855        self.sld_unit = self._sld_unit
856        self.pix_type = 'pixel'
857        self.pos_x = pos_x
858        self.pos_y = pos_y
859        self.pos_z = pos_z
860        self.sld_n = sld_n
861        self.line_x = None
862        self.line_y = None
863        self.line_z = None
864        self.sld_mx = sld_mx
865        self.sld_my = sld_my
866        self.sld_mz = sld_mz
867        self.vol_pix = vol_pix
868        self.sld_m = None
869        self.sld_phi = None
870        self.sld_theta = None
871        self.pix_symbol = None
872        if sld_mx is not None and sld_my is not None and sld_mz is not None:
873            self.set_sldms(sld_mx, sld_my, sld_mz)
874        self.set_nodes()
875
876    def __str__(self):
877        """
878        doc strings
879        """
880        _str = "Type:       %s\n" % self.__class__.__name__
881        _str += "File:       %s\n" % self.filename
882        _str += "Axis_unit:  %s\n" % self.pos_unit
883        _str += "SLD_unit:   %s\n" % self.sld_unit
884        return _str
885
886    def set_pix_type(self, pix_type):
887        """
888        Set pixel type
889        :Param pix_type: string, 'pixel' or 'atom'
890        """
891        self.pix_type = pix_type
892
893    def set_sldn(self, sld_n):
894        """
895        Sets neutron SLD
896        """
897        if sld_n.__class__.__name__ == 'float':
898            if self.is_data:
899                # For data, put the value to only the pixels w non-zero M
900                is_nonzero = (np.fabs(self.sld_mx) +
901                              np.fabs(self.sld_my) +
902                              np.fabs(self.sld_mz)).nonzero()
903                self.sld_n = np.zeros(len(self.pos_x))
904                if len(self.sld_n[is_nonzero]) > 0:
905                    self.sld_n[is_nonzero] = sld_n
906                else:
907                    self.sld_n.fill(sld_n)
908            else:
909                # For non-data, put the value to all the pixels
910                self.sld_n = np.ones(len(self.pos_x)) * sld_n
911        else:
912            self.sld_n = sld_n
913
914    def set_sldms(self, sld_mx, sld_my, sld_mz):
915        r"""
916        Sets mx, my, mz and abs(m).
917        """ # Note: escaping
918        if sld_mx.__class__.__name__ == 'float':
919            self.sld_mx = np.ones(len(self.pos_x)) * sld_mx
920        else:
921            self.sld_mx = sld_mx
922        if sld_my.__class__.__name__ == 'float':
923            self.sld_my = np.ones(len(self.pos_x)) * sld_my
924        else:
925            self.sld_my = sld_my
926        if sld_mz.__class__.__name__ == 'float':
927            self.sld_mz = np.ones(len(self.pos_x)) * sld_mz
928        else:
929            self.sld_mz = sld_mz
930
931        sld_m = np.sqrt(sld_mx * sld_mx + sld_my * sld_my + \
932                                sld_mz * sld_mz)
933        self.sld_m = sld_m
934
935    def set_pixel_symbols(self, symbol='pixel'):
936        """
937        Set pixel
938        :Params pixel: str; pixel or atomic symbol, or array of strings
939        """
940        if self.sld_n is None:
941            return
942        if symbol.__class__.__name__ == 'str':
943            self.pix_symbol = np.repeat(symbol, len(self.sld_n))
944        else:
945            self.pix_symbol = symbol
946
947    def set_pixel_volumes(self, vol):
948        """
949        Set pixel volumes
950        :Params pixel: str; pixel or atomic symbol, or array of strings
951        """
952        if self.sld_n is None:
953            return
954        if vol.__class__.__name__ == 'ndarray':
955            self.vol_pix = vol
956        elif vol.__class__.__name__.count('float') > 0:
957            self.vol_pix = np.repeat(vol, len(self.sld_n))
958        else:
959            self.vol_pix = None
960
961    def get_sldn(self):
962        """
963        Returns nuclear sld
964        """
965        return self.sld_n
966
967    def set_nodes(self):
968        """
969        Set xnodes, ynodes, and znodes
970        """
971        self.set_stepsize()
972        if self.pix_type == 'pixel':
973            try:
974                xdist = (max(self.pos_x) - min(self.pos_x)) / self.xstepsize
975                ydist = (max(self.pos_y) - min(self.pos_y)) / self.ystepsize
976                zdist = (max(self.pos_z) - min(self.pos_z)) / self.zstepsize
977                self.xnodes = int(xdist) + 1
978                self.ynodes = int(ydist) + 1
979                self.znodes = int(zdist) + 1
980            except Exception:
981                self.xnodes = None
982                self.ynodes = None
983                self.znodes = None
984        else:
985            self.xnodes = None
986            self.ynodes = None
987            self.znodes = None
988
989    def set_stepsize(self):
990        """
991        Set xtepsize, ystepsize, and zstepsize
992        """
993        if self.pix_type == 'pixel':
994            try:
995                xpos_pre = self.pos_x[0]
996                ypos_pre = self.pos_y[0]
997                zpos_pre = self.pos_z[0]
998                for x_pos in self.pos_x:
999                    if xpos_pre != x_pos:
1000                        self.xstepsize = np.fabs(x_pos - xpos_pre)
1001                        break
1002                for y_pos in self.pos_y:
1003                    if ypos_pre != y_pos:
1004                        self.ystepsize = np.fabs(y_pos - ypos_pre)
1005                        break
1006                for z_pos in self.pos_z:
1007                    if zpos_pre != z_pos:
1008                        self.zstepsize = np.fabs(z_pos - zpos_pre)
1009                        break
1010                #default pix volume
1011                self.vol_pix = np.ones(len(self.pos_x))
1012                vol = self.xstepsize * self.ystepsize * self.zstepsize
1013                self.set_pixel_volumes(vol)
1014                self.has_stepsize = True
1015            except Exception:
1016                self.xstepsize = None
1017                self.ystepsize = None
1018                self.zstepsize = None
1019                self.vol_pix = None
1020                self.has_stepsize = False
1021        else:
1022            self.xstepsize = None
1023            self.ystepsize = None
1024            self.zstepsize = None
1025            self.has_stepsize = True
1026        return self.xstepsize, self.ystepsize, self.zstepsize
1027
1028    def set_conect_lines(self, line_x, line_y, line_z):
1029        """
1030        Set bonding line data if taken from pdb
1031        """
1032        if line_x.__class__.__name__ != 'list' or len(line_x) < 1:
1033            return
1034        if line_y.__class__.__name__ != 'list' or len(line_y) < 1:
1035            return
1036        if line_z.__class__.__name__ != 'list' or len(line_z) < 1:
1037            return
1038        self.has_conect = True
1039        self.line_x = line_x
1040        self.line_y = line_y
1041        self.line_z = line_z
1042
1043def test_load():
1044    """
1045        Test code
1046    """
1047    from mpl_toolkits.mplot3d import Axes3D
1048    current_dir = os.path.abspath(os.path.curdir)
1049    print(current_dir)
1050    for i in range(6):
1051        current_dir, _ = os.path.split(current_dir)
1052        tfile = os.path.join(current_dir, "test", "CoreXY_ShellZ.txt")
1053        ofile = os.path.join(current_dir, "test", "A_Raw_Example-1.omf")
1054        if os.path.isfile(tfile):
1055            tfpath = tfile
1056            ofpath = ofile
1057            break
1058    reader = SLDReader()
1059    oreader = OMFReader()
1060    output = decode(reader.read(tfpath))
1061    ooutput = decode(oreader.read(ofpath))
1062    foutput = OMF2SLD()
1063    foutput.set_data(ooutput)
1064
1065    import matplotlib.pyplot as plt
1066    fig = plt.figure()
1067    ax = Axes3D(fig)
1068    ax.plot(output.pos_x, output.pos_y, output.pos_z, '.', c="g",
1069            alpha=0.7, markeredgecolor='gray', rasterized=True)
1070    gap = 7
1071    max_mx = max(output.sld_mx)
1072    max_my = max(output.sld_my)
1073    max_mz = max(output.sld_mz)
1074    max_m = max(max_mx, max_my, max_mz)
1075    x2 = output.pos_x+output.sld_mx/max_m * gap
1076    y2 = output.pos_y+output.sld_my/max_m * gap
1077    z2 = output.pos_z+output.sld_mz/max_m * gap
1078    x_arrow = np.column_stack((output.pos_x, x2))
1079    y_arrow = np.column_stack((output.pos_y, y2))
1080    z_arrow = np.column_stack((output.pos_z, z2))
1081    unit_x2 = output.sld_mx / max_m
1082    unit_y2 = output.sld_my / max_m
1083    unit_z2 = output.sld_mz / max_m
1084    color_x = np.fabs(unit_x2 * 0.8)
1085    color_y = np.fabs(unit_y2 * 0.8)
1086    color_z = np.fabs(unit_z2 * 0.8)
1087    colors = np.column_stack((color_x, color_y, color_z))
1088    plt.show()
1089
1090def test():
1091    """
1092        Test code
1093    """
1094    current_dir = os.path.abspath(os.path.curdir)
1095    for i in range(3):
1096        current_dir, _ = os.path.split(current_dir)
1097        ofile = os.path.join(current_dir, "test", "A_Raw_Example-1.omf")
1098        if os.path.isfile(ofile):
1099            ofpath = ofile
1100            break
1101    oreader = OMFReader()
1102    ooutput = decode(oreader.read(ofpath))
1103    foutput = OMF2SLD()
1104    foutput.set_data(ooutput)
1105    writer = SLDReader()
1106    writer.write(os.path.join(os.path.dirname(ofpath), "out.txt"),
1107                 foutput.output)
1108    model = GenSAS()
1109    model.set_sld_data(foutput.output)
1110    x = np.arange(1000)/10000. + 1e-5
1111    y = np.arange(1000)/10000. + 1e-5
1112    i = np.zeros(1000)
1113    model.runXY([x, y, i])
1114
1115if __name__ == "__main__":
1116    test()
1117    test_load()
Note: See TracBrowser for help on using the repository browser.