"""Plotting tools for Gammy models and formulae
"""
import logging
from typing import Callable, List, Optional, Union
import bayespy as bp
try:
    from mpl_toolkits.mplot3d import Axes3D
except ImportError:
    logging.info(
        "Problem with importing Axes3D from mpl_toolkits.mplot3d. Skipping."
    )
import matplotlib.pyplot as plt
from matplotlib.colors import SymLogNorm
import numpy as np
import gammy
from gammy import utils
from gammy.utils import pipe
def validation_plot(
        model,
        input_data,
        y,
        grid_limits,
        input_maps,
        index=None,
        xlabels=None,
        titles=None,
        gridsize=20,
        color="r",
        **kwargs
):
    """Validation plot for a GAM object

    Contains:
        - Series plot with predicted vs. observed (mean +- 2 std band)
        - Scatter plot of predictions vs. observations
        - Partial residual plots, one per model term

    Parameters
    ----------
    model : gammy.bayespy.GAM | gammy.numpy.GAM
        Visualized model
    input_data : np.ndarray
        Input data, either 1-D or 2-D with one column per input dimension
    y : np.ndarray
        Observations
    grid_limits : List
        Grid limits, either `[a, b]` (1-D input data) or
        `[[a_1, b_1], ..., [a_N, b_N]]` (one pair per input column)
    input_maps : List[Callable]
        One input map per model term; applied to both the evaluation grid
        and `input_data` to get the x-axis of that term's plot
    index : np.ndarray
        Optional x-axis for the series plot (defaults to `0, ..., len - 1`)
    xlabels : List[str]
        Optional x-labels for the partial residual plots, one per term
    titles : List[str]
        Optional titles for the partial residual plots, one per term
    gridsize : int
        Number of points in the input dimensions discretizations
    color : str
        Color of scatter points
    **kwargs
        Passed on to `ax.scatter` in the 1-D partial residual plots

    Returns
    -------
    matplotlib.figure.Figure
        The created figure

    Raises
    ------
    AssertionError
        If grid limits, input maps, titles or x-labels do not match the
        input data / model terms
    NotImplementedError
        If an input map produces more than two dimensions
    """
    N = len(model.formula)
    # Two full-width rows (series + XY plot), then two partial plots per row
    N_rows = 2 + (N + 1) // 2
    fig = plt.figure(figsize=(8, 2 * N_rows))
    gs = fig.add_gridspec(N_rows, 2)
    xlabels = xlabels or [None] * N
    titles = titles or [None] * N
    index = np.arange(len(input_data)) if index is None else index
    assert (
        len(grid_limits) == 2 if len(input_data.shape) == 1 else
        (
            len(grid_limits) == input_data.shape[1] and
            all([len(xs) == 2 for xs in grid_limits])
        )
    ), (
        "Given grid limits do not match with the shape of input data."
    )
    assert len(model.formula.terms) == len(input_maps), (
        "Must give exactly one input per model term."
    )
    assert len(model.formula.terms) == len(titles), (
        "Must give exactly one title per model term."
    )
    # Without this check `zip` below would silently drop plots
    assert len(model.formula.terms) == len(xlabels), (
        "Must give exactly one x-label per model term."
    )

    # Data and predictions
    if len(input_data.shape) == 2:
        # One linspace per input column, stacked column-wise
        grid = np.array(
            utils.listmap(
                lambda x: np.linspace(x[0], x[1], gridsize)
            )(grid_limits)
        ).T
    else:
        grid = np.linspace(grid_limits[0], grid_limits[1], gridsize)
    marginals = model.predict_variance_marginals(grid)
    residuals = model.marginal_residuals(input_data, y)

    # Time series plot
    ax = fig.add_subplot(gs[0, :])
    (mu, sigma_theta) = model.predict_variance_theta(input_data)
    # Predictive band includes the noise variance as well
    lower = mu - 2 * np.sqrt(sigma_theta + model.inv_mean_tau)
    upper = mu + 2 * np.sqrt(sigma_theta + model.inv_mean_tau)
    ax.plot(index, y, linewidth=0, marker="o", alpha=0.3, color=color)
    ax.plot(index, mu, color="k")
    ax.fill_between(index, lower, upper, color="k", alpha=0.3)
    ax.grid(True)

    # XY-plot
    ax = fig.add_subplot(gs[1, :])
    ax.plot(mu, y, alpha=0.3, marker="o", lw=0, color=color)
    ax.plot([mu.min(), mu.max()], [mu.min(), mu.max()], c="k", label="x=y")
    ax.legend(loc="best")
    ax.grid(True)
    ax.set_xlabel("Predictions")
    ax.set_ylabel("Observations")

    # Partial residual plots
    for i, ((term_mu, term_sigma), res, input_map, xlabel, title) in enumerate(
        zip(marginals, residuals, input_maps, xlabels, titles)
    ):
        x = input_map(grid)
        if len(x.shape) == 1 or x.shape[1] == 1:
            ax = fig.add_subplot(gs[2 + i // 2, i % 2])
            # `fill_between` only accepts 1-D arrays, so flatten (n, 1)
            # shaped mapped inputs / marginals
            x = np.ravel(x)
            term_mu = np.ravel(term_mu)
            term_std = np.sqrt(np.ravel(term_sigma))
            ax.scatter(input_map(input_data), res, color=color, **kwargs)
            ax.plot(x, term_mu, c='k', lw=2)
            ax.fill_between(
                x,
                term_mu - 2 * term_std,
                term_mu + 2 * term_std,
                alpha=0.3,
                color="k"
            )
            ax.set_xlabel(xlabel)
        elif x.shape[1] == 2:
            ax = fig.add_subplot(gs[2 + i // 2, i % 2], projection="3d")
            # `grid` holds only `gridsize` points (one per linspace step), so
            # re-evaluate the marginal on the full 2-D mesh for the surface
            u, v = np.meshgrid(x[:, 0], x[:, 1])
            w = np.hstack((u.reshape(-1, 1), v.reshape(-1, 1)))
            (mesh_mu, _) = model.predict_variance_marginal(w, i)
            ax.plot_surface(u, v, mesh_mu.reshape(u.shape))
        else:
            raise NotImplementedError("High-dimensional plots not supported.")
        ax.set_title(title)
        ax.grid(True)
    fig.tight_layout()
    return fig


def gaussian1d_density_plot(model: gammy.bayespy.GAM):
    """Plot 1-D density for each parameter

    The first panel shows the density of the noise precision `tau`; each
    following panel overlays the 1-D marginal Gaussian densities of the
    coefficients of one entry of `model.theta_marginals`.

    Parameters
    ----------
    model : gammy.bayespy.GAM
        Fitted model

    Returns
    -------
    matplotlib.figure.Figure
        The created figure
    """
    N = len(model.formula)
    # One stacked row for tau plus one per term; size the figure to match
    N_rows = N + 1
    fig = plt.figure(figsize=(8, 2 * N_rows))
    gs = fig.add_gridspec(N_rows, 1)

    # Density of tau. `a / b` centres the plotting window; it is the mean if
    # tau ~ Gamma(a, b) (assumes bayespy natural parameters phi = [-b, a])
    ax = fig.add_subplot(gs[0])
    (b, a) = (-model.tau.phi[0], model.tau.phi[1])
    mu = a / b
    grid = np.arange(0.5 * mu, 1.5 * mu, mu / 300)
    ax.plot(grid, model.tau.pdf(grid))
    ax.set_title(r"$\tau$ = noise inverse variance")
    ax.grid(True)

    # Plot marginal thetas
    for i, theta in enumerate(model.theta_marginals):
        ax = fig.add_subplot(gs[i + 1])
        mus = theta.get_moments()[0]
        # Scalar terms have 0-d moments; promote them to 1-element arrays
        mus = np.array([mus]) if mus.shape == () else mus
        cov = utils.solve_covariance(theta.get_moments())
        stds = pipe(
            np.array([cov]) if cov.shape == ()
            else np.diag(cov),
            np.sqrt
        )
        # Common +-4 std window covering every coefficient of the term
        left = (mus - 4 * stds).min()
        right = (mus + 4 * stds).max()
        grid = np.arange(left, right, (right - left) / 300)
        for (mu, std) in zip(mus, stds):
            node = bp.nodes.GaussianARD(mu, 1 / std ** 2)
            ax.plot(grid, node.pdf(grid))
        # Braces around the index so multi-digit indices are fully subscripted
        ax.set_title(r"$\theta_{{{0}}}$".format(i))
        ax.grid(True)
    fig.tight_layout()
    return fig


def gaussian2d_density_plot(model: gammy.bayespy.GAM, i, j):
    """Plot the 2-D joint density of parameters `i` and `j`.

    Placeholder: not implemented yet, every call raises.

    Raises
    ------
    NotImplementedError
        Always
    """
    raise NotImplementedError()


def covariance_plot(model, ax=None, linthresh=0.1, **kwargs):
    """Draw the model's parameter covariance matrix as an image.

    Colours use a symmetric log scale (base 10) spanning the matrix's
    minimum to maximum entry.

    Parameters
    ----------
    model
        Model exposing a `covariance_theta` matrix
    ax : matplotlib.axes.Axes, optional
        Target axes; when omitted, a new figure is created and its axes used
    linthresh : float
        Range around zero within which the colour scale is linear
    **kwargs
        Forwarded to `ax.imshow`

    Returns
    -------
    tuple
        `(ax, im)`: the axes and the image returned by `imshow`
    """
    if ax is None:
        ax = plt.figure().gca()
    cov = model.covariance_theta
    sym_log_norm = SymLogNorm(
        linthresh=linthresh,
        vmin=np.min(cov),
        vmax=np.max(cov),
        base=10,
    )
    image = ax.imshow(cov, norm=sym_log_norm, **kwargs)
    return ax, image


def basis_plot(
        formula: gammy.formulae.Formula,
        grid_limits,
        input_maps,
        gridsize=20
):
    """Plot all basis functions

    One stacked subplot per formula term; each subplot draws every basis
    function of the term against the term's mapped input.

    Parameters
    ----------
    formula : gammy.formulae.Formula
        Formula whose terms (lists of basis functions) are plotted
    grid_limits : List
        Grid limits, either `[a, b]` for 1-D input or
        `[[a_1, b_1], ..., [a_N, b_N]]` for N input dimensions
    input_maps : List[Callable]
        One input map per term, applied to the grid to get the x-axis
    gridsize : int
        Number of points in the input dimensions discretizations

    Returns
    -------
    matplotlib.figure.Figure
        The created figure
    """
    # Figure definition
    N = len(formula)
    fig = plt.figure(figsize=(8, max(4 * N // 2, 8)))
    gs = fig.add_gridspec(N, 1)

    # Evaluation grid
    if np.ndim(grid_limits[0]) == 0:
        # Single pair `[a, b]`: 1-D grid (same convention as validation_plot)
        grid = np.linspace(grid_limits[0], grid_limits[1], gridsize)
    else:
        # One linspace per pair, stacked column-wise -> (gridsize, N)
        grid = pipe(
            grid_limits,
            utils.listmap(lambda x: np.linspace(x[0], x[1], gridsize)),
            lambda x: np.array(x).T
        )

    # Plot stuff
    for i, (basis, input_map) in enumerate(zip(formula.terms, input_maps)):
        ax = fig.add_subplot(gs[i])
        x = input_map(grid)
        for f in basis:
            ax.plot(x, f(grid))
    return fig