# Source code for bumps.dream.outliers

"""
Chain outlier tests.
"""

__all__ = ["identify_outliers"]

import os

from numpy import mean, std, sqrt, where, argmin, arange, array
from numpy import sort
from scipy.stats import t as student_t
from scipy.stats import scoreatpercentile

from .mahal import mahalanobis
from .acr import ACR

# Inverse CDF (percent point function) of Student's t distribution.
tinv = student_t.ppf

# Hack to adjust the aggressiveness of the interquartile range test.
# The value is the multiplier applied to the inter-quartile range when
# identify_outliers computes the IQR rejection threshold.  It is read once,
# at import time, from the BUMPS_INTERQUARTILES environment variable and
# defaults to 2.0; a value that float() cannot parse raises ValueError
# on import.
# TODO: Document/remove BUMPS_INTERQUARTILES or replace with command option.
BUMPS_INTERQUARTILES = float(os.environ.get("BUMPS_INTERQUARTILES", "2.0"))

# CRUFT: scoreatpercentile not accepting array arguments in older scipy
def prctile(v, Q):
    """
    Return the list of scores of *v* at each percentile in *Q*.

    *v* is the sample and *Q* is an iterable of percentiles in [0, 100].
    The percentiles are evaluated one at a time on a sorted copy of *v*,
    and the result has one entry per entry of *Q*, in the same order.
    """
    ordered = sort(v)

    def score(percent):
        return scoreatpercentile(ordered, percent)

    return list(map(score, Q))


def identify_outliers(test, llf, x=None):
    """
    Determine which chains have converged on a local maximum much lower than
    the maximum likelihood.

    *test* is the name of the test to use (one of IQR, Grubbs, Mahal or none).
    The name is matched case-insensitively.

    IQR computes the quartiles of the per-chain mean log likelihood and
    rejects any chain whose maximum log likelihood is more than
    BUMPS_INTERQUARTILES (default 2) times the inter-quartile range below
    the value of the 25% quartile.

    The Grubbs method uses a t-test to determine which chains have a mean
    log likelihood extremely far below the mean across all the chains.  It
    needs at least three chains whose means are not all identical; otherwise
    no outliers are reported.

    The Mahal test looks at the head of the chain with the worst mean log
    likelihood and marks it as an outlier if it is far from the centroid of
    the population.  This assumes that the posterior is approximately
    gaussian, which is not true in general.

    *llf* is a set of log likelihood values for all chains, which is an
    array of shape (chain len, num chains).

    *x* is the current population with one point for each chain, which is
    an array of shape (num chains, num vars).  This is only used for the
    Mahal test, where it is required.

    Returns an integer array of outlier indices, which is empty when no
    chain is rejected.

    Raises ValueError if *test* is not a known test name, or if the Mahal
    test is requested without *x*.
    """
    test = test.lower()

    # Every "no outliers" path returns this so that the result can always be
    # used as an index array; a bare array([]) would have dtype float64.
    no_outliers = array([], dtype=int)

    if test == 'iqr':
        # Determine the mean log density of the active chains
        v = mean(llf, axis=0)
        # Derive the upper and lower quartile of the chain averages
        q1, q3 = prctile(v, [25., 75.])
        # Derive the Inter Quartile Range (IQR)
        iqr = q3 - q1
        # See whether there are any outlier chains
        # 2017-10-06 [PAK] test against chain max rather than chain mean.
        # Chains wandering inside the active region should not be punished.
        # Since removing outliers will delay convergence tests until the
        # outlier removal effects have disappeared (i.e., an entire frame),
        # a less aggressive test will speed completion.
        vmax = llf.max(axis=0)
        outliers = where(vmax < q1 - BUMPS_INTERQUARTILES*iqr)[0]

    elif test == 'grubbs':
        # Determine the mean log density of the active chains
        v = mean(llf, axis=0)
        n = len(v)
        # The t distribution below has n-2 degrees of freedom, so the test is
        # undefined for fewer than three chains; with zero spread the z-score
        # is 0/0.  Neither case can flag an outlier, so skip the arithmetic
        # (and its NaN warnings) and report none.
        spread = std(v, ddof=1) if n > 2 else 0.0
        if spread == 0:
            outliers = no_outliers
        else:
            # Compute zscore for chain averages; positive means below average
            zscore = (mean(v) - v) / spread
            # Determine t-value of the one-sided interval at significance
            # level 0.01, shared across the n chains.
            t2 = student_t.ppf(1 - 0.01/n, n-2)**2
            # Determine the critical value
            gcrit = ((n - 1)/sqrt(n)) * sqrt(t2/(n-2 + t2))
            # Then check against this
            outliers = where(zscore > gcrit)[0]

    elif test == 'mahal':
        if x is None:
            raise ValueError("Mahal outlier test requires the population x")
        # Determine the mean log density of the active chains
        v = mean(llf, axis=0)
        # Find which chain has minimum log_density
        minidx = argmin(v)
        # Use the Mahalanobis distance to find outliers in the population
        alpha = 0.01
        npop, nvar = x.shape
        gcrit = ACR(nvar, npop-1, alpha)
        # Check the Mahalanobis distance of the worst chain's current point
        # to the points of all the other chains
        d1 = mahalanobis(x[minidx, :], x[minidx != arange(npop), :])
        # and see if it is an outlier
        outliers = array([minidx]) if d1 > gcrit else no_outliers

    elif test == 'none':
        outliers = no_outliers

    else:
        raise ValueError("Unknown outlier test "+test)

    return outliers
def test_outliers():
    """
    Check identify_outliers on synthetic chains, then check outlier
    replacement on an MCMCDraw state filled with those chains.

    The random seed is fixed, so the statements that draw random numbers
    must stay in their current order for the assertions to hold.
    """
    from .walk import walk
    from numpy.random import multivariate_normal, seed
    from numpy import vstack, ones, eye
    seed(2) # Remove uncertainty on tests
    # Set a number of good and bad chains
    ngood, nbad = 25, 2

    # Make chains mean-reverting chains with widely separated values for
    # bad and good; put bad chains first.
    chains = walk(1000, mu=[1]*nbad+[5]*ngood, sigma=0.45, alpha=0.1)

    # Check IQR and Grubbs
    # Neither test uses the population, so x is passed as None; both are
    # expected to flag exactly the first nbad chains.
    assert (identify_outliers('IQR', chains, None) == arange(nbad)).all()
    assert (identify_outliers('Grubbs', chains, None) == arange(nbad)).all()

    # Put points for 'bad' chains at [-1,...,-1] and 'good' chains at [1,...,1]
    x = vstack((multivariate_normal(-ones(4), 0.1*eye(4), size=nbad),
                multivariate_normal(ones(4), 0.1*eye(4), size=ngood)))
    # Mahal reports at most one chain; it must be one of the bad ones.
    assert identify_outliers('Mahal', chains, x)[0] in range(nbad)

    # Put points for _all_ chains at [1,...,1] and check that mahal return []
    xsame = multivariate_normal(ones(4), 0.2*eye(4), size=ngood+nbad)
    assert len(identify_outliers('Mahal', chains, xsame)) == 0

    # Check again with large variance
    # The good points are spread widely enough (covariance 10*I) that the
    # bad points centered at -3 are not expected to be flagged.
    x = vstack((multivariate_normal(-3*ones(4), eye(4), size=nbad),
                multivariate_normal(ones(4), 10*eye(4), size=ngood)))
    assert len(identify_outliers('Mahal', chains, x)) == 0

    # =====================================================================
    # Test replacement

    # Construct a state object
    # NOTE(review): norm is imported here but not used in this function.
    from numpy.linalg import norm
    from .state import MCMCDraw
    ngen, npop = chains.shape
    # npop is rebound from x; both sources give ngood+nbad chains.
    npop, nvar = x.shape
    state = MCMCDraw(Ngen=ngen, Nthin=ngen, Nupdate=0,
                     Nvar=nvar, Npop=npop, Ncr=0, thinning=0)
    # Fill it with chains
    # The same population x is recorded for every generation; only logp
    # changes from one generation to the next.
    for i in range(ngen):
        state._generation(new_draws=npop, x=x, logp=chains[i], accept=npop)

    # Make a copy of the current state so we can check it was updated
    nx, nlogp = x+0, chains[-1]+0
    # Remove outliers
    state.remove_outliers(nx, nlogp, test='IQR')
    # Check that the outliers were removed
    outliers = state.outliers()
    assert outliers.shape[0] == nbad
    # Presumably column 1 of each outlier record is the index of the chain
    # that was replaced and column 2 is the index of the chain it was copied
    # from -- TODO confirm against MCMCDraw.outliers.
    for i in range(nbad):
        assert nlogp[outliers[i, 1]] == chains[-1][outliers[i, 2]]
        assert nx[outliers[i, 1], -1] == x[outliers[i, 2], -1]


if __name__ == "__main__":
    test_outliers()