diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 10b5dfd..e10a511 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -705,7 +705,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non Parameters ---------- - idata: InferenceData + idata : InferenceData InferenceData containing a collection of BART_trees in sample_stats group X : npt.NDArray[np.float64] The covariate matrix. @@ -784,7 +784,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 Parameters ---------- - idata: InferenceData + idata : InferenceData InferenceData containing a collection of BART_trees in sample_stats group bartrv : BART Random Variable BART variable once the model that include it has been fitted. @@ -949,8 +949,10 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 indices = least_important_vars[::-1] + labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]) + vi_results = { - "indices": indices, + "indices": np.asarray(indices), "labels": labels[indices], "r2_mean": r2_mean, "r2_hdi": r2_hdi, @@ -962,8 +964,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 def plot_variable_importance( vi_results: dict, - labels=None, - figsize=None, + submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None, + labels: Optional[list[str]] = None, + figsize: Optional[tuple[float, float]] = None, plot_kwargs: Optional[dict[str, Any]] = None, ax: Optional[plt.Axes] = None, ): @@ -974,8 +977,11 @@ def plot_variable_importance( ---------- vi_results: Dictionary Dictionary computed with `compute_variable_importance` - X : npt.NDArray[np.float64] - The covariate matrix. + submodels : Optional[Union[list[int], np.ndarray]] + List of the indices of the submodels to plot. Defaults to None, all variables are ploted. + The indices correspond to order computed by `compute_variable_importance`. + For example `submodels=[0,1]` will plot the two most important variables. + `submodels=[1,0]` is equivalent as values are sorted before use. labels : Optional[list[str]] List of the names of the covariates. If X is a DataFrame the names of the covariables will be taken from it and this argument will be ignored. @@ -995,11 +1001,15 @@ def plot_variable_importance( ------- axes: matplotlib axes """ + if submodels is None: + submodels = np.sort(vi_results["indices"]) + else: + submodels = np.sort(submodels) - indices = vi_results["indices"] - r2_mean = vi_results["r2_mean"] - r2_hdi = vi_results["r2_hdi"] - preds = vi_results["preds"] + indices = vi_results["indices"][submodels] + r2_mean = vi_results["r2_mean"][submodels] + r2_hdi = vi_results["r2_hdi"][submodels] + preds = vi_results["preds"][submodels] preds_all = vi_results["preds_all"] samples = preds.shape[1] @@ -1016,9 +1026,7 @@ def plot_variable_importance( _, ax = plt.subplots(1, 1, figsize=figsize) if labels is None: - labels = vi_results["labels"] - - labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] + labels = vi_results["labels"][submodels] r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)]) @@ -1059,21 +1067,27 @@ def plot_variable_importance( def plot_scatter_submodels( vi_results: dict, func: Optional[Callable] = None, + submodels: Optional[Union[list[int], np.ndarray]] = None, grid: str = "long", - labels=None, + labels: Optional[list[str]] = None, figsize: Optional[tuple[float, float]] = None, plot_kwargs: Optional[dict[str, Any]] = None, - axes: Optional[plt.Axes] = None, -): + ax: Optional[plt.Axes] = None, +) -> list[plt.Axes]: """ Plot submodel's predictions against reference-model's predictions. Parameters ---------- - vi_results: Dictionary + vi_results : Dictionary Dictionary computed with `compute_variable_importance` func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. + submodels : Optional[Union[list[int], np.ndarray]] + List of the indices of the submodels to plot. Defaults to None, all variables are ploted. + The indices correspond to order computed by `compute_variable_importance`. + For example `submodels=[0,1]` will plot the two most important variables. + `submodels=[1,0]` is equivalent as values are sorted before use. grid : str or tuple How to arrange the subplots. Defaults to "long", one subplot below the other. Other options are "wide", one subplot next to each other or a tuple indicating the number @@ -1092,20 +1106,23 @@ def plot_scatter_submodels( ------- axes: matplotlib axes """ - indices = vi_results["indices"] - preds = vi_results["preds"] + if submodels is None: + submodels = np.sort(vi_results["indices"]) + else: + submodels = np.sort(submodels) + + indices = vi_results["indices"][submodels] + preds = vi_results["preds"][submodels] preds_all = vi_results["preds_all"] - if axes is None: - _, axes = _get_axes(grid, len(indices), True, True, figsize) + if ax is None: + _, ax = _get_axes(grid, len(indices), True, True, figsize) if plot_kwargs is None: plot_kwargs = {} if labels is None: - labels = vi_results["labels"] - - labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] + labels = vi_results["labels"][submodels] if func is not None: preds = func(preds) @@ -1114,8 +1131,8 @@ def plot_scatter_submodels( min_ = min(np.min(preds), np.min(preds_all)) max_ = max(np.max(preds), np.max(preds_all)) - for pred, x_label, ax in zip(preds, labels, axes.ravel()): - ax.plot( + for pred, x_label, axi in zip(preds, labels, ax.ravel()): + axi.plot( pred, preds_all, marker=plot_kwargs.get("marker_scatter", "."), @@ -1123,13 +1140,14 @@ def plot_scatter_submodels( color=plot_kwargs.get("color_scatter", "C0"), alpha=plot_kwargs.get("alpha_scatter", 0.1), ) - ax.set_xlabel(x_label) - ax.axline( + axi.set_xlabel(x_label) + axi.axline( [min_, min_], [max_, max_], color=plot_kwargs.get("color_ref", "0.5"), ls=plot_kwargs.get("ls_ref", "--"), ) + return ax def generate_sequences(n_vars, i_var, include):