Skip to content

Submodels #200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 47 additions & 29 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non

Parameters
----------
idata: InferenceData
idata : InferenceData
InferenceData containing a collection of BART_trees in sample_stats group
X : npt.NDArray[np.float64]
The covariate matrix.
Expand Down Expand Up @@ -784,7 +784,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

Parameters
----------
idata: InferenceData
idata : InferenceData
InferenceData containing a collection of BART_trees in sample_stats group
bartrv : BART Random Variable
BART variable once the model that include it has been fitted.
Expand Down Expand Up @@ -949,8 +949,10 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

indices = least_important_vars[::-1]

labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)])

vi_results = {
"indices": indices,
"indices": np.asarray(indices),
"labels": labels[indices],
"r2_mean": r2_mean,
"r2_hdi": r2_hdi,
Expand All @@ -962,8 +964,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

def plot_variable_importance(
vi_results: dict,
labels=None,
figsize=None,
submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None,
labels: Optional[list[str]] = None,
figsize: Optional[tuple[float, float]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
):
Expand All @@ -974,8 +977,11 @@ def plot_variable_importance(
----------
vi_results: Dictionary
Dictionary computed with `compute_variable_importance`
X : npt.NDArray[np.float64]
The covariate matrix.
submodels : Optional[Union[list[int], np.ndarray]]
List of the indices of the submodels to plot. Defaults to None, all variables are ploted.
The indices correspond to order computed by `compute_variable_importance`.
For example `submodels=[0,1]` will plot the two most important variables.
`submodels=[1,0]` is equivalent as values are sorted before use.
labels : Optional[list[str]]
List of the names of the covariates. If X is a DataFrame the names of the covariables will
be taken from it and this argument will be ignored.
Expand All @@ -995,11 +1001,15 @@ def plot_variable_importance(
-------
axes: matplotlib axes
"""
if submodels is None:
submodels = np.sort(vi_results["indices"])
else:
submodels = np.sort(submodels)

indices = vi_results["indices"]
r2_mean = vi_results["r2_mean"]
r2_hdi = vi_results["r2_hdi"]
preds = vi_results["preds"]
indices = vi_results["indices"][submodels]
r2_mean = vi_results["r2_mean"][submodels]
r2_hdi = vi_results["r2_hdi"][submodels]
preds = vi_results["preds"][submodels]
preds_all = vi_results["preds_all"]
samples = preds.shape[1]

Expand All @@ -1016,9 +1026,7 @@ def plot_variable_importance(
_, ax = plt.subplots(1, 1, figsize=figsize)

if labels is None:
labels = vi_results["labels"]

labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
labels = vi_results["labels"][submodels]

r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)])

Expand Down Expand Up @@ -1059,21 +1067,27 @@ def plot_variable_importance(
def plot_scatter_submodels(
vi_results: dict,
func: Optional[Callable] = None,
submodels: Optional[Union[list[int], np.ndarray]] = None,
grid: str = "long",
labels=None,
labels: Optional[list[str]] = None,
figsize: Optional[tuple[float, float]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
axes: Optional[plt.Axes] = None,
):
ax: Optional[plt.Axes] = None,
) -> list[plt.Axes]:
"""
Plot submodel's predictions against reference-model's predictions.

Parameters
----------
vi_results: Dictionary
vi_results : Dictionary
Dictionary computed with `compute_variable_importance`
func : Optional[Callable], by default None.
Arbitrary function to apply to the predictions. Defaults to the identity function.
submodels : Optional[Union[list[int], np.ndarray]]
List of the indices of the submodels to plot. Defaults to None, all variables are ploted.
The indices correspond to order computed by `compute_variable_importance`.
For example `submodels=[0,1]` will plot the two most important variables.
`submodels=[1,0]` is equivalent as values are sorted before use.
grid : str or tuple
How to arrange the subplots. Defaults to "long", one subplot below the other.
Other options are "wide", one subplot next to each other or a tuple indicating the number
Expand All @@ -1092,20 +1106,23 @@ def plot_scatter_submodels(
-------
axes: matplotlib axes
"""
indices = vi_results["indices"]
preds = vi_results["preds"]
if submodels is None:
submodels = np.sort(vi_results["indices"])
else:
submodels = np.sort(submodels)

indices = vi_results["indices"][submodels]
preds = vi_results["preds"][submodels]
preds_all = vi_results["preds_all"]

if axes is None:
_, axes = _get_axes(grid, len(indices), True, True, figsize)
if ax is None:
_, ax = _get_axes(grid, len(indices), True, True, figsize)

if plot_kwargs is None:
plot_kwargs = {}

if labels is None:
labels = vi_results["labels"]

labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
labels = vi_results["labels"][submodels]

if func is not None:
preds = func(preds)
Expand All @@ -1114,22 +1131,23 @@ def plot_scatter_submodels(
min_ = min(np.min(preds), np.min(preds_all))
max_ = max(np.max(preds), np.max(preds_all))

for pred, x_label, ax in zip(preds, labels, axes.ravel()):
ax.plot(
for pred, x_label, axi in zip(preds, labels, ax.ravel()):
axi.plot(
pred,
preds_all,
marker=plot_kwargs.get("marker_scatter", "."),
ls="",
color=plot_kwargs.get("color_scatter", "C0"),
alpha=plot_kwargs.get("alpha_scatter", 0.1),
)
ax.set_xlabel(x_label)
ax.axline(
axi.set_xlabel(x_label)
axi.axline(
[min_, min_],
[max_, max_],
color=plot_kwargs.get("color_ref", "0.5"),
ls=plot_kwargs.get("ls_ref", "--"),
)
return ax


def generate_sequences(n_vars, i_var, include):
Expand Down