Plotting
pymgcv.plot
provides plotting utilities for visualizing GAM models. Plotting is performed using matplotlib.
Across the package and examples, we use the import convention
import pymgcv.plot as gplt
- For an overall plot of the partial effects of a GAM model, use plot.
- For finer control over the plotting, specific terms can be plotted onto a single matplotlib axis using the term-specific functions documented below: continuous_1d, continuous_2d, categorical and random_effect.
plot
plot(
    gam: AbstractGAM,
    *,
    ncols: int = 2,
    scatter: bool = False,
    data: DataFrame | Mapping[str, ndarray | Series] | None = None,
    to_plot: type | UnionType | dict[str, list[AbstractTerm]] = AbstractTerm,
    kwargs_mapper: dict[Callable, dict[str, Any]] | None = None,
) -> tuple[matplotlib.figure.Figure, matplotlib.axes._axes.Axes | numpy.ndarray]
Plot a gam model.
Except for some specialised cases, this plots the partial effects of the terms.
Parameters:
-
gam
(AbstractGAM
) –The fitted gam object to plot.
-
ncols
(int
, default:2
) –The number of columns before wrapping axes.
-
scatter
(bool
, default:False
) –Whether to plot the residuals (where possible) and the overlaid data points on 2D plots. For finer control, see kwargs_mapper. Defaults to False.
-
data
(DataFrame | Mapping[str, ndarray | Series] | None
, default:None
) –The data to use for plotting partial residuals and scatter points. Defaults to the data used for fitting. Only relevant if scatter=True.
-
to_plot
(type | UnionType | dict[str, list[AbstractTerm]]
, default:AbstractTerm
) –Which terms to plot. If a type, only plots terms of that type (e.g. to_plot = S | T to plot smooths). If a dictionary, it should map the target names to an iterable of terms to plot (similar to how models are specified).
-
kwargs_mapper
(dict[Callable, dict[str, Any]] | None
, default:None
) –Used to pass keyword arguments to the underlying pymgcv.plot functions. A dictionary mapping each plotting function to its keyword arguments. For example, to disable the confidence intervals on the 1D plots, set kwargs_mapper to {continuous_1d: {"fill_between_kwargs": {"disable": True}}}.
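To make the shape of kwargs_mapper more concrete, the following minimal sketch shows how a dictionary keyed by function objects can route keyword arguments to the matching plotting function. The functions draw_1d and draw_2d and the helper call_with_mapped_kwargs are hypothetical stand-ins and not part of pymgcv; the actual dispatch inside plot may well differ.

```python
from collections.abc import Callable
from typing import Any, Optional


# Hypothetical stand-ins for plotting functions (not part of pymgcv).
def draw_1d(term: str, *, fill_between_kwargs: Optional[dict[str, Any]] = None) -> str:
    band = "off" if (fill_between_kwargs or {}).get("disable") else "on"
    return f"1d:{term}:band={band}"


def draw_2d(term: str, *, scatter_kwargs: Optional[dict[str, Any]] = None) -> str:
    points = "off" if (scatter_kwargs or {}).get("disable") else "on"
    return f"2d:{term}:points={points}"


def call_with_mapped_kwargs(
    func: Callable[..., str],
    term: str,
    kwargs_mapper: Optional[dict[Callable, dict[str, Any]]] = None,
) -> str:
    """Look up the kwargs registered for ``func`` (if any) and forward them."""
    kwargs = (kwargs_mapper or {}).get(func, {})
    return func(term, **kwargs)


# Functions are hashable, so they can be used directly as dictionary keys.
mapper = {draw_1d: {"fill_between_kwargs": {"disable": True}}}

result_1d = call_with_mapped_kwargs(draw_1d, "s(x)", mapper)      # band disabled
result_2d = call_with_mapped_kwargs(draw_2d, "s(x0,x1)", mapper)  # no entry, defaults used
```

Functions without an entry in the mapping simply receive no extra keyword arguments, which is why only the 1D stand-in is affected here.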
continuous_1d
continuous_1d(
    *,
    term: AbstractTerm | int,
    gam: AbstractGAM,
    target: str | None = None,
    data: DataFrame | Mapping[str, Series | ndarray] | None = None,
    eval_density: int = 100,
    level: str | None = None,
    n_standard_errors: int | float = 2,
    residuals: bool = False,
    plot_kwargs: dict[str, Any] | None = None,
    fill_between_kwargs: dict[str, Any] | None = None,
    scatter_kwargs: dict[str, Any] | None = None,
    ax: Axes | None = None,
) -> Axes
Plot 1D smooth or linear terms with confidence intervals.
Note
- For terms with numeric "by" variables, the "by" variable is set to 1, showing the unscaled effect of the smooth.
Parameters:
-
term
(AbstractTerm | int
) –The model term to plot. Must be a univariate term (single variable). If an integer is provided, it is assumed to be the index of the term in the predictor of target.
-
gam
(AbstractGAM
) –GAM model containing the term to plot.
-
target
(str | None
, default:None
) –Name of the target variable (response variable or family parameter name from the model specification). If set to None, an error is raised when multiple predictors are present; otherwise, the sole available target is used.
-
data
(DataFrame | Mapping[str, Series | ndarray] | None
, default:None
) –DataFrame used for plotting partial residuals and determining axis limits. Defaults to the data used for training.
-
eval_density
(int
, default:100
) –Number of evaluation points along the variable range for plotting the smooth curve. Higher values give smoother curves but increase computation time. Default is 100.
-
level
(str | None
, default:None
) –Must be provided for smooths with a categorical "by" variable or a FactorSmooth basis. Specifies the level to plot.
-
n_standard_errors
(int | float
, default:2
) –Number of standard errors for confidence intervals.
-
residuals
(bool
, default:False
) –Whether to plot partial residuals.
-
plot_kwargs
(dict[str, Any] | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.plot for the main curve.
-
fill_between_kwargs
(dict[str, Any] | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.fill_between for the confidence interval band. Pass {"disable": True} to disable the confidence interval band.
-
scatter_kwargs
(dict[str, Any] | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.scatter for partial residuals (ignored if residuals=False).
-
ax
(Axes | None
, default:None
) –Matplotlib Axes object to plot on. If None, uses current axes.
Returns:
-
Axes
–The matplotlib Axes object with the plot.
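As a rough illustration of how eval_density and n_standard_errors relate to the plotted band (a sketch of the idea, not the library's implementation), the band can be thought of as the fitted partial effect plus or minus that many standard errors, evaluated on a grid of eval_density points. The arrays fit and se below are made-up values standing in for predictions from a fitted model.

```python
import numpy as np

# Made-up partial effect and standard errors on a grid of eval_density points.
eval_density = 100
x = np.linspace(0.0, 1.0, eval_density)
fit = np.sin(2 * np.pi * x)   # stand-in for the estimated partial effect
se = np.full_like(x, 0.1)     # stand-in for its pointwise standard error

n_standard_errors = 2
lower = fit - n_standard_errors * se
upper = fit + n_standard_errors * se

# With n_standard_errors=2 the band is roughly a 95% pointwise interval
# under a normal approximation.
band_width = upper - lower
```

Arrays like lower and upper are the kind of input that matplotlib's fill_between expects, which is why the band is styled through fill_between_kwargs.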
continuous_2d
continuous_2d(
    *,
    term: AbstractTerm | int,
    gam: AbstractGAM,
    target: str | None = None,
    data: DataFrame | Mapping[str, ndarray | Series] | None = None,
    eval_density: int = 50,
    level: str | None = None,
    contour_kwargs: dict | None = None,
    contourf_kwargs: dict | None = None,
    scatter_kwargs: dict | None = None,
    ax: Axes | None = None,
) -> Axes
Plot 2D smooth surfaces as contour plots with data overlay.
This function is useful for inspecting bivariate relationships and interactions between two continuous variables.
Parameters:
-
term
(AbstractTerm | int
) –The bivariate term to plot. Must have exactly two variables, e.g. S('x1', 'x2') or T('x1', 'x2'). If an integer is provided, it is interpreted as the index of the term in the list of predictors for target.
-
gam
(AbstractGAM
) –GAM model containing the term to plot.
-
target
(str | None
, default:None
) –Name of the target variable (response variable or family parameter name from the model specification). If set to None, an error is raised when multiple predictors are present; otherwise, the sole available target is used.
-
data
(DataFrame | Mapping[str, ndarray | Series] | None
, default:None
) –DataFrame containing the variables for determining plot range and showing data points. Should typically be the training data.
-
eval_density
(int
, default:50
) –Number of evaluation points along each axis, creating an eval_density × eval_density grid. Higher values give smoother surfaces but increase computation time. Default is 50.
-
level
(str | None
, default:None
) –Must be provided for smooths with a categorical "by" variable or a FactorSmooth basis. Specifies the level to plot.
-
contour_kwargs
(dict | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.contour for the contour lines.
-
contourf_kwargs
(dict | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.contourf for the filled contours.
-
scatter_kwargs
(dict | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.scatter for the data points overlay. Pass {"disable": True} to avoid plotting.
-
ax
(Axes | None
, default:None
) –Matplotlib Axes object to plot on. If None, uses current axes.
Returns:
-
Axes
–The matplotlib Axes object with the plot, allowing further customization.
Raises:
-
ValueError
–If the term doesn't have exactly two variables.
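To illustrate what an eval_density × eval_density grid means in practice, here is a minimal sketch that builds such a grid from the range of two variables. It assumes the axis limits come from the observed data range, as the data parameter description suggests; the column names x0 and x1 and the values are made up, and the library's own grid construction may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=200)  # made-up first variable
x1 = rng.uniform(0.0, 5.0, size=200)   # made-up second variable

eval_density = 50
g0 = np.linspace(x0.min(), x0.max(), eval_density)
g1 = np.linspace(x1.min(), x1.max(), eval_density)
grid0, grid1 = np.meshgrid(g0, g1)  # each has shape (eval_density, eval_density)

# One row per grid point, the layout a prediction routine would typically need.
grid_points = {"x0": grid0.ravel(), "x1": grid1.ravel()}
n_points = grid_points["x0"].size
```

The number of evaluation points grows quadratically with eval_density (2500 points at the default of 50), which is the source of the extra computation time mentioned above.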
categorical
categorical(
    *,
    term: L | int,
    gam: AbstractGAM,
    target: str | None = None,
    data: DataFrame | Mapping[str, Series | ndarray] | None = None,
    residuals: bool = False,
    n_standard_errors: int | float = 2,
    errorbar_kwargs: dict[str, Any] | None = None,
    scatter_kwargs: dict[str, Any] | None = None,
    ax: Axes | None = None,
) -> Axes
Plot categorical terms with error bars and partial residuals.
Creates a plot showing:
- The estimated effect of each category level as points.
- Error bars representing confidence intervals.
- Partial residuals as jittered scatter points.
Parameters:
-
term
(L | int
) –The categorical term to plot. Must be an L term with a single categorical variable.
-
gam
(AbstractGAM
) –GAM model containing the term to plot.
-
target
(str | None
, default:None
) –Name of the target variable (response variable or family parameter name from the model specification). If set to None, an error is raised when multiple predictors are present; otherwise, the sole available target is used.
-
data
(DataFrame | Mapping[str, Series | ndarray] | None
, default:None
) –DataFrame (or dictionary) containing the categorical variable and response variable.
-
residuals
(bool
, default:False
) –Whether to plot partial residuals (jittered on x-axis).
-
n_standard_errors
(int | float
, default:2
) –Number of standard errors for confidence intervals.
-
errorbar_kwargs
(dict[str, Any] | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.errorbar.
-
scatter_kwargs
(dict[str, Any] | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.scatter.
-
ax
(Axes | None
, default:None
) –Matplotlib Axes object to plot on. If None, uses current axes.
random_effect
random_effect(
    *,
    term: S | int,
    gam: AbstractGAM,
    target: str | None = None,
    confidence_interval_level: float = 0.95,
    axline_kwargs: dict[str, Any] | None = None,
    scatter_kwargs: dict[str, Any] | None = None,
    fill_between_kwargs: dict[str, Any] | None = None,
    ax: Axes | None = None,
) -> Axes
A QQ-like plot for random effect terms.
This function plots the estimated random effects against Gaussian quantiles and includes a confidence envelope to assess whether the random effects follow a normal distribution, as assumed by the model.
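The pairing of estimated random effects with Gaussian quantiles can be sketched as follows. This is only an illustration of the general QQ construction: the effect values are made up, the plotting positions (i - 0.5) / n are one common convention assumed here, and the confidence envelope is omitted.

```python
from statistics import NormalDist

# Made-up estimated random effects (one coefficient per group level).
effects = [0.31, -0.12, 0.05, -0.48, 0.22, 0.02]

n = len(effects)
# Plotting positions (i - 0.5) / n: an assumed convention, not necessarily pymgcv's.
probs = [(i - 0.5) / n for i in range(1, n + 1)]
theoretical = [NormalDist().inv_cdf(p) for p in probs]  # standard normal quantiles
observed = sorted(effects)

# Each pair is one point on the QQ-like plot: (Gaussian quantile, sorted effect).
qq_points = list(zip(theoretical, observed))
```

If the random effects are close to normally distributed, these points fall near a straight line through the origin.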
Parameters:
-
term
(S | int
) –The random effect term to plot. Must be a smooth term with a RandomEffect basis function. If an integer is provided, it is assumed to be the index of the term in the predictors for target.
-
gam
(AbstractGAM
) –The fitted GAM model containing the random effect.
-
target
(str | None
, default:None
) –The target variable to plot when multiple predictors are present. If None and only one predictor exists, that predictor is used.
-
confidence_interval_level
(float
, default:0.95
) –The confidence level for the confidence envelope.
-
axline_kwargs
(dict[str, Any] | None
, default:None
) –Keyword arguments passed to matplotlib.axes.Axes.axline for the reference line.
-
scatter_kwargs
(dict[str, Any] | None
, default:None
) –Keyword arguments passed to matplotlib.axes.Axes.scatter for the random effect points.
-
fill_between_kwargs
(dict[str, Any] | None
, default:None
) –Keyword arguments passed to matplotlib.axes.Axes.fill_between for the confidence envelope.
-
ax
(Axes | None
, default:None
) –Matplotlib axes to use for the plot. If None, uses the current axes.
Returns:
-
Axes
–The matplotlib axes object.
Note
The confidence interval calculation is based on the formula from: "Worm plot: a simple diagnostic device for modelling growth reference curves" (page 6). The random effects are constrained to be centered, so the reference line passes through (0, 0).
qq
qq(
    gam: AbstractGAM,
    *,
    qq_fun: Callable[[AbstractGAM], QQResult] = qq_simulate,
    scatter_kwargs: dict | None = None,
    fill_between_kwargs: dict | None = None,
    axline_kwargs: dict | None = None,
    ax: Axes | None = None,
) -> Axes
A Q-Q plot of deviance residuals.
Parameters:
-
gam
(AbstractGAM
) –The fitted GAM model.
-
qq_fun
(Callable[[AbstractGAM], QQResult]
, default:qq_simulate
) –A function taking only the GAM model and returning a QQResult object storing the theoretical residuals, the observed residuals, and the confidence interval. Defaults to qq_simulate, which is the most widely supported method, only requiring the family to provide a sampling function. qq_transform can be used for families providing a cdf method, which transforms the data to a known distribution for which an analytical confidence interval is available.
-
scatter_kwargs
(dict | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.scatter.
-
fill_between_kwargs
(dict | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.fill_between, for plotting the confidence interval.
-
axline_kwargs
(dict | None
, default:None
) –Keyword arguments passed to matplotlib.pyplot.axline for plotting the reference line. Pass {"disable": True} to avoid plotting.
-
ax
(Axes | None
, default:None
) –Matplotlib axes to use for the plot.
Note
To change the settings of qq_fun, use partial application, e.g. with functools.partial.
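A minimal sketch of partial application with functools.partial is given here. The function my_qq_fun and its parameters n_sim and level are hypothetical placeholders, not pymgcv's API; check the signature of qq_simulate or qq_transform for the settings they actually accept.

```python
from functools import partial


# Hypothetical stand-in for a QQ function with extra settings (not pymgcv's API).
def my_qq_fun(gam, *, n_sim=50, level=0.9):
    return {"gam": gam, "n_sim": n_sim, "level": level}


# Fix the settings up front; the result takes only the model, as qq_fun requires.
qq_fun = partial(my_qq_fun, n_sim=10, level=0.95)

result = qq_fun("fitted-model-placeholder")
```

The resulting callable has the single-argument form expected by the qq_fun parameter, so it could be passed as qq(gam, qq_fun=qq_fun).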
Returns:
-
Axes
–The matplotlib axes object.
Example
As an example, we create a heavy-tailed response variable and fit both a Gaussian model and a Scat model.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pymgcv.families import Gaussian, Scat
from pymgcv.gam import GAM
import pymgcv.plot as gplt
from pymgcv.terms import S
rng = np.random.default_rng(1)
n = 1000
x = np.linspace(0, 1, n)
y = np.sin(2 * np.pi * x) + rng.standard_t(df=3, size=n) # Heavy-tailed
data = pd.DataFrame({"x": x, "y": y})
models = [
    GAM({"y": S("x")}, family=Gaussian()),
    GAM({"y": S("x")}, family=Scat()),  # Better for heavy-tailed data
]
fig, axes = plt.subplots(ncols=2)
for model, ax in zip(models, axes, strict=False):
    model.fit(data)
    gplt.qq(model, ax=ax)
    ax.set_title(model.family.__class__.__name__)
    ax.set_box_aspect(1)
# fig.show()  # Uncomment to display the figure
residuals_vs_linear_predictor
residuals_vs_linear_predictor(
    gam: AbstractGAM,
    type: Literal[
        "deviance", "pearson", "scaled.pearson", "working", "response"
    ] = "deviance",
    target: str | None = None,
    ax: Axes | None = None,
    scatter_kwargs: dict[str, Any] | None = None,
)
Plot the residuals against the linear predictor.
Parameters:
-
gam
(AbstractGAM
) –The fitted GAM model.
-
type
(Literal['deviance', 'pearson', 'scaled.pearson', 'working', 'response']
, default:'deviance'
) –The type of residuals to plot.
-
target
(str | None
, default:None
) –The target variable to plot residuals for.
-
ax
(Axes | None
, default:None
) –The axes to plot on.
-
scatter_kwargs
(dict[str, Any] | None
, default:None
) –Keyword arguments to pass to the scatter plot.
hexbin_residuals
hexbin_residuals(
    residuals: ndarray,
    var1: str,
    var2: str,
    data: DataFrame | Mapping[str, ndarray | Series],
    *,
    gridsize: int = 25,
    max_val: int | float | None = None,
    ax: Axes | None = None,
    **kwargs: Any,
)
Hexbin plot for visualising residuals as a function of two variables.
Useful, for example, for assessing whether interactions might be required. This is a thin wrapper around matplotlib.pyplot.hexbin, with better defaults for plotting residuals (e.g. it uses a symmetric color scale).
The default reduction function is np.sum(res) / np.sqrt(len(res)), which has constant variance with respect to the number of points in a bin (assuming independent residuals with equal variance).
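The constant-variance property of this reduction can be checked with a small simulation. The sketch below assumes independent standard normal residuals; the helper names are illustrative only. For n such residuals, the sum has variance n, so dividing by sqrt(n) gives variance 1 regardless of n, whereas a plain mean would have variance 1/n.

```python
import numpy as np

rng = np.random.default_rng(0)


def reduce_bin(res: np.ndarray) -> float:
    """Reduction quoted above: sum of residuals scaled by sqrt(count)."""
    return float(np.sum(res) / np.sqrt(len(res)))


def simulated_variance(bin_size: int, n_bins: int = 5000) -> float:
    """Variance of the reduced value across many simulated bins of equal size."""
    values = [reduce_bin(rng.normal(size=bin_size)) for _ in range(n_bins)]
    return float(np.var(values))


var_small = simulated_variance(4)    # sparsely populated hexagons
var_large = simulated_variance(400)  # densely populated hexagons

# For comparison: the plain mean shrinks towards zero as bins fill up.
var_of_means = float(np.var([rng.normal(size=400).mean() for _ in range(5000)]))
```

Both var_small and var_large come out close to 1, so sparsely and densely populated hexagons are on a comparable color scale, while the bin mean would make dense bins look artificially calm.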
Parameters:
-
residuals
(ndarray
) –Residuals to plot.
-
var1
(str
) –Name of the first variable.
-
var2
(str
) –Name of the second variable.
-
data
(DataFrame | Mapping[str, ndarray | Series]
) –The data (containing var1 and var2).
-
gridsize
(int
, default:25
) –The number of hexagons in the x-direction. The y direction is chosen such that the hexagons are approximately regular.
-
max_val
(int | float | None
, default:None
) –Maximum and minimum value for the symmetric color scale. Defaults to the maximum absolute value of the residuals.
-
ax
(Axes | None
, default:None
) –Axes to plot on. If None, the current axes are used.
-
**kwargs
(Any
) –Additional keyword arguments passed to matplotlib.pyplot.hexbin.
Example
import numpy as np
import pymgcv.plot as gplt
import matplotlib.pyplot as plt
rng = np.random.default_rng(1)
fig, ax = plt.subplots()
residuals = rng.normal(size=500) # or gam.residuals()
data = {
    "x0": rng.normal(size=residuals.shape),
    "x1": rng.normal(size=residuals.shape),
}
gplt.hexbin_residuals(residuals, "x0", "x1", data=data, ax=ax)