package org.wdssii.decisiontree;

import java.util.ArrayList;
import java.util.List;

/**
 * C4.5 learning algorithm to create an axial decision tree. See J. R. Quinlan,
 * "Improved use of continuous attributes in C4.5", Journal of Artificial
 * Intelligence Research, 4:77-90, 1996.
 *
 * Usage:
 * <pre>
 * float[][] data = new float[numTraining][numAttr];
 * int[] categories = new int[numTraining];
 * // populate the arrays
 * ...
 * QuinlanC45AxialDecisionTreeCreator classifier = new QuinlanC45AxialDecisionTreeCreator(0.1f); // pruning fraction
 * DecisionTree tree = classifier.learn(data, categories);
 * </pre>
 *
 * @author lakshman
 */
public class QuinlanC45AxialDecisionTreeCreator implements DecisionTreeCreator {
    /**
     * Minimum population size at which a node is considered for splitting;
     * smaller populations become leaves.
     */
    private int populationToConsiderSplitting = 10;

    /**
     * Fraction of the training data set to keep aside so that the learned tree
     * is not overfit. A value of 0.1f may be a pretty good choice. The pruning
     * points will be the last few instances of the training data set. Pass in
     * a randomized sample if this simply won't do.
     */
    private float pruningFraction = 0.1f;

    /** How deep can this tree go? The deeper the tree, the less general it is. */
    private int maxDepth = 10;

    /** How many classes are there? */
    private int numCategories = 0;

    /** By default, gain ratio is used. */
    private FitnessFunction fitness = new GainRatioFitnessFunction();

    @SuppressWarnings("serial")
    public static class TreeCreationException extends RuntimeException {
        TreeCreationException(String cause) {
            super(cause);
        }
    }

    public QuinlanC45AxialDecisionTreeCreator(float pruningFraction) {
        this.pruningFraction = pruningFraction;
    }

    public QuinlanC45AxialDecisionTreeCreator() {
    }
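    /*
     * Note on collaborators (inferred from how they are used in this class;
     * their definitions live elsewhere in this package):
     * FitnessFunction.computeSplitAndGain(...) is assumed to return a
     * FitnessFunction.Split carrying a candidate threshold (thresh) and a
     * fitness score (score) for one attribute, and AxialTreeNode is assumed
     * to support both leaves (a single category) and branches (attribute
     * index, threshold, left/right children, and a default category taken
     * from the more populous child).
     */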
    /**
     * @param inputData
     *            an array where each row corresponds to a single instance (to
     *            be classified) and the columns hold the attributes of that
     *            instance
     * @param targetClass
     *            an array where each row corresponds to a single instance,
     *            specifically the actual classification of that instance. The
     *            class needs to be a number 0,1,2,...,N-1 where N is the
     *            number of classes. Some of these classes may have no
     *            examples.
     * @return the learned decision tree
     */
    public AxialDecisionTree learn(float[][] inputData, int[] targetClass)
            throws TreeCreationException, IllegalArgumentException {
        if (inputData.length == 0 || inputData[0].length == 0
                || targetClass.length != inputData.length
                || pruningFraction < 0.0f) {
            throw new IllegalArgumentException();
        }
        int numTraining = Math.round(inputData.length * (1 - pruningFraction));
        int numTesting = inputData.length - numTraining;

        int[] toConsider = new int[numTraining];
        for (int i = 0; i < numTraining; ++i) {
            toConsider[i] = i;
            if (targetClass[i] >= numCategories) {
                numCategories = targetClass[i] + 1;
            }
        }

        AxialTreeNode node = buildTree(inputData, targetClass, toConsider, 0);
        if (node == null) {
            throw new TreeCreationException(
                    "Cannot create decision tree as there are too few unique inputs");
        }

        if (numTesting > 0) {
            toConsider = new int[numTesting];
            for (int i = 0; i < numTesting; ++i) {
                toConsider[i] = numTraining + i; // the last few rows
            }
            node = pruneTree(node, inputData, targetClass, toConsider);
        }

        return new AxialDecisionTree(node, -1, inputData[0].length,
                numCategories);
    }

    /**
     * Removes subtrees that do not perform well on the validation (pruning)
     * data set.
     */
    private AxialTreeNode pruneTree(AxialTreeNode node, float[][] inputData,
            int[] targetClass, int[] toConsider) {
        // Prune the left/right branches
        List<Integer> leftPoints = new ArrayList<Integer>();
        List<Integer> rightPoints = new ArrayList<Integer>();
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            int row = toConsider[i];
            if (node.isHandledByLeftBranch(inputData[row])) {
                leftPoints.add(row); // collect row indices, not loop indices
            } else if (node.isHandledByRightBranch(inputData[row])) {
                rightPoints.add(row);
            }
        }
        if (leftPoints.size() > 0) {
            int[] toConsiderLeft = new int[leftPoints.size()];
            for (int i = 0, n = toConsiderLeft.length; i < n; ++i) {
                toConsiderLeft[i] = leftPoints.get(i).intValue();
            }
            node.setLeft(pruneTree(node.getLeft(), inputData, targetClass,
                    toConsiderLeft));
        }
        if (rightPoints.size() > 0) {
            int[] toConsiderRight = new int[rightPoints.size()];
            for (int i = 0, n = toConsiderRight.length; i < n; ++i) {
                toConsiderRight[i] = rightPoints.get(i).intValue();
            }
            node.setRight(pruneTree(node.getRight(), inputData, targetClass,
                    toConsiderRight));
        }

        node.normalize();

        // Decide whether to replace this node by just a stump. The
        // replacement happens if it increases the number of correct
        // classifications on the validation set.
        int numCorrect = 0;
        int numCorrect_asStump = 0;
        int defaultCategory = node.getDefaultCategory();
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            int row = toConsider[i];
            int trueCategory = targetClass[row];
            int estCategory = node.classify(inputData[row]);
            if (trueCategory == estCategory) {
                ++numCorrect;
            }
            if (trueCategory == defaultCategory) {
                ++numCorrect_asStump;
            }
        }
        if (numCorrect_asStump > numCorrect) {
            System.out.println("Pruning " + node);
            return new AxialTreeNode(defaultCategory);
        }

        // Also collapse nodes that classify fewer than 40% of their
        // validation points correctly.
        float fractionCorrect = (float) numCorrect / toConsider.length;
        if (fractionCorrect < 0.4f) {
            System.out.println("Pruning " + node);
            return new AxialTreeNode(defaultCategory);
        }

        return node;
    }
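    /*
     * Worked example of the pruning rule above (illustrative numbers): if a
     * subtree classifies 6 of its 10 validation points correctly, but 7 of
     * those 10 points belong to the node's default category, the subtree is
     * replaced by a leaf that always answers the default category. The same
     * collapse happens whenever the subtree gets fewer than 40% of its
     * validation points right.
     */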
    /**
     * Helper method that creates a sub-tree and returns its root node.
     *
     * @return the root node of the created sub-tree
     */
    private AxialTreeNode buildTree(float[][] inputData, int[] targetClass,
            int[] toConsider, int depth) {
        if (toConsider.length < populationToConsiderSplitting
                || depth == maxDepth) {
            // find the most likely category and return a node that always
            // supplies it
            int mostLikelyCategory = getMostLikelyCategory(inputData,
                    targetClass, toConsider);
            return new AxialTreeNode(mostLikelyCategory);
        }

        // Is everything of the same category?
        boolean allSame = true;
        int startCategory = targetClass[toConsider[0]];
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            if (targetClass[toConsider[i]] != startCategory) {
                allSame = false;
                break;
            }
        }
        if (allSame) {
            return new AxialTreeNode(startCategory);
        }

        // compute the best attribute by finding the one with the highest
        // fitness score (gain ratio by default)
        int numAttributes = inputData[0].length;
        FitnessFunction.Split[] splits = new FitnessFunction.Split[numAttributes];
        for (int i = 0; i < numAttributes; ++i) {
            splits[i] = fitness.computeSplitAndGain(inputData, targetClass,
                    numCategories, toConsider, i);
        }
        int bestAttribute = 0;
        for (int i = 1; i < numAttributes; ++i) {
            if (splits[i].score > splits[bestAttribute].score) {
                bestAttribute = i;
            }
        }

        // create a node to split on bestAttribute
        float thresh = splits[bestAttribute].thresh;
        int[][] toConsiderSplit = split(inputData, toConsider, bestAttribute,
                thresh);

        // if there are no examples on one side of the branch, simply return
        // the other side
        if (toConsiderSplit[0].length == 0) {
            return buildTree(inputData, targetClass, toConsiderSplit[1],
                    depth + 1);
        } else if (toConsiderSplit[1].length == 0) {
            return buildTree(inputData, targetClass, toConsiderSplit[0],
                    depth + 1);
        }

        int leftCategory = getMostLikelyCategory(inputData, targetClass,
                toConsiderSplit[0]);
        int rightCategory = getMostLikelyCategory(inputData, targetClass,
                toConsiderSplit[1]);
        AxialTreeNode leftNode = buildTree(inputData, targetClass,
                toConsiderSplit[0], depth + 1);
        AxialTreeNode rightNode = buildTree(inputData, targetClass,
                toConsiderSplit[1], depth + 1);
        int defaultCategory = (toConsiderSplit[0].length > toConsiderSplit[1].length)
                ? leftCategory : rightCategory;
        AxialTreeNode branch = new AxialTreeNode(bestAttribute, thresh,
                leftNode, rightNode, defaultCategory);
        branch.normalize();
        return branch;
    }
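    /*
     * Illustration of split() below (hypothetical values): with
     * bestAttribute's values 1.0, 3.5 and 2.2 for the rows under
     * consideration and thresh = 2.5, the rows holding 1.0 and 2.2 land in
     * result[0] (left, value < thresh) and the row holding 3.5 lands in
     * result[1] (right, value >= thresh).
     */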
    private int[][] split(float[][] inputData, int[] toConsider,
            int bestAttribute, float thresh) {
        int numLeft = 0;
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            if (inputData[toConsider[i]][bestAttribute] < thresh) {
                ++numLeft;
            }
        }
        int numRight = toConsider.length - numLeft;
        int[][] result = new int[2][];
        result[0] = new int[numLeft];
        result[1] = new int[numRight];
        int leftIndex = 0;
        int rightIndex = 0;
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            if (inputData[toConsider[i]][bestAttribute] < thresh) {
                result[0][leftIndex] = toConsider[i];
                ++leftIndex;
            } else {
                result[1][rightIndex] = toConsider[i];
                ++rightIndex;
            }
        }
        return result;
    }

    private int getMostLikelyCategory(float[][] inputData, int[] targetClass,
            int[] toConsider) {
        if (toConsider.length == 0) {
            throw new IllegalStateException(
                    "should not have an empty toConsider array here");
        }
        int[] populationByCategory = new int[numCategories];
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            int category = targetClass[toConsider[i]];
            ++populationByCategory[category];
        }
        int bestCategory = 0;
        for (int i = 1; i < numCategories; ++i) {
            if (populationByCategory[i] > populationByCategory[bestCategory]) {
                bestCategory = i;
            }
        }
        return bestCategory;
    }

    /** Number of categories in the previously learned data set. */
    public int getNumCategories() {
        return numCategories;
    }
}
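/*
 * Minimal usage sketch (not part of the original API): learns a tree from
 * synthetic, linearly separable data. The demo class name, the data, and the
 * 0.1f pruning fraction are all illustrative assumptions.
 */
class QuinlanC45AxialDecisionTreeCreatorDemo {
    public static void main(String[] args) {
        // Two synthetic classes on a single attribute: class 0 clusters near
        // 0, class 1 clusters near 10. Rows alternate so that the trailing
        // pruning slice still contains both classes.
        int numInstances = 100;
        float[][] data = new float[numInstances][1];
        int[] categories = new int[numInstances];
        for (int i = 0; i < numInstances; ++i) {
            categories[i] = i % 2;
            data[i][0] = (categories[i] == 0) ? (i * 0.01f) : (10 + i * 0.01f);
        }
        QuinlanC45AxialDecisionTreeCreator classifier =
                new QuinlanC45AxialDecisionTreeCreator(0.1f); // pruning fraction
        AxialDecisionTree tree = classifier.learn(data, categories);
        System.out.println("Learned a tree over "
                + classifier.getNumCategories() + " categories: " + tree);
    }
}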