source: sasview/src/sas/sascalc/calculator/sas_gen.py @ 95d7c4f

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 95d7c4f was f926abb, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

add smoke test for generic scattering calculator

  • Property mode set to 100644
File size: 39.9 KB
Line 
1# pylint: disable=invalid-name
2"""
3SAS generic computation and sld file readers
4"""
5from __future__ import print_function
6
7import os
8import sys
9import copy
10import logging
11
12from periodictable import formula
13from periodictable import nsf
14import numpy as np
15
16from .core import sld2i as mod
17from .BaseComponent import BaseComponent
18
19logger = logging.getLogger(__name__)
20
21if sys.version_info[0] < 3:
22    def decode(s):
23        return s
24else:
25    def decode(s):
26        return s.decode() if isinstance(s, bytes) else s
27
28MFACTOR_AM = 2.853E-12
29MFACTOR_MT = 2.3164E-9
30METER2ANG = 1.0E+10
31#Avogadro constant [1/mol]
32NA = 6.02214129e+23
33
34def mag2sld(mag, v_unit=None):
35    """
36    Convert magnetization to magnatic SLD
37    sldm = Dm * mag where Dm = gamma * classical elec. radius/(2*Bohr magneton)
38    Dm ~ 2.853E-12 [A^(-2)] ==> Shouldn't be 2.90636E-12 [A^(-2)]???
39    """
40    if v_unit == "A/m":
41        factor = MFACTOR_AM
42    elif v_unit == "mT":
43        factor = MFACTOR_MT
44    else:
45        raise ValueError("Invalid valueunit")
46    sld_m = factor * mag
47    return sld_m
48
49def transform_center(pos_x, pos_y, pos_z):
50    """
51    re-center
52    :return: posx, posy, posz   [arrays]
53    """
54    posx = pos_x - (min(pos_x) + max(pos_x)) / 2.0
55    posy = pos_y - (min(pos_y) + max(pos_y)) / 2.0
56    posz = pos_z - (min(pos_z) + max(pos_z)) / 2.0
57    return posx, posy, posz
58
59class GenSAS(BaseComponent):
60    """
61    Generic SAS computation Model based on sld (n & m) arrays
62    """
63    def __init__(self):
64        """
65        Init
66        :Params sld_data: MagSLD object
67        """
68        # Initialize BaseComponent
69        BaseComponent.__init__(self)
70        self.sld_data = None
71        self.data_pos_unit = None
72        self.data_x = None
73        self.data_y = None
74        self.data_z = None
75        self.data_sldn = None
76        self.data_mx = None
77        self.data_my = None
78        self.data_mz = None
79        self.data_vol = None #[A^3]
80        self.is_avg = False
81        ## Name of the model
82        self.name = "GenSAS"
83        ## Define parameters
84        self.params = {}
85        self.params['scale'] = 1.0
86        self.params['background'] = 0.0
87        self.params['solvent_SLD'] = 0.0
88        self.params['total_volume'] = 1.0
89        self.params['Up_frac_in'] = 1.0
90        self.params['Up_frac_out'] = 1.0
91        self.params['Up_theta'] = 0.0
92        self.description = 'GenSAS'
93        ## Parameter details [units, min, max]
94        self.details = {}
95        self.details['scale'] = ['', 0.0, np.inf]
96        self.details['background'] = ['[1/cm]', 0.0, np.inf]
97        self.details['solvent_SLD'] = ['1/A^(2)', -np.inf, np.inf]
98        self.details['total_volume'] = ['A^(3)', 0.0, np.inf]
99        self.details['Up_frac_in'] = ['[u/(u+d)]', 0.0, 1.0]
100        self.details['Up_frac_out'] = ['[u/(u+d)]', 0.0, 1.0]
101        self.details['Up_theta'] = ['[deg]', -np.inf, np.inf]
102        # fixed parameters
103        self.fixed = []
104
105    def set_pixel_volumes(self, volume):
106        """
107        Set the volume of a pixel in (A^3) unit
108        :Param volume: pixel volume [float]
109        """
110        if self.data_vol is None:
111            raise TypeError("data_vol is missing")
112        self.data_vol = volume
113
114    def set_is_avg(self, is_avg=False):
115        """
116        Sets is_avg: [bool]
117        """
118        self.is_avg = is_avg
119
120    def _gen(self, qx, qy):
121        """
122        Evaluate the function
123        :Param x: array of x-values
124        :Param y: array of y-values
125        :Param i: array of initial i-value
126        :return: function value
127        """
128        pos_x = self.data_x
129        pos_y = self.data_y
130        pos_z = self.data_z
131        if self.is_avg is None:
132            pos_x, pos_y, pos_z = transform_center(pos_x, pos_y, pos_z)
133        sldn = copy.deepcopy(self.data_sldn)
134        sldn -= self.params['solvent_SLD']
135        model = mod.new_GenI((1 if self.is_avg else 0),
136                             pos_x, pos_y, pos_z,
137                             sldn, self.data_mx, self.data_my,
138                             self.data_mz, self.data_vol,
139                             self.params['Up_frac_in'],
140                             self.params['Up_frac_out'],
141                             self.params['Up_theta'])
142        if len(qy):
143            qx, qy = _vec(qx), _vec(qy)
144            I_out = np.empty_like(qx)
145            mod.genicomXY(model, qx, qy, I_out)
146        else:
147            qx = _vec(qx)
148            I_out = np.empty_like(qx)
149            mod.genicom(model, qx, I_out)
150        vol_correction = self.data_total_volume / self.params['total_volume']
151        result = (self.params['scale'] * vol_correction * I_out
152                  + self.params['background'])
153        return result
154
155    def set_sld_data(self, sld_data=None):
156        """
157        Sets sld_data
158        """
159        self.sld_data = sld_data
160        self.data_pos_unit = sld_data.pos_unit
161        self.data_x = _vec(sld_data.pos_x)
162        self.data_y = _vec(sld_data.pos_y)
163        self.data_z = _vec(sld_data.pos_z)
164        self.data_sldn = _vec(sld_data.sld_n)
165        self.data_mx = _vec(sld_data.sld_mx)
166        self.data_my = _vec(sld_data.sld_my)
167        self.data_mz = _vec(sld_data.sld_mz)
168        self.data_vol = _vec(sld_data.vol_pix)
169        self.data_total_volume = sum(sld_data.vol_pix)
170        self.params['total_volume'] = sum(sld_data.vol_pix)
171
172    def getProfile(self):
173        """
174        Get SLD profile
175        : return: sld_data
176        """
177        return self.sld_data
178
179    def run(self, x=0.0):
180        """
181        Evaluate the model
182        :param x: simple value
183        :return: (I value)
184        """
185        if isinstance(x, list):
186            if len(x[1]) > 0:
187                msg = "Not a 1D."
188                raise ValueError(msg)
189            # 1D I is found at y =0 in the 2D pattern
190            out = self._gen(x[0], [])
191            return out
192        else:
193            msg = "Q must be given as list of qx's and qy's"
194            raise ValueError(msg)
195
196    def runXY(self, x=0.0):
197        """
198        Evaluate the model
199        :param x: simple value
200        :return: I value
201        :Use this runXY() for the computation
202        """
203        if isinstance(x, list):
204            return self._gen(x[0], x[1])
205        else:
206            msg = "Q must be given as list of qx's and qy's"
207            raise ValueError(msg)
208
209    def evalDistribution(self, qdist):
210        """
211        Evaluate a distribution of q-values.
212
213        :param qdist: ndarray of scalar q-values (for 1D) or list [qx,qy]
214                      where qx,qy are 1D ndarrays (for 2D).
215        """
216        if isinstance(qdist, list):
217            return self.run(qdist) if len(qdist[1]) < 1 else self.runXY(qdist)
218        else:
219            mesg = "evalDistribution is expecting an ndarray of "
220            mesg += "a list [qx,qy] where qx,qy are arrays."
221            raise RuntimeError(mesg)
222
223def _vec(v):
224    return np.ascontiguousarray(v, 'd')
225
226class OMF2SLD(object):
227    """
228    Convert OMFData to MAgData
229    """
230    def __init__(self):
231        """
232        Init
233        """
234        self.pos_x = None
235        self.pos_y = None
236        self.pos_z = None
237        self.mx = None
238        self.my = None
239        self.mz = None
240        self.sld_n = None
241        self.vol_pix = None
242        self.output = None
243        self.omfdata = None
244
245    def set_data(self, omfdata, shape='rectangular'):
246        """
247        Set all data
248        """
249        self.omfdata = omfdata
250        length = int(omfdata.xnodes * omfdata.ynodes * omfdata.znodes)
251        pos_x = np.arange(omfdata.xmin,
252                             omfdata.xnodes*omfdata.xstepsize + omfdata.xmin,
253                             omfdata.xstepsize)
254        pos_y = np.arange(omfdata.ymin,
255                             omfdata.ynodes*omfdata.ystepsize + omfdata.ymin,
256                             omfdata.ystepsize)
257        pos_z = np.arange(omfdata.zmin,
258                             omfdata.znodes*omfdata.zstepsize + omfdata.zmin,
259                             omfdata.zstepsize)
260        self.pos_x = np.tile(pos_x, int(omfdata.ynodes * omfdata.znodes))
261        self.pos_y = pos_y.repeat(int(omfdata.xnodes))
262        self.pos_y = np.tile(self.pos_y, int(omfdata.znodes))
263        self.pos_z = pos_z.repeat(int(omfdata.xnodes * omfdata.ynodes))
264        self.mx = omfdata.mx
265        self.my = omfdata.my
266        self.mz = omfdata.mz
267        self.sld_n = np.zeros(length)
268
269        if omfdata.mx is None:
270            self.mx = np.zeros(length)
271        if omfdata.my is None:
272            self.my = np.zeros(length)
273        if omfdata.mz is None:
274            self.mz = np.zeros(length)
275
276        self._check_data_length(length)
277        self.remove_null_points(False, False)
278        mask = np.ones(len(self.sld_n), dtype=bool)
279        if shape.lower() == 'ellipsoid':
280            try:
281                # Pixel (step) size included
282                x_c = max(self.pos_x) + min(self.pos_x)
283                y_c = max(self.pos_y) + min(self.pos_y)
284                z_c = max(self.pos_z) + min(self.pos_z)
285                x_d = max(self.pos_x) - min(self.pos_x)
286                y_d = max(self.pos_y) - min(self.pos_y)
287                z_d = max(self.pos_z) - min(self.pos_z)
288                x_r = (x_d + omfdata.xstepsize) / 2.0
289                y_r = (y_d + omfdata.ystepsize) / 2.0
290                z_r = (z_d + omfdata.zstepsize) / 2.0
291                x_dir2 = ((self.pos_x - x_c / 2.0) / x_r)
292                x_dir2 *= x_dir2
293                y_dir2 = ((self.pos_y - y_c / 2.0) / y_r)
294                y_dir2 *= y_dir2
295                z_dir2 = ((self.pos_z - z_c / 2.0) / z_r)
296                z_dir2 *= z_dir2
297                mask = (x_dir2 + y_dir2 + z_dir2) <= 1.0
298            except Exception:
299                logger.error(sys.exc_value)
300        self.output = MagSLD(self.pos_x[mask], self.pos_y[mask],
301                             self.pos_z[mask], self.sld_n[mask],
302                             self.mx[mask], self.my[mask], self.mz[mask])
303        self.output.set_pix_type('pixel')
304        self.output.set_pixel_symbols('pixel')
305
306    def get_omfdata(self):
307        """
308        Return all data
309        """
310        return self.omfdata
311
312    def get_output(self):
313        """
314        Return output
315        """
316        return self.output
317
318    def _check_data_length(self, length):
319        """
320        Check if the data lengths are consistent
321        :Params length: data length
322        """
323        parts = (self.pos_x, self.pos_y, self.pos_z, self.mx, self.my, self.mz)
324        if any(len(v) != length for v in parts):
325            raise ValueError("Error: Inconsistent data length.")
326
327    def remove_null_points(self, remove=False, recenter=False):
328        """
329        Removes any mx, my, and mz = 0 points
330        """
331        if remove:
332            is_nonzero = (np.fabs(self.mx) + np.fabs(self.my) +
333                          np.fabs(self.mz)).nonzero()
334            if len(is_nonzero[0]) > 0:
335                self.pos_x = self.pos_x[is_nonzero]
336                self.pos_y = self.pos_y[is_nonzero]
337                self.pos_z = self.pos_z[is_nonzero]
338                self.sld_n = self.sld_n[is_nonzero]
339                self.mx = self.mx[is_nonzero]
340                self.my = self.my[is_nonzero]
341                self.mz = self.mz[is_nonzero]
342        if recenter:
343            self.pos_x -= (min(self.pos_x) + max(self.pos_x)) / 2.0
344            self.pos_y -= (min(self.pos_y) + max(self.pos_y)) / 2.0
345            self.pos_z -= (min(self.pos_z) + max(self.pos_z)) / 2.0
346
347    def get_magsld(self):
348        """
349        return MagSLD
350        """
351        return self.output
352
353
354class OMFReader(object):
355    """
356    Class to load omf/ascii files (3 columns w/header).
357    """
358    ## File type
359    type_name = "OMF ASCII"
360
361    ## Wildcards
362    type = ["OMF files (*.OMF, *.omf)|*.omf"]
363    ## List of allowed extensions
364    ext = ['.omf', '.OMF']
365
366    def read(self, path):
367        """
368        Load data file
369        :param path: file path
370        :return: x, y, z, sld_n, sld_mx, sld_my, sld_mz
371        """
372        desc = ""
373        mx = np.zeros(0)
374        my = np.zeros(0)
375        mz = np.zeros(0)
376        try:
377            input_f = open(path, 'rb')
378            buff = decode(input_f.read())
379            lines = buff.split('\n')
380            input_f.close()
381            output = OMFData()
382            valueunit = None
383            for line in lines:
384                line = line.strip()
385                # Read data
386                if line and not line.startswith('#'):
387                    try:
388                        toks = line.split()
389                        _mx = float(toks[0])
390                        _my = float(toks[1])
391                        _mz = float(toks[2])
392                        _mx = mag2sld(_mx, valueunit)
393                        _my = mag2sld(_my, valueunit)
394                        _mz = mag2sld(_mz, valueunit)
395                        mx = np.append(mx, _mx)
396                        my = np.append(my, _my)
397                        mz = np.append(mz, _mz)
398                    except Exception as exc:
399                        # Skip non-data lines
400                        logger.error(str(exc)+" when processing %r"%line)
401                #Reading Header; Segment count ignored
402                s_line = line.split(":", 1)
403                if s_line[0].lower().count("oommf") > 0:
404                    oommf = s_line[1].lstrip()
405                if s_line[0].lower().count("title") > 0:
406                    title = s_line[1].lstrip()
407                if s_line[0].lower().count("desc") > 0:
408                    desc += s_line[1].lstrip()
409                    desc += '\n'
410                if s_line[0].lower().count("meshtype") > 0:
411                    meshtype = s_line[1].lstrip()
412                if s_line[0].lower().count("meshunit") > 0:
413                    meshunit = s_line[1].lstrip()
414                    if meshunit.count("m") < 1:
415                        msg = "Error: \n"
416                        msg += "We accept only m as meshunit"
417                        raise ValueError(msg)
418                if s_line[0].lower().count("xbase") > 0:
419                    xbase = s_line[1].lstrip()
420                if s_line[0].lower().count("ybase") > 0:
421                    ybase = s_line[1].lstrip()
422                if s_line[0].lower().count("zbase") > 0:
423                    zbase = s_line[1].lstrip()
424                if s_line[0].lower().count("xstepsize") > 0:
425                    xstepsize = s_line[1].lstrip()
426                if s_line[0].lower().count("ystepsize") > 0:
427                    ystepsize = s_line[1].lstrip()
428                if s_line[0].lower().count("zstepsize") > 0:
429                    zstepsize = s_line[1].lstrip()
430                if s_line[0].lower().count("xnodes") > 0:
431                    xnodes = s_line[1].lstrip()
432                if s_line[0].lower().count("ynodes") > 0:
433                    ynodes = s_line[1].lstrip()
434                if s_line[0].lower().count("znodes") > 0:
435                    znodes = s_line[1].lstrip()
436                if s_line[0].lower().count("xmin") > 0:
437                    xmin = s_line[1].lstrip()
438                if s_line[0].lower().count("ymin") > 0:
439                    ymin = s_line[1].lstrip()
440                if s_line[0].lower().count("zmin") > 0:
441                    zmin = s_line[1].lstrip()
442                if s_line[0].lower().count("xmax") > 0:
443                    xmax = s_line[1].lstrip()
444                if s_line[0].lower().count("ymax") > 0:
445                    ymax = s_line[1].lstrip()
446                if s_line[0].lower().count("zmax") > 0:
447                    zmax = s_line[1].lstrip()
448                if s_line[0].lower().count("valueunit") > 0:
449                    valueunit = s_line[1].lstrip().rstrip()
450                if s_line[0].lower().count("valuemultiplier") > 0:
451                    valuemultiplier = s_line[1].lstrip()
452                if s_line[0].lower().count("valuerangeminmag") > 0:
453                    valuerangeminmag = s_line[1].lstrip()
454                if s_line[0].lower().count("valuerangemaxmag") > 0:
455                    valuerangemaxmag = s_line[1].lstrip()
456                if s_line[0].lower().count("end") > 0:
457                    output.filename = os.path.basename(path)
458                    output.oommf = oommf
459                    output.title = title
460                    output.desc = desc
461                    output.meshtype = meshtype
462                    output.xbase = float(xbase) * METER2ANG
463                    output.ybase = float(ybase) * METER2ANG
464                    output.zbase = float(zbase) * METER2ANG
465                    output.xstepsize = float(xstepsize) * METER2ANG
466                    output.ystepsize = float(ystepsize) * METER2ANG
467                    output.zstepsize = float(zstepsize) * METER2ANG
468                    output.xnodes = float(xnodes)
469                    output.ynodes = float(ynodes)
470                    output.znodes = float(znodes)
471                    output.xmin = float(xmin) * METER2ANG
472                    output.ymin = float(ymin) * METER2ANG
473                    output.zmin = float(zmin) * METER2ANG
474                    output.xmax = float(xmax) * METER2ANG
475                    output.ymax = float(ymax) * METER2ANG
476                    output.zmax = float(zmax) * METER2ANG
477                    output.valuemultiplier = valuemultiplier
478                    output.valuerangeminmag = mag2sld(float(valuerangeminmag), \
479                                                      valueunit)
480                    output.valuerangemaxmag = mag2sld(float(valuerangemaxmag), \
481                                                      valueunit)
482            output.set_m(mx, my, mz)
483            return output
484        except Exception:
485            msg = "%s is not supported: \n" % path
486            msg += "We accept only Text format OMF file."
487            raise RuntimeError(msg)
488
489class PDBReader(object):
490    """
491    PDB reader class: limited for reading the lines starting with 'ATOM'
492    """
493    type_name = "PDB"
494    ## Wildcards
495    type = ["pdb files (*.PDB, *.pdb)|*.pdb"]
496    ## List of allowed extensions
497    ext = ['.pdb', '.PDB']
498
499    def read(self, path):
500        """
501        Load data file
502
503        :param path: file path
504        :return: MagSLD
505        :raise RuntimeError: when the file can't be opened
506        """
507        pos_x = np.zeros(0)
508        pos_y = np.zeros(0)
509        pos_z = np.zeros(0)
510        sld_n = np.zeros(0)
511        sld_mx = np.zeros(0)
512        sld_my = np.zeros(0)
513        sld_mz = np.zeros(0)
514        vol_pix = np.zeros(0)
515        pix_symbol = np.zeros(0)
516        x_line = []
517        y_line = []
518        z_line = []
519        x_lines = []
520        y_lines = []
521        z_lines = []
522        try:
523            input_f = open(path, 'rb')
524            buff = decode(input_f.read())
525            lines = buff.split('\n')
526            input_f.close()
527            num = 0
528            for line in lines:
529                try:
530                    # check if line starts with "ATOM"
531                    if line[0:6].strip().count('ATM') > 0 or \
532                                line[0:6].strip() == 'ATOM':
533                        # define fields of interest
534                        atom_name = line[12:16].strip()
535                        try:
536                            float(line[12])
537                            atom_name = atom_name[1].upper()
538                        except Exception:
539                            if len(atom_name) == 4:
540                                atom_name = atom_name[0].upper()
541                            elif line[12] != ' ':
542                                atom_name = atom_name[0].upper() + \
543                                        atom_name[1].lower()
544                            else:
545                                atom_name = atom_name[0].upper()
546                        _pos_x = float(line[30:38].strip())
547                        _pos_y = float(line[38:46].strip())
548                        _pos_z = float(line[46:54].strip())
549                        pos_x = np.append(pos_x, _pos_x)
550                        pos_y = np.append(pos_y, _pos_y)
551                        pos_z = np.append(pos_z, _pos_z)
552                        try:
553                            val = nsf.neutron_sld(atom_name)[0]
554                            # sld in Ang^-2 unit
555                            val *= 1.0e-6
556                            sld_n = np.append(sld_n, val)
557                            atom = formula(atom_name)
558                            # cm to A units
559                            vol = 1.0e+24 * atom.mass / atom.density / NA
560                            vol_pix = np.append(vol_pix, vol)
561                        except Exception:
562                            logger.error("Error: set the sld of %s to zero"% atom_name)
563                            sld_n = np.append(sld_n, 0.0)
564                        sld_mx = np.append(sld_mx, 0)
565                        sld_my = np.append(sld_my, 0)
566                        sld_mz = np.append(sld_mz, 0)
567                        pix_symbol = np.append(pix_symbol, atom_name)
568                    elif line[0:6].strip().count('CONECT') > 0:
569                        toks = line.split()
570                        num = int(toks[1]) - 1
571                        val_list = []
572                        for val in toks[2:]:
573                            try:
574                                int_val = int(val)
575                            except Exception:
576                                break
577                            if int_val == 0:
578                                break
579                            val_list.append(int_val)
580                        #need val_list ordered
581                        for val in val_list:
582                            index = val - 1
583                            if (pos_x[index], pos_x[num]) in x_line and \
584                               (pos_y[index], pos_y[num]) in y_line and \
585                               (pos_z[index], pos_z[num]) in z_line:
586                                continue
587                            x_line.append((pos_x[num], pos_x[index]))
588                            y_line.append((pos_y[num], pos_y[index]))
589                            z_line.append((pos_z[num], pos_z[index]))
590                    if len(x_line) > 0:
591                        x_lines.append(x_line)
592                        y_lines.append(y_line)
593                        z_lines.append(z_line)
594                except Exception:
595                    logger.error(sys.exc_value)
596
597            output = MagSLD(pos_x, pos_y, pos_z, sld_n, sld_mx, sld_my, sld_mz)
598            output.set_conect_lines(x_line, y_line, z_line)
599            output.filename = os.path.basename(path)
600            output.set_pix_type('atom')
601            output.set_pixel_symbols(pix_symbol)
602            output.set_nodes()
603            output.set_pixel_volumes(vol_pix)
604            output.sld_unit = '1/A^(2)'
605            return output
606        except Exception:
607            raise RuntimeError("%s is not a sld file" % path)
608
609    def write(self, path, data):
610        """
611        Write
612        """
613        print("Not implemented... ")
614
615class SLDReader(object):
616    """
617    Class to load ascii files (7 columns).
618    """
619    ## File type
620    type_name = "SLD ASCII"
621    ## Wildcards
622    type = ["sld files (*.SLD, *.sld)|*.sld",
623            "txt files (*.TXT, *.txt)|*.txt",
624            "all files (*.*)|*.*"]
625    ## List of allowed extensions
626    ext = ['.sld', '.SLD', '.txt', '.TXT', '.*']
627    def read(self, path):
628        """
629        Load data file
630        :param path: file path
631        :return MagSLD: x, y, z, sld_n, sld_mx, sld_my, sld_mz
632        :raise RuntimeError: when the file can't be opened
633        :raise ValueError: when the length of the data vectors are inconsistent
634        """
635        try:
636            pos_x = np.zeros(0)
637            pos_y = np.zeros(0)
638            pos_z = np.zeros(0)
639            sld_n = np.zeros(0)
640            sld_mx = np.zeros(0)
641            sld_my = np.zeros(0)
642            sld_mz = np.zeros(0)
643            try:
644                # Use numpy to speed up loading
645                input_f = np.loadtxt(path, dtype='float', skiprows=1,
646                                        ndmin=1, unpack=True)
647                pos_x = np.array(input_f[0])
648                pos_y = np.array(input_f[1])
649                pos_z = np.array(input_f[2])
650                sld_n = np.array(input_f[3])
651                sld_mx = np.array(input_f[4])
652                sld_my = np.array(input_f[5])
653                sld_mz = np.array(input_f[6])
654                ncols = len(input_f)
655                if ncols == 8:
656                    vol_pix = np.array(input_f[7])
657                elif ncols == 7:
658                    vol_pix = None
659            except Exception:
660                # For older version of numpy
661                input_f = open(path, 'rb')
662                buff = decode(input_f.read())
663                lines = buff.split('\n')
664                input_f.close()
665                for line in lines:
666                    toks = line.split()
667                    try:
668                        _pos_x = float(toks[0])
669                        _pos_y = float(toks[1])
670                        _pos_z = float(toks[2])
671                        _sld_n = float(toks[3])
672                        _sld_mx = float(toks[4])
673                        _sld_my = float(toks[5])
674                        _sld_mz = float(toks[6])
675                        pos_x = np.append(pos_x, _pos_x)
676                        pos_y = np.append(pos_y, _pos_y)
677                        pos_z = np.append(pos_z, _pos_z)
678                        sld_n = np.append(sld_n, _sld_n)
679                        sld_mx = np.append(sld_mx, _sld_mx)
680                        sld_my = np.append(sld_my, _sld_my)
681                        sld_mz = np.append(sld_mz, _sld_mz)
682                        try:
683                            _vol_pix = float(toks[7])
684                            vol_pix = np.append(vol_pix, _vol_pix)
685                        except Exception:
686                            vol_pix = None
687                    except Exception:
688                        # Skip non-data lines
689                        logger.error(sys.exc_value)
690            output = MagSLD(pos_x, pos_y, pos_z, sld_n,
691                            sld_mx, sld_my, sld_mz)
692            output.filename = os.path.basename(path)
693            output.set_pix_type('pixel')
694            output.set_pixel_symbols('pixel')
695            if vol_pix is not None:
696                output.set_pixel_volumes(vol_pix)
697            return output
698        except Exception:
699            raise RuntimeError("%s is not a sld file" % path)
700
701    def write(self, path, data):
702        """
703        Write sld file
704        :Param path: file path
705        :Param data: MagSLD data object
706        """
707        if path is None:
708            raise ValueError("Missing the file path.")
709        if data is None:
710            raise ValueError("Missing the data to save.")
711        x_val = data.pos_x
712        y_val = data.pos_y
713        z_val = data.pos_z
714        vol_pix = data.vol_pix
715        length = len(x_val)
716        sld_n = data.sld_n
717        if sld_n is None:
718            sld_n = np.zeros(length)
719        sld_mx = data.sld_mx
720        if sld_mx is None:
721            sld_mx = np.zeros(length)
722            sld_my = np.zeros(length)
723            sld_mz = np.zeros(length)
724        else:
725            sld_my = data.sld_my
726            sld_mz = data.sld_mz
727        out = open(path, 'w')
728        # First Line: Column names
729        out.write("X  Y  Z  SLDN SLDMx  SLDMy  SLDMz VOLUMEpix")
730        for ind in range(length):
731            out.write("\n%g  %g  %g  %g  %g  %g  %g %g" % \
732                      (x_val[ind], y_val[ind], z_val[ind], sld_n[ind],
733                       sld_mx[ind], sld_my[ind], sld_mz[ind], vol_pix[ind]))
734        out.close()
735
736
737class OMFData(object):
738    """
739    OMF Data.
740    """
741    _meshunit = "A"
742    _valueunit = "A^(-2)"
743    def __init__(self):
744        """
745        Init for mag SLD
746        """
747        self.filename = 'default'
748        self.oommf = ''
749        self.title = ''
750        self.desc = ''
751        self.meshtype = ''
752        self.meshunit = self._meshunit
753        self.valueunit = self._valueunit
754        self.xbase = 0.0
755        self.ybase = 0.0
756        self.zbase = 0.0
757        self.xstepsize = 6.0
758        self.ystepsize = 6.0
759        self.zstepsize = 6.0
760        self.xnodes = 10.0
761        self.ynodes = 10.0
762        self.znodes = 10.0
763        self.xmin = 0.0
764        self.ymin = 0.0
765        self.zmin = 0.0
766        self.xmax = 60.0
767        self.ymax = 60.0
768        self.zmax = 60.0
769        self.mx = None
770        self.my = None
771        self.mz = None
772        self.valuemultiplier = 1.
773        self.valuerangeminmag = 0
774        self.valuerangemaxmag = 0
775
776    def __str__(self):
777        """
778        doc strings
779        """
780        _str = "Type:            %s\n" % self.__class__.__name__
781        _str += "File:            %s\n" % self.filename
782        _str += "OOMMF:           %s\n" % self.oommf
783        _str += "Title:           %s\n" % self.title
784        _str += "Desc:            %s\n" % self.desc
785        _str += "meshtype:        %s\n" % self.meshtype
786        _str += "meshunit:        %s\n" % str(self.meshunit)
787        _str += "xbase:           %s [%s]\n" % (str(self.xbase), self.meshunit)
788        _str += "ybase:           %s [%s]\n" % (str(self.ybase), self.meshunit)
789        _str += "zbase:           %s [%s]\n" % (str(self.zbase), self.meshunit)
790        _str += "xstepsize:       %s [%s]\n" % (str(self.xstepsize),
791                                                self.meshunit)
792        _str += "ystepsize:       %s [%s]\n" % (str(self.ystepsize),
793                                                self.meshunit)
794        _str += "zstepsize:       %s [%s]\n" % (str(self.zstepsize),
795                                                self.meshunit)
796        _str += "xnodes:          %s\n" % str(self.xnodes)
797        _str += "ynodes:          %s\n" % str(self.ynodes)
798        _str += "znodes:          %s\n" % str(self.znodes)
799        _str += "xmin:            %s [%s]\n" % (str(self.xmin), self.meshunit)
800        _str += "ymin:            %s [%s]\n" % (str(self.ymin), self.meshunit)
801        _str += "zmin:            %s [%s]\n" % (str(self.zmin), self.meshunit)
802        _str += "xmax:            %s [%s]\n" % (str(self.xmax), self.meshunit)
803        _str += "ymax:            %s [%s]\n" % (str(self.ymax), self.meshunit)
804        _str += "zmax:            %s [%s]\n" % (str(self.zmax), self.meshunit)
805        _str += "valueunit:       %s\n" % self.valueunit
806        _str += "valuemultiplier: %s\n" % str(self.valuemultiplier)
807        _str += "ValueRangeMinMag:%s [%s]\n" % (str(self.valuerangeminmag),
808                                                self.valueunit)
809        _str += "ValueRangeMaxMag:%s [%s]\n" % (str(self.valuerangemaxmag),
810                                                self.valueunit)
811        return _str
812
813    def set_m(self, mx, my, mz):
814        """
815        Set the Mx, My, Mz values
816        """
817        self.mx = mx
818        self.my = my
819        self.mz = mz
820
821class MagSLD(object):
822    """
823    Magnetic SLD.
824    """
825    pos_x = None
826    pos_y = None
827    pos_z = None
828    sld_n = None
829    sld_mx = None
830    sld_my = None
831    sld_mz = None
832    # Units
833    _pos_unit = 'A'
834    _sld_unit = '1/A^(2)'
835    _pix_type = 'pixel'
836
837    def __init__(self, pos_x, pos_y, pos_z, sld_n=None,
838                 sld_mx=None, sld_my=None, sld_mz=None, vol_pix=None):
839        """
840        Init for mag SLD
841        :params : All should be numpy 1D array
842        """
843        self.is_data = True
844        self.filename = ''
845        self.xstepsize = 6.0
846        self.ystepsize = 6.0
847        self.zstepsize = 6.0
848        self.xnodes = 10.0
849        self.ynodes = 10.0
850        self.znodes = 10.0
851        self.has_stepsize = False
852        self.has_conect = False
853        self.pos_unit = self._pos_unit
854        self.sld_unit = self._sld_unit
855        self.pix_type = 'pixel'
856        self.pos_x = pos_x
857        self.pos_y = pos_y
858        self.pos_z = pos_z
859        self.sld_n = sld_n
860        self.line_x = None
861        self.line_y = None
862        self.line_z = None
863        self.sld_mx = sld_mx
864        self.sld_my = sld_my
865        self.sld_mz = sld_mz
866        self.vol_pix = vol_pix
867        self.sld_m = None
868        self.sld_phi = None
869        self.sld_theta = None
870        self.pix_symbol = None
871        if sld_mx is not None and sld_my is not None and sld_mz is not None:
872            self.set_sldms(sld_mx, sld_my, sld_mz)
873        self.set_nodes()
874
875    def __str__(self):
876        """
877        doc strings
878        """
879        _str = "Type:       %s\n" % self.__class__.__name__
880        _str += "File:       %s\n" % self.filename
881        _str += "Axis_unit:  %s\n" % self.pos_unit
882        _str += "SLD_unit:   %s\n" % self.sld_unit
883        return _str
884
885    def set_pix_type(self, pix_type):
886        """
887        Set pixel type
888        :Param pix_type: string, 'pixel' or 'atom'
889        """
890        self.pix_type = pix_type
891
892    def set_sldn(self, sld_n):
893        """
894        Sets neutron SLD
895        """
896        if sld_n.__class__.__name__ == 'float':
897            if self.is_data:
898                # For data, put the value to only the pixels w non-zero M
899                is_nonzero = (np.fabs(self.sld_mx) +
900                              np.fabs(self.sld_my) +
901                              np.fabs(self.sld_mz)).nonzero()
902                self.sld_n = np.zeros(len(self.pos_x))
903                if len(self.sld_n[is_nonzero]) > 0:
904                    self.sld_n[is_nonzero] = sld_n
905                else:
906                    self.sld_n.fill(sld_n)
907            else:
908                # For non-data, put the value to all the pixels
909                self.sld_n = np.ones(len(self.pos_x)) * sld_n
910        else:
911            self.sld_n = sld_n
912
913    def set_sldms(self, sld_mx, sld_my, sld_mz):
914        r"""
915        Sets mx, my, mz and abs(m).
916        """ # Note: escaping
917        if sld_mx.__class__.__name__ == 'float':
918            self.sld_mx = np.ones(len(self.pos_x)) * sld_mx
919        else:
920            self.sld_mx = sld_mx
921        if sld_my.__class__.__name__ == 'float':
922            self.sld_my = np.ones(len(self.pos_x)) * sld_my
923        else:
924            self.sld_my = sld_my
925        if sld_mz.__class__.__name__ == 'float':
926            self.sld_mz = np.ones(len(self.pos_x)) * sld_mz
927        else:
928            self.sld_mz = sld_mz
929
930        sld_m = np.sqrt(sld_mx * sld_mx + sld_my * sld_my + \
931                                sld_mz * sld_mz)
932        self.sld_m = sld_m
933
934    def set_pixel_symbols(self, symbol='pixel'):
935        """
936        Set pixel
937        :Params pixel: str; pixel or atomic symbol, or array of strings
938        """
939        if self.sld_n is None:
940            return
941        if symbol.__class__.__name__ == 'str':
942            self.pix_symbol = np.repeat(symbol, len(self.sld_n))
943        else:
944            self.pix_symbol = symbol
945
946    def set_pixel_volumes(self, vol):
947        """
948        Set pixel volumes
949        :Params pixel: str; pixel or atomic symbol, or array of strings
950        """
951        if self.sld_n is None:
952            return
953        if vol.__class__.__name__ == 'ndarray':
954            self.vol_pix = vol
955        elif vol.__class__.__name__.count('float') > 0:
956            self.vol_pix = np.repeat(vol, len(self.sld_n))
957        else:
958            self.vol_pix = None
959
960    def get_sldn(self):
961        """
962        Returns nuclear sld
963        """
964        return self.sld_n
965
966    def set_nodes(self):
967        """
968        Set xnodes, ynodes, and znodes
969        """
970        self.set_stepsize()
971        if self.pix_type == 'pixel':
972            try:
973                xdist = (max(self.pos_x) - min(self.pos_x)) / self.xstepsize
974                ydist = (max(self.pos_y) - min(self.pos_y)) / self.ystepsize
975                zdist = (max(self.pos_z) - min(self.pos_z)) / self.zstepsize
976                self.xnodes = int(xdist) + 1
977                self.ynodes = int(ydist) + 1
978                self.znodes = int(zdist) + 1
979            except Exception:
980                self.xnodes = None
981                self.ynodes = None
982                self.znodes = None
983        else:
984            self.xnodes = None
985            self.ynodes = None
986            self.znodes = None
987
988    def set_stepsize(self):
989        """
990        Set xtepsize, ystepsize, and zstepsize
991        """
992        if self.pix_type == 'pixel':
993            try:
994                xpos_pre = self.pos_x[0]
995                ypos_pre = self.pos_y[0]
996                zpos_pre = self.pos_z[0]
997                for x_pos in self.pos_x:
998                    if xpos_pre != x_pos:
999                        self.xstepsize = np.fabs(x_pos - xpos_pre)
1000                        break
1001                for y_pos in self.pos_y:
1002                    if ypos_pre != y_pos:
1003                        self.ystepsize = np.fabs(y_pos - ypos_pre)
1004                        break
1005                for z_pos in self.pos_z:
1006                    if zpos_pre != z_pos:
1007                        self.zstepsize = np.fabs(z_pos - zpos_pre)
1008                        break
1009                #default pix volume
1010                self.vol_pix = np.ones(len(self.pos_x))
1011                vol = self.xstepsize * self.ystepsize * self.zstepsize
1012                self.set_pixel_volumes(vol)
1013                self.has_stepsize = True
1014            except Exception:
1015                self.xstepsize = None
1016                self.ystepsize = None
1017                self.zstepsize = None
1018                self.vol_pix = None
1019                self.has_stepsize = False
1020        else:
1021            self.xstepsize = None
1022            self.ystepsize = None
1023            self.zstepsize = None
1024            self.has_stepsize = True
1025        return self.xstepsize, self.ystepsize, self.zstepsize
1026
1027    def set_conect_lines(self, line_x, line_y, line_z):
1028        """
1029        Set bonding line data if taken from pdb
1030        """
1031        if line_x.__class__.__name__ != 'list' or len(line_x) < 1:
1032            return
1033        if line_y.__class__.__name__ != 'list' or len(line_y) < 1:
1034            return
1035        if line_z.__class__.__name__ != 'list' or len(line_z) < 1:
1036            return
1037        self.has_conect = True
1038        self.line_x = line_x
1039        self.line_y = line_y
1040        self.line_z = line_z
1041
1042def _get_data_path(*path_parts):
1043    from os.path import realpath, join as joinpath, dirname, abspath
1044    # in sas/sascalc/calculator;  want sas/sasview/test
1045    return joinpath(dirname(realpath(__file__)),
1046                    '..', '..', 'sasview', 'test', *path_parts)
1047
1048def test_load():
1049    """
1050        Test code
1051    """
1052    from mpl_toolkits.mplot3d import Axes3D
1053    tfpath = _get_data_path("1d_data", "CoreXY_ShellZ.txt")
1054    ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf")
1055    if not os.path.isfile(tfpath) or not os.path.isfile(ofpath):
1056        raise ValueError("file(s) not found: %r, %r"%(tfpath, ofpath))
1057    reader = SLDReader()
1058    oreader = OMFReader()
1059    output = reader.read(tfpath)
1060    ooutput = oreader.read(ofpath)
1061    foutput = OMF2SLD()
1062    foutput.set_data(ooutput)
1063
1064    import matplotlib.pyplot as plt
1065    fig = plt.figure()
1066    ax = Axes3D(fig)
1067    ax.plot(output.pos_x, output.pos_y, output.pos_z, '.', c="g",
1068            alpha=0.7, markeredgecolor='gray', rasterized=True)
1069    gap = 7
1070    max_mx = max(output.sld_mx)
1071    max_my = max(output.sld_my)
1072    max_mz = max(output.sld_mz)
1073    max_m = max(max_mx, max_my, max_mz)
1074    x2 = output.pos_x+output.sld_mx/max_m * gap
1075    y2 = output.pos_y+output.sld_my/max_m * gap
1076    z2 = output.pos_z+output.sld_mz/max_m * gap
1077    x_arrow = np.column_stack((output.pos_x, x2))
1078    y_arrow = np.column_stack((output.pos_y, y2))
1079    z_arrow = np.column_stack((output.pos_z, z2))
1080    unit_x2 = output.sld_mx / max_m
1081    unit_y2 = output.sld_my / max_m
1082    unit_z2 = output.sld_mz / max_m
1083    color_x = np.fabs(unit_x2 * 0.8)
1084    color_y = np.fabs(unit_y2 * 0.8)
1085    color_z = np.fabs(unit_z2 * 0.8)
1086    colors = np.column_stack((color_x, color_y, color_z))
1087    plt.show()
1088
1089def test_save():
1090    ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf")
1091    if not os.path.isfile(ofpath):
1092        raise ValueError("file(s) not found: %r"%(ofpath,))
1093    oreader = OMFReader()
1094    omfdata = oreader.read(ofpath)
1095    omf2sld = OMF2SLD()
1096    omf2sld.set_data(omfdata)
1097    writer = SLDReader()
1098    writer.write("out.txt", omf2sld.output)
1099
1100def test():
1101    """
1102        Test code
1103    """
1104    ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf")
1105    if not os.path.isfile(ofpath):
1106        raise ValueError("file(s) not found: %r"%(ofpath,))
1107    oreader = OMFReader()
1108    omfdata = oreader.read(ofpath)
1109    omf2sld = OMF2SLD()
1110    omf2sld.set_data(omfdata)
1111    model = GenSAS()
1112    model.set_sld_data(omf2sld.output)
1113    x = np.linspace(0, 0.1, 11)[1:]
1114    model.runXY([x, x])
1115
1116if __name__ == "__main__":
1117    #test_load()
1118    #test_save()
1119    test()
Note: See TracBrowser for help on using the repository browser.