-
Notifications
You must be signed in to change notification settings - Fork 4
/
grf.R
66 lines (66 loc) · 4.26 KB
/
grf.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Custom caret model specification for grf's generalized random forest
# (regression only).  Each list element follows the contract described in
# caret's "Using Your Own Model in train" documentation.
c_grf <- list(
  label = "Generalized Random Forest",
  library = c("grf", "dplyr"),
  loop = NULL,
  type = "Regression",
  # Tunable hyper-parameters exposed to caret::train().
  parameters = data.frame(
    parameter = c("sample.fraction", "mtry", "num.trees", "min.node.size",
                  "alpha", "honesty", "imbalance.penalty"),
    class = c("numeric", "numeric", "numeric", "numeric",
              "numeric", "logical", "numeric"),
    label = c("Fraction of data used to build each tree",
              "Number of variables tried for each split",
              "Number of trees grown in the forest",
              "Target for the minimum number of observations in each tree leaf",
              "Tuning parameter that controls the maximum imbalance of a split",
              "Whether or not honest splitting",
              "Controls how harshly imbalance is penalized")
  ),
  # Candidate tuning grid.  Both search modes seed the grid from grf's own
  # tuner (tune_regression_forest) so the search starts near sensible values;
  # the call is hoisted out of the branches since both used identical args.
  grid = function(x, y, len = NULL, search = "grid") {
    param_grf <- grf::tune_regression_forest(X = x, Y = y, num.fit.trees = 50)
    param_grf <- as.data.frame(t(param_grf$params))
    if (search == "grid") {
      out <- expand.grid(
        sample.fraction = 0.5,
        mtry = param_grf$mtry,
        num.trees = 2000,
        min.node.size = param_grf$min.node.size,
        alpha = param_grf$alpha,
        honesty = TRUE,
        imbalance.penalty = param_grf$imbalance.penalty
      )
    } else {
      # Random search: jitter +/-10% around the tuned values.
      out <- data.frame(
        sample.fraction = runif(len, 0.2, 0.5),
        mtry = sample(seq_len(ncol(x)), size = len, replace = TRUE),
        num.trees = floor(runif(len, 1500, 2500)),
        min.node.size = sample(c(1:4, param_grf$min.node.size),
                               size = len, replace = TRUE),
        alpha = runif(len, param_grf$alpha * 0.9, param_grf$alpha * 1.1),
        honesty = rep(TRUE, len),
        imbalance.penalty = runif(len,
                                  param_grf$imbalance.penalty * 0.9,
                                  param_grf$imbalance.penalty * 1.1)
      )
    }
    out
  },
  # Fit a regression forest with the tuning parameters chosen by caret.
  fit = function(x, y, wts, param, lev, last, classProbs, ...) {
    grf::regression_forest(
      X = x, Y = y,
      sample.fraction = param$sample.fraction,
      mtry = param$mtry,
      num.trees = param$num.trees,
      min.node.size = param$min.node.size,
      alpha = param$alpha,
      honesty = param$honesty,
      imbalance.penalty = param$imbalance.penalty,
      ...
    )
  },
  # grf's predict() returns a list; caret needs the bare predictions vector.
  predict = function(modelFit, newdata, submodels = NULL) {
    if (!is.null(newdata)) {
      predict(modelFit, newdata, estimate.variance = TRUE)$predictions
    } else {
      predict(modelFit, estimate.variance = TRUE)$predictions
    }
  },
  prob = NULL, # classification is not supported by this spec
  predictors = function(x, ...) {
    unique(as.vector(variable.names(x)))
  },
  # Variable importance via grf::variable_importance().
  # FIX: the original referenced an undefined `grf_object` when labelling the
  # rows; the caret fit object passed in is `object` (caret attaches xNames
  # and problemType to the fitted model).
  varImp = function(object, ...) {
    varImp <- grf::variable_importance(object, ...)
    if (object$problemType == "Regression") {
      rownames(varImp) <- object$xNames
      varImp <- data.frame(Overall = varImp)
    } else {
      varImp <- "not implemented"
    }
    varImp
  },
  levels = NULL,
  tags = c("Generalized Random Forest", "Ensemble Model", "Bagging",
           "Implicit Feature Selection")
)