diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 8774803..18fe054 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -17,11 +17,14 @@ from pymc_bart.pgbart import PGBART from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule from pymc_bart.utils import ( + compute_variable_importance, plot_convergence, plot_dependence, plot_ice, plot_pdp, + plot_scatter_submodels, plot_variable_importance, + plot_variable_inclusion, ) __all__ = [ @@ -30,11 +33,14 @@ "ContinuousSplitRule", "OneHotSplitRule", "SubsetSplitRule", + "compute_variable_importance", "plot_convergence", "plot_dependence", "plot_ice", "plot_pdp", + "plot_scatter_submodels", "plot_variable_importance", + "plot_variable_inclusion", ] __version__ = "0.7.1" diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index a50f2d9..e8c60bb 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-branches """Utility function for variable selection and bart interpretability.""" import warnings @@ -248,7 +249,7 @@ def identity(x): _, ) = _prepare_plot_data(X, Y, "linear", None, var_idx, var_discrete) - fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax) + fig, axes, shape = _create_figure_axes(bartrv, var_idx, grid, sharey, figsize, ax) instances_ary = rng.choice(range(X.shape[0]), replace=False, size=instances) idx_s = list(range(X.shape[0])) @@ -270,7 +271,6 @@ def identity(x): new_x = fake_X[:, var] p_d = np.array(y_pred) - print(p_d.shape) for s_i in range(shape): if centered: @@ -398,7 +398,7 @@ def identity(x): xs_values, ) = _prepare_plot_data(X, Y, xs_interval, xs_values, var_idx, var_discrete) - fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax) + fig, axes, shape = _create_figure_axes(bartrv, var_idx, grid, sharey, figsize, ax) count = 0 fake_X = _create_pdp_data(X, xs_interval, xs_values) @@ -447,7 +447,7 @@ def identity(x): return axes -def _get_axes( +def _create_figure_axes( bartrv: Variable, var_idx: List[int], grid: str = "long", @@ -492,29 +492,8 @@ def _get_axes( n_plots = len(var_idx) * shape if ax is None: - if grid == "long": - fig, axes = plt.subplots(n_plots, sharey=sharey, figsize=figsize) - if n_plots == 1: - axes = [axes] - elif grid == "wide": - fig, axes = plt.subplots(1, n_plots, sharey=sharey, figsize=figsize) - if n_plots == 1: - axes = [axes] - elif isinstance(grid, tuple): - grid_size = grid[0] * grid[1] - if n_plots > grid_size: - warnings.warn( - """The grid is smaller than the number of available variables to plot. - Automatically adjusting the grid size.""" - ) - grid = (n_plots // grid[1] + (n_plots % grid[1] > 0), grid[1]) - - fig, axes = plt.subplots(*grid, sharey=sharey, figsize=figsize) - axes = np.ravel(axes) + fig, axes = _get_axes(grid, n_plots, False, sharey, figsize) - for i in range(n_plots, len(axes)): - fig.delaxes(axes[i]) - axes = axes[:n_plots] elif isinstance(ax, np.ndarray): axes = ax fig = ax[0].get_figure() @@ -525,6 +504,33 @@ def _get_axes( return fig, axes, shape +def _get_axes(grid, n_plots, sharex, sharey, figsize): + if grid == "long": + fig, axes = plt.subplots(n_plots, sharex=sharex, sharey=sharey, figsize=figsize) + if n_plots == 1: + axes = [axes] + elif grid == "wide": + fig, axes = plt.subplots(1, n_plots, sharex=sharex, sharey=sharey, figsize=figsize) + if n_plots == 1: + axes = [axes] + elif isinstance(grid, tuple): + grid_size = grid[0] * grid[1] + if n_plots > grid_size: + warnings.warn( + """The grid is smaller than the number of available variables to plot. + Automatically adjusting the grid size.""" + ) + grid = (n_plots // grid[1] + (n_plots % grid[1] > 0), grid[1]) + + fig, axes = plt.subplots(*grid, sharey=sharey, figsize=figsize) + axes = np.ravel(axes) + + for i in range(n_plots, len(axes)): + fig.delaxes(axes[i]) + axes = axes[:n_plots] + return fig, axes + + def _prepare_plot_data( X: npt.NDArray[np.float64], Y: Optional[npt.NDArray[np.float64]] = None, @@ -693,18 +699,86 @@ def _smooth_mean( return x_data, y_data -def plot_variable_importance( # noqa: PLR0915 +def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None): + """ + Plot normalized variable inclusion from BART model. + + Parameters + ---------- + idata: InferenceData + InferenceData containing a collection of BART_trees in sample_stats group + X : npt.NDArray[np.float64] + The covariate matrix. + labels : Optional[List[str]] + List of the names of the covariates. If X is a DataFrame the names of the covariables will + be taken from it and this argument will be ignored. + figsize : tuple + Figure size. If None it will be defined automatically. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color: matplotlib valid color for VI + - marker: matplotlib valid marker for VI + - ls: matplotlib valid linestyle for the VI line + - rotation: float, rotation of the x-axis labels + ax : axes + Matplotlib axes. + + Returns + ------- + idxs: indexes of the covariates from higher to lower relative importance + axes: matplotlib axes + """ + if plot_kwargs is None: + plot_kwargs = {} + + VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values + VIs = VIs / VIs.sum() + idxs = np.argsort(VIs) + + indices = idxs[::-1] + n_vars = len(indices) + + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns + + if labels is None: + labels = np.arange(n_vars).astype(str) + + new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + + ticks = np.arange(n_vars, dtype=int) + + if figsize is None: + figsize = (8, 3) + + if ax is None: + _, ax = plt.subplots(1, 1, figsize=figsize) + + ax.plot( + VIs[indices], + color=plot_kwargs.get("color", "k"), + marker=plot_kwargs.get("marker", "o"), + ls=plot_kwargs.get("ls", "-"), + ) + + ax.set_xticks(ticks, new_labels, rotation=plot_kwargs.get("rotation", 0)) + + ax.axhline(1 / n_vars, color="0.5", linestyle="--") + ax.set_ylim(0, 1) + + return idxs, ax + + +def compute_variable_importance( # noqa: PLR0915 PLR0912 idata: az.InferenceData, bartrv: Variable, X: npt.NDArray[np.float64], - labels: Optional[List[str]] = None, method: str = "VI", - figsize: Optional[Tuple[float, float]] = None, + fixed: int = 0, samples: int = 50, random_seed: Optional[int] = None, - plot_kwargs: Optional[Dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, -) -> Tuple[List[int], Union[List[plt.Axes], Any]]: +) -> dict[str, object]: """ Estimates variable importance from the BART-posterior. @@ -716,87 +790,74 @@ def plot_variable_importance( # noqa: PLR0915 BART variable once the model that include it has been fitted. X : npt.NDArray[np.float64] The covariate matrix. - labels : Optional[List[str]] - List of the names of the covariates. If X is a DataFrame the names of the covariables will - be taken from it and this argument will be ignored. method : str - Method used to rank variables. Available options are "VI" (default) and "backward". + Method used to rank variables. Available options are "VI" (default), "backward" + and "backward_VI". The R squared will be computed following this ranking. "VI" counts how many times each variable is included in the posterior distribution of trees. "backward" uses a backward search based on the R squared. - VI requieres less computation time. - figsize : tuple - Figure size. If None it will be defined automatically. + "backward_VI" combines both methods with the backward search excluding + the ``fixed`` number of variables with the lowest variable inclusion. + "VI" is the fastest method, while "backward" is the slowest. + fixed : Optional[int] + Number of variables to fix in the backward search. Defaults to None. + Must be greater than 0 and less than the number of variables. + Ignored if method is "VI" or "backward". samples : int - Number of predictions used to compute correlation for subsets of variables. Defaults to 100 + Number of predictions used to compute correlation for subsets of variables. Defaults to 50 random_seed : Optional[int] random_seed used to sample from the posterior. Defaults to None. - plot_kwargs : dict - Additional keyword arguments for the plot. Defaults to None. - Valid keys are: - - color_r2: matplotlib valid color for error bars - - marker_r2: matplotlib valid marker for the mean R squared - - marker_fc_r2: matplotlib valid marker face color for the mean R squared - - ls_ref: matplotlib valid linestyle for the reference line - - color_ref: matplotlib valid color for the reference line - ax : axes - Matplotlib axes. Returns ------- - idxs: indexes of the covariates from higher to lower relative importance - axes: matplotlib axes + vi_results: dictionary """ + if method not in ["VI", "backward", "backward_VI"]: + raise ValueError("method must be 'VI', 'backward' or 'backward_VI'") + rng = np.random.default_rng(random_seed) all_trees = bartrv.owner.op.all_trees - if plot_kwargs is None: - plot_kwargs = {} - if bartrv.ndim == 1: # type: ignore shape = 1 else: shape = bartrv.eval().shape[0] if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns X = X.to_numpy() n_vars = X.shape[1] - - if figsize is None: - figsize = (8, 3) - - if ax is None: - _, ax = plt.subplots(1, 1, figsize=figsize) - - if labels is None: - labels_ary = np.arange(n_vars).astype(str) + r2_mean = np.zeros(n_vars) + r2_hdi = np.zeros((n_vars, 2)) + preds = np.zeros((n_vars, samples, bartrv.eval().shape[0])) + + if method == "backward_VI": + if fixed >= n_vars: + raise ValueError("fixed must be less than the number of variables") + elif fixed < 1: + raise ValueError("fixed must be greater than 0") + init = fixed + 1 else: - labels_ary = np.array(labels) - - ticks = np.arange(n_vars, dtype=int) + fixed = 0 + init = 0 predicted_all = _sample_posterior( all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape ) - r_2_ref = np.array( - [pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)] - ) - - if method == "VI": + if method in ["VI", "backward_VI"]: idxs = np.argsort( idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values ) subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))] subsets.append(None) # type: ignore + if method == "backward_VI": + subsets = subsets[-init:] + indices: List[int] = list(idxs[::-1]) - r2_mean = np.zeros(n_vars) - r2_hdi = np.zeros((n_vars, 2)) for idx, subset in enumerate(subsets): predicted_subset = _sample_posterior( all_trees=all_trees, @@ -811,19 +872,24 @@ def plot_variable_importance( # noqa: PLR0915 ) r2_mean[idx] = np.mean(r_2) r2_hdi[idx] = az.hdi(r_2) - - elif method == "backward": - r2_mean = np.zeros(n_vars) - r2_hdi = np.zeros((n_vars, 2)) - - variables = set(range(n_vars)) - least_important_vars: List[int] = [] - indices = [] + preds[idx] = predicted_subset.squeeze() + + if method in ["backward", "backward_VI"]: + if method == "backward_VI": + least_important_vars: List[int] = indices[-fixed:] + r2_mean_vi = r2_mean[:init] + r2_hdi_vi = r2_hdi[:init] + preds_vi = preds[:init] + r2_mean = np.zeros(n_vars - fixed - 1) + r2_hdi = np.zeros((n_vars - fixed - 1, 2)) + preds = np.zeros((n_vars - fixed - 1, samples, bartrv.eval().shape[0])) + else: + least_important_vars = [] # Iterate over each variable to determine its contribution # least_important_vars tracks the variable with the lowest contribution - # at the current stage. One new varible is added at each iteration. - for i_var in range(n_vars): + # at the current stage. One new variable is added at each iteration. + for i_var in range(init, n_vars): # Generate all possible subsets by adding one variable at a time to # least_important_vars subsets = generate_sequences(n_vars, i_var, least_important_vars) @@ -851,30 +917,116 @@ def plot_variable_importance( # noqa: PLR0915 max_r_2 = mean_r_2 least_important_subset = subset r_2_without_least_important_vars = r_2 + least_important_samples = predicted_subset # Save values for plotting later - r2_mean[i_var] = max_r_2 - r2_hdi[i_var] = az.hdi(r_2_without_least_important_vars) + r2_mean[i_var - init] = max_r_2 + r2_hdi[i_var - init] = az.hdi(r_2_without_least_important_vars) + preds[i_var - init] = least_important_samples.squeeze() # extend current list of least important variable - least_important_vars += least_important_subset + for var_i in least_important_subset: + if var_i not in least_important_vars: + least_important_vars.append(var_i) + + # Add the remaining variables to the list of least important variables + for var_i in range(n_vars): + if var_i not in least_important_vars: + least_important_vars.append(var_i) + + if method == "backward_VI": + r2_mean = np.concatenate((r2_mean[::-1], r2_mean_vi)) + r2_hdi = np.concatenate((r2_hdi[::-1], r2_hdi_vi)) + preds = np.concatenate((preds[::-1], preds_vi)) + else: + r2_mean = r2_mean[::-1] + r2_hdi = r2_hdi[::-1] + preds = preds[::-1] + + indices = least_important_vars[::-1] + + vi_results = { + "indices": indices, + "r2_mean": r2_mean, + "r2_hdi": r2_hdi, + "preds": preds, + "preds_all": predicted_all.squeeze(), + } + return vi_results + + +def plot_variable_importance( + vi_results: dict, + X: npt.NDArray[np.float64], + labels=None, + figsize=None, + plot_kwargs: Optional[Dict[str, Any]] = None, + ax: Optional[plt.Axes] = None, +): + """ + Estimates variable importance from the BART-posterior. - # add index of removed variable - indices += list(set(least_important_subset) - set(indices)) + Parameters + ---------- + vi_results: Dictionary + Dictionary computed with `compute_variable_importance` + X : npt.NDArray[np.float64] + The covariate matrix. + labels : Optional[List[str]] + List of the names of the covariates. If X is a DataFrame the names of the covariables will + be taken from it and this argument will be ignored. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color_r2: matplotlib valid color for error bars + - marker_r2: matplotlib valid marker for the mean R squared + - marker_fc_r2: matplotlib valid marker face color for the mean R squared + - ls_ref: matplotlib valid linestyle for the reference line + - color_ref: matplotlib valid color for the reference line + - rotation: float, rotation angle of the x-axis labels. Defaults to 0. + ax : axes + Matplotlib axes. - # add remaining index - indices += list(set(variables) - set(least_important_vars)) + Returns + ------- + axes: matplotlib axes + """ - indices = indices[::-1] - r2_mean = r2_mean[::-1] - r2_hdi = r2_hdi[::-1] + indices = vi_results["indices"] + r2_mean = vi_results["r2_mean"] + r2_hdi = vi_results["r2_hdi"] + preds = vi_results["preds"] + preds_all = vi_results["preds_all"] + samples = preds.shape[1] - new_labels = [ - "+ " + ele if index != 0 else ele for index, ele in enumerate(labels_ary[indices]) - ] + n_vars = len(indices) + ticks = np.arange(n_vars, dtype=int) + + if plot_kwargs is None: + plot_kwargs = {} + + if figsize is None: + figsize = (8, 3) + + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns + X = X.to_numpy() + + if ax is None: + _, ax = plt.subplots(1, 1, figsize=figsize) + + if labels is None: + labels = np.arange(n_vars).astype(str) + else: + labels = np.asarray(labels) + + new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + + r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)]) r2_yerr_min = np.clip(r2_mean - r2_hdi[:, 0], 0, None) r2_yerr_max = np.clip(r2_hdi[:, 1] - r2_mean, 0, None) + ax.errorbar( ticks, r2_mean, @@ -903,7 +1055,28 @@ def plot_variable_importance( # noqa: PLR0915 ax.set_ylim(0, 1) ax.set_xlim(-0.5, n_vars - 0.5) - return indices, ax + return ax + + +def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None): + indices = vi_results["indices"] + preds = vi_results["preds"] + preds_all = vi_results["preds_all"] + + if axes is None: + _, axes = _get_axes(grid, len(indices), False, True, None) + + func = None + if func is not None: + preds = func(preds) + preds_all = func(preds_all) + + min_ = min(np.min(preds), np.min(preds_all)) + max_ = max(np.max(preds), np.max(preds_all)) + + for pred, ax in zip(preds, axes.ravel()): + ax.plot(pred, preds_all, ".", color="C0", alpha=0.1) + ax.axline([min_, min_], [max_, max_], color="0.5") def generate_sequences(n_vars, i_var, include): diff --git a/tests/test_bart.py b/tests/test_bart.py index e56735e..c10fc94 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -184,12 +184,17 @@ def test_pdp(self, kwargs): @pytest.mark.parametrize( "kwargs", [ - {}, + {"samples": 50}, {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)}, ], ) def test_vi(self, kwargs): - pmb.plot_variable_importance(self.idata, X=self.X, bartrv=self.mu, **kwargs) + samples = kwargs.pop("samples") + vi_results = pmb.compute_variable_importance( + self.idata, bartrv=self.mu, X=self.X, samples=samples + ) + pmb.plot_variable_importance(vi_results, X=self.X, **kwargs) + pmb.plot_scatter_submodels(vi_results) def test_pdp_pandas_labels(self): pd = pytest.importorskip("pandas")