001    /**
002     * 
003     */
004    package org.wdssii.decisiontree;
005    
006    import java.util.Arrays;
007    
008    /**
009     * 
010     * Picks a threshold to maximize the gain but uses gain-ratio as the fitness of
011     * this attribute
012     * 
013     * @author lakshman
014     * 
015     */
016    public class GainRatioFitnessFunction implements FitnessFunction {
017    
018            private static class ValueAndCategory implements Comparable<ValueAndCategory> {
019                    float value;
020    
021                    int category;
022    
023                    public ValueAndCategory(float f, int i) {
024                            this.value = f;
025                            this.category = i;
026                    }
027    
028                    @Override
029                    public int compareTo(ValueAndCategory other) {
030                            return Float.compare(value, other.value);
031                    }
032            }
033    
034            /*
035             * (non-Javadoc)
036             * 
037             * @see org.jscience.statistics.classifiers.decisiontree.FitnessFunction#computeSplitAndGain(float[][],
038             *      int[], int[], int)
039             */
040            public Split computeSplitAndGain(float[][] inputData, int[] targetClass,
041                            int numCategories, int[] toConsider, int attribute) {
042    
043                    if (toConsider.length < 2) {
044                            throw new IllegalStateException(
045                                            "should not be called with < 2 samples");
046                    }
047    
048                    // sort all the possible values
049                    ValueAndCategory[] all = new ValueAndCategory[toConsider.length];
050                    for (int i = 0, n = toConsider.length; i < n; ++i) {
051                            int row = toConsider[i];
052                            all[i] = new ValueAndCategory(inputData[row][attribute],
053                                            targetClass[row]);
054                    }
055                    Arrays.sort(all);
056    
057                    // compute the gain for each possible value of the threshold
058                    int bestSplit = -1;
059                    float bestGain = Float.MIN_VALUE;
060                    for (int splitBefore = 1; splitBefore < all.length; ++splitBefore) {
061                            // split will not be maximal if on both sides of the split, the same
062                            // category exists
063                            if (all[splitBefore].category != all[splitBefore - 1].category) {
064                                    float gain = computeGain(all, numCategories, splitBefore);
065                                    if (gain > bestGain) {
066                                            bestGain = gain;
067                                            bestSplit = splitBefore;
068                                    }
069                            }
070                    }
071    
072                    if (bestSplit < 0) {
073                            bestSplit = all.length / 2; // in-half
074                    }
075    
076                    // compute information gain-ratio
077                    Split result = new Split();
078                    result.score = computeGainRatio(all, numCategories, bestSplit);
079                    result.thresh = (all[bestSplit].value + all[bestSplit].value) / 2;
080    
081                    System.out.println("att=" + attribute + " split=" + bestSplit + "/"
082                                    + all.length + " thresh=" + result.thresh + " gain=" + bestGain
083                                    + " gainratio=" + result.score);
084    
085                    return result;
086            }
087    
088            private float computeGain(ValueAndCategory[] all, int numCategories,
089                            int split) {
090                    float info_d = computeInfo(all, numCategories, 0, all.length);
091                    float info_d0 = computeInfo(all, numCategories, 0, split);
092                    float info_d1 = computeInfo(all, numCategories, split, all.length);
093                    float mld_corr = log2(all.length - 1) / all.length;
094                    float gain = info_d - (info_d0 * split) / all.length
095                                    - (info_d1 * (all.length - split)) / all.length - mld_corr;
096                    return gain;
097            }
098    
099            private float computeInfo(ValueAndCategory[] all, int numCategories,
100                            int start, int end) {
101                    int[] freq = new int[numCategories];
102                    for (int i = start; i < end; ++i) {
103                            int category = all[i].category;
104                            ++freq[category];
105                    }
106                    int total = end - start;
107    
108                    float sum = 0;
109                    for (int j = 0; j < numCategories; ++j) {
110                            if (freq[j] != 0) {
111                                    float prob_dj = (float) freq[j] / total;
112                                    sum += prob_dj * log2(prob_dj);
113                            }
114                    }
115                    return (-sum);
116            }
117    
118            private static double LOG2 = Math.log(2);
119    
120            private static float log2(float p) {
121                    float result = (float) (Math.log(p) / LOG2);
122                    return result;
123            }
124    
125            private float computeGainRatio(ValueAndCategory[] all, int numCategories,
126                            int split) {
127                    float gain = computeGain(all, numCategories, split);
128                    float frac = (float) split / all.length;
129                    float cost_0 = (split == 0) ? 0 : (frac * log2(frac));
130                    float cost_1 = (split == all.length) ? 0
131                                    : ((1 - frac) * log2(1 - frac));
132                    return -gain / (cost_0 + cost_1);
133            }
134    
135    }