Python matplotlib example 'render tree usa'

  def get_color_name(state):

Python matplotlib example: render tree usa

from Bio import Phylo
import datetime
from decimal import Decimal

import numpy as np

import pandas as pd

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.colorbar as cb
from matplotlib.patches import Polygon, Circle, PathPatch
from matplotlib import rc

from mpl_toolkits.basemap import Basemap

rc('font',**{'family':'sans-serif','sans-serif':['Helvetica'], 'size': 18})

# Load the RAxML bipartitions tree. Phylo.parse yields every tree found in the
# file; the loop keeps the last one it sees.
tp = Phylo.parse("../2018.10.17/RAxML_bipartitions.2018.10.17.wnv.usa", 'newick')
tree = None
for t in tp:
    tree = t

# Root at oldest sequence
# Tip names are "_"-separated and the second field is parsed as a YYYY-MM-DD
# date. Start from the largest representable datetime so the earliest tip is
# always found, instead of depending on a hard-coded 2018 timestamp.
_min_date = datetime.datetime.max
root = None
for i in tree.get_terminals():
    _ = datetime.datetime.strptime(i.name.split("_")[1], "%Y-%m-%d")
    if _ < _min_date:
        _min_date = _
        root = i
    # Correct a misspelt location in the tip label.
    i.name = i.name.replace("SanFransisco", "SanFrancisco")

# Actually apply the rooting: previously `root` was located but never used,
# so the file written below as ".rerooted.nwk" still had the original root.
if root is not None:
    tree.root_with_outgroup(root)

# Write tree to file
Phylo.write(tree, "../2018.10.17/RAxML_bipartitions.2018.10.17.wnv.usa.rerooted.nwk", "newick")

# Distinct lower-cased state codes, taken from the fourth "_"-separated field
# of every tip name; tips whose field is exactly "CA" are left out.
_state_fields = (tip.name.split("_")[3] for tip in tree.get_terminals())
states = list({code.lower() for code in _state_fields if code != "CA"})

# Dark2 colour scale normalised over 0..len(states).
# NOTE(review): smap_state is assigned again further down in this script,
# before get_color_name is called -- confirm whether this first mapping is
# still needed.
cnorm_state = colors.Normalize(vmin=0, vmax=len(states))
smap_state = cm.ScalarMappable(cmap=cm.Dark2, norm=cnorm_state)

# Tips: x is the distance reported by tree.distance(), y is the tip's index
# in get_terminals() order.
for row, tip in enumerate(tree.get_terminals()):
    tip.x = tree.distance(tip)
    tip.y = row

# Internal nodes, visited in reverse get_nonterminals() order (presumably so
# that child nodes already carry a y value -- TODO confirm traversal order).
# Each node is placed midway between its first and last child.
for node in reversed(tree.get_nonterminals()):
    first_child = node.clades[0]
    last_child = node.clades[-1]
    node.y = (first_child.y + last_child.y)/2
    node.x = tree.distance(node)

# 2x2 grid: the tree axis spans both rows of the left column; the right
# column holds the regression axis (top) and the legend axis (bottom).
f = plt.figure(figsize=(20, 25))
gs = gridspec.GridSpec(
    nrows=2,
    ncols=2,
    width_ratios=[1, 0.75],
    height_ratios=[0.5, 0.5],
)

ax = plt.subplot(gs[:, 0])
rtax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])

# Panel captions.
for _panel, _caption in (
    (ax, "A. Maximum likelihood tree"),
    (rtax, "B. Root to tip regression plot"),
    (cax, "C. Legend"),
):
    _panel.set_title(_caption)

# Legend map: Mercator basemap drawn on the legend axis over the bounding box
# below, with the GADM level-1 USA shapefile loaded as "units" (boundaries are
# not drawn at load time).
# NOTE(review): the box looks like the contiguous United States rather than
# California alone -- confirm.
ll_lng, ll_lat = -125.0, 25.0
ur_lng, ur_lat = -66.96, 49.5
m = Basemap(
    ax=cax,
    projection='merc',
    resolution='i',
    llcrnrlon=ll_lng,
    llcrnrlat=ll_lat,
    urcrnrlon=ur_lng,
    urcrnrlat=ur_lat,
)
m.readshapefile(
    "../../shapefile/gadm36_USA_shp/gadm36_USA_1",
    "units",
    drawbounds=False,
)

# Accumulators used by the shape loop that follows.
patches = []
centroids = []
# Walk every shape loaded from the shapefile together with its attribute
# record.
# NOTE(review): this loop is incomplete as shown. The body of the `if` and the
# statement that opens and closes the dict literal below are not present in
# this excerpt, so the block does not parse. Restore the missing lines from
# the original script before relying on it.
for info, shape in zip(m.units_info, m.units):
    # Alaska and Hawaii are singled out; the action taken for them is missing.
    if info["NAME_1"] == "Alaska" or info["NAME_1"] == "Hawaii":
    # Polygon patch built from the shape's vertex array.
    poly = Polygon(np.array(shape), True)
    # Split the vertices into separate x and y sequences.
    x, y = zip(*shape)
        # Record fields: mean x and mean y of the vertices, the HASC_1 code
        # with its first three characters dropped, and the ring number.
        # Presumably this dict is appended to `centroids` -- TODO confirm.
        "x": np.mean(x),
        "y": np.mean(y),
        "name": info["HASC_1"][3:],
        "RINGNUM": info["RINGNUM"]

# Colour scale keyed on the x value of each centroid record: normalise between
# the smallest and largest x and map through the plasma colormap.
_ = [i["x"] for i in centroids]
cNorm  = colors.Normalize(vmin=np.min(_), vmax=np.max(_))
# Rebinds smap_state (first assigned earlier in this script); this is the
# binding get_color_name reads when it is called.
smap_state = cm.ScalarMappable(norm=cNorm, cmap=cm.plasma)

# NOTE(review): the loop index is unused and the only statement in the body is
# the scatter call below, so it is repeated once per entry in `patches`.
# Lines may be missing from this excerpt -- confirm against the original.
for i in range(0, len(patches)):

    # Scatter of all centroid positions with marker size 0.
    m.scatter([i["x"] for i in centroids],[i["y"] for i in centroids],color="#000000",marker="o",s=0)

# Convert the list of records to a DataFrame; get_color_name filters it by the
# "name" column.
centroids = pd.DataFrame(centroids)

def get_color_name(state):
    """Return the colour that the module-level colour map assigns to *state*.

    Rows of the module-level ``centroids`` table are selected where the
    ``name`` column, lower-cased and with spaces removed, equals
    ``state.lower()``. The mean of their ``x`` values is then passed to
    ``smap_state.to_rgba`` and the result returned.
    """
    normalised_names = centroids["name"].str.lower().str.replace(" ", "")
    matching_rows = centroids[normalised_names == state.lower()]
    return smap_state.to_rgba(matching_rows["x"].mean())

# Plot branches
# NOTE(review): this section is incomplete as shown. The dict literal below is
# never closed and the `if`/`elif` headers on j.confidence have no bodies, so
# the block does not parse. Presumably the missing lines fill the "x"/"y"/"c"
# lists with node positions and a colour chosen from the confidence
# thresholds -- TODO restore from the original script.
# Accumulator for the node markers scattered after the loop.
_ = {
    "x": [],
    "y": [],
    "c": []
for i in tree.get_nonterminals():
    for j in i.clades:
        # Vertical connector at the parent's x, from the parent's y to the child's y.
        _t = ax.plot([i.x, i.x], [i.y, j.y], ls='-', color="#000000", zorder = 1)
        # Horizontal branch at the child's y, from the parent's x to the child's x.
        _t = ax.plot([i.x, j.x], [j.y, j.y], ls='-', color="#000000", zorder = 1)
        # Confidence cases: no value, >= 75, >= 50 (actions missing here).
        if j.confidence == None:
        if j.confidence >= 75:
        elif j.confidence >= 50:

# Two scatters at the same positions: a larger black marker, then a smaller
# one coloured from the "c" list, both above the branch lines (zorder 2).
ax.scatter(_["x"], _["y"], c = "#000000", s = 50, zorder = 2)
ax.scatter(_["x"], _["y"], c = _["c"], s = 25, zorder = 2)

# for i in tree.get_nonterminals():
#     if i.branch_length != None:
#         _ = ax.plot(i.x, i.y, marker = 'o', color='#000000')
#         _ = ax.text(i.x, i.y+1, str(i.y))

# Tip markers.
# NOTE(review): this section is incomplete as shown. The dict literal below is
# never closed and nothing in the visible loop adds to its lists, so the block
# does not parse. Presumably the missing lines append each tip's x, y and
# colour -- TODO restore from the original script.
_ = {
    "x": [],
    "y": [],
    "c": []
for i in tree.get_terminals():
    # A default colour (150, 75, 0) is scaled to 0..1 with alpha forced to 1,
    # then immediately replaced by the colour looked up from the fourth
    # "_"-separated field of the tip name.
    c = (150,75,0,1)
    c = [i/255 for i in c]
    c[-1] = 1
    c =get_color_name(i.name.split("_")[3].lower())
    # ax.text(i.x+0.001, i.y, i.name.split("_")[3])
    # if i.name in hp["taxon name"].tolist():
    #     ax.text(i.x + 0.000025, i.y, i.name.split("_")[0] + " " + str(hp[hp["taxon name"] == i.name]["#of homoplasic mutations"].values[0]))

# Two scatters at the same positions: a larger black marker, then a smaller
# one coloured from the "c" list.
ax.scatter(_["x"], _["y"], c = "#000000", s = 100, zorder = 2)
ax.scatter(_["x"], _["y"], c = _["c"], s = 50, zorder = 2)

# Light-grey vertical bands, each 0.002 wide, starting every 0.004 from 0 up
# to (but excluding) 0.014; drawn behind everything else on the tree axis.
for band_start in np.arange(0, 0.014, 0.004):
    ax.axvspan(
        xmin=band_start,
        xmax=band_start + 0.002,
        color="#ECECEC",
        zorder=-1,
    )

# Root-to-tip regression panel: read the per-tip table, fit a straight line to
# distance against date, and plot the points coloured by state.
df = pd.read_table("../2018.10.17/clock_rate.tsv", sep ="\t")

# Degree-1 least-squares fit; np.polyfit returns the highest power first, so
# fit[0] is the slope and fit[1] the intercept.
fit = np.polyfit(df["date"],df["distance"],1)
fit_fn = np.poly1d(fit)

# Dashed black regression line over the years 1995..2020.
years = np.arange(1995, 2021)
rtax.plot(years, fit_fn(years), "--k")

# One colour per row, looked up from the fourth "_"-separated field of the tip
# name. Previously this list was never filled (so the coloured scatter below
# received an empty colour list) and rows were read through `df.ix`, which no
# longer exists in pandas >= 1.0.
c = [get_color_name(tip.split("_")[3].lower()) for tip in df["tip"]]

# Slope label. Read the slope straight from the fit coefficients: the earlier
# finite-difference form reused `_`, which the colour loop had overwritten.
slope_label = "slope = " + "%.2E" % fit[0] + " subs/site/year"
rtax.text(1996, 0.012, slope_label)

# Two scatters at the same positions: a larger black marker, then a smaller
# state-coloured one.
rtax.scatter(df["date"], df["distance"], c = "#000000", s = 100)
rtax.scatter(df["date"], df["distance"], c = c, s = 50)
