Changeset 31d5187 in sasmodels


Ignore:
Timestamp:
Mar 6, 2019 5:14:53 PM (6 years ago)
Author:
GitHub <noreply@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
674186e
Parents:
cff2939 (diff), e589e9a (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
git-author:
Paul Kienzle <pkienzle@…> (03/06/19 17:14:53)
git-committer:
GitHub <noreply@…> (03/06/19 17:14:53)
Message:

Merge branch 'beta_approx' into webgl_jitter_viewer

Files:
10 edited

Legend:

Unmodified
Added
Removed
  • README.rst

    re30d645 r2a64722  
    1010is available. 
    1111 
    12 Example 
     12Install 
    1313------- 
     14 
     15The easiest way to use sasmodels is from `SasView <http://www.sasview.org/>`_. 
     16 
     17You can also install sasmodels as a standalone package in python. Use 
     18`miniconda <https://docs.conda.io/en/latest/miniconda.html>`_ 
     19or `anaconda <https://www.anaconda.com/>`_ 
     20to create a python environment with the sasmodels dependencies:: 
     21 
     22    $ conda create -n sasmodels -c conda-forge numpy scipy matplotlib pyopencl 
     23 
     24The option ``-n sasmodels`` names the environment sasmodels, and the option 
     25``-c conda-forge`` selects the conda-forge package channel because pyopencl 
     26is not part of the base anaconda distribution. 
     27 
     28Activate the environment and install sasmodels:: 
     29 
     30    $ conda activate sasmodels 
     31    (sasmodels) $ pip install sasmodels 
     32 
     33Install `bumps <https://github.com/bumps/bumps>`_ if you want to use it to fit 
     34your data:: 
     35 
     36    (sasmodels) $ pip install bumps 
     37 
     38Usage 
     39----- 
     40 
     41Check that the works:: 
     42 
     43    (sasmodels) $ python -m sasmodels.compare cylinder 
     44 
     45To show the orientation explorer:: 
     46 
     47    (sasmodels) $ python -m sasmodels.jitter 
     48 
     49Documentation is available online as part of the SasView 
     50`fitting perspective <http://www.sasview.org/docs/index.html>`_ 
     51as well as separate pages for 
     52`individual models <http://www.sasview.org/docs/user/sasgui/perspectives/fitting/models/index.html>`_. 
     53Programming details for sasmodels are available in the 
     54`developer documentation <http://www.sasview.org/docs/dev/dev.html>`_. 
     55 
     56 
     57Fitting Example 
     58--------------- 
    1459 
    1560The example directory contains a radial+tangential data set for an oriented 
    1661rod-like shape. 
    1762 
    18 The data is loaded by sas.dataloader from the sasview package, so sasview 
    19 is needed to run the example. 
     63To load the example data, you will need the SAS data loader from the sasview 
     64package. This is not yet available on PyPI, so you will need a copy of the 
     65SasView source code to run it.  Create a directory somewhere to hold the 
     66sasview and sasmodels source code, which we will refer to as $SOURCE. 
    2067 
    21 To run the example, you need sasview, sasmodels and bumps.  Assuming these 
    22 repositories are installed side by side, change to the sasmodels/example 
    23 directory and enter:: 
     68Use the following to install sasview, and the sasmodels examples:: 
    2469 
    25     PYTHONPATH=..:../../sasview/src ../../bumps/run.py fit.py \ 
    26         cylinder --preview 
     70    (sasmodels) $ cd $SOURCE 
     71    (sasmodels) $ conda install git 
     72    (sasmodels) $ git clone https://github.com/sasview/sasview.git 
     73    (sasmodels) $ git clone https://github.com/sasview/sasmodels.git 
    2774 
    28 See bumps documentation for instructions on running the fit.  With the 
    29 python packages installed, e.g., into a virtual environment, then the 
    30 python path need not be set, and the command would be:: 
     75Set the path to the sasview source on your python path within the sasmodels 
     76environment.  On Windows, this will be:: 
    3177 
    32     bumps fit.py cylinder --preview 
     78    (sasmodels)> set PYTHONPATH="$SOURCE\sasview\src" 
     79    (sasmodels)> cd $SOURCE/sasmodels/example 
     80    (sasmodels)> python -m bumps.cli fit.py cylinder --preview 
     81 
     82On Mac/Linux with the standard shell this will be:: 
     83 
     84    (sasmodels) $ export PYTHONPATH="$SOURCE/sasview/src" 
     85    (sasmodels) $ cd $SOURCE/sasmodels/example 
     86    (sasmodels) $ bumps fit.py cylinder --preview 
    3387 
    3488The fit.py model accepts up to two arguments.  The first argument is the 
     
    3892both radial and tangential simultaneously, use the word "both". 
    3993 
    40 Notes 
    41 ----- 
    42  
    43 cylinder.c + cylinder.py is the cylinder model with renamed variables and 
    44 sld scaled by 1e6 so the numbers are nicer.  The model name is "cylinder" 
    45  
    46 lamellar.py is an example of a single file model with embedded C code. 
     94See `bumps documentation <https://bumps.readthedocs.io/>`_ for detailed 
     95instructions on running the fit. 
    4796 
    4897|TravisStatus|_ 
  • explore/precision.py

    raa8c6e0 rcd28947  
    207207    return model_info 
    208208 
    209 # Hack to allow second parameter A in two parameter functions 
     209# Hack to allow second parameter A in the gammainc and gammaincc functions. 
     210# Create a 2-D variant of the precision test if we need to handle other two 
     211# parameter functions. 
    210212A = 1 
    211213def parse_extra_pars(): 
     214    """ 
     215    Parse the command line looking for the second parameter "A=..." for the 
     216    gammainc/gammaincc functions. 
     217    """ 
    212218    global A 
    213219 
     
    333339) 
    334340add_function( 
     341    # Note: "a" is given as A=... on the command line via parse_extra_pars 
    335342    name="gammainc(x)", 
    336343    mp_function=lambda x, a=A: mp.gammainc(a, a=0, b=x)/mp.gamma(a), 
     
    339346) 
    340347add_function( 
     348    # Note: "a" is given as A=... on the command line via parse_extra_pars 
    341349    name="gammaincc(x)", 
    342350    mp_function=lambda x, a=A: mp.gammainc(a, a=x, b=mp.inf)/mp.gamma(a), 
  • sasmodels/compare.py

    r07646b6 rc1799d3  
    11521152        'rel_err'   : True, 
    11531153        'explore'   : False, 
    1154         'use_demo'  : True, 
     1154        'use_demo'  : False, 
    11551155        'zero'      : False, 
    11561156        'html'      : False, 
  • sasmodels/direct_model.py

    r5024a56 rc1799d3  
    332332 
    333333        # Need to pull background out of resolution for multiple scattering 
    334         background = pars.get('background', DEFAULT_BACKGROUND) 
     334        default_background = self._model.info.parameters.common_parameters[1].default 
     335        background = pars.get('background', default_background) 
    335336        pars = pars.copy() 
    336337        pars['background'] = 0. 
  • sasmodels/generate.py

    r39a06c9 ra8a1d48  
    703703    """ 
    704704    for code in source: 
    705         m = _FQ_PATTERN.search(code) 
    706         if m is not None: 
     705        if _FQ_PATTERN.search(code) is not None: 
    707706            return True 
    708707    return False 
     
    712711    # type: (List[str]) -> bool 
    713712    """ 
    714     Return True if C source defines "void Fq(". 
     713    Return True if C source defines "double shell_volume(". 
    715714    """ 
    716715    for code in source: 
    717         m = _SHELL_VOLUME_PATTERN.search(code) 
    718         if m is not None: 
     716        if _SHELL_VOLUME_PATTERN.search(code) is not None: 
    719717            return True 
    720718    return False 
     
    10081006        pars = model_info.parameters.kernel_parameters 
    10091007    else: 
    1010         pars = model_info.parameters.COMMON + model_info.parameters.kernel_parameters 
     1008        pars = (model_info.parameters.common_parameters 
     1009                + model_info.parameters.kernel_parameters) 
    10111010    partable = make_partable(pars) 
    10121011    subst = dict(id=model_info.id.replace('_', '-'), 
  • sasmodels/kernel.py

    re44432d rcd28947  
    133133        nout = 2 if self.info.have_Fq and self.dim == '1d' else 1 
    134134        total_weight = self.result[nout*self.q_input.nq + 0] 
     135        # Note: total_weight = sum(weight > cutoff), with cutoff >= 0, so it 
     136        # is okay to test directly against zero.  If weight is zero then I(q), 
     137        # etc. must also be zero. 
    135138        if total_weight == 0.: 
    136139            total_weight = 1. 
  • sasmodels/modelinfo.py

    r39a06c9 rc1799d3  
    404404      parameters counted as n individual parameters p1, p2, ... 
    405405 
     406    * *common_parameters* is the list of common parameters, with a unique 
     407      copy for each model so that structure factors can have a default 
     408      background of 0.0. 
     409 
    406410    * *call_parameters* is the complete list of parameters to the kernel, 
    407411      including scale and background, with vector parameters recorded as 
     
    416420    parameters don't use vector notation, and instead use p1, p2, ... 
    417421    """ 
    418     # scale and background are implicit parameters 
    419     COMMON = [Parameter(*p) for p in COMMON_PARAMETERS] 
    420  
    421422    def __init__(self, parameters): 
    422423        # type: (List[Parameter]) -> None 
     424 
     425        # scale and background are implicit parameters 
     426        # Need them to be unique to each model in case they have different 
     427        # properties, such as default=0.0 for structure factor backgrounds. 
     428        self.common_parameters = [Parameter(*p) for p in COMMON_PARAMETERS] 
     429 
    423430        self.kernel_parameters = parameters 
    424431        self._set_vector_lengths() 
     
    468475                         if p.polydisperse and p.type not in ('orientation', 'magnetic')) 
    469476        self.pd_2d = set(p.name for p in self.call_parameters if p.polydisperse) 
     477 
     478    def set_zero_background(self): 
     479        """ 
     480        Set the default background to zero for this model.  This is done for 
     481        structure factor models. 
     482        """ 
     483        # type: () -> None 
     484        # Make sure background is the second common parameter. 
     485        assert self.common_parameters[1].id == "background" 
     486        self.common_parameters[1].default = 0.0 
     487        self.defaults = self._get_defaults() 
    470488 
    471489    def check_angles(self): 
     
    567585    def _get_call_parameters(self): 
    568586        # type: () -> List[Parameter] 
    569         full_list = self.COMMON[:] 
     587        full_list = self.common_parameters[:] 
    570588        for p in self.kernel_parameters: 
    571589            if p.length == 1: 
     
    670688 
    671689        # Gather the user parameters in order 
    672         result = control + self.COMMON 
     690        result = control + self.common_parameters 
    673691        for p in self.kernel_parameters: 
    674692            if not is2d and p.type in ('orientation', 'magnetic'): 
     
    770788 
    771789    info = ModelInfo() 
     790 
     791    # Build the parameter table 
    772792    #print("make parameter table", kernel_module.parameters) 
    773793    parameters = make_parameter_table(getattr(kernel_module, 'parameters', [])) 
     794 
     795    # background defaults to zero for structure factor models 
     796    structure_factor = getattr(kernel_module, 'structure_factor', False) 
     797    if structure_factor: 
     798        parameters.set_zero_background() 
     799 
     800    # TODO: remove demo parameters 
     801    # The plots in the docs are generated from the model default values. 
     802    # Sascomp set parameters from the command line, and so doesn't need 
     803    # demo values for testing. 
    774804    demo = expand_pars(parameters, getattr(kernel_module, 'demo', None)) 
     805 
    775806    filename = abspath(kernel_module.__file__).replace('.pyc', '.py') 
    776807    kernel_id = splitext(basename(filename))[0] 
  • sasmodels/models/hardsphere.py

    r304c775 rc1799d3  
    162162    return pars 
    163163 
    164 demo = dict(radius_effective=200, volfraction=0.2, 
    165             radius_effective_pd=0.1, radius_effective_pd_n=40) 
    166164# Q=0.001 is in the Taylor series, low Q part, so add Q=0.1, 
    167165# assuming double precision sasview is correct 
  • sasmodels/sasview_model.py

    r5024a56 ra8a1d48  
    382382            hidden.add('scale') 
    383383            hidden.add('background') 
    384             self._model_info.parameters.defaults['background'] = 0. 
    385384 
    386385        # Update the parameter lists to exclude any hidden parameters 
     
    695694            return self._calculate_Iq(qx, qy) 
    696695 
    697     def _calculate_Iq(self, qx, qy=None, Fq=False, effective_radius_type=1): 
     696    def _calculate_Iq(self, qx, qy=None): 
    698697        if self._model is None: 
    699698            self._model = core.build_model(self._model_info) 
     
    715714        #print("values", values) 
    716715        #print("is_mag", is_magnetic) 
    717         if Fq: 
    718             result = calculator.Fq(call_details, values, cutoff=self.cutoff, 
    719                                    magnetic=is_magnetic, 
    720                                    effective_radius_type=effective_radius_type) 
    721716        result = calculator(call_details, values, cutoff=self.cutoff, 
    722717                            magnetic=is_magnetic) 
     
    736731        Calculate the effective radius for P(q)*S(q) 
    737732 
     733        *mode* is the R_eff type, which defaults to 1 to match the ER 
     734        calculation for sasview models from version 3.x. 
     735 
    738736        :return: the value of the effective radius 
    739737        """ 
    740         Fq = self._calculate_Iq([0.1], True, mode) 
    741         return Fq[2] 
     738        # ER and VR are only needed for old multiplication models, based on 
     739        # sas.sascalc.fit.MultiplicationModel.  Fail for now.  If we want to 
     740        # continue supporting them then add some test cases so that the code 
     741        # is exercised.  We can access ER/VR using the kernel Fq function by 
     742        # extending _calculate_Iq so that it calls: 
     743        #    if er_mode > 0: 
     744        #        res = calculator.Fq(call_details, values, cutoff=self.cutoff, 
     745        #                            magnetic=False, effective_radius_type=mode) 
     746        #        R_eff, form_shell_ratio = res[2], res[4] 
     747        #        return R_eff, form_shell_ratio 
     748        # Then use the following in calculate_ER: 
     749        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=mode) 
     750        #    return ER 
     751        # Similarly, for calculate_VR: 
     752        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=1) 
     753        #    return VR 
     754        # Obviously a combined calculate_ER_VR method would be better, but 
     755        # we only need them to support very old models, so ignore the 2x 
     756        # performance hit. 
     757        raise NotImplementedError("ER function is no longer available.") 
    742758 
    743759    def calculate_VR(self): 
     
    748764        :return: the value of the form:shell volume ratio 
    749765        """ 
    750         Fq = self._calculate_Iq([0.1], True, mode) 
    751         return Fq[4] 
     766        # See comments in calculate_ER. 
     767        raise NotImplementedError("VR function is no longer available.") 
    752768 
    753769    def set_dispersion(self, parameter, dispersion): 
     
    914930    CylinderModel().evalDistribution([0.1, 0.1]) 
    915931 
     932def test_structure_factor_background(): 
     933    # type: () -> None 
     934    """ 
     935    Check that sasview model and direct model match, with background=0. 
     936    """ 
     937    from .data import empty_data1D 
     938    from .core import load_model_info, build_model 
     939    from .direct_model import DirectModel 
     940 
     941    model_name = "hardsphere" 
     942    q = [0.0] 
     943 
     944    sasview_model = _make_standard_model(model_name)() 
     945    sasview_value = sasview_model.evalDistribution(np.array(q))[0] 
     946 
     947    data = empty_data1D(q) 
     948    model_info = load_model_info(model_name) 
     949    model = build_model(model_info) 
     950    direct_model = DirectModel(data, model) 
     951    direct_value_zero_background = direct_model(background=0.0) 
     952 
     953    assert sasview_value == direct_value_zero_background 
     954 
     955    # Additionally check that direct value background defaults to zero 
     956    direct_value_default = direct_model() 
     957    assert sasview_value == direct_value_default 
     958 
     959 
    916960def magnetic_demo(): 
    917961    Model = _make_standard_model('sphere') 
     
    934978    #print("rpa:", test_rpa()) 
    935979    #test_empty_distribution() 
     980    #test_structure_factor_background() 
  • sasmodels/jitter.py

    r7d97437 rcff2939  
    11#!/usr/bin/env python 
     2# -*- coding: utf-8 -*- 
    23""" 
    34Jitter Explorer 
     
    56 
    67Application to explore orientation angle and angular dispersity. 
     8 
     9From the command line:: 
     10 
     11    # Show docs 
     12    python -m sasmodels.jitter --help 
     13 
     14    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi 
     15    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0 
     16 
     17From a jupyter cell:: 
     18 
     19    import ipyvolume as ipv 
     20    from sasmodels import jitter 
     21    import importlib; importlib.reload(jitter) 
     22    jitter.set_plotter("ipv") 
     23 
     24    size = (10, 40, 100) 
     25    view = (20, 0, 0) 
     26 
     27    #size = (15, 15, 100) 
     28    #view = (60, 60, 0) 
     29 
     30    dview = (0, 0, 0) 
     31    #dview = (5, 5, 0) 
     32    #dview = (15, 180, 0) 
     33    #dview = (180, 15, 0) 
     34 
     35    projection = 'equirectangular' 
     36    #projection = 'azimuthal_equidistance' 
     37    #projection = 'guyou' 
     38    #projection = 'sinusoidal' 
     39    #projection = 'azimuthal_equal_area' 
     40 
     41    dist = 'uniform' 
     42    #dist = 'gaussian' 
     43 
     44    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection) 
     45    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '') 
     46    #ipv.savefig(filename+'.png') 
    747""" 
    848from __future__ import division, print_function 
     
    1050import argparse 
    1151 
    12 try: # CRUFT: travis-ci does not support mpl_toolkits.mplot3d 
    13     import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot 
    14 except ImportError: 
    15     pass 
    16  
    17 import matplotlib as mpl 
    18 import matplotlib.pyplot as plt 
    19 from matplotlib.widgets import Slider 
    2052import numpy as np 
    2153from numpy import pi, cos, sin, sqrt, exp, degrees, radians 
    2254 
    23 def draw_beam(axes, view=(0, 0)): 
     55def draw_beam(axes, view=(0, 0), alpha=0.5, steps=25): 
    2456    """ 
    2557    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1) 
    2658    """ 
    2759    #axes.plot([0,0],[0,0],[1,-1]) 
    28     #axes.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=0.8) 
    29  
    30     steps = 25 
     60    #axes.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=alpha) 
     61 
    3162    u = np.linspace(0, 2 * np.pi, steps) 
    32     v = np.linspace(-1, 1, steps) 
     63    v = np.linspace(-1, 1, 2) 
    3364 
    3465    r = 0.02 
     
    4273    points = Rz(phi)*Ry(theta)*points 
    4374    x, y, z = [v.reshape(shape) for v in points] 
    44  
    45     axes.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5) 
     75    axes.plot_surface(x, y, z, color='yellow', alpha=alpha) 
     76 
     77    # TODO: draw endcaps on beam 
     78    ## Drawing tiny balls on the end will work 
     79    #draw_sphere(axes, radius=0.02, center=(0, 0, 1.3), color='yellow', alpha=alpha) 
     80    #draw_sphere(axes, radius=0.02, center=(0, 0, -1.3), color='yellow', alpha=alpha) 
     81    ## The following does not work 
     82    #triangles = [(0, i+1, i+2) for i in range(steps-2)] 
     83    #x_cap, y_cap = x[:, 0], y[:, 0] 
     84    #for z_cap in z[:, 0], z[:, -1]: 
     85    #    axes.plot_trisurf(x_cap, y_cap, z_cap, triangles, 
     86    #                      color='yellow', alpha=alpha) 
     87 
    4688 
    4789def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1): 
     
    5597    x, y, z = transform_xyz(view, jitter, x, y, z) 
    5698 
    57     axes.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha) 
     99    axes.plot_surface(x, y, z, color='w', alpha=alpha) 
    58100 
    59101    draw_labels(axes, view, jitter, [ 
     
    124166    return atoms 
    125167 
    126 def draw_parallelepiped(axes, size, view, jitter, steps=None, alpha=1): 
     168def draw_box(axes, size, view): 
     169    a, b, c = size 
     170    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1]) 
     171    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1]) 
     172    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1]) 
     173    x, y, z = transform_xyz(view, None, x, y, z) 
     174    def draw(i, j): 
     175        axes.plot([x[i],x[j]], [y[i], y[j]], [z[i], z[j]], color='black') 
     176    draw(0, 1) 
     177    draw(0, 2) 
     178    draw(0, 3) 
     179    draw(7, 4) 
     180    draw(7, 5) 
     181    draw(7, 6) 
     182 
     183def draw_parallelepiped(axes, size, view, jitter, steps=None, 
     184                        color=(0.6, 1.0, 0.6), alpha=1): 
    127185    """Draw a parallelepiped.""" 
    128186    a, b, c = size 
     
    142200 
    143201    x, y, z = transform_xyz(view, jitter, x, y, z) 
    144     axes.plot_trisurf(x, y, triangles=tri, Z=z, color='w', alpha=alpha) 
    145  
    146     # Draw pink face on box. 
     202    axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha, 
     203                      linewidth=0) 
     204 
     205    # Colour the c+ face of the box. 
    147206    # Since I can't control face color, instead draw a thin box situated just 
    148207    # in front of the "c+" face.  Use the c face so that rotations about psi 
    149208    # rotate that face. 
    150     if 1: 
     209    if 0: 
     210        color = (1, 0.6, 0.6)  # pink 
    151211        x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1]) 
    152212        y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1]) 
    153213        z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1]) 
    154214        x, y, z = transform_xyz(view, jitter, x, y, abs(z)+0.001) 
    155         axes.plot_trisurf(x, y, triangles=tri, Z=z, color=[1, 0.6, 0.6], alpha=alpha) 
     215        axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha) 
    156216 
    157217    draw_labels(axes, view, jitter, [ 
     
    164224    ]) 
    165225 
    166 def draw_sphere(axes, radius=10., steps=100): 
     226def draw_sphere(axes, radius=0.5, steps=25, center=(0,0,0), color='w', alpha=1.): 
    167227    """Draw a sphere""" 
    168228    u = np.linspace(0, 2 * np.pi, steps) 
    169229    v = np.linspace(0, np.pi, steps) 
    170230 
    171     x = radius * np.outer(np.cos(u), np.sin(v)) 
    172     y = radius * np.outer(np.sin(u), np.sin(v)) 
    173     z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) 
    174     axes.plot_surface(x, y, z, rstride=4, cstride=4, color='w') 
    175  
    176 def draw_jitter(axes, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0), 
    177                 draw_shape=draw_parallelepiped): 
     231    x = radius * np.outer(np.cos(u), np.sin(v)) + center[0] 
     232    y = radius * np.outer(np.sin(u), np.sin(v)) + center[1] 
     233    z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) + center[2] 
     234    axes.plot_surface(x, y, z, color=color, alpha=alpha) 
     235    #axes.plot_wireframe(x, y, z) 
     236 
     237def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)): 
     238    x, y, z = origin 
     239    dx, dy, dz = length 
     240    axes.plot([x, x+dx], [y, y], [z, z], color='black') 
     241    axes.plot([x, x], [y, y+dy], [z, z], color='black') 
     242    axes.plot([x, x], [y, y], [z, z+dz], color='black') 
     243 
     244def draw_person_on_sphere(axes, view, height=0.5, radius=0.5): 
     245    limb_offset = height * 0.05 
     246    head_radius = height * 0.10 
     247    head_height = height - head_radius 
     248    neck_length = head_radius * 0.50 
     249    shoulder_height = height - 2*head_radius - neck_length 
     250    torso_length = shoulder_height * 0.55 
     251    torso_radius = torso_length * 0.30 
     252    leg_length = shoulder_height - torso_length 
     253    arm_length = torso_length * 0.90 
     254 
     255    def _draw_part(x, z): 
     256        y = np.zeros_like(x) 
     257        xp, yp, zp = transform_xyz(view, None, x, y, z + radius) 
     258        axes.plot(xp, yp, zp, color='k') 
     259 
     260    # circle for head 
     261    u = np.linspace(0, 2 * np.pi, 40) 
     262    x = head_radius * np.cos(u) 
     263    z = head_radius * np.sin(u) + head_height 
     264    _draw_part(x, z) 
     265 
     266    # rectangle for body 
     267    x = np.array([-torso_radius, torso_radius, torso_radius, -torso_radius, -torso_radius]) 
     268    z = np.array([0., 0, torso_length, torso_length, 0]) + leg_length 
     269    _draw_part(x, z) 
     270 
     271    # arms 
     272    x = np.array([-torso_radius - limb_offset, -torso_radius - limb_offset, -torso_radius]) 
     273    z = np.array([shoulder_height - arm_length, shoulder_height, shoulder_height]) 
     274    _draw_part(x, z) 
     275    _draw_part(-x, z) 
     276 
     277    # legs 
     278    x = np.array([-torso_radius + limb_offset, -torso_radius + limb_offset]) 
     279    z = np.array([0, leg_length]) 
     280    _draw_part(x, z) 
     281    _draw_part(-x, z) 
     282 
     283    limits = [-radius-height, radius+height] 
     284    axes.set_xlim(limits) 
     285    axes.set_ylim(limits) 
     286    axes.set_zlim(limits) 
     287    axes.set_axis_off() 
     288 
     289def draw_jitter(axes, view, jitter, dist='gaussian', 
     290                size=(0.1, 0.4, 1.0), 
     291                draw_shape=draw_parallelepiped, 
     292                projection='equirectangular', 
     293                alpha=0.8, 
     294                views=None): 
    178295    """ 
    179296    Represent jitter as a set of shapes at different orientations. 
    180297    """ 
     298    project, project_weight = get_projection(projection) 
     299 
    181300    # set max diagonal to 0.95 
    182301    scale = 0.95/sqrt(sum(v**2 for v in size)) 
    183302    size = tuple(scale*v for v in size) 
    184303 
    185     #np.random.seed(10) 
    186     #cloud = np.random.randn(10,3) 
    187     cloud = [ 
    188         [-1, -1, -1], 
    189         [-1, -1, +0], 
    190         [-1, -1, +1], 
    191         [-1, +0, -1], 
    192         [-1, +0, +0], 
    193         [-1, +0, +1], 
    194         [-1, +1, -1], 
    195         [-1, +1, +0], 
    196         [-1, +1, +1], 
    197         [+0, -1, -1], 
    198         [+0, -1, +0], 
    199         [+0, -1, +1], 
    200         [+0, +0, -1], 
    201         [+0, +0, +0], 
    202         [+0, +0, +1], 
    203         [+0, +1, -1], 
    204         [+0, +1, +0], 
    205         [+0, +1, +1], 
    206         [+1, -1, -1], 
    207         [+1, -1, +0], 
    208         [+1, -1, +1], 
    209         [+1, +0, -1], 
    210         [+1, +0, +0], 
    211         [+1, +0, +1], 
    212         [+1, +1, -1], 
    213         [+1, +1, +0], 
    214         [+1, +1, +1], 
    215     ] 
    216304    dtheta, dphi, dpsi = jitter 
    217     if dtheta == 0: 
    218         cloud = [v for v in cloud if v[0] == 0] 
    219     if dphi == 0: 
    220         cloud = [v for v in cloud if v[1] == 0] 
    221     if dpsi == 0: 
    222         cloud = [v for v in cloud if v[2] == 0] 
    223     draw_shape(axes, size, view, [0, 0, 0], steps=100, alpha=0.8) 
    224     scale = {'gaussian':1, 'rectangle':1/sqrt(3), 'uniform':1/3}[dist] 
    225     for point in cloud: 
    226         delta = [scale*dtheta*point[0], scale*dphi*point[1], scale*dpsi*point[2]] 
    227         draw_shape(axes, size, view, delta, alpha=0.8) 
     305    base = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist] 
     306    def steps(delta): 
     307        if views is None: 
     308            n = max(3, min(25, 2*int(base*delta/5))) 
     309        else: 
     310            n = views 
     311        return base*delta*np.linspace(-1, 1, n) if delta > 0 else [0.] 
     312    for theta in steps(dtheta): 
     313        for phi in steps(dphi): 
     314            for psi in steps(dpsi): 
     315                w = project_weight(theta, phi, 1.0, 1.0) 
     316                if w > 0: 
     317                    dview = project(theta, phi, psi) 
     318                    draw_shape(axes, size, view, dview, alpha=alpha) 
    228319    for v in 'xyz': 
    229320        a, b, c = size 
    230321        lim = np.sqrt(a**2 + b**2 + c**2) 
    231322        getattr(axes, 'set_'+v+'lim')([-lim, lim]) 
    232         getattr(axes, v+'axis').label.set_text(v) 
     323        #getattr(axes, v+'axis').label.set_text(v) 
    233324 
    234325PROJECTIONS = [ 
     
    238329    'azimuthal_equal_area', 
    239330] 
    240 def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian', 
    241               projection='equirectangular'): 
    242     """ 
    243     Draw the dispersion mesh showing the theta-phi orientations at which 
    244     the model will be evaluated. 
    245  
     331def get_projection(projection): 
     332 
     333    """ 
    246334    jitter projections 
    247335    <https://en.wikipedia.org/wiki/List_of_map_projections> 
     
    299387    # TODO: try Kent distribution instead of a gaussian warped by projection 
    300388 
    301     dist_x = np.linspace(-1, 1, n) 
    302     weights = np.ones_like(dist_x) 
    303     if dist == 'gaussian': 
    304         dist_x *= 3 
    305         weights = exp(-0.5*dist_x**2) 
    306     elif dist == 'rectangle': 
    307         # Note: uses sasmodels ridiculous definition of rectangle width 
    308         dist_x *= sqrt(3) 
    309     elif dist == 'uniform': 
    310         pass 
    311     else: 
    312         raise ValueError("expected dist to be gaussian, rectangle or uniform") 
    313  
    314389    if projection == 'equirectangular':  #define PROJECTION 1 
    315         def _rotate(theta_i, phi_j): 
    316             return Rx(phi_j)*Ry(theta_i) 
     390        def _project(theta_i, phi_j, psi): 
     391            latitude, longitude = theta_i, phi_j 
     392            return latitude, longitude, psi 
     393            #return Rx(phi_j)*Ry(theta_i) 
    317394        def _weight(theta_i, phi_j, w_i, w_j): 
    318395            return w_i*w_j*abs(cos(radians(theta_i))) 
    319396    elif projection == 'sinusoidal':  #define PROJECTION 2 
    320         def _rotate(theta_i, phi_j): 
     397        def _project(theta_i, phi_j, psi): 
    321398            latitude = theta_i 
    322399            scale = cos(radians(latitude)) 
    323400            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0 
    324401            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude)) 
    325             return Rx(longitude)*Ry(latitude) 
    326         def _weight(theta_i, phi_j, w_i, w_j): 
     402            return latitude, longitude, psi 
     403            #return Rx(longitude)*Ry(latitude) 
     404        def _project(theta_i, phi_j, w_i, w_j): 
    327405            latitude = theta_i 
    328406            scale = cos(radians(latitude)) 
     
    330408            return active*w_i*w_j 
    331409    elif projection == 'guyou':  #define PROJECTION 3  (eventually?) 
    332         def _rotate(theta_i, phi_j): 
     410        def _project(theta_i, phi_j, psi): 
    333411            from .guyou import guyou_invert 
    334412            #latitude, longitude = guyou_invert([theta_i], [phi_j]) 
    335413            longitude, latitude = guyou_invert([phi_j], [theta_i]) 
    336             return Rx(longitude[0])*Ry(latitude[0]) 
     414            return latitude, longitude, psi 
     415            #return Rx(longitude[0])*Ry(latitude[0]) 
    337416        def _weight(theta_i, phi_j, w_i, w_j): 
    338417            return w_i*w_j 
    339     elif projection == 'azimuthal_equidistance':  # Note: Rz Ry, not Rx Ry 
    340         def _rotate(theta_i, phi_j): 
     418    elif projection == 'azimuthal_equidistance': 
     419        # Note that calculates angles for Rz Ry rather than Rx Ry 
     420        def _project(theta_i, phi_j, psi): 
    341421            latitude = sqrt(theta_i**2 + phi_j**2) 
    342422            longitude = degrees(np.arctan2(phi_j, theta_i)) 
    343423            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude)) 
    344             return Rz(longitude)*Ry(latitude) 
     424            return latitude, longitude, psi-longitude, 'zyz' 
     425            #R = Rz(longitude)*Ry(latitude)*Rz(psi) 
     426            #return R_to_xyz(R) 
     427            #return Rz(longitude)*Ry(latitude) 
    345428        def _weight(theta_i, phi_j, w_i, w_j): 
    346429            # Weighting for each point comes from the integral: 
     
    376459            return weight*w_i*w_j if latitude < 180 else 0 
    377460    elif projection == 'azimuthal_equal_area': 
    378         def _rotate(theta_i, phi_j): 
     461        # Note that calculates angles for Rz Ry rather than Rx Ry 
     462        def _project(theta_i, phi_j, psi): 
    379463            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180) 
    380464            latitude = 180-degrees(2*np.arccos(radius)) 
    381465            longitude = degrees(np.arctan2(phi_j, theta_i)) 
    382466            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude)) 
    383             return Rz(longitude)*Ry(latitude) 
     467            return latitude, longitude, psi, "zyz" 
     468            #R = Rz(longitude)*Ry(latitude)*Rz(psi) 
     469            #return R_to_xyz(R) 
     470            #return Rz(longitude)*Ry(latitude) 
    384471        def _weight(theta_i, phi_j, w_i, w_j): 
    385472            latitude = sqrt(theta_i**2 + phi_j**2) 
     
    389476        raise ValueError("unknown projection %r"%projection) 
    390477 
     478    return _project, _weight 
     479 
     480def R_to_xyz(R): 
     481    """ 
     482    Return phi, theta, psi Tait-Bryan angles corresponding to the given rotation matrix. 
     483 
     484    Extracting Euler Angles from a Rotation Matrix 
     485    Mike Day, Insomniac Games 
     486    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf 
     487    Based on: Shoemake’s "Euler Angle Conversion", Graphics Gems IV, pp.  222-229 
     488    """ 
     489    phi = np.arctan2(R[1, 2], R[2, 2]) 
     490    theta = np.arctan2(-R[0, 2], np.sqrt(R[0, 0]**2 + R[0, 1]**2)) 
     491    psi = np.arctan2(R[0, 1], R[0, 0]) 
     492    return np.degrees(phi), np.degrees(theta), np.degrees(psi) 
     493 
     494def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian', 
     495              projection='equirectangular'): 
     496    """ 
     497    Draw the dispersion mesh showing the theta-phi orientations at which 
     498    the model will be evaluated. 
     499    """ 
     500 
     501    _project, _weight = get_projection(projection) 
     502    def _rotate(theta, phi, z): 
     503        dview = _project(theta, phi, 0.) 
     504        if len(dview) == 4: # hack for zyz coords 
     505            return Rz(dview[1])*Ry(dview[0])*z 
     506        else: 
     507            return Rx(dview[1])*Ry(dview[0])*z 
     508 
     509 
     510    dist_x = np.linspace(-1, 1, n) 
     511    weights = np.ones_like(dist_x) 
     512    if dist == 'gaussian': 
     513        dist_x *= 3 
     514        weights = exp(-0.5*dist_x**2) 
     515    elif dist == 'rectangle': 
     516        # Note: uses sasmodels ridiculous definition of rectangle width 
     517        dist_x *= sqrt(3) 
     518    elif dist == 'uniform': 
     519        pass 
     520    else: 
     521        raise ValueError("expected dist to be gaussian, rectangle or uniform") 
     522 
    391523    # mesh in theta, phi formed by rotating z 
    392524    dtheta, dphi, dpsi = jitter 
    393525    z = np.matrix([[0], [0], [radius]]) 
    394     points = np.hstack([_rotate(theta_i, phi_j)*z 
     526    points = np.hstack([_rotate(theta_i, phi_j, z) 
    395527                        for theta_i in dtheta*dist_x 
    396528                        for phi_j in dphi*dist_x]) 
     
    470602    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
    471603    """ 
    472     dtheta, dphi, dpsi = jitter 
    473     points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points 
     604    if jitter is None: 
     605        return points 
     606    # Hack to deal with the fact that azimuthal_equidistance uses euler angles 
     607    if len(jitter) == 4: 
     608        dtheta, dphi, dpsi, _ = jitter 
     609        points = Rz(dphi)*Ry(dtheta)*Rz(dpsi)*points 
     610    else: 
     611        dtheta, dphi, dpsi = jitter 
     612        points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points 
    474613    return points 
    475614 
     
    481620    """ 
    482621    theta, phi, psi = view 
    483     points = Rz(phi)*Ry(theta)*Rz(psi)*points 
     622    points = Rz(phi)*Ry(theta)*Rz(psi)*points # viewing angle 
     623    #points = Rz(psi)*Ry(pi/2-theta)*Rz(phi)*points # 1-D integration angles 
     624    #points = Rx(phi)*Ry(theta)*Rz(psi)*points  # angular dispersion angle 
    484625    return points 
     626 
     627def orient_relative_to_beam_quaternion(view, points): 
     628    """ 
     629    Apply the view transform to a set of points. 
     630 
     631    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
     632 
     633    This variant uses quaternions rather than rotation matrices for the 
     634    computation.  It works but it is not used because it doesn't solve 
     635    any problems.  The challenge of mapping theta/phi/psi to SO(3) does 
     636    not disappear by calculating the transform differently. 
     637    """ 
     638    theta, phi, psi = view 
     639    x, y, z = [1, 0, 0], [0, 1, 0], [0, 0, 1] 
     640    q = Quaternion(1, [0, 0, 0]) 
     641    ## Compose a rotation about the three axes by rotating 
     642    ## the unit vectors before applying the rotation. 
     643    #q = Quaternion.from_angle_axis(theta, q.rot(x)) * q 
     644    #q = Quaternion.from_angle_axis(phi, q.rot(y)) * q 
     645    #q = Quaternion.from_angle_axis(psi, q.rot(z)) * q 
     646    ## The above turns out to be equivalent to reversing 
     647    ## the order of application, so ignore it and use below. 
     648    q = q * Quaternion.from_angle_axis(theta, x) 
     649    q = q * Quaternion.from_angle_axis(phi, y) 
     650    q = q * Quaternion.from_angle_axis(psi, z) 
     651    ## Reverse the order by post-multiply rather than pre-multiply 
     652    #q = Quaternion.from_angle_axis(theta, x) * q 
     653    #q = Quaternion.from_angle_axis(phi, y) * q 
     654    #q = Quaternion.from_angle_axis(psi, z) * q 
     655    #print("axes psi", q.rot(np.matrix([x, y, z]).T)) 
     656    return q.rot(points) 
     657#orient_relative_to_beam = orient_relative_to_beam_quaternion 
     658 
     659# Simple stand-alone quaternion class 
     660import numpy as np 
     661from copy import copy 
     662class Quaternion(object): 
     663    def __init__(self, w, r): 
     664         self.w = w 
     665         self.r = np.asarray(r, dtype='d') 
     666    @staticmethod 
     667    def from_angle_axis(theta, r): 
     668         theta = np.radians(theta)/2 
     669         r = np.asarray(r) 
     670         w = np.cos(theta) 
     671         r = np.sin(theta)*r/np.dot(r,r) 
     672         return Quaternion(w, r) 
     673    def __mul__(self, other): 
     674        if isinstance(other, Quaternion): 
     675            w = self.w*other.w - np.dot(self.r, other.r) 
     676            r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r) 
     677            return Quaternion(w, r) 
     678    def rot(self, v): 
     679        v = np.asarray(v).T 
     680        use_transpose = (v.shape[-1] != 3) 
     681        if use_transpose: v = v.T 
     682        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v) 
     683        #v = v + 2*self.w*np.cross(self.r, v) + np.cross(2*self.r, np.cross(self.r, v)) 
     684        if use_transpose: v = v.T 
     685        return v.T 
     686    def conj(self): 
     687        return Quaternion(self.w, -self.r) 
     688    def inv(self): 
     689        return self.conj()/self.norm()**2 
     690    def norm(self): 
     691        return np.sqrt(self.w**2 + np.sum(self.r**2)) 
     692    def __str__(self): 
     693        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2]) 
     694def test_qrot(): 
     695    # Define rotation of 60 degrees around an axis in y-z that is 60 degrees from y. 
     696    # The rotation axis is determined by rotating the point [0, 1, 0] about x. 
     697    ax = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0]) 
     698    q = Quaternion.from_angle_axis(60, ax) 
     699    # Set the point to be rotated, and its expected rotated position. 
     700    p = [1, -1, 2] 
     701    target = [(10+4*np.sqrt(3))/8, (1+2*np.sqrt(3))/8, (14-3*np.sqrt(3))/8] 
     702    #print(q, q.rot(p) - target) 
     703    assert max(abs(q.rot(p) - target)) < 1e-14 
     704#test_qrot() 
     705#import sys; sys.exit() 
    485706 
    486707# translate between number of dimension of dispersity and the number of 
     
    556777        vmin = vmax*10**-7 
    557778        #vmin, vmax = clipped_range(Iqxy, portion=portion, mode='top') 
     779    #vmin, vmax = Iqxy.min(), Iqxy.max() 
    558780    #print("range",(vmin,vmax)) 
    559781    #qx, qy = np.meshgrid(qx, qy) 
    560782    if 0: 
     783        from matplotlib import cm 
    561784        level = np.asarray(255*(Iqxy - vmin)/(vmax - vmin), 'i') 
    562785        level[level < 0] = 0 
    563786        colors = plt.get_cmap()(level) 
    564         axes.plot_surface(qx, qy, -1.1, rstride=1, cstride=1, facecolors=colors) 
     787        #colors = cm.coolwarm(level) 
     788        #colors = cm.gist_yarg(level) 
     789        #colors = cm.Wistia(level) 
     790        colors[level<=0, 3] = 0.  # set floor to transparent 
     791        x, y = np.meshgrid(qx/qx.max(), qy/qy.max()) 
     792        axes.plot_surface(x, y, -1.1*np.ones_like(x), facecolors=colors) 
    565793    elif 1: 
    566794        axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1, 
     
    692920} 
    693921 
     922 
    694923def run(model_name='parallelepiped', size=(10, 40, 100), 
     924        view=(0, 0, 0), jitter=(0, 0, 0), 
    695925        dist='gaussian', mesh=30, 
    696926        projection='equirectangular'): 
     
    702932 
    703933    *size* gives the dimensions (a, b, c) of the shape. 
     934 
     935    *view* gives the initial view (theta, phi, psi) of the shape. 
     936 
     937    *view* gives the initial jitter (dtheta, dphi, dpsi) of the shape. 
    704938 
    705939    *dist* is the type of dispersition: gaussian, rectangle, or uniform. 
     
    721955    calculator, size = select_calculator(model_name, n=150, size=size) 
    722956    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped) 
     957    #draw_shape = draw_fcc 
    723958 
    724959    ## uncomment to set an independent the colour range for every view 
     
    726961    calculator.limits = None 
    727962 
    728     ## initial view 
    729     #theta, dtheta = 70., 10. 
    730     #phi, dphi = -45., 3. 
    731     #psi, dpsi = -45., 3. 
    732     theta, phi, psi = 0, 0, 0 
    733     dtheta, dphi, dpsi = 0, 0, 0 
     963    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection) 
     964 
     965def mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection): 
     966    # Note: travis-ci does not support mpl_toolkits.mplot3d, but this shouldn't be 
     967    # an issue since we are lazy-loading the package on a path that isn't tested. 
     968    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot 
     969    import matplotlib as mpl 
     970    import matplotlib.pyplot as plt 
     971    from matplotlib.widgets import Slider 
    734972 
    735973    ## create the plot window 
     
    752990 
    753991    ## add control widgets to plot 
    754     axes_theta = plt.axes([0.1, 0.15, 0.45, 0.04], **props) 
    755     axes_phi = plt.axes([0.1, 0.1, 0.45, 0.04], **props) 
    756     axes_psi = plt.axes([0.1, 0.05, 0.45, 0.04], **props) 
    757     stheta = Slider(axes_theta, 'Theta', -90, 90, valinit=theta) 
    758     sphi = Slider(axes_phi, 'Phi', -180, 180, valinit=phi) 
    759     spsi = Slider(axes_psi, 'Psi', -180, 180, valinit=psi) 
    760  
    761     axes_dtheta = plt.axes([0.75, 0.15, 0.15, 0.04], **props) 
    762     axes_dphi = plt.axes([0.75, 0.1, 0.15, 0.04], **props) 
    763     axes_dpsi = plt.axes([0.75, 0.05, 0.15, 0.04], **props) 
     992    axes_theta = plt.axes([0.05, 0.15, 0.50, 0.04], **props) 
     993    axes_phi = plt.axes([0.05, 0.10, 0.50, 0.04], **props) 
     994    axes_psi = plt.axes([0.05, 0.05, 0.50, 0.04], **props) 
     995    stheta = Slider(axes_theta, u'Ξ', -90, 90, valinit=0) 
     996    sphi = Slider(axes_phi, u'φ', -180, 180, valinit=0) 
     997    spsi = Slider(axes_psi, u'ψ', -180, 180, valinit=0) 
     998 
     999    axes_dtheta = plt.axes([0.70, 0.15, 0.20, 0.04], **props) 
     1000    axes_dphi = plt.axes([0.70, 0.1, 0.20, 0.04], **props) 
     1001    axes_dpsi = plt.axes([0.70, 0.05, 0.20, 0.04], **props) 
     1002 
    7641003    # Note: using ridiculous definition of rectangle distribution, whose width 
    7651004    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep 
    7661005    # the maximum width to 90. 
    7671006    dlimit = DIST_LIMITS[dist] 
    768     sdtheta = Slider(axes_dtheta, 'dTheta', 0, 2*dlimit, valinit=dtheta) 
    769     sdphi = Slider(axes_dphi, 'dPhi', 0, 2*dlimit, valinit=dphi) 
    770     sdpsi = Slider(axes_dpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi) 
    771  
     1007    sdtheta = Slider(axes_dtheta, u'Δξ', 0, 2*dlimit, valinit=0) 
     1008    sdphi = Slider(axes_dphi, u'Δφ', 0, 2*dlimit, valinit=0) 
     1009    sdpsi = Slider(axes_dpsi, u'Δψ', 0, 2*dlimit, valinit=0) 
     1010 
     1011    ## initial view and jitter 
     1012    theta, phi, psi = view 
     1013    stheta.set_val(theta) 
     1014    sphi.set_val(phi) 
     1015    spsi.set_val(psi) 
     1016    dtheta, dphi, dpsi = jitter 
     1017    sdtheta.set_val(dtheta) 
     1018    sdphi.set_val(dphi) 
     1019    sdpsi.set_val(dpsi) 
    7721020 
    7731021    ## callback to draw the new view 
     
    7771025        # set small jitter as 0 if multiple pd dims 
    7781026        dims = sum(v > 0 for v in jitter) 
    779         limit = [0, 0.5, 5][dims] 
     1027        limit = [0, 0.5, 5, 5][dims] 
    7801028        jitter = [0 if v < limit else v for v in jitter] 
    7811029        axes.cla() 
    782         draw_beam(axes, (0, 0)) 
    783         draw_jitter(axes, view, jitter, dist=dist, size=size, draw_shape=draw_shape) 
    784         #draw_jitter(axes, view, (0,0,0)) 
     1030 
     1031        ## Visualize as person on globe 
     1032        #draw_sphere(axes) 
     1033        #draw_person_on_sphere(axes, view) 
     1034 
     1035        ## Move beam instead of shape 
     1036        #draw_beam(axes, -view[:2]) 
     1037        #draw_jitter(axes, (0,0,0), (0,0,0), views=3) 
     1038 
     1039        ## Move shape and draw scattering 
     1040        draw_beam(axes, (0, 0), alpha=1.) 
     1041        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1., 
     1042                    draw_shape=draw_shape, projection=projection, views=3) 
    7851043        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection) 
    7861044        draw_scattering(calculator, axes, view, jitter, dist=dist) 
     1045 
    7871046        plt.gcf().canvas.draw() 
    7881047 
     
    8001059    ## go interactive 
    8011060    plt.show() 
     1061 
     1062 
     1063def map_colors(z, kw): 
     1064    from matplotlib import cm 
     1065 
     1066    cmap = kw.pop('cmap', cm.coolwarm) 
     1067    alpha = kw.pop('alpha', None) 
     1068    vmin = kw.pop('vmin', z.min()) 
     1069    vmax = kw.pop('vmax', z.max()) 
     1070    c = kw.pop('c', None) 
     1071    color = kw.pop('color', c) 
     1072    if color is None: 
     1073        znorm = ((z - vmin) / (vmax - vmin)).clip(0, 1) 
     1074        color = cmap(znorm) 
     1075    elif isinstance(color, np.ndarray) and color.shape == z.shape: 
     1076        color = cmap(color) 
     1077    if alpha is None: 
     1078        if isinstance(color, np.ndarray): 
     1079            color = color[..., :3] 
     1080    else: 
     1081        color[..., 3] = alpha 
     1082    kw['color'] = color 
     1083 
     1084def make_vec(*args): 
     1085    #return [np.asarray(v, 'd').flatten() for v in args] 
     1086    return [np.asarray(v, 'd') for v in args] 
     1087 
     1088def make_image(z, kw): 
     1089    import PIL.Image 
     1090    from matplotlib import cm 
     1091 
     1092    cmap = kw.pop('cmap', cm.coolwarm) 
     1093 
     1094    znorm = (z-z.min())/z.ptp() 
     1095    c = cmap(znorm) 
     1096    c = c[..., :3] 
     1097    rgb = np.asarray(c*255, 'u1') 
     1098    image = PIL.Image.fromarray(rgb, mode='RGB') 
     1099    return image 
     1100 
     1101 
     1102_IPV_MARKERS = { 
     1103    'o': 'sphere', 
     1104} 
     1105_IPV_COLORS = { 
     1106    'w': 'white', 
     1107    'k': 'black', 
     1108    'c': 'cyan', 
     1109    'm': 'magenta', 
     1110    'y': 'yellow', 
     1111    'r': 'red', 
     1112    'g': 'green', 
     1113    'b': 'blue', 
     1114} 
     1115def ipv_fix_color(kw): 
     1116    alpha = kw.pop('alpha', None) 
     1117    color = kw.get('color', None) 
     1118    if isinstance(color, str): 
     1119        color = _IPV_COLORS.get(color, color) 
     1120        kw['color'] = color 
     1121    if alpha is not None: 
     1122        color = kw['color'] 
     1123        #TODO: convert color to [r, g, b, a] if not already 
     1124        if isinstance(color, (tuple, list)): 
     1125            if len(color) == 3: 
     1126                color = (color[0], color[1], color[2], alpha) 
     1127            else: 
     1128                color = (color[0], color[1], color[2], alpha*color[3]) 
     1129            color = np.array(color) 
     1130        if isinstance(color, np.ndarray) and color.shape[-1] == 4: 
     1131            color[..., 3] = alpha 
     1132            kw['color'] = color 
     1133 
     1134def ipv_set_transparency(kw, obj): 
     1135    color = kw.get('color', None) 
     1136    if (isinstance(color, np.ndarray) 
     1137            and color.shape[-1] == 4 
     1138            and (color[..., 3] != 1.0).any()): 
     1139        obj.material.transparent = True 
     1140        obj.material.side = "FrontSide" 
     1141 
     1142def ipv_axes(): 
     1143    import ipyvolume as ipv 
     1144 
     1145    class Plotter: 
     1146        # transparency can be achieved by setting the following: 
     1147        #    mesh.color = [r, g, b, alpha] 
     1148        #    mesh.material.transparent = True 
     1149        #    mesh.material.side = "FrontSide" 
     1150        # smooth(ish) rotation can be achieved by setting: 
     1151        #    slide.continuous_update = True 
     1152        #    figure.animation = 0. 
     1153        #    mesh.material.x = x 
     1154        #    mesh.material.y = y 
     1155        #    mesh.material.z = z 
     1156        # maybe need to synchronize update of x/y/z to avoid shimmy when moving 
     1157        def plot(self, x, y, z, **kw): 
     1158            ipv_fix_color(kw) 
     1159            x, y, z = make_vec(x, y, z) 
     1160            ipv.plot(x, y, z, **kw) 
     1161        def plot_surface(self, x, y, z, **kw): 
     1162            facecolors = kw.pop('facecolors', None) 
     1163            if facecolors is not None: 
     1164                kw['color'] = facecolors 
     1165            ipv_fix_color(kw) 
     1166            x, y, z = make_vec(x, y, z) 
     1167            h = ipv.plot_surface(x, y, z, **kw) 
     1168            ipv_set_transparency(kw, h) 
     1169            #h.material.side = "DoubleSide" 
     1170            return h 
     1171        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw): 
     1172            kw.pop('linewidth', None) 
     1173            ipv_fix_color(kw) 
     1174            x, y, z = make_vec(x, y, Z) 
     1175            if triangles is not None: 
     1176                triangles = np.asarray(triangles) 
     1177            h = ipv.plot_trisurf(x, y, z, triangles=triangles, **kw) 
     1178            ipv_set_transparency(kw, h) 
     1179            return h 
     1180        def scatter(self, x, y, z, **kw): 
     1181            x, y, z = make_vec(x, y, z) 
     1182            map_colors(z, kw) 
     1183            marker = kw.get('marker', None) 
     1184            kw['marker'] = _IPV_MARKERS.get(marker, marker) 
     1185            h = ipv.scatter(x, y, z, **kw) 
     1186            ipv_set_transparency(kw, h) 
     1187            return h 
     1188        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw): 
     1189            # Don't use contour for now (although we might want to later) 
     1190            self.pcolor(x, y, v, zdir='z', offset=offset, **kw) 
     1191        def pcolor(self, x, y, v, zdir='z', offset=0, **kw): 
     1192            x, y, v = make_vec(x, y, v) 
     1193            image = make_image(v, kw) 
     1194            xmin, xmax = x.min(), x.max() 
     1195            ymin, ymax = y.min(), y.max() 
     1196            x = np.array([[xmin, xmax], [xmin, xmax]]) 
     1197            y = np.array([[ymin, ymin], [ymax, ymax]]) 
     1198            z = x*0 + offset 
     1199            u = np.array([[0., 1], [0, 1]]) 
     1200            v = np.array([[0., 0], [1, 1]]) 
     1201            h = ipv.plot_mesh(x, y, z, u=u, v=v, texture=image, wireframe=False) 
     1202            ipv_set_transparency(kw, h) 
     1203            h.material.side = "DoubleSide" 
     1204            return h 
     1205        def text(self, *args, **kw): 
     1206            pass 
     1207        def set_xlim(self, limits): 
     1208            ipv.xlim(*limits) 
     1209        def set_ylim(self, limits): 
     1210            ipv.ylim(*limits) 
     1211        def set_zlim(self, limits): 
     1212            ipv.zlim(*limits) 
     1213        def set_axes_on(self): 
     1214            ipv.style.axis_on() 
     1215        def set_axis_off(self): 
     1216            ipv.style.axes_off() 
     1217    return Plotter() 
     1218 
     1219def ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection): 
     1220    import ipywidgets as widgets 
     1221    import ipyvolume as ipv 
     1222 
     1223    axes = ipv_axes() 
     1224 
     1225    def draw(view, jitter): 
     1226        camera = ipv.gcf().camera 
     1227        #print(ipv.gcf().__dict__.keys()) 
     1228        #print(dir(ipv.gcf())) 
     1229        ipv.figure(animation=0.)  # no animation when updating object mesh 
     1230 
     1231        # set small jitter as 0 if multiple pd dims 
     1232        dims = sum(v > 0 for v in jitter) 
     1233        limit = [0, 0.5, 5, 5][dims] 
     1234        jitter = [0 if v < limit else v for v in jitter] 
     1235 
     1236        ## Visualize as person on globe 
     1237        #draw_beam(axes, (0, 0)) 
     1238        #draw_sphere(axes) 
     1239        #draw_person_on_sphere(axes, view) 
     1240 
     1241        ## Move beam instead of shape 
     1242        #draw_beam(axes, view=(-view[0], -view[1])) 
     1243        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0)) 
     1244 
     1245        ## Move shape and draw scattering 
     1246        draw_beam(axes, (0, 0), steps=25) 
     1247        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.0, 
     1248                    draw_shape=draw_shape, projection=projection) 
     1249        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95, 
     1250                  projection=projection) 
     1251        draw_scattering(calculator, axes, view, jitter, dist=dist) 
     1252 
     1253        draw_axes(axes, origin=(-1, -1, -1.1)) 
     1254        ipv.style.box_off() 
     1255        ipv.style.axes_off() 
     1256        ipv.xyzlabel(" ", " ", " ") 
     1257 
     1258        ipv.gcf().camera = camera 
     1259        ipv.show() 
     1260 
     1261 
     1262    trange, prange = (-180., 180., 1.), (-180., 180., 1.) 
     1263    dtrange, dprange = (0., 180., 1.), (0., 180., 1.) 
     1264 
     1265    ## Super simple interfaca, but uses non-ascii variable namese 
     1266    # Ξ φ ψ Δξ Δφ Δψ 
     1267    #def update(**kw): 
     1268    #    view = kw['Ξ'], kw['φ'], kw['ψ'] 
     1269    #    jitter = kw['Δξ'], kw['Δφ'], kw['Δψ'] 
     1270    #    draw(view, jitter) 
     1271    #widgets.interact(update, Ξ=trange, φ=prange, ψ=prange, Δξ=dtrange, Δφ=dprange, Δψ=dprange) 
     1272 
     1273    def update(theta, phi, psi, dtheta, dphi, dpsi): 
     1274        draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi)) 
     1275 
     1276    def slider(name, slice, init=0.): 
     1277        return widgets.FloatSlider( 
     1278            value=init, 
     1279            min=slice[0], 
     1280            max=slice[1], 
     1281            step=slice[2], 
     1282            description=name, 
     1283            disabled=False, 
     1284            #continuous_update=True, 
     1285            continuous_update=False, 
     1286            orientation='horizontal', 
     1287            readout=True, 
     1288            readout_format='.1f', 
     1289            ) 
     1290    theta = slider(u'Ξ', trange, view[0]) 
     1291    phi = slider(u'φ', prange, view[1]) 
     1292    psi = slider(u'ψ', prange, view[2]) 
     1293    dtheta = slider(u'Δξ', dtrange, jitter[0]) 
     1294    dphi = slider(u'Δφ', dprange, jitter[1]) 
     1295    dpsi = slider(u'Δψ', dprange, jitter[2]) 
     1296    fields = { 
     1297        'theta': theta, 'phi': phi, 'psi': psi, 
     1298        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi, 
     1299    } 
     1300    ui = widgets.HBox([ 
     1301        widgets.VBox([theta, phi, psi]), 
     1302        widgets.VBox([dtheta, dphi, dpsi]) 
     1303    ]) 
     1304 
     1305    out = widgets.interactive_output(update, fields) 
     1306    display(ui, out) 
     1307 
     1308 
     1309_ENGINES = { 
     1310    "matplotlib": mpl_plot, 
     1311    "mpl": mpl_plot, 
     1312    #"plotly": plotly_plot, 
     1313    "ipvolume": ipv_plot, 
     1314    "ipv": ipv_plot, 
     1315} 
     1316PLOT_ENGINE = _ENGINES["matplotlib"] 
     1317def set_plotter(name): 
     1318    global PLOT_ENGINE 
     1319    PLOT_ENGINE = _ENGINES[name] 
    8021320 
    8031321def main(): 
     
    8111329    parser.add_argument('-s', '--size', type=str, default='10,40,100', 
    8121330                        help='a,b,c lengths') 
     1331    parser.add_argument('-v', '--view', type=str, default='0,0,0', 
     1332                        help='initial view angles') 
     1333    parser.add_argument('-j', '--jitter', type=str, default='0,0,0', 
     1334                        help='initial angular dispersion') 
    8131335    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS, 
    8141336                        default=DISTRIBUTIONS[0], 
     
    8191341                        help='oriented shape') 
    8201342    opts = parser.parse_args() 
    821     size = tuple(int(v) for v in opts.size.split(',')) 
    822     run(opts.shape, size=size, 
     1343    size = tuple(float(v) for v in opts.size.split(',')) 
     1344    view = tuple(float(v) for v in opts.view.split(',')) 
     1345    jitter = tuple(float(v) for v in opts.jitter.split(',')) 
     1346    run(opts.shape, size=size, view=view, jitter=jitter, 
    8231347        mesh=opts.mesh, dist=opts.distribution, 
    8241348        projection=opts.projection) 
Note: See TracChangeset for help on using the changeset viewer.