# Source code for gammy.plot

"""Plotting tools for Gammy models and formulae

"""

import logging
from typing import Callable, List, Optional, Union

import bayespy as bp
try:
    from mpl_toolkits.mplot3d import Axes3D
except ImportError:
    logging.info(
        "Problem with importing Axes3D from mpl_toolkits.mplot3d. Skipping."
    )
import matplotlib.pyplot as plt
from matplotlib.colors import SymLogNorm
import numpy as np

import gammy
from gammy import utils
from gammy.utils import pipe


def _plot_prediction_panels(fig, gs, model, input_data, y, index, color):
    """Draw the predicted-vs-observed series (row 0) and the XY plot (row 1)."""
    (mu, sigma_theta) = model.predict_variance_theta(input_data)
    # Band of +-2 standard deviations of the prediction variance plus the
    # model's noise term ``inv_mean_tau``
    spread = 2 * np.sqrt(sigma_theta + model.inv_mean_tau)

    # Time series plot
    ax = fig.add_subplot(gs[0, :])
    ax.plot(index, y, linewidth=0, marker="o", alpha=0.3, color=color)
    ax.plot(index, mu, color="k")
    ax.fill_between(index, mu - spread, mu + spread, color="k", alpha=0.3)
    ax.grid(True)

    # XY-plot
    ax = fig.add_subplot(gs[1, :])
    ax.plot(mu, y, alpha=0.3, marker="o", lw=0, color=color)
    ax.plot([mu.min(), mu.max()], [mu.min(), mu.max()], c="k", label="x=y")
    ax.legend(loc="best")
    ax.grid(True)
    ax.set_xlabel("Predictions")
    ax.set_ylabel("Observations")


def _plot_partial_residual(
        fig,
        cell,
        model,
        i,
        grid,
        input_data,
        marginal,
        res,
        input_map,
        xlabel,
        color,
        scatter_kwargs
):
    """Draw the partial residual plot of the ``i``-th term into ``cell``."""
    (mu, sigma) = marginal
    x = np.asarray(input_map(grid))
    if x.ndim == 1 or x.shape[1] == 1:
        # ``fill_between`` only accepts a 1-D x, so flatten (n, 1) columns
        x = x.ravel()
        ax = fig.add_subplot(cell)
        band = 2 * np.sqrt(sigma)
        ax.scatter(input_map(input_data), res, color=color, **scatter_kwargs)
        ax.plot(x, mu, c='k', lw=2)
        ax.fill_between(x, mu - band, mu + band, alpha=0.3, color="k")
        ax.set_xlabel(xlabel)
    elif x.shape[1] == 2:
        ax = fig.add_subplot(cell, projection="3d")
        u, v = np.meshgrid(x[:, 0], x[:, 1])
        w = np.hstack((
            u.reshape(-1, 1),
            v.reshape(-1, 1)
        ))
        # A surface needs the term evaluated on the full mesh, not only on
        # ``grid``, so re-evaluate term ``i`` on ``w``
        (mu_mesh, _) = model.predict_variance_marginal(w, i)
        ax.plot_surface(u, v, mu_mesh.reshape(u.shape))
    else:
        raise NotImplementedError("High-dimensional plots not supported.")
    return ax


def validation_plot(
        model,
        input_data,
        y,
        grid_limits,
        input_maps,
        index=None,
        xlabels=None,
        titles=None,
        gridsize=20,
        color="r",
        **kwargs
):
    """Validation plot for a GAM object

    Contains:

    - Series plot with predicted vs. observed
    - XY-plot of predictions against observations
    - Partial residual plots, one per model term

    Parameters
    ----------
    model : gammy.bayespy.GAM | gammy.numpy.GAM
        Visualized model
    input_data : np.ndarray
        Input data
    y : np.ndarray
        Observations
    grid_limits : List
        Grid limits, either `[a, b]` or `[[a_1, b_1], ..., [a_N, b_N]]`
    input_maps : List[Callable]
        List of input maps to be used for each pair of grid limits
    index : np.ndarray
        Optional x-axis for the series plot
    xlabels : List[str]
        Optional x-labels for the partial residual plots
    titles : List[str]
        Optional titles for the partial residual plots, one per model term
    gridsize : int
        Number of points in the input dimensions discretizations
    color : str
        Color of scatter points
    **kwargs
        Passed on to ``Axes.scatter`` in the 1-D partial residual plots

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    AssertionError
        If grid limits, input maps or titles do not match the data / model
    NotImplementedError
        If an input map produces more than two dimensions

    """
    N = len(model.formula)
    N_rows = 2 + (N + 1) // 2
    fig = plt.figure(figsize=(8, 2 * N_rows))
    gs = fig.add_gridspec(N_rows, 2)
    xlabels = xlabels or [None] * N
    titles = titles or [None] * N
    index = np.arange(len(input_data)) if index is None else index

    assert (
        len(grid_limits) == 2 if len(input_data.shape) == 1 else (
            len(grid_limits) == input_data.shape[1] and
            all([len(xs) == 2 for xs in grid_limits])
        )
    ), (
        "Given grid limits do not match with the shape of input data."
    )
    assert len(model.formula.terms) == len(input_maps), (
        "Must give exactly one input per model term."
    )
    assert len(model.formula.terms) == len(titles), (
        "Must give exactly one title per model term."
    )

    # Data and predictions
    if len(input_data.shape) == 2:
        grid = np.array(
            utils.listmap(
                lambda x: np.linspace(x[0], x[1], gridsize)
            )(grid_limits)
        ).T
    else:
        grid = np.linspace(grid_limits[0], grid_limits[1], gridsize)
    marginals = model.predict_variance_marginals(grid)
    residuals = model.marginal_residuals(input_data, y)

    _plot_prediction_panels(fig, gs, model, input_data, y, index, color)

    # Partial residual plots, two per row below the prediction panels
    rows = zip(marginals, residuals, input_maps, xlabels, titles)
    for i, (marginal, res, input_map, xlabel, title) in enumerate(rows):
        ax = _plot_partial_residual(
            fig,
            gs[2 + i // 2, i % 2],
            model,
            i,
            grid,
            input_data,
            marginal,
            res,
            input_map,
            xlabel,
            color,
            kwargs
        )
        ax.set_title(title)
        ax.grid(True)

    fig.tight_layout()
    return fig
def gaussian1d_density_plot(model: gammy.bayespy.GAM):
    """Plot 1-D density for each parameter

    The first panel shows the density of the noise precision ``tau``; each
    following panel overlays the 1-D Gaussian densities of the entries of
    one element of ``model.theta_marginals``.

    Parameters
    ----------
    model : gammy.bayespy.GAM
        Model whose ``tau`` and ``theta_marginals`` are plotted

    Returns
    -------
    matplotlib.figure.Figure

    """
    N = len(model.formula)
    n_panels = N + 1
    # Panels are stacked in a single column, so the height scales with the
    # number of panels
    fig = plt.figure(figsize=(8, 2 * n_panels))
    gs = fig.add_gridspec(n_panels, 1)

    # Plot the density of the noise precision tau
    ax = fig.add_subplot(gs[0])
    # Natural parameters are read as (-rate, shape); a / b is then the mean,
    # assuming tau is a Gamma node
    (b, a) = (-model.tau.phi[0], model.tau.phi[1])
    mu = a / b
    grid = np.linspace(0.5 * mu, 1.5 * mu, 300)
    ax.plot(grid, model.tau.pdf(grid))
    ax.set_title(r"$\tau$ = noise inverse variance")
    ax.grid(True)

    # Plot marginal thetas
    for i, theta in enumerate(model.theta_marginals):
        ax = fig.add_subplot(gs[i + 1])
        moments = theta.get_moments()
        mus = moments[0]
        mus = np.array([mus]) if mus.shape == () else mus
        cov = utils.solve_covariance(moments)
        variances = np.array([cov]) if cov.shape == () else np.diag(cov)
        stds = np.sqrt(variances)
        left = (mus - 4 * stds).min()
        right = (mus + 4 * stds).max()
        # linspace stays well-defined even when left == right
        grid = np.linspace(left, right, 300)
        for (mu, std) in zip(mus, stds):
            node = bp.nodes.GaussianARD(mu, 1 / std ** 2)
            ax.plot(grid, node.pdf(grid))
        # Doubled braces survive str.format so the LaTeX subscript keeps its
        # grouping (e.g. theta_{10} instead of theta_1 followed by 0)
        ax.set_title(r"$\theta_{{{0}}}$".format(i))
        ax.grid(True)

    fig.tight_layout()
    return fig
def gaussian2d_density_plot(model: gammy.bayespy.GAM, i, j):
    """Plot 2-D joint distribution of indices i and j

    Placeholder: no implementation exists yet.

    Raises
    ------
    NotImplementedError
        On every call

    """
    raise NotImplementedError()
def covariance_plot(model, ax=None, linthresh=0.1, **kwargs):
    """Covariance matrix

    Draw ``model.covariance_theta`` as an image on a symmetric-log colour
    scale that spans the matrix's smallest and largest entries.

    Parameters
    ----------
    model
        Model exposing a ``covariance_theta`` matrix
    ax : matplotlib.axes.Axes, optional
        Target axes; a new figure is created when omitted
    linthresh : float
        Range ``(-linthresh, linthresh)`` in which the colour scale is linear
    **kwargs
        Forwarded to ``Axes.imshow``

    Returns
    -------
    tuple
        ``(ax, im)``: the axes and the image drawn into it

    """
    if ax is None:
        ax = plt.figure().gca()
    cov_matrix = model.covariance_theta
    color_norm = SymLogNorm(
        linthresh=linthresh,
        vmin=np.min(cov_matrix),
        vmax=np.max(cov_matrix),
        base=10
    )
    image = ax.imshow(cov_matrix, norm=color_norm, **kwargs)
    return ax, image
def basis_plot(
        formula: gammy.formulae.Formula,
        grid_limits,
        input_maps,
        gridsize=20
):
    """Plot all basis functions

    One panel per formula term; each panel draws every basis function of the
    term against the term's input map evaluated on the grid.

    Parameters
    ----------
    formula : gammy.formulae.Formula
        Formula whose terms' basis functions are plotted
    grid_limits : List
        Grid limits, either `[a, b]` or `[[a_1, b_1], ..., [a_N, b_N]]`
    input_maps : List[Callable]
        One input map per formula term, used as the x-axis of its panel
    gridsize : int
        Number of points in the input dimensions discretizations

    Returns
    -------
    matplotlib.figure.Figure

    """
    # Figure definition
    N = len(formula)
    fig = plt.figure(figsize=(8, max(4 * N // 2, 8)))
    gs = fig.add_gridspec(N, 1)

    # Evaluation grid: a 1-D grid for `[a, b]`, otherwise one column per
    # `[a_k, b_k]` pair (same convention as ``validation_plot``)
    if np.ndim(grid_limits) == 1:
        grid = np.linspace(grid_limits[0], grid_limits[1], gridsize)
    else:
        grid = pipe(
            grid_limits,
            utils.listmap(lambda x: np.linspace(x[0], x[1], gridsize)),
            lambda x: np.array(x).T
        )

    # One panel per term, all basis functions of the term overlaid
    for i, (basis, input_map) in enumerate(zip(formula.terms, input_maps)):
        ax = fig.add_subplot(gs[i])
        x = input_map(grid)
        for f in basis:
            ax.plot(x, f(grid))
        ax.grid(True)

    fig.tight_layout()
    return fig