Este tutorial é baseado nos exemplos do livro An Introduction to Statistical Learning, de Gareth James, Daniela Witten, Trevor Hastie e Rob Tibshirani (https://www.statlearning.com/).
Neste tutorial, a biblioteca tree é usada para construir
árvores de classificação e regressão:
library(tree)
Primeiro usamos árvores de classificação para analisar o conjunto de
dados Carseats.
Nesse conjunto de dados, Sales é uma variável contínua,
e por isso começamos recodificando-a como uma variável binária (para
podermos usá-la em nosso exemplo sobre árvores de classifição).
Utilizamos a função ifelse() para criar uma nova variável,
chamada High, que recebe o valor Yes se a
variável Sales for maior que \(8\), e No caso contrário.
library(ISLR2)
attach(Carseats)
High <- factor(ifelse(Sales <= 8, "No", "Yes"))
Por fim, usamos a função data.frame() para combinar a
variável High com o restante dos dados de
Carseats:
Carseats <- data.frame(Carseats, High)
Note que poderiamos ter feito mesmo com:
Carseats$High <- factor(ifelse(Sales <= 8, "No", "Yes"))
Agora usamos a função tree() para ajustar uma árvore de
classificação com o objetivo de prever High utilizando
todas as variáveis, exceto Sales. A sintaxe da função
tree() é bastante semelhante à da função
lm():
tree.carseats <- tree(High ~ . - Sales, Carseats)
A função summary() lista as variáveis que são utilizadas
como nós (nodes) internos na árvore, o número de nós terminais
(leaf nodes) e a taxa de erro (de treinamento):
summary(tree.carseats)
Classification tree:
tree(formula = High ~ . - Sales, data = Carseats)
Variables actually used in tree construction:
[1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
[6] "Advertising" "Age" "US"
Number of terminal nodes: 27
Residual mean deviance: 0.4575 = 170.7 / 373
Misclassification error rate: 0.09 = 36 / 400
Observamos que a taxa de erro de treinamento é de \(9\%\). Para árvores de classificação
criadas com a função tree(), a deviance relatada
na saída da função summary() é dada por: \[
-2 \sum_m \sum_k n_{mk} \log \hat{p}_{mk}
\] onde \(n_{mk}\) é o número de
observações no \(m\)-ésimo nó terminal
que pertencem à \(k\)-ésima classe. A
deviance está relacionada à entropia e mede a falta de
homogeneidade das classes nos nós da árvore. Quanto menor a deviance,
mais puros são os nós. Lembre-se que a impureza de Gini também
é frequentemente usada para medir a impureza dos nós, mas utiliza uma
métrica diferente baseada na probabilidade ao quadrado das classes
(\(k\)): \[
1-\overset{k}{\sum_{i=1}} (p_i)^2
\] Ambas, deviance e impureza de Gini, são
utilizadas como critérios para dividir os nós em árvores de
classificação, com o objetivo de criar nós mais homogêneos em termos de
classe. A escolha entre elas pode, em alguns casos, levar a árvores
ligeiramente diferentes, mas geralmente produzem resultados
similares.
Uma deviance pequena indica que a árvore fornece um bom ajuste aos dados de treinamento. A deviance média residual relatada é simplesmente a deviance dividida por \(n - |{T}_0|\), que neste caso é \(400 - 27 = 373\) (i.e., o número total de observações dividido pelo número de nós terminais).
Uma das propriedades mais atrativas das árvores é que elas podem ser
exibidas graficamente. Utilizamos a função plot() para
mostrar a estrutura da árvore, e a função text() para
exibir os rótulos dos nós. O argumento pretty = 0 instrui o
R a incluir os nomes das categorias para quaisquer preditores
qualitativos, ao invés de simplesmente mostrar uma letra para cada
categoria:
plot(tree.carseats)
text(tree.carseats, pretty = 0)
O indicador mais importante de Sales parece ser a
localização da prateleira, uma vez que o primeiro ramo diferencia locais
Good de locais Bad e Medium.
Se simplesmente digitarmos o nome do objeto da árvore
(tree.carseats, no nosso caso), o R imprime
uma saída correspondente a cada ramo da árvore. O R exibe o
critério de divisão (por exemplo, Price < 92.5), o
número de observações naquele ramo, a deviance, a previsão
geral para o ramo (Yes ou No) e a fração de
observações naquele ramo que assumem os valores Yes e
No. Ramos que levam a nós terminais são indicados com
asteriscos:
tree.carseats
node), split, n, deviance, yval, (yprob)
* denotes terminal node
1) root 400 541.500 No ( 0.59000 0.41000 )
2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
Para avaliar adequadamente o desempenho de uma árvore de
classificação nesses dados, devemos estimar o erro de teste em vez de
simplesmente calcular o erro de treinamento. Dividimos as observações em
um conjunto de treinamento e um conjunto de teste, construímos a árvore
usando o conjunto de treinamento e avaliamos seu desempenho nos dados de
teste. A função predict() pode ser usada para esse
propósito. No caso de uma árvore de classificação, o argumento
type = "class" instrui o R a retornar a
previsão de classe propriamente dita.
Essa abordagem leva a previsões corretas para cerca de \(77\%\) das localizações no conjunto de teste seguindo o script abaixo:
set.seed(2)
train <- sample(1:nrow(Carseats), 200)
Carseats.test <- Carseats[-train, ]
High.test <- High[-train]
tree.carseats <- tree(High ~ . - Sales, Carseats,
subset = train)
tree.pred <- predict(tree.carseats, Carseats.test,
type = "class")
table(tree.pred, High.test)
High.test
tree.pred No Yes
No 104 33
Yes 13 50
(104 + 50) / 200
[1] 0.77
Em raros casos, se você executar novamente a função
predict(), poderá obter resultados ligeiramente diferentes
devido a “empates”. Por exemplo, isso pode ocorrer quando as observações
de treinamento correspondentes a um nó terminal estão igualmente
divididas entre os valores de resposta Yes e
No.
Em seguida, consideramos se a “poda” (prune) da árvore pode
levar a melhores resultados. A função cv.tree() realiza
validação cruzada para determinar o nível ideal de complexidade da
árvore. A poda por complexidade de custo é usada para selecionar uma
sequência de árvores a serem consideradas.
Usamos o argumento FUN = prune.misclass para indicar que
queremos que a taxa de erro de classificação guie o processo de
validação cruzada e de poda, em vez do padrão da função
cv.tree(), que é a deviance. A função
cv.tree() (onde cv significa
cross-validation) informa o número de nós terminais de cada
árvore considerada (size), bem como a taxa de erro
correspondente e o valor do parâmetro de complexidade de custo utilizado
(k):
set.seed(7)
cv.carseats <- cv.tree(tree.carseats, FUN = prune.misclass)
names(cv.carseats)
[1] "size" "dev" "k" "method"
cv.carseats
$size
[1] 21 19 14 9 8 5 3 2 1
$dev
[1] 75 75 75 74 82 83 83 85 82
$k
[1] -Inf 0.0 1.0 1.4 2.0 3.0 4.0 9.0 18.0
$method
[1] "misclass"
attr(,"class")
[1] "prune" "tree.sequence"
Apesar do nome, dev corresponde ao número de erros de
validação cruzada. O output acima indica que a árvore com \(9\) nós terminais resulta em apenas \(74\) erros de validação cruzada. Isso pode
ser visualizado quando plotamos a taxa de erro como uma função tanto de
size quanto de k da seguinte forma:
par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")
Agora aplicamos a função prune.misclass() para podar a
árvore e obter a árvore com \(9\) nós
terminais:
prune.carseats <- prune.misclass(tree.carseats, best = 9)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
Como a árvore podada se comporta no conjunto de dados de teste? Mais
uma vez, aplicamos a função predict():
tree.pred <- predict(prune.carseats, Carseats.test,
type = "class")
table(tree.pred, High.test)
High.test
tree.pred No Yes
No 97 25
Yes 20 58
(97 + 58) / 200
[1] 0.775
Agora, \(77,5\%\) das observações de teste são corretamente classificadas. Então, não só o processo de poda resultou em uma árvore mais interpretável, mas também melhorou ligeiramente a precisão da classificação.
Se aumentarmos o valor de best, obtemos uma árvore
podada maior com menor precisão de classificação (\(77\%\), no caso):
prune.carseats <- prune.misclass(tree.carseats, best = 14)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
tree.pred <- predict(prune.carseats, Carseats.test,
type = "class")
table(tree.pred, High.test)
High.test
tree.pred No Yes
No 102 31
Yes 15 52
(102 + 52) / 200
[1] 0.77
Aqui, ajustamos uma árvore de regressão ao conjunto de dados
Boston, já utilizado em tutoriais anteriores. Primeiro,
criamos um conjunto de treinamento e ajustamos a árvore aos dados de
treinamento usando novamente a função tree():
set.seed(1)
train <- sample(1:nrow(Boston), nrow(Boston) / 2)
tree.boston <- tree(medv ~ ., Boston, subset = train)
summary(tree.boston)
Regression tree:
tree(formula = medv ~ ., data = Boston, subset = train)
Variables actually used in tree construction:
[1] "rm" "lstat" "crim" "age"
Number of terminal nodes: 7
Residual mean deviance: 10.38 = 2555 / 246
Distribution of residuals:
Min. 1st Qu. Median Mean 3rd Qu. Max.
-10.1800 -1.7770 -0.1775 0.0000 1.9230 16.5800
Note que a saída do summary() indica que apenas quatro
das variáveis foram usadas na construção da árvore. Lembre-se que no
contexto de uma árvore de regressão, a deviance é simplesmente
a soma dos erros quadrados para a árvore. Agora, vamos plotar a
árvore:
plot(tree.boston)
text(tree.boston, pretty = 0)
A variável lstat mede a porcentagem de domicílios de
baixo status socioeconômico, enquanto a variável rm
corresponde ao número médio de cômodos. A árvore indica que valores
maiores de rm ou menores de lstat correspondem
a casas mais caras. Por exemplo, a árvore prevê um preço mediano de casa
de \(45.38\) (mil dólares) para casas
em setores censitários nos quais rm >= 7.553.
Agora, usamos a função cv.tree() (onde cv
significa cross-validation) para verificar se o poda da árvore
pode melhorar o desempenho:
cv.boston <- cv.tree(tree.boston)
plot(cv.boston$size, cv.boston$dev, type = "b")
Neste caso, a árvore mais complexa sob consideração é selecionada por
meio de validação cruzada. No entanto, se desejarmos podar a
árvore, podemos fazer isso da seguinte forma, utilizando a função
prune.tree():
prune.boston <- prune.tree(tree.boston, best = 5)
plot(prune.boston)
text(prune.boston, pretty = 0)
Seguindo os resultados da validação cruzada, utilizamos a árvore não podada para fazer previsões no conjunto de teste:
yhat <- predict(tree.boston, newdata = Boston[-train, ])
boston.test <- Boston[-train, "medv"]
plot(yhat, boston.test)
abline(0, 1)
mean((yhat - boston.test)^2)
[1] 35.28688
Vemos que o erro quadrático médio no conjunto de teste associado à árvore de regressão é \(35{,}29\). A raiz quadrada do erro é, portanto, aproximadamente \(5{,}94\), indicando que esse modelo gera previsões no conjunto de teste que estão (em média) a cerca de \(5{,}94\) do valor mediano real das casas no setor censitário.
Vamos introduzir um novo conceito/tipo de modelo aqui chamado de bagging (puro), o qual é semelhante, embora distinto, de uma floresta aleatória:
Bagging (puro): Utiliza um subconjunto aleatório de amostras dos dados para treinar cada árvore (através de bootstrapping), mas considera todos os preditores (variáveis independentes) disponíveis em cada nó para determinar a melhor divisão.
Floresta aleatória: Além de utilizar o bootstrapping das amostras (como no bagging puro), a floresta aleatória também introduz aleatoriedade na seleção de preditores. Em cada nó de cada árvore, apenas um subconjunto aleatório dos preditores é considerado para encontrar a melhor divisão.
Aqui aplicamos bagging (puro) e florestas aleatória
aos dados Boston, utilizando o pacote
randomForest do R. Note que os resultados
apresentados a seguir podem variar um pouco dependendo das versões do
R e do pacote randomForest instaladas em seu
computador.
Já que o bagging (puro) é simplesmente um caso especial de
uma floresta aleatória, a função randomForest()
pode ser usada tanto para criar florestas aleatórias quanto
para realizar bagging (puro). Executamos o bagging
(puro) da seguinte forma:
library(randomForest)
randomForest 4.7-1.2
Type rfNews() to see new features/changes/bug fixes.
set.seed(1)
bag.boston <- randomForest(medv ~ ., data = Boston,
subset = train, mtry = 12, importance = TRUE)
bag.boston
Call:
randomForest(formula = medv ~ ., data = Boston, mtry = 12, importance = TRUE, subset = train)
Type of random forest: regression
Number of trees: 500
No. of variables tried at each split: 12
Mean of squared residuals: 11.40162
% Var explained: 85.17
O argumento mtry = 12 indica que todos os \(12\) preditores devem ser considerados para
cada divisão da árvore—ou seja, que o método de bagging (puro)
será utilizado.
Quão bem esse modelo com bagging (puro) se sai no conjunto de teste?
yhat.bag <- predict(bag.boston, newdata = Boston[-train, ])
plot(yhat.bag, boston.test)
abline(0, 1)
mean((yhat.bag - boston.test)^2)
[1] 23.41916
O erro quadrático médio no conjunto de teste associado à árvore de regressão com bagging (puro) é \(23{,}42\), que é cerca de dois terços daquele obtido com uma única árvore podada de forma ótima.
Podemos alterar o número de árvores geradas pela função
randomForest() usando o argumento ntree:
bag.boston <- randomForest(medv ~ ., data = Boston,
subset = train, mtry = 12, ntree = 25)
yhat.bag <- predict(bag.boston, newdata = Boston[-train, ])
mean((yhat.bag - boston.test)^2)
[1] 25.75055
A forma de se criar um modelo de florestas aleatória procede
exatamente da mesma forma, exceto pelo fato de usarmos um valor menor
para o argumento mtry.
Por padrão, a função randomForest() utiliza \(x/3\) variáveis ao construir uma
florestas aleatória de regressão, e \(\sqrt{x}\) variáveis ao construir uma
florestas aleatória de classificação.
Aqui, usamos mtry = 6:
set.seed(1)
rf.boston <- randomForest(medv ~ ., data = Boston,
subset = train, mtry = 6, importance = TRUE)
yhat.rf <- predict(rf.boston, newdata = Boston[-train, ])
mean((yhat.rf - boston.test)^2)
[1] 20.06644
O erro quadrático médio no conjunto de teste é \(20{,}07\). Ou seja, o modelo de
florestas aleatória (utilizando mtry = 6) se saiu
melhor que o bagging (puro) (utilizando mtry = 12)
neste caso—parece até contraintutitivo, não acha? Não é incomum que uma
floresta aleatória supere o bagging puro em termos de
desempenho no conjunto de teste. A aleatoriedade na seleção de
preditores em cada divisão é uma característica poderosa das
florestas aleatórias que frequentemente leva a modelos mais
robustos e com melhor capacidade de generalização.
Usando a função importance(), podemos visualizar a
importância de cada variável:
importance(rf.boston)
%IncMSE IncNodePurity
crim 19.435587 1070.42307
zn 3.091630 82.19257
indus 6.140529 590.09536
chas 1.370310 36.70356
nox 13.263466 859.97091
rm 35.094741 8270.33906
age 15.144821 634.31220
dis 9.163776 684.87953
rad 4.793720 83.18719
tax 4.410714 292.20949
ptratio 8.612780 902.20190
lstat 28.725343 5813.04833
Duas medidas de importância das variáveis são reportadas acima. A primeira baseia-se na redução média da acurácia nas previsões sobre as amostras out of bag quando uma determinada variável é permutada. A segunda é uma medida da redução total da impureza dos nós resultante das divisões feitas com base naquela variável, calculada como a média ao longo de todas as árvores.
No caso de árvores de regressão criadas pela função
randomForest(), a impureza dos nós é medida pelo erro
quadrático residual (RSS, em inglês) do treino, e para
árvores de classificação, pela deviance. Gráficos dessas
medidas de importância podem ser gerados com a função
varImpPlot(). Note que o gráfico abaixo não usa o termo
RSS, mas sim MSE (Mean Squared Error ou Erro
Quadrático Médio). Embora o processo de construção de árvores em
árvores de regressão frequentemente se concentre em reduzir o
RSS a cada divisão, o MSE é uma métrica mais
padronizada e interpretável para avaliar o desempenho geral de um modelo
de regressão. O MSE é simplesmente a média das diferenças
quadráticas entre os valores previstos e os valores reais:
\[ MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 = \frac{RSS}{n} \]
varImpPlot(rf.boston)
Os resultados indicam que, considerando-se todas as árvores da
floresta aleatória, lstat e rm são as duas
variáveis mais importantes em nosso modelo.
Aqui utilizamos o pacote gbm3, e dentro dele a função
gbm(), para ajustar árvores de regressão com
boosting ao conjunto de dados
Boston. (Note que podemos usar boosting com
árvores de classificação também).
Executamos gbm() com a opção
distribution = "gaussian" porque queremos um modelo de
regressão (se quiséssemos um modelo de classificação binária, usaríamos
distribution = "bernoulli"). O argumento
n.trees = 5000 indica que queremos \(5000\) árvores, e a opção
interaction.depth = 4 limita a profundidade de cada
árvore:
library(gbm)
Loaded gbm 2.2.2
This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
set.seed(1)
boost.boston <- gbm(medv ~ ., data = Boston[train, ],
distribution = "gaussian", n.trees = 5000,
interaction.depth = 4)
Note a mensagem que alerta que o pacote gbm não está
mais em desenvolvimento e os autores sugerem o uso do pacote
gbm3 de agora em diante (https://github.com/gbm-developers/gbm3). No entanto,
como esse pacote ainda não está disponível no repositório central do
R (CRAN), vamos continuar utilizando o pacote
gbm original nesse tutorial (apenas por conveniência).
A função summary() produz um gráfico de influência
relativa e também exibe as estatísticas de influência relativa:
summary(boost.boston)
var rel.inf
rm rm 44.48249588
lstat lstat 32.70281223
crim crim 4.85109954
dis dis 4.48693083
nox nox 3.75222394
age age 3.19769210
ptratio ptratio 2.81354826
tax tax 1.54417603
indus indus 1.03384666
rad rad 0.87625748
zn zn 0.16220479
chas chas 0.09671228
Vemos que, novamente, lstat e rm são, de
longe, as variáveis mais importantes em nosso modelo.
Também podemos produzir gráficos de dependência parcial para
essas duas variáveis. Esses gráficos ilustram o efeito marginal das
variáveis selecionadas sobre a resposta, após integrar o efeito
das outras variáveis. Neste caso, como era de se esperar, os preços
medianos das casas (y) aumentam com rm e
diminuem com lstat:
plot(boost.boston, i = "rm")
plot(boost.boston, i = "lstat")
Agora usamos o modelo com boosting para prever
medv no conjunto de teste:
yhat.boost <- predict(boost.boston,
newdata = Boston[-train, ], n.trees = 5000)
mean((yhat.boost - boston.test)^2)
[1] 18.39057
O erro quadrático médio (MSE, em inglês) no conjunto de teste obtido é \(18{,}39\). Este valor é inferior ao MSE dos modelos random forests e bagging anteriores, indicando melhor desempenho do modelo de árvores de regressão com boosting.
Se quisermos, podemos realizar o boosting com um valor diferente para o hiperparâmetro de redução (shrinkage) \(\lambda\). Esse é um hiperparâmetro crucial e controla a contribuição de cada nova árvore adicionada ao modelo (ensemble) final.
Em essência, o shrinkage funciona como um fator de moderação ou regularização no processo de aprendizado sequencial do boosting. Em vez de adicionar a previsão completa de cada nova árvore ao ensemble, o shrinkage multiplica essa previsão por um valor menor (tipicamente entre 0 e 1). Ou seja, ele controla a magnitude (ou “peso”) da contribuição de cada árvore individual para a construção do modelo final.
A configuração padrão da função gbm() adota
shrinkage = 0.001, mas isso pode ser facilmente modificado.
Aqui, escolhemos shrinkage = 0.2:
boost.boston <- gbm(medv ~ ., data = Boston[train, ],
distribution = "gaussian", n.trees = 5000,
interaction.depth = 4, shrinkage = 0.2, verbose = F)
yhat.boost <- predict(boost.boston,
newdata = Boston[-train, ], n.trees = 5000)
mean((yhat.boost - boston.test)^2)
[1] 16.54778
Neste caso, utilizar \(\lambda = 0{,}2\) resulta em um MSE no conjunto de teste menor do que ao usar \(\lambda = 0{,}001\).
Agora é sua vez de ajustar modelos de florestas aleatóras e árvores
de regressão com boosting a um novo conjunto de dados. Você pode fazer
isso utilizando seus próprios dados, dados disponíveis na internet ou em
pacotes do R. Outra opção é utilizar os dados sobre
desmatamento municipal em 2004, disponíveis em https://thaleswest.wixsite.com/home/tutorials (note que
o site contém a descrição das variáveis deste conjunto de dados).
Siga os seguintes passos:
randomForest().