Back to index

python-biopython  1.60
_utils.py
Go to the documentation of this file.
00001 # Copyright (C) 2009 by Eric Talevich (eric.talevich@gmail.com)
00002 # This code is part of the Biopython distribution and governed by its
00003 # license. Please see the LICENSE file that should have been included
00004 # as part of this package.
00005 
00006 """Utilities for handling, displaying and exporting Phylo trees.
00007 
00008 Third-party libraries are loaded when the corresponding function is called.
00009 """
00010 __docformat__ = "restructuredtext en"
00011 
00012 import math
00013 import sys
00014 
00015 
def to_networkx(tree):
    """Build a NetworkX graph that mirrors the topology of a Phylo tree.

    Each clade becomes a node and each parent-child link becomes an edge.
    A rooted tree produces a ``networkx.DiGraph``; an unrooted tree produces
    a plain ``networkx.Graph``.

    With NetworkX 1.0 or later every edge carries a ``weight`` attribute (the
    child's branch length, or 1.0 if it is unset or zero, stored as a string)
    and, when available, ``color`` (hex string) and ``width`` attributes.
    Color and width cascade: a clade without its own value takes its parent's
    value, and that value is also written back onto the clade object itself.

    The graph suits graph-oriented analysis and interactive plotting with
    pylab, matplotlib or pygraphviz, but the resulting picture is usually not
    an ideal rendering of a phylogeny.

    Requires NetworkX version 0.99 or later.
    """
    try:
        import networkx
    except ImportError:
        from Bio import MissingPythonDependencyError
        raise MissingPythonDependencyError(
                "Install NetworkX if you want to use to_networkx.")

    # (attribute name, conversion applied before storing it on the edge)
    cascading_attrs = (
        ('color', lambda value: value.to_hex()),
        ('width', lambda value: value),
    )

    def link(graph, parent, child):
        """Add the parent-child edge using whichever NetworkX API is present."""
        # NB (1/2010): the networkx API congealed recently; v0.99 and v1.0+
        # are both supported. Versions are compared as plain strings.
        version = networkx.__version__
        if version < '1.0':
            if version >= '0.99':
                graph.add_edge(parent, child, (child.branch_length or 1.0))
            else:
                graph.add_edge(parent, child)
            return

        graph.add_edge(parent, child, weight=str(child.branch_length or 1.0))
        for attr, to_edge_value in cascading_attrs:
            own = getattr(child, attr, None)
            if own is not None:
                graph[parent][child][attr] = to_edge_value(own)
                continue
            inherited = getattr(parent, attr, None)
            if inherited is not None:
                # Cascade the parent's value down to the child clade
                graph[parent][child][attr] = to_edge_value(inherited)
                setattr(child, attr, inherited)

    def populate(graph, parent):
        """Depth-first walk adding a node and an edge for every descendant."""
        for child in parent:
            graph.add_node(child.root)
            link(graph, parent.root, child.root)
            populate(graph, child)

    graph_class = networkx.DiGraph if tree.rooted else networkx.Graph
    G = graph_class()
    G.add_node(tree.root)
    populate(G, tree.root)
    return G
00070 
00071 
def draw_graphviz(tree, label_func=str, prog='twopi', args='',
        node_color='#c0deff', **kwargs):
    """Display a tree or clade as a graph, using the graphviz engine.

    Requires NetworkX, matplotlib, Graphviz and either PyGraphviz or pydot.

    The third and fourth parameters apply to Graphviz, and the remaining
    arbitrary keyword arguments are passed directly to networkx.draw(), which
    in turn mostly wraps matplotlib/pylab.  See the documentation for Graphviz
    and networkx for detailed explanations.

    The NetworkX/matplotlib parameters are described in the docstrings for
    networkx.draw() and pylab.scatter(), but the most reasonable options to try
    are: *alpha, node_color, node_size, node_shape, edge_color, style,
    font_size, font_color, font_weight, font_family*

    Unless the caller supplies them, ``edge_color`` and ``width`` are filled in
    from the per-edge ``color`` and ``width`` attributes set by to_networkx(),
    defaulting to black ('k') and 1.0. The ``nodelist`` keyword is always
    replaced by the nodes that ended up with a label (optionally restricted to
    the nodes the caller listed).

    :Parameters:

        tree : Tree or Clade
            The tree to draw; it is converted with to_networkx().
        label_func : callable
            A function to extract a label from a node. By default this is str(),
            but you can use a different function to select another string
            associated with each node. If this function returns None for a node,
            no label will be shown for that node.

            The label will also be silently skipped if the function throws an
            exception related to ordinary attribute access (LookupError,
            AttributeError, ValueError); all other exception types will still
            be raised. This means you can use a lambda expression that simply
            attempts to look up the desired value without checking if the
            intermediate attributes are available:

                >>> Phylo.draw_graphviz(tree, lambda n: n.taxonomies[0].code)

        prog : string
            The Graphviz program to use when rendering the graph. 'twopi'
            behaves the best for large graphs, reliably avoiding crossing edges,
            but for moderate graphs 'neato' looks a bit nicer.  For small
            directed graphs, 'dot' may produce a normal-looking cladogram, but
            will cross and distort edges in larger graphs. (The programs 'circo'
            and 'fdp' are not recommended.)
        args : string
            Options passed to the external graphviz program.  Normally not
            needed, but offered here for completeness.
        node_color : string
            Color passed to networkx.draw() for the nodes.

    :Raises:
        MissingPythonDependencyError
            If NetworkX, or the PyGraphviz/pydot layout backend, cannot be
            imported.

    Example
    -------

    >>> import pylab
    >>> from Bio import Phylo
    >>> tree = Phylo.read('ex/apaf.xml', 'phyloxml')
    >>> Phylo.draw_graphviz(tree)
    >>> pylab.show()
    >>> pylab.savefig('apaf.png')
    """
    try:
        import networkx
    except ImportError:
        from Bio import MissingPythonDependencyError
        raise MissingPythonDependencyError(
                "Install NetworkX if you want to use draw_graphviz.")

    G = to_networkx(tree)
    Gi = networkx.convert_node_labels_to_integers(G, discard_old_labels=False)
    try:
        posi = networkx.graphviz_layout(Gi, prog, args=args)
    except ImportError:
        # Must be imported here as well: when NetworkX itself loaded fine the
        # handler above never ran, so the name would otherwise be unbound.
        from Bio import MissingPythonDependencyError
        raise MissingPythonDependencyError(
                "Install PyGraphviz or pydot if you want to use draw_graphviz.")

    def get_label_mapping(G, selection):
        """Yield (node, label) pairs for the selected nodes that have a label."""
        for node in G.nodes():
            if (selection is None) or (node in selection):
                try:
                    label = label_func(node)
                except (LookupError, AttributeError, ValueError):
                    # Deliberate best-effort: nodes lacking the requested
                    # attribute are simply left unlabeled.
                    continue
                # The bare class name is what str() gives for unnamed clades
                if label not in (None, node.__class__.__name__):
                    yield (node, label)

    def edge_attribute(data, key, default):
        """Return data[key] if data is a dict with a truthy value, else default."""
        if isinstance(data, dict):
            return data.get(key, default) or default
        # Older NetworkX versions may store a bare weight instead of a dict
        return default

    if 'nodelist' in kwargs:
        labels = dict(get_label_mapping(G, set(kwargs['nodelist'])))
    else:
        labels = dict(get_label_mapping(G, None))
    # A real list (not a dict view) so it behaves the same on Python 2 and 3
    kwargs['nodelist'] = list(labels)
    if 'edge_color' not in kwargs:
        kwargs['edge_color'] = [edge_attribute(e[2], 'color', 'k')
                                for e in G.edges(data=True)]
    if 'width' not in kwargs:
        kwargs['width'] = [edge_attribute(e[2], 'width', 1.0)
                           for e in G.edges(data=True)]

    # Map the layout computed on the integer-labeled copy back to the clades
    posn = dict((n, posi[Gi.node_labels[n]]) for n in G)
    networkx.draw(G, posn, labels=labels, node_color=node_color, **kwargs)
00167 
00168 
def draw_ascii(tree, file=None, column_width=80):
    """Draw an ascii-art phylogram of the given tree.

    The printed result looks like::

                                        _________ Orange
                         ______________|
                        |              |______________ Tangerine
          ______________|
         |              |          _________________________ Grapefruit
        _|              |_________|
         |                        |______________ Pummelo
         |
         |__________________________________ Apple


    Each terminal taxon occupies every second text row, and internal clades
    are placed midway between their first and last child. If the tree has no
    branch lengths, unit branch lengths are assumed. The drawing is followed
    by one blank line.

    :Parameters:
        tree : Tree or Clade
            The tree to draw.
        file : file-like object
            File handle opened for writing the output drawing. Defaults to
            sys.stdout, looked up when the function is called.
        column_width : int
            Total number of text columns used by the drawing.
    """
    if file is None:
        # Resolve at call time so that a redirected sys.stdout is honoured
        file = sys.stdout

    taxa = tree.get_terminals()
    # Some constants for the drawing calculations
    max_label_width = max(len(str(taxon)) for taxon in taxa)
    drawing_width = column_width - max_label_width - 1
    drawing_height = 2 * len(taxa) - 1

    def get_col_positions(tree):
        """Create a mapping of each clade to its column position."""
        depths = tree.depths()
        # If there are no branch lengths, assume unit branch lengths
        if not max(depths.values()):
            depths = tree.depths(unit_branch_lengths=True)
        max_depth = max(depths.values())
        # Potential drawing overflow due to rounding -- 1 char per tree layer
        fudge_margin = int(math.ceil(math.log(len(taxa), 2)))
        if max_depth:
            cols_per_branch_unit = ((drawing_width - fudge_margin)
                                    / float(max_depth))
        else:
            # Single-node tree: nothing to scale, avoid dividing by zero
            cols_per_branch_unit = 0.0
        # floor(x + 1.0) rounds x + 0.5 half-up for non-negative depths, which
        # is independent of the interpreter's round() tie-breaking rule.
        return dict((clade,
                     int(math.floor(blen * cols_per_branch_unit + 1.0)))
                    for clade, blen in depths.items())

    def get_row_positions(tree):
        """Create a mapping of each clade to its (integer) row position."""
        positions = dict((taxon, 2 * idx) for idx, taxon in enumerate(taxa))

        def calc_row(clade):
            for subclade in clade:
                if subclade not in positions:
                    calc_row(subclade)
            # Floor division keeps rows usable as list indices
            positions[clade] = (positions[clade.clades[0]] +
                                positions[clade.clades[-1]]) // 2

        # A root without children is itself a terminal and already placed
        if tree.root.clades:
            calc_row(tree.root)
        return positions

    col_positions = get_col_positions(tree)
    row_positions = get_row_positions(tree)
    char_matrix = [[' ' for x in range(drawing_width)]
                    for y in range(drawing_height)]

    def draw_clade(clade, startcol):
        """Recursively draw the branch to this clade and its descendents."""
        thiscol = col_positions[clade]
        thisrow = row_positions[clade]
        # Draw a horizontal line
        for col in range(startcol, thiscol):
            char_matrix[thisrow][col] = '_'
        if clade.clades:
            # Draw a vertical line
            toprow = row_positions[clade.clades[0]]
            botrow = row_positions[clade.clades[-1]]
            for row in range(toprow + 1, botrow + 1):
                char_matrix[row][thiscol] = '|'
            # NB: Short terminal branches need something to stop rstrip()
            if (col_positions[clade.clades[0]] - thiscol) < 2:
                char_matrix[toprow][thiscol] = ','
            # Draw descendents
            for child in clade:
                draw_clade(child, thiscol + 1)

    draw_clade(tree.root, 0)
    # Print the complete drawing
    for idx, row in enumerate(char_matrix):
        line = ''.join(row).rstrip()
        # Add labels for terminal taxa in the right margin
        if idx % 2 == 0:
            line += ' ' + str(taxa[idx // 2])
        file.write(line + '\n')
    file.write('\n')
00254 
00255 
def draw(tree, label_func=str, do_show=True, show_confidence=True,
        # For power users
        axes=None, branch_labels=None):
    """Plot the given tree using matplotlib (or pylab).

    The graphic is a rooted tree, drawn with roughly the same algorithm as
    draw_ascii.

    Visual aspects of the plot can be modified using pyplot's own functions and
    objects (via pylab or matplotlib). In particular, the pyplot.rcParams
    object can be used to scale the font size (rcParams["font.size"]) and line
    width (rcParams["lines.linewidth"]).

    :Parameters:
        tree : Tree or Clade
            The tree to plot.
        label_func : callable
            A function to extract a label from a node. By default this is str(),
            but you can use a different function to select another string
            associated with each node. If this function returns None for a node,
            no label will be shown for that node.
        do_show : bool
            Whether to show() the plot automatically.
        show_confidence : bool
            Whether to display confidence values, if present on the tree.
        axes : matplotlib/pylab axes
            If a valid matplotlib.axes.Axes instance, the phylogram is plotted
            in that Axes. By default (None), a new figure is created.
        branch_labels : dict or callable
            A mapping of each clade to the label that will be shown along the
            branch leading to it. By default this is the confidence value(s) of
            the clade, taken from the ``confidence`` attribute, and can be
            easily toggled off with this function's ``show_confidence`` option.
            But if you would like to alter the formatting of confidence values,
            or label the branches with something other than confidence, then use
            this option.

    :Raises:
        MissingPythonDependencyError
            If neither matplotlib nor pylab can be imported.
        TypeError
            If ``branch_labels`` is given but is neither a dict nor callable.
        ValueError
            If ``axes`` is given but is not a matplotlib Axes instance.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        try:
            import pylab as plt
        except ImportError:
            from Bio import MissingPythonDependencyError
            raise MissingPythonDependencyError(
                    "Install matplotlib or pylab if you want to use draw.")

    # Options for displaying branch labels / confidence
    def conf2str(conf):
        """Format a confidence value, dropping a redundant '.0'."""
        if int(conf) == conf:
            return str(int(conf))
        return str(conf)

    if not branch_labels:
        if show_confidence:
            def format_branch_label(clade):
                if hasattr(clade, 'confidences'):
                    # phyloXML supports multiple confidences
                    return '/'.join(conf2str(cnf.value)
                                    for cnf in clade.confidences)
                # Compare against None so a confidence of 0 is still shown
                if clade.confidence is not None:
                    return conf2str(clade.confidence)
                return None
        else:
            def format_branch_label(clade):
                return None
    elif isinstance(branch_labels, dict):
        def format_branch_label(clade):
            return branch_labels.get(clade)
    else:
        # Explicit check rather than assert, which is stripped under -O
        if not callable(branch_labels):
            raise TypeError(
                "branch_labels must be either a dict or a callable (function)")
        format_branch_label = branch_labels

    # Layout

    def get_x_positions(tree):
        """Create a mapping of each clade to its horizontal position.

        Dict of {clade: x-coord}
        """
        depths = tree.depths()
        # If there are no branch lengths, assume unit branch lengths
        if not max(depths.values()):
            depths = tree.depths(unit_branch_lengths=True)
        return depths

    def get_y_positions(tree):
        """Create a mapping of each clade to its vertical position.

        Dict of {clade: y-coord}.
        Tips get the integers 1..N in terminal order (the first terminal is
        at 1); internal nodes sit midway between their first and last child.
        """
        maxheight = tree.count_terminals()
        # Rows are defined by the tips
        heights = dict((tip, maxheight - i)
                for i, tip in enumerate(reversed(tree.get_terminals())))

        # Internal nodes: place at midpoint of children
        def calc_row(clade):
            for subclade in clade:
                if subclade not in heights:
                    calc_row(subclade)
            # Closure over heights
            heights[clade] = (heights[clade.clades[0]] +
                                heights[clade.clades[-1]]) / 2.0

        if tree.root.clades:
            calc_row(tree.root)
        return heights

    x_posns = get_x_positions(tree)
    y_posns = get_y_positions(tree)
    # The function draw_clade closes over the axes object
    if axes is None:
        fig = plt.figure()
        axes = fig.add_subplot(1, 1, 1)
    elif not isinstance(axes, plt.matplotlib.axes.Axes):
        raise ValueError("Invalid argument for axes: %s" % axes)

    def draw_clade(clade, x_start, color, lw):
        """Recursively draw a tree, down from the given clade."""
        x_here = x_posns[clade]
        y_here = y_posns[clade]
        # phyloXML-only graphics annotations; values are inherited by
        # descendents through the color/lw arguments of the recursive call
        if hasattr(clade, 'color') and clade.color is not None:
            color = clade.color.to_hex()
        if hasattr(clade, 'width') and clade.width is not None:
            lw = clade.width * plt.rcParams['lines.linewidth']
        # Draw a horizontal line from start to here
        axes.hlines(y_here, x_start, x_here, color=color, lw=lw)
        # Add node/taxon labels
        label = label_func(clade)
        if label not in (None, clade.__class__.__name__):
            axes.text(x_here, y_here, ' %s' % label, verticalalignment='center')
        # Add label above the branch (optional)
        conf_label = format_branch_label(clade)
        if conf_label:
            axes.text(0.5*(x_start + x_here), y_here, conf_label,
                    fontsize='small', horizontalalignment='center')
        if clade.clades:
            # Draw a vertical line connecting all children
            y_top = y_posns[clade.clades[0]]
            y_bot = y_posns[clade.clades[-1]]
            # Only apply widths to horizontal lines, like Archaeopteryx
            axes.vlines(x_here, y_bot, y_top, color=color)
            # Draw descendents
            for child in clade:
                draw_clade(child, x_here, color, lw)

    draw_clade(tree.root, 0, 'k', plt.rcParams['lines.linewidth'])

    # Aesthetics

    if hasattr(tree, 'name') and tree.name:
        axes.set_title(tree.name)
    axes.set_xlabel('branch length')
    axes.set_ylabel('taxa')
    # Add margins around the tree to prevent overlapping the axes
    xmax = max(x_posns.values())
    axes.set_xlim(-0.05 * xmax, 1.25 * xmax)
    # Also invert the y-axis (origin at the top)
    # Add a small vertical margin, but avoid including 0 and N+1 on the y axis
    axes.set_ylim(max(y_posns.values()) + 0.8, 0.2)
    if do_show:
        plt.show()
00418