Changeset 674186e in sasmodels


Ignore:
Timestamp:
Mar 6, 2019 4:08:44 PM (5 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
c11d09f
Parents:
31d5187 (diff), 9150036 (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

Merge branch 'beta_approx' into webgl_jitter_viewer

Files:
8 edited

Legend:

Unmodified
Added
Removed
  • doc/guide/plugin.rst

    raa8c6e0 r9150036  
    272272structure factor to account for interactions between particles.  See 
    273273`Form_Factors`_ for more details. 
     274 
     275**model_info = ...** lets you define a model directly, for example, by 
     276loading and modifying existing models.  This is done implicitly by 
     277:func:`sasmodels.core.load_model_info`, which can create a mixture model 
     278from a pair of existing models.  For example:: 
     279 
     280    from sasmodels.core import load_model_info 
     281    model_info = load_model_info('sphere+cylinder') 
     282 
     283See :class:`sasmodels.modelinfo.ModelInfo` for details about the model 
     284attributes that are defined. 
    274285 
    275286Model Parameters 
     
    894905             - \frac{\sin(x)}{x}\left(\frac{1}{x} - \frac{3!}{x^3} + \frac{5!}{x^5} - \frac{7!}{x^7}\right) 
    895906 
    896         For small arguments , 
     907        For small arguments, 
    897908 
    898909        .. math:: 
  • example/multiscatfit.py

    r49d1f8b8 r2c4a190  
    1515 
    1616    # Show the model without fitting 
    17     PYTHONPATH=..:../explore:../../bumps:../../sasview/src python multiscatfit.py 
     17    PYTHONPATH=..:../../bumps:../../sasview/src python multiscatfit.py 
    1818 
    1919    # Run the fit 
    20     PYTHONPATH=..:../explore:../../bumps:../../sasview/src ../../bumps/run.py \ 
     20    PYTHONPATH=..:../../bumps:../../sasview/src ../../bumps/run.py \ 
    2121    multiscatfit.py --store=/tmp/t1 
    2222 
     
    5555    ) 
    5656 
     57# Tie the model to the data 
     58M = Experiment(data=data, model=model) 
     59 
     60# Stack mulitple scattering on top of the existing resolution function. 
     61M.resolution = MultipleScattering(resolution=M.resolution, probability=0.) 
     62 
    5763# SET THE FITTING PARAMETERS 
    5864model.radius_polar.range(15, 3000) 
     
    6571model.scale.range(0, 0.1) 
    6672 
    67 # Mulitple scattering probability parameter 
    68 # HACK: the probability is stuffed in as an extra parameter to the experiment. 
    69 probability = Parameter(name="probability", value=0.0) 
    70 probability.range(0.0, 0.9) 
     73# The multiple scattering probability parameter is in the resolution function 
     74# instead of the scattering function, so access it through M.resolution 
     75M.scattering_probability.range(0.0, 0.9) 
    7176 
    72 M = Experiment(data=data, model=model, extra_pars={'probability': probability}) 
    73  
    74 # Stack mulitple scattering on top of the existing resolution function. 
    75 # Because resolution functions in sasview don't have fitting parameters, 
    76 # we instead allow the multiple scattering calculator to take a function 
    77 # instead of a probability.  This function returns the current value of 
    78 # the parameter. ** THIS IS TEMPORARY ** when multiple scattering is 
    79 # properly integrated into sasmodels and sasview, its fittable parameter 
    80 # will be treated like the model parameters. 
    81 M.resolution = MultipleScattering(resolution=M.resolution, 
    82                                   probability=lambda: probability.value, 
    83                                   ) 
    84 M._kernel_inputs = M.resolution.q_calc 
     77# Let bumps know that we are fitting this experiment 
    8578problem = FitProblem(M) 
    8679 
  • sasmodels/__init__.py

    ra1ec908 r37f38ff  
    1414defining new models. 
    1515""" 
    16 __version__ = "0.98" 
     16__version__ = "0.99" 
    1717 
    1818def data_files(): 
  • sasmodels/bumps_model.py

    r49d1f8b8 r2c4a190  
    3535    # when bumps is not on the path. 
    3636    from bumps.names import Parameter # type: ignore 
     37    from bumps.parameter import Reference # type: ignore 
    3738except ImportError: 
    3839    pass 
     
    139140    def __init__(self, data, model, cutoff=1e-5, name=None, extra_pars=None): 
    140141        # type: (Data, Model, float) -> None 
     142        # Allow resolution function to define fittable parameters.  We do this 
     143        # by creating reference parameters within the resolution object rather 
     144        # than modifying the object itself to use bumps parameters.  We need 
     145        # to reset the parameters each time the object has changed.  These 
     146        # additional parameters need to be returned from the fitting engine. 
     147        # To make them available to the user, they are added as top-level 
     148        # attributes to the experiment object.  The only change to the 
     149        # resolution function is that it needs an optional 'fittable' attribute 
     150        # which maps the internal name to the user visible name for the 
     151        # for the parameter. 
     152        self._resolution = None 
     153        self._resolution_pars = {} 
    141154        # remember inputs so we can inspect from outside 
    142155        self.name = data.filename if name is None else name 
     
    145158        self._interpret_data(data, model.sasmodel) 
    146159        self._cache = {} 
     160        # CRUFT: no longer need extra parameters 
     161        # Multiple scattering probability is now retrieved directly from the 
     162        # multiple scattering resolution function. 
    147163        self.extra_pars = extra_pars 
    148164 
     
    162178        return len(self.Iq) 
    163179 
     180    @property 
     181    def resolution(self): 
     182        return self._resolution 
     183 
     184    @resolution.setter 
     185    def resolution(self, value): 
     186        self._resolution = value 
     187 
     188        # Remove old resolution fitting parameters from experiment 
     189        for name in self._resolution_pars: 
     190            delattr(self, name) 
     191 
     192        # Create new resolution fitting parameters 
     193        res_pars = getattr(self._resolution, 'fittable', {}) 
     194        self._resolution_pars = { 
     195            name: Reference(self._resolution, refname, name=name) 
     196            for refname, name in res_pars.items() 
     197        } 
     198 
     199        # Add new resolution fitting parameters as experiment attributes 
     200        for name, ref in self._resolution_pars.items(): 
     201            setattr(self, name, ref) 
     202 
    164203    def parameters(self): 
    165204        # type: () -> Dict[str, Parameter] 
     
    168207        """ 
    169208        pars = self.model.parameters() 
    170         if self.extra_pars: 
     209        if self.extra_pars is not None: 
    171210            pars.update(self.extra_pars) 
     211        pars.update(self._resolution_pars) 
    172212        return pars 
    173213 
  • sasmodels/direct_model.py

    rc1799d3 r9150036  
    224224            else: 
    225225                Iq, dIq = None, None 
    226             #self._theory = np.zeros_like(q) 
    227             q_vectors = [res.q_calc] 
    228226        elif self.data_type == 'Iqxy': 
    229227            #if not model.info.parameters.has_2d: 
     
    242240            res = resolution2d.Pinhole2D(data=data, index=index, 
    243241                                         nsigma=3.0, accuracy=accuracy) 
    244             #self._theory = np.zeros_like(self.Iq) 
    245             q_vectors = res.q_calc 
    246242        elif self.data_type == 'Iq': 
    247243            index = (data.x >= data.qmin) & (data.x <= data.qmax) 
     
    268264            else: 
    269265                res = resolution.Perfect1D(data.x[index]) 
    270  
    271             #self._theory = np.zeros_like(self.Iq) 
    272             q_vectors = [res.q_calc] 
    273266        elif self.data_type == 'Iq-oriented': 
    274267            index = (data.x >= data.qmin) & (data.x <= data.qmax) 
     
    286279                                      qx_width=data.dxw[index], 
    287280                                      qy_width=data.dxl[index]) 
    288             q_vectors = res.q_calc 
    289281        else: 
    290282            raise ValueError("Unknown data type") # never gets here 
     
    292284        # Remember function inputs so we can delay loading the function and 
    293285        # so we can save/restore state 
    294         self._kernel_inputs = q_vectors 
    295286        self._kernel = None 
    296287        self.Iq, self.dIq, self.index = Iq, dIq, index 
     
    329320        # type: (ParameterSet, float) -> np.ndarray 
    330321        if self._kernel is None: 
    331             self._kernel = self._model.make_kernel(self._kernel_inputs) 
     322            # TODO: change interfaces so that resolution returns kernel inputs 
     323            # Maybe have resolution always return a tuple, or maybe have 
     324            # make_kernel accept either an ndarray or a pair of ndarrays. 
     325            kernel_inputs = self.resolution.q_calc 
     326            if isinstance(kernel_inputs, np.ndarray): 
     327                kernel_inputs = (kernel_inputs,) 
     328            self._kernel = self._model.make_kernel(kernel_inputs) 
    332329 
    333330        # Need to pull background out of resolution for multiple scattering 
  • sasmodels/multiscat.py

    rb3703f5 r2c4a190  
    342342 
    343343    *probability* is related to the expected number of scattering 
    344     events in the sample $\lambda$ as $p = 1 = e^{-\lambda}$.  As a 
    345     hack to allow probability to be a fitted parameter, the "value" 
    346     can be a function that takes no parameters and returns the current 
    347     value of the probability.  *coverage* determines how many scattering 
    348     steps to consider.  The default is 0.99, which sets $n$ such that 
    349     $1 \ldots n$ covers 99% of the Poisson probability mass function. 
     344    events in the sample $\lambda$ as $p = 1 - e^{-\lambda}$. 
     345    *coverage* determines how many scattering steps to consider.  The 
     346    default is 0.99, which sets $n$ such that $1 \ldots n$ covers 99% 
     347    of the Poisson probability mass function. 
    350348 
    351349    *is2d* is True then 2D scattering is used, otherwise it accepts 
     
    399397        self.qmin = qmin 
    400398        self.nq = nq 
    401         self.probability = probability 
     399        self.probability = 0. if probability is None else probability 
    402400        self.coverage = coverage 
    403401        self.is2d = is2d 
     
    456454        self.Iqxy = None # type: np.ndarray 
    457455 
     456        # Label probability as a fittable parameter, and give its external name 
     457        # Note that the external name must be a valid python identifier, since 
     458        # is will be set as an experiment attribute. 
     459        self.fittable = {'probability': 'scattering_probability'} 
     460 
    458461    def apply(self, theory): 
    459462        if self.is2d: 
     
    463466        Iq_calc = Iq_calc.reshape(self.nq, self.nq) 
    464467 
     468        # CRUFT: don't need probability as a function anymore 
    465469        probability = self.probability() if callable(self.probability) else self.probability 
    466470        coverage = self.coverage 
  • sasmodels/sasview_model.py

    ra8a1d48 r9150036  
    2525from . import core 
    2626from . import custom 
     27from . import kernelcl 
    2728from . import product 
    2829from . import generate 
     
    3031from . import modelinfo 
    3132from .details import make_kernel_args, dispersion_mesh 
     33from .kernelcl import reset_environment 
    3234 
    3335# pylint: disable=unused-import 
     
    6870#: has changed since we last reloaded. 
    6971_CACHED_MODULE = {}  # type: Dict[str, "module"] 
     72 
     73def reset_environment(): 
     74    # type: () -> None 
     75    """ 
     76    Clear the compute engine context so that the GUI can change devices. 
     77 
     78    This removes all compiled kernels, even those that are active on fit 
     79    pages, but they will be restored the next time they are needed. 
     80    """ 
     81    kernelcl.reset_environment() 
     82    for model in MODELS.values(): 
     83        model._model = None 
    7084 
    7185def find_model(modelname): 
     
    696710    def _calculate_Iq(self, qx, qy=None): 
    697711        if self._model is None: 
    698             self._model = core.build_model(self._model_info) 
     712            # Only need one copy of the compiled kernel regardless of how many 
     713            # times it is used, so store it in the class.  Also, to reset the 
     714            # compute engine, need to clear out all existing compiled kernels, 
     715            # which is much easier to do if we store them in the class. 
     716            self.__class__._model = core.build_model(self._model_info) 
    699717        if qy is not None: 
    700718            q_vectors = [np.asarray(qx), np.asarray(qy)] 
  • sasmodels/jitter.py

    r7d97437 rcff2939  
    11#!/usr/bin/env python 
     2# -*- coding: utf-8 -*- 
    23""" 
    34Jitter Explorer 
     
    56 
    67Application to explore orientation angle and angular dispersity. 
     8 
     9From the command line:: 
     10 
     11    # Show docs 
     12    python -m sasmodels.jitter --help 
     13 
     14    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi 
     15    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0 
     16 
     17From a jupyter cell:: 
     18 
     19    import ipyvolume as ipv 
     20    from sasmodels import jitter 
     21    import importlib; importlib.reload(jitter) 
     22    jitter.set_plotter("ipv") 
     23 
     24    size = (10, 40, 100) 
     25    view = (20, 0, 0) 
     26 
     27    #size = (15, 15, 100) 
     28    #view = (60, 60, 0) 
     29 
     30    dview = (0, 0, 0) 
     31    #dview = (5, 5, 0) 
     32    #dview = (15, 180, 0) 
     33    #dview = (180, 15, 0) 
     34 
     35    projection = 'equirectangular' 
     36    #projection = 'azimuthal_equidistance' 
     37    #projection = 'guyou' 
     38    #projection = 'sinusoidal' 
     39    #projection = 'azimuthal_equal_area' 
     40 
     41    dist = 'uniform' 
     42    #dist = 'gaussian' 
     43 
     44    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection) 
     45    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '') 
     46    #ipv.savefig(filename+'.png') 
    747""" 
    848from __future__ import division, print_function 
     
    1050import argparse 
    1151 
    12 try: # CRUFT: travis-ci does not support mpl_toolkits.mplot3d 
    13     import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot 
    14 except ImportError: 
    15     pass 
    16  
    17 import matplotlib as mpl 
    18 import matplotlib.pyplot as plt 
    19 from matplotlib.widgets import Slider 
    2052import numpy as np 
    2153from numpy import pi, cos, sin, sqrt, exp, degrees, radians 
    2254 
    23 def draw_beam(axes, view=(0, 0)): 
     55def draw_beam(axes, view=(0, 0), alpha=0.5, steps=25): 
    2456    """ 
    2557    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1) 
    2658    """ 
    2759    #axes.plot([0,0],[0,0],[1,-1]) 
    28     #axes.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=0.8) 
    29  
    30     steps = 25 
     60    #axes.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=alpha) 
     61 
    3162    u = np.linspace(0, 2 * np.pi, steps) 
    32     v = np.linspace(-1, 1, steps) 
     63    v = np.linspace(-1, 1, 2) 
    3364 
    3465    r = 0.02 
     
    4273    points = Rz(phi)*Ry(theta)*points 
    4374    x, y, z = [v.reshape(shape) for v in points] 
    44  
    45     axes.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5) 
     75    axes.plot_surface(x, y, z, color='yellow', alpha=alpha) 
     76 
     77    # TODO: draw endcaps on beam 
     78    ## Drawing tiny balls on the end will work 
     79    #draw_sphere(axes, radius=0.02, center=(0, 0, 1.3), color='yellow', alpha=alpha) 
     80    #draw_sphere(axes, radius=0.02, center=(0, 0, -1.3), color='yellow', alpha=alpha) 
     81    ## The following does not work 
     82    #triangles = [(0, i+1, i+2) for i in range(steps-2)] 
     83    #x_cap, y_cap = x[:, 0], y[:, 0] 
     84    #for z_cap in z[:, 0], z[:, -1]: 
     85    #    axes.plot_trisurf(x_cap, y_cap, z_cap, triangles, 
     86    #                      color='yellow', alpha=alpha) 
     87 
    4688 
    4789def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1): 
     
    5597    x, y, z = transform_xyz(view, jitter, x, y, z) 
    5698 
    57     axes.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha) 
     99    axes.plot_surface(x, y, z, color='w', alpha=alpha) 
    58100 
    59101    draw_labels(axes, view, jitter, [ 
     
    124166    return atoms 
    125167 
    126 def draw_parallelepiped(axes, size, view, jitter, steps=None, alpha=1): 
     168def draw_box(axes, size, view): 
     169    a, b, c = size 
     170    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1]) 
     171    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1]) 
     172    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1]) 
     173    x, y, z = transform_xyz(view, None, x, y, z) 
     174    def draw(i, j): 
     175        axes.plot([x[i],x[j]], [y[i], y[j]], [z[i], z[j]], color='black') 
     176    draw(0, 1) 
     177    draw(0, 2) 
     178    draw(0, 3) 
     179    draw(7, 4) 
     180    draw(7, 5) 
     181    draw(7, 6) 
     182 
     183def draw_parallelepiped(axes, size, view, jitter, steps=None, 
     184                        color=(0.6, 1.0, 0.6), alpha=1): 
    127185    """Draw a parallelepiped.""" 
    128186    a, b, c = size 
     
    142200 
    143201    x, y, z = transform_xyz(view, jitter, x, y, z) 
    144     axes.plot_trisurf(x, y, triangles=tri, Z=z, color='w', alpha=alpha) 
    145  
    146     # Draw pink face on box. 
     202    axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha, 
     203                      linewidth=0) 
     204 
     205    # Colour the c+ face of the box. 
    147206    # Since I can't control face color, instead draw a thin box situated just 
    148207    # in front of the "c+" face.  Use the c face so that rotations about psi 
    149208    # rotate that face. 
    150     if 1: 
     209    if 0: 
     210        color = (1, 0.6, 0.6)  # pink 
    151211        x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1]) 
    152212        y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1]) 
    153213        z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1]) 
    154214        x, y, z = transform_xyz(view, jitter, x, y, abs(z)+0.001) 
    155         axes.plot_trisurf(x, y, triangles=tri, Z=z, color=[1, 0.6, 0.6], alpha=alpha) 
     215        axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha) 
    156216 
    157217    draw_labels(axes, view, jitter, [ 
     
    164224    ]) 
    165225 
    166 def draw_sphere(axes, radius=10., steps=100): 
     226def draw_sphere(axes, radius=0.5, steps=25, center=(0,0,0), color='w', alpha=1.): 
    167227    """Draw a sphere""" 
    168228    u = np.linspace(0, 2 * np.pi, steps) 
    169229    v = np.linspace(0, np.pi, steps) 
    170230 
    171     x = radius * np.outer(np.cos(u), np.sin(v)) 
    172     y = radius * np.outer(np.sin(u), np.sin(v)) 
    173     z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) 
    174     axes.plot_surface(x, y, z, rstride=4, cstride=4, color='w') 
    175  
    176 def draw_jitter(axes, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0), 
    177                 draw_shape=draw_parallelepiped): 
     231    x = radius * np.outer(np.cos(u), np.sin(v)) + center[0] 
     232    y = radius * np.outer(np.sin(u), np.sin(v)) + center[1] 
     233    z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) + center[2] 
     234    axes.plot_surface(x, y, z, color=color, alpha=alpha) 
     235    #axes.plot_wireframe(x, y, z) 
     236 
     237def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)): 
     238    x, y, z = origin 
     239    dx, dy, dz = length 
     240    axes.plot([x, x+dx], [y, y], [z, z], color='black') 
     241    axes.plot([x, x], [y, y+dy], [z, z], color='black') 
     242    axes.plot([x, x], [y, y], [z, z+dz], color='black') 
     243 
     244def draw_person_on_sphere(axes, view, height=0.5, radius=0.5): 
     245    limb_offset = height * 0.05 
     246    head_radius = height * 0.10 
     247    head_height = height - head_radius 
     248    neck_length = head_radius * 0.50 
     249    shoulder_height = height - 2*head_radius - neck_length 
     250    torso_length = shoulder_height * 0.55 
     251    torso_radius = torso_length * 0.30 
     252    leg_length = shoulder_height - torso_length 
     253    arm_length = torso_length * 0.90 
     254 
     255    def _draw_part(x, z): 
     256        y = np.zeros_like(x) 
     257        xp, yp, zp = transform_xyz(view, None, x, y, z + radius) 
     258        axes.plot(xp, yp, zp, color='k') 
     259 
     260    # circle for head 
     261    u = np.linspace(0, 2 * np.pi, 40) 
     262    x = head_radius * np.cos(u) 
     263    z = head_radius * np.sin(u) + head_height 
     264    _draw_part(x, z) 
     265 
     266    # rectangle for body 
     267    x = np.array([-torso_radius, torso_radius, torso_radius, -torso_radius, -torso_radius]) 
     268    z = np.array([0., 0, torso_length, torso_length, 0]) + leg_length 
     269    _draw_part(x, z) 
     270 
     271    # arms 
     272    x = np.array([-torso_radius - limb_offset, -torso_radius - limb_offset, -torso_radius]) 
     273    z = np.array([shoulder_height - arm_length, shoulder_height, shoulder_height]) 
     274    _draw_part(x, z) 
     275    _draw_part(-x, z) 
     276 
     277    # legs 
     278    x = np.array([-torso_radius + limb_offset, -torso_radius + limb_offset]) 
     279    z = np.array([0, leg_length]) 
     280    _draw_part(x, z) 
     281    _draw_part(-x, z) 
     282 
     283    limits = [-radius-height, radius+height] 
     284    axes.set_xlim(limits) 
     285    axes.set_ylim(limits) 
     286    axes.set_zlim(limits) 
     287    axes.set_axis_off() 
     288 
     289def draw_jitter(axes, view, jitter, dist='gaussian', 
     290                size=(0.1, 0.4, 1.0), 
     291                draw_shape=draw_parallelepiped, 
     292                projection='equirectangular', 
     293                alpha=0.8, 
     294                views=None): 
    178295    """ 
    179296    Represent jitter as a set of shapes at different orientations. 
    180297    """ 
     298    project, project_weight = get_projection(projection) 
     299 
    181300    # set max diagonal to 0.95 
    182301    scale = 0.95/sqrt(sum(v**2 for v in size)) 
    183302    size = tuple(scale*v for v in size) 
    184303 
    185     #np.random.seed(10) 
    186     #cloud = np.random.randn(10,3) 
    187     cloud = [ 
    188         [-1, -1, -1], 
    189         [-1, -1, +0], 
    190         [-1, -1, +1], 
    191         [-1, +0, -1], 
    192         [-1, +0, +0], 
    193         [-1, +0, +1], 
    194         [-1, +1, -1], 
    195         [-1, +1, +0], 
    196         [-1, +1, +1], 
    197         [+0, -1, -1], 
    198         [+0, -1, +0], 
    199         [+0, -1, +1], 
    200         [+0, +0, -1], 
    201         [+0, +0, +0], 
    202         [+0, +0, +1], 
    203         [+0, +1, -1], 
    204         [+0, +1, +0], 
    205         [+0, +1, +1], 
    206         [+1, -1, -1], 
    207         [+1, -1, +0], 
    208         [+1, -1, +1], 
    209         [+1, +0, -1], 
    210         [+1, +0, +0], 
    211         [+1, +0, +1], 
    212         [+1, +1, -1], 
    213         [+1, +1, +0], 
    214         [+1, +1, +1], 
    215     ] 
    216304    dtheta, dphi, dpsi = jitter 
    217     if dtheta == 0: 
    218         cloud = [v for v in cloud if v[0] == 0] 
    219     if dphi == 0: 
    220         cloud = [v for v in cloud if v[1] == 0] 
    221     if dpsi == 0: 
    222         cloud = [v for v in cloud if v[2] == 0] 
    223     draw_shape(axes, size, view, [0, 0, 0], steps=100, alpha=0.8) 
    224     scale = {'gaussian':1, 'rectangle':1/sqrt(3), 'uniform':1/3}[dist] 
    225     for point in cloud: 
    226         delta = [scale*dtheta*point[0], scale*dphi*point[1], scale*dpsi*point[2]] 
    227         draw_shape(axes, size, view, delta, alpha=0.8) 
     305    base = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist] 
     306    def steps(delta): 
     307        if views is None: 
     308            n = max(3, min(25, 2*int(base*delta/5))) 
     309        else: 
     310            n = views 
     311        return base*delta*np.linspace(-1, 1, n) if delta > 0 else [0.] 
     312    for theta in steps(dtheta): 
     313        for phi in steps(dphi): 
     314            for psi in steps(dpsi): 
     315                w = project_weight(theta, phi, 1.0, 1.0) 
     316                if w > 0: 
     317                    dview = project(theta, phi, psi) 
     318                    draw_shape(axes, size, view, dview, alpha=alpha) 
    228319    for v in 'xyz': 
    229320        a, b, c = size 
    230321        lim = np.sqrt(a**2 + b**2 + c**2) 
    231322        getattr(axes, 'set_'+v+'lim')([-lim, lim]) 
    232         getattr(axes, v+'axis').label.set_text(v) 
     323        #getattr(axes, v+'axis').label.set_text(v) 
    233324 
    234325PROJECTIONS = [ 
     
    238329    'azimuthal_equal_area', 
    239330] 
    240 def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian', 
    241               projection='equirectangular'): 
    242     """ 
    243     Draw the dispersion mesh showing the theta-phi orientations at which 
    244     the model will be evaluated. 
    245  
     331def get_projection(projection): 
     332 
     333    """ 
    246334    jitter projections 
    247335    <https://en.wikipedia.org/wiki/List_of_map_projections> 
     
    299387    # TODO: try Kent distribution instead of a gaussian warped by projection 
    300388 
    301     dist_x = np.linspace(-1, 1, n) 
    302     weights = np.ones_like(dist_x) 
    303     if dist == 'gaussian': 
    304         dist_x *= 3 
    305         weights = exp(-0.5*dist_x**2) 
    306     elif dist == 'rectangle': 
    307         # Note: uses sasmodels ridiculous definition of rectangle width 
    308         dist_x *= sqrt(3) 
    309     elif dist == 'uniform': 
    310         pass 
    311     else: 
    312         raise ValueError("expected dist to be gaussian, rectangle or uniform") 
    313  
    314389    if projection == 'equirectangular':  #define PROJECTION 1 
    315         def _rotate(theta_i, phi_j): 
    316             return Rx(phi_j)*Ry(theta_i) 
     390        def _project(theta_i, phi_j, psi): 
     391            latitude, longitude = theta_i, phi_j 
     392            return latitude, longitude, psi 
     393            #return Rx(phi_j)*Ry(theta_i) 
    317394        def _weight(theta_i, phi_j, w_i, w_j): 
    318395            return w_i*w_j*abs(cos(radians(theta_i))) 
    319396    elif projection == 'sinusoidal':  #define PROJECTION 2 
    320         def _rotate(theta_i, phi_j): 
     397        def _project(theta_i, phi_j, psi): 
    321398            latitude = theta_i 
    322399            scale = cos(radians(latitude)) 
    323400            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0 
    324401            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude)) 
    325             return Rx(longitude)*Ry(latitude) 
    326         def _weight(theta_i, phi_j, w_i, w_j): 
     402            return latitude, longitude, psi 
     403            #return Rx(longitude)*Ry(latitude) 
     404        def _project(theta_i, phi_j, w_i, w_j): 
    327405            latitude = theta_i 
    328406            scale = cos(radians(latitude)) 
     
    330408            return active*w_i*w_j 
    331409    elif projection == 'guyou':  #define PROJECTION 3  (eventually?) 
    332         def _rotate(theta_i, phi_j): 
     410        def _project(theta_i, phi_j, psi): 
    333411            from .guyou import guyou_invert 
    334412            #latitude, longitude = guyou_invert([theta_i], [phi_j]) 
    335413            longitude, latitude = guyou_invert([phi_j], [theta_i]) 
    336             return Rx(longitude[0])*Ry(latitude[0]) 
     414            return latitude, longitude, psi 
     415            #return Rx(longitude[0])*Ry(latitude[0]) 
    337416        def _weight(theta_i, phi_j, w_i, w_j): 
    338417            return w_i*w_j 
    339     elif projection == 'azimuthal_equidistance':  # Note: Rz Ry, not Rx Ry 
    340         def _rotate(theta_i, phi_j): 
     418    elif projection == 'azimuthal_equidistance': 
     419        # Note that calculates angles for Rz Ry rather than Rx Ry 
     420        def _project(theta_i, phi_j, psi): 
    341421            latitude = sqrt(theta_i**2 + phi_j**2) 
    342422            longitude = degrees(np.arctan2(phi_j, theta_i)) 
    343423            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude)) 
    344             return Rz(longitude)*Ry(latitude) 
     424            return latitude, longitude, psi-longitude, 'zyz' 
     425            #R = Rz(longitude)*Ry(latitude)*Rz(psi) 
     426            #return R_to_xyz(R) 
     427            #return Rz(longitude)*Ry(latitude) 
    345428        def _weight(theta_i, phi_j, w_i, w_j): 
    346429            # Weighting for each point comes from the integral: 
     
    376459            return weight*w_i*w_j if latitude < 180 else 0 
    377460    elif projection == 'azimuthal_equal_area': 
    378         def _rotate(theta_i, phi_j): 
     461        # Note that calculates angles for Rz Ry rather than Rx Ry 
     462        def _project(theta_i, phi_j, psi): 
    379463            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180) 
    380464            latitude = 180-degrees(2*np.arccos(radius)) 
    381465            longitude = degrees(np.arctan2(phi_j, theta_i)) 
    382466            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude)) 
    383             return Rz(longitude)*Ry(latitude) 
     467            return latitude, longitude, psi, "zyz" 
     468            #R = Rz(longitude)*Ry(latitude)*Rz(psi) 
     469            #return R_to_xyz(R) 
     470            #return Rz(longitude)*Ry(latitude) 
    384471        def _weight(theta_i, phi_j, w_i, w_j): 
    385472            latitude = sqrt(theta_i**2 + phi_j**2) 
     
    389476        raise ValueError("unknown projection %r"%projection) 
    390477 
     478    return _project, _weight 
     479 
     480def R_to_xyz(R): 
     481    """ 
     482    Return phi, theta, psi Tait-Bryan angles corresponding to the given rotation matrix. 
     483 
     484    Extracting Euler Angles from a Rotation Matrix 
     485    Mike Day, Insomniac Games 
     486    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf 
     487    Based on: Shoemake’s "Euler Angle Conversion", Graphics Gems IV, pp.  222-229 
     488    """ 
     489    phi = np.arctan2(R[1, 2], R[2, 2]) 
     490    theta = np.arctan2(-R[0, 2], np.sqrt(R[0, 0]**2 + R[0, 1]**2)) 
     491    psi = np.arctan2(R[0, 1], R[0, 0]) 
     492    return np.degrees(phi), np.degrees(theta), np.degrees(psi) 
     493 
     494def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian', 
     495              projection='equirectangular'): 
     496    """ 
     497    Draw the dispersion mesh showing the theta-phi orientations at which 
     498    the model will be evaluated. 
     499    """ 
     500 
     501    _project, _weight = get_projection(projection) 
     502    def _rotate(theta, phi, z): 
     503        dview = _project(theta, phi, 0.) 
     504        if len(dview) == 4: # hack for zyz coords 
     505            return Rz(dview[1])*Ry(dview[0])*z 
     506        else: 
     507            return Rx(dview[1])*Ry(dview[0])*z 
     508 
     509 
     510    dist_x = np.linspace(-1, 1, n) 
     511    weights = np.ones_like(dist_x) 
     512    if dist == 'gaussian': 
     513        dist_x *= 3 
     514        weights = exp(-0.5*dist_x**2) 
     515    elif dist == 'rectangle': 
     516        # Note: uses sasmodels ridiculous definition of rectangle width 
     517        dist_x *= sqrt(3) 
     518    elif dist == 'uniform': 
     519        pass 
     520    else: 
     521        raise ValueError("expected dist to be gaussian, rectangle or uniform") 
     522 
    391523    # mesh in theta, phi formed by rotating z 
    392524    dtheta, dphi, dpsi = jitter 
    393525    z = np.matrix([[0], [0], [radius]]) 
    394     points = np.hstack([_rotate(theta_i, phi_j)*z 
     526    points = np.hstack([_rotate(theta_i, phi_j, z) 
    395527                        for theta_i in dtheta*dist_x 
    396528                        for phi_j in dphi*dist_x]) 
     
    470602    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
    471603    """ 
    472     dtheta, dphi, dpsi = jitter 
    473     points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points 
     604    if jitter is None: 
     605        return points 
     606    # Hack to deal with the fact that azimuthal_equidistance uses euler angles 
     607    if len(jitter) == 4: 
     608        dtheta, dphi, dpsi, _ = jitter 
     609        points = Rz(dphi)*Ry(dtheta)*Rz(dpsi)*points 
     610    else: 
     611        dtheta, dphi, dpsi = jitter 
     612        points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points 
    474613    return points 
    475614 
     
    481620    """ 
    482621    theta, phi, psi = view 
    483     points = Rz(phi)*Ry(theta)*Rz(psi)*points 
     622    points = Rz(phi)*Ry(theta)*Rz(psi)*points # viewing angle 
     623    #points = Rz(psi)*Ry(pi/2-theta)*Rz(phi)*points # 1-D integration angles 
     624    #points = Rx(phi)*Ry(theta)*Rz(psi)*points  # angular dispersion angle 
    484625    return points 
     626 
     627def orient_relative_to_beam_quaternion(view, points): 
     628    """ 
     629    Apply the view transform to a set of points. 
     630 
     631    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
     632 
     633    This variant uses quaternions rather than rotation matrices for the 
     634    computation.  It works but it is not used because it doesn't solve 
     635    any problems.  The challenge of mapping theta/phi/psi to SO(3) does 
     636    not disappear by calculating the transform differently. 
     637    """ 
     638    theta, phi, psi = view 
     639    x, y, z = [1, 0, 0], [0, 1, 0], [0, 0, 1] 
     640    q = Quaternion(1, [0, 0, 0]) 
     641    ## Compose a rotation about the three axes by rotating 
     642    ## the unit vectors before applying the rotation. 
     643    #q = Quaternion.from_angle_axis(theta, q.rot(x)) * q 
     644    #q = Quaternion.from_angle_axis(phi, q.rot(y)) * q 
     645    #q = Quaternion.from_angle_axis(psi, q.rot(z)) * q 
     646    ## The above turns out to be equivalent to reversing 
     647    ## the order of application, so ignore it and use below. 
     648    q = q * Quaternion.from_angle_axis(theta, x) 
     649    q = q * Quaternion.from_angle_axis(phi, y) 
     650    q = q * Quaternion.from_angle_axis(psi, z) 
     651    ## Reverse the order by post-multiply rather than pre-multiply 
     652    #q = Quaternion.from_angle_axis(theta, x) * q 
     653    #q = Quaternion.from_angle_axis(phi, y) * q 
     654    #q = Quaternion.from_angle_axis(psi, z) * q 
     655    #print("axes psi", q.rot(np.matrix([x, y, z]).T)) 
     656    return q.rot(points) 
     657#orient_relative_to_beam = orient_relative_to_beam_quaternion 
     658 
     659# Simple stand-alone quaternion class 
     660import numpy as np 
     661from copy import copy 
     662class Quaternion(object): 
     663    def __init__(self, w, r): 
     664         self.w = w 
     665         self.r = np.asarray(r, dtype='d') 
     666    @staticmethod 
     667    def from_angle_axis(theta, r): 
     668         theta = np.radians(theta)/2 
     669         r = np.asarray(r) 
     670         w = np.cos(theta) 
     671         r = np.sin(theta)*r/np.dot(r,r) 
     672         return Quaternion(w, r) 
     673    def __mul__(self, other): 
     674        if isinstance(other, Quaternion): 
     675            w = self.w*other.w - np.dot(self.r, other.r) 
     676            r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r) 
     677            return Quaternion(w, r) 
     678    def rot(self, v): 
     679        v = np.asarray(v).T 
     680        use_transpose = (v.shape[-1] != 3) 
     681        if use_transpose: v = v.T 
     682        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v) 
     683        #v = v + 2*self.w*np.cross(self.r, v) + np.cross(2*self.r, np.cross(self.r, v)) 
     684        if use_transpose: v = v.T 
     685        return v.T 
     686    def conj(self): 
     687        return Quaternion(self.w, -self.r) 
     688    def inv(self): 
     689        return self.conj()/self.norm()**2 
     690    def norm(self): 
     691        return np.sqrt(self.w**2 + np.sum(self.r**2)) 
     692    def __str__(self): 
     693        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2]) 
     694def test_qrot(): 
     695    # Define rotation of 60 degrees around an axis in y-z that is 60 degrees from y. 
     696    # The rotation axis is determined by rotating the point [0, 1, 0] about x. 
     697    ax = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0]) 
     698    q = Quaternion.from_angle_axis(60, ax) 
     699    # Set the point to be rotated, and its expected rotated position. 
     700    p = [1, -1, 2] 
     701    target = [(10+4*np.sqrt(3))/8, (1+2*np.sqrt(3))/8, (14-3*np.sqrt(3))/8] 
     702    #print(q, q.rot(p) - target) 
     703    assert max(abs(q.rot(p) - target)) < 1e-14 
     704#test_qrot() 
     705#import sys; sys.exit() 
    485706 
    486707# translate between number of dimension of dispersity and the number of 
     
    556777        vmin = vmax*10**-7 
    557778        #vmin, vmax = clipped_range(Iqxy, portion=portion, mode='top') 
     779    #vmin, vmax = Iqxy.min(), Iqxy.max() 
    558780    #print("range",(vmin,vmax)) 
    559781    #qx, qy = np.meshgrid(qx, qy) 
    560782    if 0: 
     783        from matplotlib import cm 
    561784        level = np.asarray(255*(Iqxy - vmin)/(vmax - vmin), 'i') 
    562785        level[level < 0] = 0 
    563786        colors = plt.get_cmap()(level) 
    564         axes.plot_surface(qx, qy, -1.1, rstride=1, cstride=1, facecolors=colors) 
     787        #colors = cm.coolwarm(level) 
     788        #colors = cm.gist_yarg(level) 
     789        #colors = cm.Wistia(level) 
     790        colors[level<=0, 3] = 0.  # set floor to transparent 
     791        x, y = np.meshgrid(qx/qx.max(), qy/qy.max()) 
     792        axes.plot_surface(x, y, -1.1*np.ones_like(x), facecolors=colors) 
    565793    elif 1: 
    566794        axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1, 
     
    692920} 
    693921 
     922 
    694923def run(model_name='parallelepiped', size=(10, 40, 100), 
     924        view=(0, 0, 0), jitter=(0, 0, 0), 
    695925        dist='gaussian', mesh=30, 
    696926        projection='equirectangular'): 
     
    702932 
    703933    *size* gives the dimensions (a, b, c) of the shape. 
     934 
     935    *view* gives the initial view (theta, phi, psi) of the shape. 
     936 
     937    *view* gives the initial jitter (dtheta, dphi, dpsi) of the shape. 
    704938 
    705939    *dist* is the type of dispersition: gaussian, rectangle, or uniform. 
     
    721955    calculator, size = select_calculator(model_name, n=150, size=size) 
    722956    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped) 
     957    #draw_shape = draw_fcc 
    723958 
    724959    ## uncomment to set an independent the colour range for every view 
     
    726961    calculator.limits = None 
    727962 
    728     ## initial view 
    729     #theta, dtheta = 70., 10. 
    730     #phi, dphi = -45., 3. 
    731     #psi, dpsi = -45., 3. 
    732     theta, phi, psi = 0, 0, 0 
    733     dtheta, dphi, dpsi = 0, 0, 0 
     963    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection) 
     964 
     965def mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection): 
     966    # Note: travis-ci does not support mpl_toolkits.mplot3d, but this shouldn't be 
     967    # an issue since we are lazy-loading the package on a path that isn't tested. 
     968    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot 
     969    import matplotlib as mpl 
     970    import matplotlib.pyplot as plt 
     971    from matplotlib.widgets import Slider 
    734972 
    735973    ## create the plot window 
     
    752990 
    753991    ## add control widgets to plot 
    754     axes_theta = plt.axes([0.1, 0.15, 0.45, 0.04], **props) 
    755     axes_phi = plt.axes([0.1, 0.1, 0.45, 0.04], **props) 
    756     axes_psi = plt.axes([0.1, 0.05, 0.45, 0.04], **props) 
    757     stheta = Slider(axes_theta, 'Theta', -90, 90, valinit=theta) 
    758     sphi = Slider(axes_phi, 'Phi', -180, 180, valinit=phi) 
    759     spsi = Slider(axes_psi, 'Psi', -180, 180, valinit=psi) 
    760  
    761     axes_dtheta = plt.axes([0.75, 0.15, 0.15, 0.04], **props) 
    762     axes_dphi = plt.axes([0.75, 0.1, 0.15, 0.04], **props) 
    763     axes_dpsi = plt.axes([0.75, 0.05, 0.15, 0.04], **props) 
     992    axes_theta = plt.axes([0.05, 0.15, 0.50, 0.04], **props) 
     993    axes_phi = plt.axes([0.05, 0.10, 0.50, 0.04], **props) 
     994    axes_psi = plt.axes([0.05, 0.05, 0.50, 0.04], **props) 
     995    stheta = Slider(axes_theta, u'Ξ', -90, 90, valinit=0) 
     996    sphi = Slider(axes_phi, u'φ', -180, 180, valinit=0) 
     997    spsi = Slider(axes_psi, u'ψ', -180, 180, valinit=0) 
     998 
     999    axes_dtheta = plt.axes([0.70, 0.15, 0.20, 0.04], **props) 
     1000    axes_dphi = plt.axes([0.70, 0.1, 0.20, 0.04], **props) 
     1001    axes_dpsi = plt.axes([0.70, 0.05, 0.20, 0.04], **props) 
     1002 
    7641003    # Note: using ridiculous definition of rectangle distribution, whose width 
    7651004    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep 
    7661005    # the maximum width to 90. 
    7671006    dlimit = DIST_LIMITS[dist] 
    768     sdtheta = Slider(axes_dtheta, 'dTheta', 0, 2*dlimit, valinit=dtheta) 
    769     sdphi = Slider(axes_dphi, 'dPhi', 0, 2*dlimit, valinit=dphi) 
    770     sdpsi = Slider(axes_dpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi) 
    771  
     1007    sdtheta = Slider(axes_dtheta, u'Δξ', 0, 2*dlimit, valinit=0) 
     1008    sdphi = Slider(axes_dphi, u'Δφ', 0, 2*dlimit, valinit=0) 
     1009    sdpsi = Slider(axes_dpsi, u'Δψ', 0, 2*dlimit, valinit=0) 
     1010 
     1011    ## initial view and jitter 
     1012    theta, phi, psi = view 
     1013    stheta.set_val(theta) 
     1014    sphi.set_val(phi) 
     1015    spsi.set_val(psi) 
     1016    dtheta, dphi, dpsi = jitter 
     1017    sdtheta.set_val(dtheta) 
     1018    sdphi.set_val(dphi) 
     1019    sdpsi.set_val(dpsi) 
    7721020 
    7731021    ## callback to draw the new view 
     
    7771025        # set small jitter as 0 if multiple pd dims 
    7781026        dims = sum(v > 0 for v in jitter) 
    779         limit = [0, 0.5, 5][dims] 
     1027        limit = [0, 0.5, 5, 5][dims] 
    7801028        jitter = [0 if v < limit else v for v in jitter] 
    7811029        axes.cla() 
    782         draw_beam(axes, (0, 0)) 
    783         draw_jitter(axes, view, jitter, dist=dist, size=size, draw_shape=draw_shape) 
    784         #draw_jitter(axes, view, (0,0,0)) 
     1030 
     1031        ## Visualize as person on globe 
     1032        #draw_sphere(axes) 
     1033        #draw_person_on_sphere(axes, view) 
     1034 
     1035        ## Move beam instead of shape 
     1036        #draw_beam(axes, -view[:2]) 
     1037        #draw_jitter(axes, (0,0,0), (0,0,0), views=3) 
     1038 
     1039        ## Move shape and draw scattering 
     1040        draw_beam(axes, (0, 0), alpha=1.) 
     1041        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1., 
     1042                    draw_shape=draw_shape, projection=projection, views=3) 
    7851043        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection) 
    7861044        draw_scattering(calculator, axes, view, jitter, dist=dist) 
     1045 
    7871046        plt.gcf().canvas.draw() 
    7881047 
     
    8001059    ## go interactive 
    8011060    plt.show() 
     1061 
     1062 
     1063def map_colors(z, kw): 
     1064    from matplotlib import cm 
     1065 
     1066    cmap = kw.pop('cmap', cm.coolwarm) 
     1067    alpha = kw.pop('alpha', None) 
     1068    vmin = kw.pop('vmin', z.min()) 
     1069    vmax = kw.pop('vmax', z.max()) 
     1070    c = kw.pop('c', None) 
     1071    color = kw.pop('color', c) 
     1072    if color is None: 
     1073        znorm = ((z - vmin) / (vmax - vmin)).clip(0, 1) 
     1074        color = cmap(znorm) 
     1075    elif isinstance(color, np.ndarray) and color.shape == z.shape: 
     1076        color = cmap(color) 
     1077    if alpha is None: 
     1078        if isinstance(color, np.ndarray): 
     1079            color = color[..., :3] 
     1080    else: 
     1081        color[..., 3] = alpha 
     1082    kw['color'] = color 
     1083 
     1084def make_vec(*args): 
     1085    #return [np.asarray(v, 'd').flatten() for v in args] 
     1086    return [np.asarray(v, 'd') for v in args] 
     1087 
     1088def make_image(z, kw): 
     1089    import PIL.Image 
     1090    from matplotlib import cm 
     1091 
     1092    cmap = kw.pop('cmap', cm.coolwarm) 
     1093 
     1094    znorm = (z-z.min())/z.ptp() 
     1095    c = cmap(znorm) 
     1096    c = c[..., :3] 
     1097    rgb = np.asarray(c*255, 'u1') 
     1098    image = PIL.Image.fromarray(rgb, mode='RGB') 
     1099    return image 
     1100 
     1101 
     1102_IPV_MARKERS = { 
     1103    'o': 'sphere', 
     1104} 
     1105_IPV_COLORS = { 
     1106    'w': 'white', 
     1107    'k': 'black', 
     1108    'c': 'cyan', 
     1109    'm': 'magenta', 
     1110    'y': 'yellow', 
     1111    'r': 'red', 
     1112    'g': 'green', 
     1113    'b': 'blue', 
     1114} 
     1115def ipv_fix_color(kw): 
     1116    alpha = kw.pop('alpha', None) 
     1117    color = kw.get('color', None) 
     1118    if isinstance(color, str): 
     1119        color = _IPV_COLORS.get(color, color) 
     1120        kw['color'] = color 
     1121    if alpha is not None: 
     1122        color = kw['color'] 
     1123        #TODO: convert color to [r, g, b, a] if not already 
     1124        if isinstance(color, (tuple, list)): 
     1125            if len(color) == 3: 
     1126                color = (color[0], color[1], color[2], alpha) 
     1127            else: 
     1128                color = (color[0], color[1], color[2], alpha*color[3]) 
     1129            color = np.array(color) 
     1130        if isinstance(color, np.ndarray) and color.shape[-1] == 4: 
     1131            color[..., 3] = alpha 
     1132            kw['color'] = color 
     1133 
     1134def ipv_set_transparency(kw, obj): 
     1135    color = kw.get('color', None) 
     1136    if (isinstance(color, np.ndarray) 
     1137            and color.shape[-1] == 4 
     1138            and (color[..., 3] != 1.0).any()): 
     1139        obj.material.transparent = True 
     1140        obj.material.side = "FrontSide" 
     1141 
     1142def ipv_axes(): 
     1143    import ipyvolume as ipv 
     1144 
     1145    class Plotter: 
     1146        # transparency can be achieved by setting the following: 
     1147        #    mesh.color = [r, g, b, alpha] 
     1148        #    mesh.material.transparent = True 
     1149        #    mesh.material.side = "FrontSide" 
     1150        # smooth(ish) rotation can be achieved by setting: 
     1151        #    slide.continuous_update = True 
     1152        #    figure.animation = 0. 
     1153        #    mesh.material.x = x 
     1154        #    mesh.material.y = y 
     1155        #    mesh.material.z = z 
     1156        # maybe need to synchronize update of x/y/z to avoid shimmy when moving 
     1157        def plot(self, x, y, z, **kw): 
     1158            ipv_fix_color(kw) 
     1159            x, y, z = make_vec(x, y, z) 
     1160            ipv.plot(x, y, z, **kw) 
     1161        def plot_surface(self, x, y, z, **kw): 
     1162            facecolors = kw.pop('facecolors', None) 
     1163            if facecolors is not None: 
     1164                kw['color'] = facecolors 
     1165            ipv_fix_color(kw) 
     1166            x, y, z = make_vec(x, y, z) 
     1167            h = ipv.plot_surface(x, y, z, **kw) 
     1168            ipv_set_transparency(kw, h) 
     1169            #h.material.side = "DoubleSide" 
     1170            return h 
     1171        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw): 
     1172            kw.pop('linewidth', None) 
     1173            ipv_fix_color(kw) 
     1174            x, y, z = make_vec(x, y, Z) 
     1175            if triangles is not None: 
     1176                triangles = np.asarray(triangles) 
     1177            h = ipv.plot_trisurf(x, y, z, triangles=triangles, **kw) 
     1178            ipv_set_transparency(kw, h) 
     1179            return h 
     1180        def scatter(self, x, y, z, **kw): 
     1181            x, y, z = make_vec(x, y, z) 
     1182            map_colors(z, kw) 
     1183            marker = kw.get('marker', None) 
     1184            kw['marker'] = _IPV_MARKERS.get(marker, marker) 
     1185            h = ipv.scatter(x, y, z, **kw) 
     1186            ipv_set_transparency(kw, h) 
     1187            return h 
     1188        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw): 
     1189            # Don't use contour for now (although we might want to later) 
     1190            self.pcolor(x, y, v, zdir='z', offset=offset, **kw) 
     1191        def pcolor(self, x, y, v, zdir='z', offset=0, **kw): 
     1192            x, y, v = make_vec(x, y, v) 
     1193            image = make_image(v, kw) 
     1194            xmin, xmax = x.min(), x.max() 
     1195            ymin, ymax = y.min(), y.max() 
     1196            x = np.array([[xmin, xmax], [xmin, xmax]]) 
     1197            y = np.array([[ymin, ymin], [ymax, ymax]]) 
     1198            z = x*0 + offset 
     1199            u = np.array([[0., 1], [0, 1]]) 
     1200            v = np.array([[0., 0], [1, 1]]) 
     1201            h = ipv.plot_mesh(x, y, z, u=u, v=v, texture=image, wireframe=False) 
     1202            ipv_set_transparency(kw, h) 
     1203            h.material.side = "DoubleSide" 
     1204            return h 
     1205        def text(self, *args, **kw): 
     1206            pass 
     1207        def set_xlim(self, limits): 
     1208            ipv.xlim(*limits) 
     1209        def set_ylim(self, limits): 
     1210            ipv.ylim(*limits) 
     1211        def set_zlim(self, limits): 
     1212            ipv.zlim(*limits) 
     1213        def set_axes_on(self): 
     1214            ipv.style.axis_on() 
     1215        def set_axis_off(self): 
     1216            ipv.style.axes_off() 
     1217    return Plotter() 
     1218 
     1219def ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection): 
     1220    import ipywidgets as widgets 
     1221    import ipyvolume as ipv 
     1222 
     1223    axes = ipv_axes() 
     1224 
     1225    def draw(view, jitter): 
     1226        camera = ipv.gcf().camera 
     1227        #print(ipv.gcf().__dict__.keys()) 
     1228        #print(dir(ipv.gcf())) 
     1229        ipv.figure(animation=0.)  # no animation when updating object mesh 
     1230 
     1231        # set small jitter as 0 if multiple pd dims 
     1232        dims = sum(v > 0 for v in jitter) 
     1233        limit = [0, 0.5, 5, 5][dims] 
     1234        jitter = [0 if v < limit else v for v in jitter] 
     1235 
     1236        ## Visualize as person on globe 
     1237        #draw_beam(axes, (0, 0)) 
     1238        #draw_sphere(axes) 
     1239        #draw_person_on_sphere(axes, view) 
     1240 
     1241        ## Move beam instead of shape 
     1242        #draw_beam(axes, view=(-view[0], -view[1])) 
     1243        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0)) 
     1244 
     1245        ## Move shape and draw scattering 
     1246        draw_beam(axes, (0, 0), steps=25) 
     1247        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.0, 
     1248                    draw_shape=draw_shape, projection=projection) 
     1249        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95, 
     1250                  projection=projection) 
     1251        draw_scattering(calculator, axes, view, jitter, dist=dist) 
     1252 
     1253        draw_axes(axes, origin=(-1, -1, -1.1)) 
     1254        ipv.style.box_off() 
     1255        ipv.style.axes_off() 
     1256        ipv.xyzlabel(" ", " ", " ") 
     1257 
     1258        ipv.gcf().camera = camera 
     1259        ipv.show() 
     1260 
     1261 
     1262    trange, prange = (-180., 180., 1.), (-180., 180., 1.) 
     1263    dtrange, dprange = (0., 180., 1.), (0., 180., 1.) 
     1264 
     1265    ## Super simple interfaca, but uses non-ascii variable namese 
     1266    # Ξ φ ψ Δξ Δφ Δψ 
     1267    #def update(**kw): 
     1268    #    view = kw['Ξ'], kw['φ'], kw['ψ'] 
     1269    #    jitter = kw['Δξ'], kw['Δφ'], kw['Δψ'] 
     1270    #    draw(view, jitter) 
     1271    #widgets.interact(update, Ξ=trange, φ=prange, ψ=prange, Δξ=dtrange, Δφ=dprange, Δψ=dprange) 
     1272 
     1273    def update(theta, phi, psi, dtheta, dphi, dpsi): 
     1274        draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi)) 
     1275 
     1276    def slider(name, slice, init=0.): 
     1277        return widgets.FloatSlider( 
     1278            value=init, 
     1279            min=slice[0], 
     1280            max=slice[1], 
     1281            step=slice[2], 
     1282            description=name, 
     1283            disabled=False, 
     1284            #continuous_update=True, 
     1285            continuous_update=False, 
     1286            orientation='horizontal', 
     1287            readout=True, 
     1288            readout_format='.1f', 
     1289            ) 
     1290    theta = slider(u'Ξ', trange, view[0]) 
     1291    phi = slider(u'φ', prange, view[1]) 
     1292    psi = slider(u'ψ', prange, view[2]) 
     1293    dtheta = slider(u'Δξ', dtrange, jitter[0]) 
     1294    dphi = slider(u'Δφ', dprange, jitter[1]) 
     1295    dpsi = slider(u'Δψ', dprange, jitter[2]) 
     1296    fields = { 
     1297        'theta': theta, 'phi': phi, 'psi': psi, 
     1298        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi, 
     1299    } 
     1300    ui = widgets.HBox([ 
     1301        widgets.VBox([theta, phi, psi]), 
     1302        widgets.VBox([dtheta, dphi, dpsi]) 
     1303    ]) 
     1304 
     1305    out = widgets.interactive_output(update, fields) 
     1306    display(ui, out) 
     1307 
     1308 
     1309_ENGINES = { 
     1310    "matplotlib": mpl_plot, 
     1311    "mpl": mpl_plot, 
     1312    #"plotly": plotly_plot, 
     1313    "ipvolume": ipv_plot, 
     1314    "ipv": ipv_plot, 
     1315} 
     1316PLOT_ENGINE = _ENGINES["matplotlib"] 
     1317def set_plotter(name): 
     1318    global PLOT_ENGINE 
     1319    PLOT_ENGINE = _ENGINES[name] 
    8021320 
    8031321def main(): 
     
    8111329    parser.add_argument('-s', '--size', type=str, default='10,40,100', 
    8121330                        help='a,b,c lengths') 
     1331    parser.add_argument('-v', '--view', type=str, default='0,0,0', 
     1332                        help='initial view angles') 
     1333    parser.add_argument('-j', '--jitter', type=str, default='0,0,0', 
     1334                        help='initial angular dispersion') 
    8131335    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS, 
    8141336                        default=DISTRIBUTIONS[0], 
     
    8191341                        help='oriented shape') 
    8201342    opts = parser.parse_args() 
    821     size = tuple(int(v) for v in opts.size.split(',')) 
    822     run(opts.shape, size=size, 
     1343    size = tuple(float(v) for v in opts.size.split(',')) 
     1344    view = tuple(float(v) for v in opts.view.split(',')) 
     1345    jitter = tuple(float(v) for v in opts.jitter.split(',')) 
     1346    run(opts.shape, size=size, view=view, jitter=jitter, 
    8231347        mesh=opts.mesh, dist=opts.distribution, 
    8241348        projection=opts.projection) 
Note: See TracChangeset for help on using the changeset viewer.