package org.wdssii.decisiontree;

import java.util.ArrayList;
import java.util.List;

/**
 * C4.5 learning algorithm to create an axial decision tree. See J. R. Quinlan,
 * "Improved use of continuous attributes in C4.5", Journal of Artificial
 * Intelligence Research, 4:77-90, 1996.
 *
 * Usage:
 * <pre>
 * float[][] data = new float[numTraining][numAttr];
 * int[] categories = new int[numTraining];
 * // populate the arrays
 * ...
 * QuinlanC45AxialDecisionTreeCreator classifier = new QuinlanC45AxialDecisionTreeCreator(0.1f); // pruning fraction
 * DecisionTree tree = classifier.learn(data, categories);
 * </pre>
 *
 * @author lakshman
 */
public class QuinlanC45AxialDecisionTreeCreator implements DecisionTreeCreator {
    /**
     * Minimum population size at which a node is considered for splitting;
     * smaller populations become leaves.
     */
    private int populationToConsiderSplitting = 10;

    /**
     * Fraction of the training data set to keep aside so that the learned tree
     * is not overfit. A value of 0.1f may be a pretty good choice. The pruning
     * points will be the last few instances of the training data set. Pass in
     * a randomized sample if this simply won't do.
     */
    private float pruningFraction = 0.1f;

    /** How deep can this tree go? The deeper the tree, the less general it is. */
    private int maxDepth = 10;

    /** How many classes are there? */
    private int numCategories = 0;

    /** By default, gain ratio is used. */
    private FitnessFunction fitness = new GainRatioFitnessFunction();

    @SuppressWarnings("serial")
    public static class TreeCreationException extends RuntimeException {
        TreeCreationException(String cause) {
            super(cause);
        }
    }

    public QuinlanC45AxialDecisionTreeCreator(float pruningFraction) {
        this.pruningFraction = pruningFraction;
    }

    public QuinlanC45AxialDecisionTreeCreator() {
    }
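    /*
     * Note on collaborators (inferred from how they are used in this class;
     * their definitions live elsewhere in this package):
     * FitnessFunction.computeSplitAndGain(...) is assumed to return a
     * FitnessFunction.Split carrying a candidate threshold (thresh) and a
     * fitness score (score) for one attribute, and AxialTreeNode is assumed
     * to support both leaves (a single category) and branches (attribute
     * index, threshold, left/right children, and a default category taken
     * from the more populous child).
     */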
    /**
     * @param inputData
     *            an array where each row corresponds to a single instance (to
     *            be classified) and the columns hold the attributes of that
     *            instance
     * @param targetClass
     *            an array where each row corresponds to a single instance,
     *            specifically the actual classification of that instance. The
     *            class needs to be a number 0,1,2,...,N-1 where N is the
     *            number of classes. Some of these classes may have no
     *            examples.
     * @return the learned decision tree
     */
    public AxialDecisionTree learn(float[][] inputData, int[] targetClass)
            throws TreeCreationException, IllegalArgumentException {
        if (inputData.length == 0 || inputData[0].length == 0
                || targetClass.length != inputData.length
                || pruningFraction < 0.0f) {
            throw new IllegalArgumentException();
        }
        int numTraining = Math.round(inputData.length * (1 - pruningFraction));
        int numTesting = inputData.length - numTraining;

        int[] toConsider = new int[numTraining];
        for (int i = 0; i < numTraining; ++i) {
            toConsider[i] = i;
            if (targetClass[i] >= numCategories) {
                numCategories = targetClass[i] + 1;
            }
        }

        AxialTreeNode node = buildTree(inputData, targetClass, toConsider, 0);
        if (node == null) {
            throw new TreeCreationException(
                    "Cannot create decision tree as there are too few unique inputs");
        }

        if (numTesting > 0) {
            toConsider = new int[numTesting];
            for (int i = 0; i < numTesting; ++i) {
                toConsider[i] = numTraining + i; // the last few rows
            }
            node = pruneTree(node, inputData, targetClass, toConsider);
        }

        return new AxialDecisionTree(node, -1, inputData[0].length,
                numCategories);
    }

    /**
     * Removes subtrees that do not perform well on the validation (pruning)
     * data set.
     */
    private AxialTreeNode pruneTree(AxialTreeNode node, float[][] inputData,
            int[] targetClass, int[] toConsider) {
        // Prune the left/right branches
        List<Integer> leftPoints = new ArrayList<Integer>();
        List<Integer> rightPoints = new ArrayList<Integer>();
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            int row = toConsider[i];
            if (node.isHandledByLeftBranch(inputData[row])) {
                leftPoints.add(row); // collect row indices, not loop indices
            } else if (node.isHandledByRightBranch(inputData[row])) {
                rightPoints.add(row);
            }
        }
        if (leftPoints.size() > 0) {
            int[] toConsiderLeft = new int[leftPoints.size()];
            for (int i = 0, n = toConsiderLeft.length; i < n; ++i) {
                toConsiderLeft[i] = leftPoints.get(i).intValue();
            }
            node.setLeft(pruneTree(node.getLeft(), inputData, targetClass,
                    toConsiderLeft));
        }
        if (rightPoints.size() > 0) {
            int[] toConsiderRight = new int[rightPoints.size()];
            for (int i = 0, n = toConsiderRight.length; i < n; ++i) {
                toConsiderRight[i] = rightPoints.get(i).intValue();
            }
            node.setRight(pruneTree(node.getRight(), inputData, targetClass,
                    toConsiderRight));
        }

        node.normalize();

        // Decide whether to replace this node by just a stump. The
        // replacement happens if it increases the number of correct
        // classifications on the validation set.
        int numCorrect = 0;
        int numCorrect_asStump = 0;
        int defaultCategory = node.getDefaultCategory();
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            int row = toConsider[i];
            int trueCategory = targetClass[row];
            int estCategory = node.classify(inputData[row]);
            if (trueCategory == estCategory) {
                ++numCorrect;
            }
            if (trueCategory == defaultCategory) {
                ++numCorrect_asStump;
            }
        }
        if (numCorrect_asStump > numCorrect) {
            System.out.println("Pruning " + node);
            return new AxialTreeNode(defaultCategory);
        }

        // Also collapse nodes that classify fewer than 40% of their
        // validation points correctly.
        float fractionCorrect = (float) numCorrect / toConsider.length;
        if (fractionCorrect < 0.4f) {
            System.out.println("Pruning " + node);
            return new AxialTreeNode(defaultCategory);
        }

        return node;
    }
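    /*
     * Worked example of the pruning rule above (illustrative numbers): if a
     * subtree classifies 6 of its 10 validation points correctly, but 7 of
     * those 10 points belong to the node's default category, the subtree is
     * replaced by a leaf that always answers the default category. The same
     * collapse happens whenever the subtree gets fewer than 40% of its
     * validation points right.
     */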
    /**
     * Helper method that creates a sub-tree and returns its root node.
     *
     * @return the root node of the created sub-tree
     */
    private AxialTreeNode buildTree(float[][] inputData, int[] targetClass,
            int[] toConsider, int depth) {
        if (toConsider.length < populationToConsiderSplitting
                || depth == maxDepth) {
            // find the most likely category and return a node that always
            // supplies it
            int mostLikelyCategory = getMostLikelyCategory(inputData,
                    targetClass, toConsider);
            return new AxialTreeNode(mostLikelyCategory);
        }

        // Is everything of the same category?
        boolean allSame = true;
        int startCategory = targetClass[toConsider[0]];
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            if (targetClass[toConsider[i]] != startCategory) {
                allSame = false;
                break;
            }
        }
        if (allSame) {
            return new AxialTreeNode(startCategory);
        }

        // compute the best attribute by finding the one with the highest
        // fitness score (gain ratio by default)
        int numAttributes = inputData[0].length;
        FitnessFunction.Split[] splits = new FitnessFunction.Split[numAttributes];
        for (int i = 0; i < numAttributes; ++i) {
            splits[i] = fitness.computeSplitAndGain(inputData, targetClass,
                    numCategories, toConsider, i);
        }
        int bestAttribute = 0;
        for (int i = 1; i < numAttributes; ++i) {
            if (splits[i].score > splits[bestAttribute].score) {
                bestAttribute = i;
            }
        }

        // create a node to split on bestAttribute
        float thresh = splits[bestAttribute].thresh;
        int[][] toConsiderSplit = split(inputData, toConsider, bestAttribute,
                thresh);

        // if there are no examples on one side of the branch, simply return
        // the other side
        if (toConsiderSplit[0].length == 0) {
            return buildTree(inputData, targetClass, toConsiderSplit[1],
                    depth + 1);
        } else if (toConsiderSplit[1].length == 0) {
            return buildTree(inputData, targetClass, toConsiderSplit[0],
                    depth + 1);
        }

        int leftCategory = getMostLikelyCategory(inputData, targetClass,
                toConsiderSplit[0]);
        int rightCategory = getMostLikelyCategory(inputData, targetClass,
                toConsiderSplit[1]);
        AxialTreeNode leftNode = buildTree(inputData, targetClass,
                toConsiderSplit[0], depth + 1);
        AxialTreeNode rightNode = buildTree(inputData, targetClass,
                toConsiderSplit[1], depth + 1);
        int defaultCategory = (toConsiderSplit[0].length > toConsiderSplit[1].length)
                ? leftCategory : rightCategory;
        AxialTreeNode branch = new AxialTreeNode(bestAttribute, thresh,
                leftNode, rightNode, defaultCategory);
        branch.normalize();
        return branch;
    }
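    /*
     * Illustration of split() below (hypothetical values): with
     * bestAttribute's values 1.0, 3.5 and 2.2 for the rows under
     * consideration and thresh = 2.5, the rows holding 1.0 and 2.2 land in
     * result[0] (left, value < thresh) and the row holding 3.5 lands in
     * result[1] (right, value >= thresh).
     */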
    private int[][] split(float[][] inputData, int[] toConsider,
            int bestAttribute, float thresh) {
        int numLeft = 0;
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            if (inputData[toConsider[i]][bestAttribute] < thresh) {
                ++numLeft;
            }
        }
        int numRight = toConsider.length - numLeft;
        int[][] result = new int[2][];
        result[0] = new int[numLeft];
        result[1] = new int[numRight];
        int leftIndex = 0;
        int rightIndex = 0;
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            if (inputData[toConsider[i]][bestAttribute] < thresh) {
                result[0][leftIndex] = toConsider[i];
                ++leftIndex;
            } else {
                result[1][rightIndex] = toConsider[i];
                ++rightIndex;
            }
        }
        return result;
    }

    private int getMostLikelyCategory(float[][] inputData, int[] targetClass,
            int[] toConsider) {
        if (toConsider.length == 0) {
            throw new IllegalStateException(
                    "should not have an empty toConsider array here");
        }
        int[] populationByCategory = new int[numCategories];
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            int category = targetClass[toConsider[i]];
            ++populationByCategory[category];
        }
        int bestCategory = 0;
        for (int i = 1; i < numCategories; ++i) {
            if (populationByCategory[i] > populationByCategory[bestCategory]) {
                bestCategory = i;
            }
        }
        return bestCategory;
    }

    /** Number of categories in the previously learned data set. */
    public int getNumCategories() {
        return numCategories;
    }
}
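/*
 * Minimal usage sketch (not part of the original API): learns a tree from
 * synthetic, linearly separable data. The demo class name, the data, and the
 * 0.1f pruning fraction are all illustrative assumptions.
 */
class QuinlanC45AxialDecisionTreeCreatorDemo {
    public static void main(String[] args) {
        // Two synthetic classes on a single attribute: class 0 clusters near
        // 0, class 1 clusters near 10. Rows alternate so that the trailing
        // pruning slice still contains both classes.
        int numInstances = 100;
        float[][] data = new float[numInstances][1];
        int[] categories = new int[numInstances];
        for (int i = 0; i < numInstances; ++i) {
            categories[i] = i % 2;
            data[i][0] = (categories[i] == 0) ? (i * 0.01f) : (10 + i * 0.01f);
        }
        QuinlanC45AxialDecisionTreeCreator classifier =
                new QuinlanC45AxialDecisionTreeCreator(0.1f); // pruning fraction
        AxialDecisionTree tree = classifier.learn(data, categories);
        System.out.println("Learned a tree over "
                + classifier.getNumCategories() + " categories: " + tree);
    }
}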