HOME/Articles/

matplotlib example py3simplex (snippet)

Article Outline

Python matplotlib example 'py3simplex'

Functions in program:

  • def projectSimplex(points):
  • def plotSimplex(points, fig=None, vertexlabels=..., **kwargs):

Modules used in program:

  • import matplotlib.patches as PA
  • import matplotlib.colors as C
  • import matplotlib.cm as CM
  • import matplotlib.lines as L
  • import matplotlib.ticker as MT
  • import matplotlib.pyplot as P
  • import numpy as NP

python py3simplex

Python matplotlib example: py3simplex

"""
Visualize points on the 3-simplex (e.g., the parameters of a
3-dimensional multinomial distribution) as a scatter plot
contained within a 2D triangle.

David Andrzejewski ([email protected])
"""
import numpy as NP
import matplotlib.pyplot as P
import matplotlib.ticker as MT
import matplotlib.lines as L
import matplotlib.cm as CM
import matplotlib.colors as C
import matplotlib.patches as PA

def plotSimplex(points, fig=None,
                vertexlabels=('1', '2', '3'),
                **kwargs):
    """
    Plot an Nx3 points array on the 3-simplex
    (with optionally labeled vertices).

    points       -- N x 3 array; each row is mapped into the triangle by
                    projectSimplex (column 0 pulls toward the lower-left
                    vertex, column 1 toward the lower-right, column 2
                    toward the top).
    fig          -- Figure to draw into; a new pyplot figure is created
                    when None. Everything is drawn on fig.gca().
    vertexlabels -- three label strings for the lower-left, lower-right
                    and top vertices, in that order.
    kwargs       -- passed along directly to Axes.scatter (the same
                    keywords matplotlib.pyplot.scatter accepts).

    Returns Figure, caller must .show()
    """
    if fig is None:
        fig = P.figure()
    # Draw on the figure's own Axes rather than pyplot's implicit "current"
    # Axes, so that a caller-supplied fig which is not the current figure
    # still receives the points (pyplot.scatter would draw them elsewhere).
    ax = fig.gca()
    height = NP.sqrt(3) / 2
    # Draw the triangle: (0, 0) -> (0.5, height) -> (1, 0) -> (0, 0)
    ax.add_line(L.Line2D([0, 0.5, 1.0, 0],       # xcoords
                         [0, height, 0, 0],      # ycoords
                         color='k'))
    ax.xaxis.set_major_locator(MT.NullLocator())
    ax.yaxis.set_major_locator(MT.NullLocator())
    # Draw vertex labels just outside each corner
    ax.text(-0.05, -0.05, vertexlabels[0])
    ax.text(1.05, -0.05, vertexlabels[1])
    ax.text(0.5, height + 0.05, vertexlabels[2])
    # Project and draw the actual points
    projected = projectSimplex(points)
    collection = ax.scatter(projected[:, 0], projected[:, 1], **kwargs)
    # pyplot.scatter also registers its result as pyplot's current image,
    # which pyplot.colorbar() relies on; preserve that whenever fig is the
    # current pyplot figure (get_fignums() guard avoids gcf() creating one).
    if P.get_fignums() and P.gcf() is fig:
        P.sci(collection)
    # Leave some buffer around the triangle for vertex labels
    ax.set_xlim(-0.2, 1.2)
    ax.set_ylim(-0.2, 1.2)

    return fig

def projectSimplex(points):
    """
    Project probabilities on the 3-simplex to a 2D triangle.

    N points are given as an N x 3 array (any array-like is accepted; a
    single length-3 point is treated as one row). Returns an N x 2 float
    array of (x, y) coordinates.

    Each row (p1, p2, p3) is placed at
        centroid + p1*(v1 - centroid) + p2*(v2 - centroid) + p3*(v3 - centroid)
    with v1 = (0, 0), v2 = (1, 0), v3 = (0.5, sqrt(3)/2) and centroid
    (0.5, 1/(2*sqrt(3))). For rows summing to 1 this is the barycentric
    point p1*v1 + p2*v2 + p3*v3.
    """
    pts = NP.atleast_2d(NP.asarray(points, dtype=float))
    if pts.size == 0:
        return NP.zeros((0, 2))
    # Distance from the centroid to each vertex of the unit-side triangle.
    scale = 1.0 / NP.sqrt(3)
    cos30 = NP.cos(NP.pi / 6)
    sin30 = NP.sin(NP.pi / 6)
    p1, p2, p3 = pts[:, 0], pts[:, 1], pts[:, 2]
    # Start at the triangle centroid, then move along each vertex bisector:
    # p1 toward lower-left, p2 toward lower-right, p3 straight up to the top.
    x = 1.0 / 2 - scale * p1 * cos30 + scale * p2 * cos30
    y = (1.0 / (2 * NP.sqrt(3))
         - scale * p1 * sin30
         - scale * p2 * sin30
         + scale * p3)
    return NP.column_stack((x, y))


if __name__ == '__main__':
    # Define a synthetic test dataset (each row sums to 1)
    labels = ('[0.1  0.1  0.8]',
              '[0.8  0.1  0.1]',
              '[0.5  0.4  0.1]',
              '[0.33  0.34  0.33]')
    testpoints = NP.array([[0.1, 0.1, 0.8],
                           [0.8, 0.1, 0.1],
                           [0.5, 0.4, 0.1],
                           [0.33, 0.34, 0.33]])
    # Define different colors for each label.
    # The 'spectral' colormap was removed in Matplotlib 2.2 (it is now
    # 'nipy_spectral'), and matplotlib.cm.get_cmap was removed in 3.9;
    # pyplot.get_cmap with the new name works on both old and new releases.
    cmap = P.get_cmap('nipy_spectral')
    norm = C.Normalize(vmin=0, vmax=len(labels))
    c = list(range(len(labels)))
    # Do scatter plot
    fig = plotSimplex(testpoints, s=100, c=c,
                      cmap=cmap, norm=norm)
    # Make color-label legend with proxy patches, colored exactly as the
    # scatter colors each index (same cmap and norm).
    fig.gca().legend([PA.Rectangle((0, 0), 1, 1,
                                   fc=cmap(norm(idx)))
                      for idx in range(len(labels))],
                     labels)

    P.show()