In Part 1 we saw the basic structure of a decision tree, and in Part 2 we created a class to handle the samples and labels of a data set. We now look at how to compute the prediction values of the leaves in order to fit a data set.
Computing the leaves’ values from a data set
We assume that the structure of the decision tree and its test functions are known, so that only the prediction values of the leaves are missing. We then compute the values of the leaves as follows:
- for each sample in the data set, we find the leaf it is associated with by following the path determined by the outcomes of the test functions,
- for each leaf, we combine the target values of the associated samples.
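The two steps above can be sketched in Scala as follows. The tree representation (`Tree`, `Node`, `Leaf` with an integer identifier) and the functions `leafOf` and `fitLeaves` are hypothetical stand-ins, not the classes from Part 1 or Part 2; the sketch only assumes that each internal node holds a boolean test on a sample.

```scala
// Hypothetical tree representation, NOT the class from Part 1:
// internal nodes hold a test function, leaves hold an identifier.
trait Tree[A]
final case class Node[A](test: A => Boolean, ifTrue: Tree[A], ifFalse: Tree[A]) extends Tree[A]
final case class Leaf[A](id: Int) extends Tree[A]

// Step 1: follow the path given by the outcomes of the tests until a leaf is reached.
@annotation.tailrec
def leafOf[A](tree: Tree[A], sample: A): Int = tree match {
  case Leaf(id)                    => id
  case Node(test, ifTrue, ifFalse) => leafOf(if (test(sample)) ifTrue else ifFalse, sample)
}

// Step 2: group the targets by leaf and combine each group into one prediction value.
// Leaves that receive no sample get no entry in the resulting map.
def fitLeaves[A, B](tree: Tree[A], samples: Vector[A], targets: Vector[B],
                    combine: Vector[B] => B): Map[Int, B] = {
  require(samples.length == targets.length, "one target per sample is expected")
  samples.zip(targets)
    .groupBy { case (sample, _) => leafOf(tree, sample) }
    .map { case (leaf, pairs) => leaf -> combine(pairs.map(_._2)) }
}
```

For instance, with a single test `x < 0.5` on real-valued samples and a mean as combination function, `fitLeaves` returns one prediction value per leaf that received at least one sample.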
How the target values are combined depends on the type of problem. If the targets are real-valued, we may simply take the mean of the targets of the associated samples:
def combineDouble(targets: Vector[Double]): Double =
  targets.sum / targets.length
val targets = Vector(0.1, 0.4, 0.5)
combineDouble(targets)
0.3333333333333333
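One detail worth keeping in mind: if no sample reaches a leaf, `combineDouble` receives an empty vector and returns `NaN` (0.0 divided by 0). Below is a minimal sketch of a guarded variant, under the assumption that failing loudly is preferable to a silent `NaN`; the name `combineMean` is hypothetical.

```scala
// Hypothetical guarded version of combineDouble: same signature, so it can be
// used wherever a Vector[Double] => Double combination function is expected.
def combineMean(targets: Vector[Double]): Double = {
  // With an empty vector, targets.sum / targets.length would be 0.0 / 0, i.e. NaN
  require(targets.nonEmpty, "cannot compute the mean of an empty vector of targets")
  targets.sum / targets.length
}
```

`combineMean(Vector(1.0, 2.0, 3.0))` returns `2.0`, while `combineMean(Vector.empty[Double])` throws an `IllegalArgumentException`.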
If the targets are not numerical, we may instead choose the most frequent target as prediction value.
For example, suppose that the samples associated with a leaf have the targets “Taiwan”, “Belgium”, “Canada”, “Belgium”, “France”, “Belgium”. It then seems natural to choose “Belgium” as prediction value:
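def combineString(targets: Vector[String]): String = {
  // Group identical targets together: Map(target -> all its occurrences)
  val groups = targets.groupBy(identity)
  // Keep the target with the largest group; ties follow the map's
  // iteration order, which is not guaranteed
  groups.maxBy(_._2.length)._1
}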
val targets = Vector("Taiwan", "Belgium", "Canada", "Belgium", "France", "Belgium")
combineString(targets)
Belgium
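`combineString` relies on `groupBy`, whose resulting `Map` has no guaranteed iteration order, so when two targets are equally frequent either one may be returned. Here is a sketch of a deterministic alternative, generic in the label type, that keeps whichever tied label appears first in the input; `combineMajority` is a hypothetical name, and this tie-breaking rule is only one possible choice.

```scala
// Hypothetical majority vote with deterministic tie-breaking.
def combineMajority[B](targets: Vector[B]): B = {
  require(targets.nonEmpty, "cannot combine an empty vector of targets")
  // Count the occurrences of each distinct label
  val counts: Map[B, Int] =
    targets.groupBy(identity).map { case (label, occurrences) => label -> occurrences.length }
  val best = counts.values.max
  // Among the labels reaching the highest count, return the first one in input order
  targets.find(label => counts(label) == best).get
}
```

`combineMajority(Vector("a", "b", "b", "a"))` returns `"a"`, and on the example above it still returns `"Belgium"`. Being generic in `B`, it also works for integer or boolean labels.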
We create a LabelCombiner class that wraps the combination function and also provides a two-argument version of it:
case class LabelCombiner[B](combine: Vector[B] => B) {
  /** Combine two elements rather than the elements of a vector */
  def combine(left: B, right: B): B = combine(Vector(left, right))
}
defined class LabelCombiner
In this part, we saw how to compute the prediction values of the leaves. In Part 4, we will combine these results to build a decision tree predictor.