Statistics Toolbox
Syntax
cost = treetest(T,'resubstitution')
cost = treetest(T,'test',X,y)
cost = treetest(T,'crossvalidate',X,y)
[cost,secost,ntnodes,bestlevel] = treetest(...)
[...] = treetest(...,'param1',val1,'param2',val2,...)
Description
cost = treetest(T,'resubstitution') computes the cost of the tree T using a resubstitution method. T is a decision tree as created by the treefit function. The cost of the tree is the sum over all terminal nodes of the estimated probability of that node times the node's cost. If T is a classification tree, the cost of a node is the sum of the misclassification costs of the observations in that node. If T is a regression tree, the cost of a node is the average squared error over the observations in that node. cost is a vector of cost values, one for each subtree in the optimal pruning sequence for T. The resubstitution cost is based on the same sample that was used to create the original tree, so it underestimates the likely cost of applying the tree to new data.
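As a minimal sketch (using the fisheriris data set that ships with the toolbox), the resubstitution costs for the whole pruning sequence can be obtained like this:

```matlab
load fisheriris;                       % example data shipped with the toolbox
t = treefit(meas,species);             % fit a classification tree
cost = treetest(t,'resubstitution');   % one cost per pruning level
% cost(1) corresponds to pruning level 0 (the full, unpruned tree).
% Resubstitution cost generally rises as the tree is pruned, and is
% optimistic because it reuses the training sample.
```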
cost = treetest(T,'test',X,y) uses the predictor matrix X and response y as a test sample, applies the decision tree T to that sample, and returns a vector cost of cost values computed for the test sample. X and y should not be the same as the learning sample, which is the sample that was used to fit the tree T.
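For example, a sketch of a simple holdout evaluation (the odd/even split below is an arbitrary illustration, not a recommendation):

```matlab
load fisheriris;
trainIdx = 1:2:150;            % odd rows as the learning sample
testIdx  = 2:2:150;            % even rows as the test sample
t = treefit(meas(trainIdx,:),species(trainIdx));
% Cost of each subtree in the pruning sequence, estimated on held-out data
cost = treetest(t,'test',meas(testIdx,:),species(testIdx));
```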
cost = treetest(T,'crossvalidate',X,y) uses 10-fold cross-validation to compute the cost vector. X and y should be the learning sample, which is the sample that was used to fit the tree T. The function partitions the sample into 10 subsamples, chosen randomly but with roughly equal size. For classification trees, the subsamples also have roughly the same class proportions. For each subsample, treetest fits a tree to the remaining data and uses it to predict the subsample. It pools the information from all subsamples to compute the cost for the whole sample.
[cost,secost,ntnodes,bestlevel] = treetest(...) also returns the vector secost containing the standard error of each cost value, the vector ntnodes containing the number of terminal nodes for each subtree, and the scalar bestlevel containing the estimated best level of pruning. bestlevel = 0 means no pruning, i.e., the full unpruned tree. The best level is the one that produces the smallest tree that is within one standard error of the minimum-cost subtree.
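The one-standard-error rule above can be reproduced by hand from cost and secost; a sketch, assuming t, X, and y come from a prior call to treefit:

```matlab
[c,s,n,best] = treetest(t,'crossvalidate',X,y);
[mincost,minloc] = min(c);        % minimum-cost subtree
cutoff = mincost + s(minloc);     % one standard error above the minimum
% bestlevel is the largest pruning level (i.e., the smallest tree) whose
% cost is still at or below the cutoff; level k is stored at index k+1.
levels = 0:(length(c)-1);
bestByHand = max(levels(c <= cutoff));   % should agree with best
```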
[...] = treetest(...,'param1',val1,'param2',val2,...)
specifies optional parameter name-value pairs chosen from the following:
Examples
Find the best tree for Fisher's iris data using cross-validation. The solid line shows the estimated cost for each tree size, the dashed line marks 1 standard error above the minimum, and the square marks the smallest tree under the dashed line.
% Start with a large tree.
load fisheriris;
t = treefit(meas,species','splitmin',5);

% Find the minimum-cost tree.
[c,s,n,best] = treetest(t,'cross',meas,species);
tmin = treeprune(t,'level',best);

% Plot smallest tree within 1 std. error of minimum cost tree.
[mincost,minloc] = min(c);
plot(n,c,'b-o',...
     n,c+s,'r:',...
     n(best+1),c(best+1),'bs',...
     n,(mincost+s(minloc))*ones(size(n)),'k--');
xlabel('Tree size (number of terminal nodes)')
ylabel('Cost')
See Also
treefit | treeprune | treeval