# Source code for bumps.dream.corrplot

# This program is public domain
# Author Paul Kienzle
"""
2-D correlation histograms

Generate 2-D correlation histograms and display them in a figure.

Uses false color plots of density.
"""
__all__ = ['Corr2d']

import numpy as np
from numpy import inf

from matplotlib import cm, colors, image, artist
from matplotlib.font_manager import FontProperties
from matplotlib.ticker import MaxNLocator

def _density_colormap():
    """
    Return the colormap used to color the density histograms.

    Builds a white -> yellow -> green -> blue -> red ramp named 'density'.
    If building it fails for any reason, matplotlib's reversed gist_earth
    colormap is used instead.
    """
    ramp = ('w', 'y', 'g', 'b', 'r')
    try:
        return colors.LinearSegmentedColormap.from_list('density', ramp)
    except Exception:
        # Best-effort fallback; deliberately broad so plotting still works.
        return cm.gist_earth_r


COLORMAP = _density_colormap()


class Corr2d(object):
    """
    Generate and manage 2D correlation histograms.

    *data* holds one sequence of samples per variable (``data[k]`` is
    variable k).  *labels* names the variables and defaults to
    ``["P1", "P2", ...]``.  Extra keyword arguments (e.g. *ranges*, or
    *bins* for :func:`numpy.histogram2d`) are forwarded to the pair-wise
    histogram generator.
    """
    def __init__(self, data, labels=None, **kw):
        if labels is None:
            labels = ["P"+str(i+1) for i, _ in enumerate(data)]
        self.N = len(data)
        self.labels = labels
        self.data = data
        # (i, j) -> (H, edges of data[i], edges of data[j]); stored for i < j
        self.hists = _hists(data, **kw)
        # Dict of subplot axes keyed by (i, j); set by plot().
        self.ax = None

    def R(self):
        """Return the correlation coefficient matrix ``np.corrcoef(self.data)``."""
        return np.corrcoef(self.data)

    def __getitem__(self, key):
        """
        Retrieve correlation histogram for data[i] X data[j].

        *key* is the pair (i, j).  Returns (histogram, bin i edges,
        bin j edges), the same order as :func:`numpy.histogram2d`, so that
        ``histogram[a, b]`` counts samples with data[i] in bin a and data[j]
        in bin b.

        Only pairs with i < j are stored.  For i > j the stored (j, i)
        histogram is returned transposed with its edges swapped.  Raises
        KeyError for i == j or out-of-range indices.
        """
        i, j = key
        if i > j:
            hist, edges_j, edges_i = self.hists[j, i]
            return hist.T, edges_i, edges_j
        return self.hists[i, j]

    def plot(self, title=None):
        """
        Plot the correlation histograms into the current figure.

        The current figure is cleared first.  If *title* is given, it is
        drawn centered near the top of the figure.  The resulting
        (i, j) -> axes dict is stored in ``self.ax``; it is None when there
        are fewer than two variables.
        """
        # pyplot provides the same clf/gcf as the discouraged pylab interface.
        import matplotlib.pyplot as plt
        plt.clf()
        fig = plt.gcf()
        if title is not None:
            fig.text(0.5, 0.95, title, horizontalalignment='center',
                     fontproperties=FontProperties(size=16))
        self.ax = _plot(fig, self.hists, self.labels, self.N)
def _hists(data, ranges=None, **kw): """ Generate pair-wise correlation histograms """ n = len(data) if ranges is None: low, high = np.min(data, axis=1), np.max(data, axis=1) ranges = [(l, h) for l, h in zip(low, high)] return dict(((i, j), np.histogram2d(data[i], data[j], range=[ranges[i], ranges[j]], **kw)) for i in range(0, n) for j in range(i+1, n)) def _plot(fig, hists, labels, n, show_ticks=None): """ Plot pair-wise correlation histograms """ if n <= 1: fig.text(0.5, 0.5, "No correlation plots when only one variable", ha="center", va="center") return vmin, vmax = inf, -inf for data, _, _ in hists.values(): positive = data[data > 0] if len(positive) > 0: vmin = min(vmin, np.amin(positive)) vmax = max(vmax, np.amax(positive)) norm = colors.LogNorm(vmin=vmin, vmax=vmax, clip=False) #norm = colors.Normalize(vmin=vmin, vmax=vmax) mapper = image.FigureImage(fig) mapper.set_array(np.zeros(0)) mapper.set_cmap(cmap=COLORMAP) mapper.set_norm(norm) if show_ticks is None: show_ticks = n < 3 ax = {} Nr = Nc = n-1 for i in range(0, n-1): for j in range(i+1, n): sharex = ax.get((0, j), None) sharey = ax.get((i, i+1), None) a = fig.add_subplot(Nr, Nc, (Nr-i-1)*Nc + j, sharex=sharex, sharey=sharey) ax[(i, j)] = a a.xaxis.set_major_locator(MaxNLocator(4, steps=[1, 2, 4, 5, 10])) a.yaxis.set_major_locator(MaxNLocator(4, steps=[1, 2, 4, 5, 10])) data, x, y = hists[(i, j)] data = np.clip(data, vmin, vmax) a.pcolorfast(y, x, data, cmap=COLORMAP, norm=norm) # Show labels or hide ticks if i != 0: artist.setp(a.get_xticklabels(), visible=False) if i == n-2 and j == n-1: a.set_xlabel(labels[j]) #a.xaxis.set_label_position("top") #a.xaxis.set_offset_position("top") if not show_ticks: a.xaxis.set_ticks([]) if j == i+1: a.set_ylabel(labels[i]) else: artist.setp(a.get_yticklabels(), visible=False) if not show_ticks: a.yaxis.set_ticks([]) a.zoomable = True # Adjust subplots and add the colorbar fig.subplots_adjust(left=0.07, bottom=0.07, top=0.9, right=0.85, wspace=0.0, hspace=0.0) cax = 
fig.add_axes([0.88, 0.2, 0.04, 0.6]) fig.colorbar(mapper, cax=cax, orientation='vertical') return ax def zoom(event, step): ax = event.inaxes if not hasattr(ax, 'zoomable'): return # TODO: test logscale step *= 3 if ax.zoomable is not True and 'mapper' in ax.zoomable: mapper = ax.zoomable['mapper'] if event.ydata is not None: lo, hi = mapper.get_clim() pt = event.ydata*(hi-lo)+lo lo, hi = _rescale(lo, hi, pt, step) mapper.set_clim((lo, hi)) if ax.zoomable is True and event.xdata is not None: lo, hi = ax.get_xlim() lo, hi = _rescale(lo, hi, event.xdata, step) ax.set_xlim((lo, hi)) if ax.zoomable is True and event.ydata is not None: lo, hi = ax.get_ylim() lo, hi = _rescale(lo, hi, event.ydata, step) ax.set_ylim((lo, hi)) ax.figure.canvas.draw_idle() def _rescale(lo, hi, pt, step): scale = float(hi-lo)*step/(100 if step > 0 else 100-step) bal = float(pt-lo)/(hi-lo) new_lo = lo - bal*scale new_hi = hi + (1-bal)*scale return new_lo, new_hi