Skip to content

Commit 88e7c60

Browse files
Merge pull request dmlc#664 from pommedeterresautee/master
Support GLM in importance plot + increase tests #Rstat
2 parents 5575257 + 1678a6f commit 88e7c60

File tree

3 files changed

+39
-26
lines changed

3 files changed

+39
-26
lines changed

R-package/R/xgb.plot.importance.R

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#' Plot feature importance bar graph
22
#'
3-
#' Read a data.table containing feature importance details and plot it.
3+
#' Read a data.table containing feature importance details and plot it (for both GLM and Trees).
44
#'
55
#' @importFrom magrittr %>%
66
#' @param importance_matrix a \code{data.table} returned by the \code{xgb.importance} function.
@@ -10,7 +10,7 @@
1010
#'
1111
#' @details
1212
#' The purpose of this function is to easily represent the importance of each feature of a model.
13-
#' The function return a ggplot graph, therefore each of its characteristic can be overriden (to customize it).
13+
#' The function returns a ggplot graph, therefore each of its characteristic can be overriden (to customize it).
1414
#' In particular you may want to override the title of the graph. To do so, add \code{+ ggtitle("A GRAPH NAME")} next to the value returned by this function.
1515
#'
1616
#' @examples
@@ -40,21 +40,29 @@ xgb.plot.importance <-
4040
stop("Ckmeans.1d.dp package is required for plotting the importance", call. = FALSE)
4141
}
4242

43+
if(isTRUE(all.equal(colnames(importance_matrix), c("Feature", "Gain", "Cover", "Frequency")))){
44+
y.axe.name <- "Gain"
45+
} else if(isTRUE(all.equal(colnames(importance_matrix), c("Feature", "Weight")))){
46+
y.axe.name <- "Weight"
47+
} else {
48+
stop("Importance matrix is not correct (column names issue)")
49+
}
50+
4351
# To avoid issues in clustering when co-occurences are used
4452
importance_matrix <-
45-
importance_matrix[, .(Gain = sum(Gain)), by = Feature]
53+
importance_matrix[, .(Gain.or.Weight = sum(get(y.axe.name))), by = Feature]
4654

4755
clusters <-
48-
suppressWarnings(Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix[,Gain], numberOfClusters))
56+
suppressWarnings(Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix[,Gain.or.Weight], numberOfClusters))
4957
importance_matrix[,"Cluster":= clusters$cluster %>% as.character]
5058

5159
plot <-
5260
ggplot2::ggplot(
5361
importance_matrix, ggplot2::aes(
54-
x = stats::reorder(Feature, Gain), y = Gain, width = 0.05
62+
x = stats::reorder(Feature, Gain.or.Weight), y = Gain.or.Weight, width = 0.05
5563
), environment = environment()
5664
) + ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity", position =
57-
"identity") + ggplot2::coord_flip() + ggplot2::xlab("Features") + ggplot2::ylab("Gain") + ggplot2::ggtitle("Feature importance") + ggplot2::theme(
65+
"identity") + ggplot2::coord_flip() + ggplot2::xlab("Features") + ggplot2::ylab(y.axe.name) + ggplot2::ggtitle("Feature importance") + ggplot2::theme(
5866
plot.title = ggplot2::element_text(lineheight = .9, face = "bold"), panel.grid.major.y = ggplot2::element_blank()
5967
)
6068

@@ -66,6 +74,6 @@ xgb.plot.importance <-
6674
# They are mainly column names inferred by Data.table...
6775
globalVariables(
6876
c(
69-
"Feature", "Gain", "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme", "element_blank", "element_text"
77+
"Feature", "Gain.or.Weight", "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme", "element_blank", "element_text", "Gain.or.Weight"
7078
)
7179
)

R-package/man/xgb.plot.importance.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/tests/testthat/test_helpers.R

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,50 +14,55 @@ df[,AgeCat := as.factor(ifelse(Age > 30, "Old", "Young"))]
1414
df[,ID := NULL]
1515
sparse_matrix <- sparse.model.matrix(Improved~.-1, data = df)
1616
output_vector <- df[,Y := 0][Improved == "Marked",Y := 1][,Y]
17-
bst <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9,
18-
eta = 1, nthread = 2, nround = 10, objective = "binary:logistic")
17+
bst.Tree <- xgboost(data = sparse_matrix, label = output_vector, max.depth = 9,
18+
eta = 1, nthread = 2, nround = 10, objective = "binary:logistic", booster = "gbtree")
19+
20+
bst.GLM <- xgboost(data = sparse_matrix, label = output_vector,
21+
eta = 1, nthread = 2, nround = 10, objective = "binary:logistic", booster = "gblinear")
1922

2023
feature.names <- agaricus.train$data@Dimnames[[2]]
2124

2225
test_that("xgb.dump works", {
23-
capture.output(print(xgb.dump(bst)))
24-
expect_true(xgb.dump(bst, 'xgb.model.dump', with.stats = T))
26+
capture.output(print(xgb.dump(bst.Tree)))
27+
capture.output(print(xgb.dump(bst.GLM)))
28+
expect_true(xgb.dump(bst.Tree, 'xgb.model.dump', with.stats = T))
2529
})
2630

2731
test_that("xgb.model.dt.tree works with and without feature names", {
2832
names.dt.trees <- c("ID", "Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover",
2933
"Tree", "Yes.Feature", "Yes.Cover", "Yes.Quality", "No.Feature", "No.Cover", "No.Quality")
30-
dt.tree <- xgb.model.dt.tree(feature_names = feature.names, model = bst)
34+
dt.tree <- xgb.model.dt.tree(feature_names = feature.names, model = bst.Tree)
3135
expect_equal(names.dt.trees, names(dt.tree))
3236
expect_equal(dim(dt.tree), c(162, 15))
33-
xgb.model.dt.tree(model = bst)
37+
xgb.model.dt.tree(model = bst.Tree)
3438
})
3539

3640
test_that("xgb.importance works with and without feature names", {
37-
importance <- xgb.importance(feature_names = sparse_matrix@Dimnames[[2]], model = bst)
38-
expect_equal(dim(importance), c(7, 4))
39-
expect_equal(colnames(importance), c("Feature", "Gain", "Cover", "Frequency"))
40-
xgb.importance(model = bst)
41+
importance.Tree <- xgb.importance(feature_names = sparse_matrix@Dimnames[[2]], model = bst.Tree)
42+
expect_equal(dim(importance.Tree), c(7, 4))
43+
expect_equal(colnames(importance.Tree), c("Feature", "Gain", "Cover", "Frequency"))
44+
xgb.importance(model = bst.Tree)
45+
xgb.plot.importance(importance_matrix = importance.Tree)
4146
})
4247

4348
test_that("xgb.importance works with GLM model", {
44-
bst.GLM <- xgboost(data = sparse_matrix, label = output_vector,
45-
eta = 1, nthread = 2, nround = 10, objective = "binary:logistic", booster = "gblinear")
4649
importance.GLM <- xgb.importance(feature_names = sparse_matrix@Dimnames[[2]], model = bst.GLM)
4750
expect_equal(dim(importance.GLM), c(10, 2))
4851
expect_equal(colnames(importance.GLM), c("Feature", "Weight"))
4952
xgb.importance(model = bst.GLM)
53+
xgb.plot.importance(importance.GLM)
5054
})
5155

5256
test_that("xgb.plot.tree works with and without feature names", {
53-
xgb.plot.tree(feature_names = feature.names, model = bst)
54-
xgb.plot.tree(model = bst)
57+
xgb.plot.tree(feature_names = feature.names, model = bst.Tree)
58+
xgb.plot.tree(model = bst.Tree)
5559
})
5660

5761
test_that("xgb.plot.multi.trees works with and without feature names", {
58-
xgb.plot.multi.trees(model = bst, feature_names = feature.names, features.keep = 3)
59-
xgb.plot.multi.trees(model = bst, features.keep = 3)
62+
xgb.plot.multi.trees(model = bst.Tree, feature_names = feature.names, features.keep = 3)
63+
xgb.plot.multi.trees(model = bst.Tree, features.keep = 3)
6064
})
65+
6166
test_that("xgb.plot.deepness works", {
62-
xgb.plot.deepness(model = bst)
67+
xgb.plot.deepness(model = bst.Tree)
6368
})

0 commit comments

Comments
 (0)