/**
 *
 */
package org.wdssii.decisiontree;

import java.util.Arrays;

/**
 * Picks a threshold to maximize the gain but uses gain-ratio as the fitness of
 * this attribute.
 *
 * @author lakshman
 */
public class GainRatioFitnessFunction implements FitnessFunction {

    /**
     * One sample's value for the attribute under test together with the
     * sample's target category. Ordered by value only.
     */
    private static class ValueAndCategory implements Comparable<ValueAndCategory> {
        final float value;

        final int category;

        public ValueAndCategory(float f, int i) {
            this.value = f;
            this.category = i;
        }

        @Override
        public int compareTo(ValueAndCategory other) {
            return Float.compare(value, other.value);
        }
    }

    /** Natural log of 2, used to convert natural logs to base-2 logs. */
    private static final double LOG2 = Math.log(2);

    /**
     * Chooses a split position for one attribute and scores it.
     * <p>
     * The selected samples are sorted by their value of {@code attribute}.
     * Every position where two neighbouring samples have different categories
     * is evaluated with {@link #computeGain}, and the position with the
     * largest gain is kept. If no such position exists, the sorted samples are
     * split in half. The returned score is the gain ratio of the chosen
     * position and the returned threshold is the midpoint between the two
     * values on either side of it.
     *
     * @param inputData
     *            sample data, indexed as {@code [row][attribute]}
     * @param targetClass
     *            category of each row; used as an index into a frequency
     *            table of size {@code numCategories}
     * @param numCategories
     *            number of distinct categories
     * @param toConsider
     *            indices of the rows to take into account
     * @param attribute
     *            column of {@code inputData} to split on
     * @return a {@code Split} whose {@code score} is the gain ratio and whose
     *         {@code thresh} is the chosen threshold
     * @throws IllegalStateException
     *             if {@code toConsider} holds fewer than two rows
     */
    public Split computeSplitAndGain(float[][] inputData, int[] targetClass,
            int numCategories, int[] toConsider, int attribute) {

        if (toConsider.length < 2) {
            throw new IllegalStateException(
                    "should not be called with < 2 samples");
        }

        // sort all the possible values
        ValueAndCategory[] all = new ValueAndCategory[toConsider.length];
        for (int i = 0, n = toConsider.length; i < n; ++i) {
            int row = toConsider[i];
            all[i] = new ValueAndCategory(inputData[row][attribute],
                    targetClass[row]);
        }
        Arrays.sort(all);

        // compute the gain for each possible value of the threshold
        int bestSplit = -1;
        // Start from -infinity: computeGain subtracts a correction term, so a
        // gain may be zero or negative and must still be able to win.
        // (Float.MIN_VALUE is the smallest POSITIVE float.)
        float bestGain = Float.NEGATIVE_INFINITY;
        for (int splitBefore = 1; splitBefore < all.length; ++splitBefore) {
            // split will not be maximal if on both sides of the split, the same
            // category exists
            // NOTE(review): neighbours with equal values but different
            // categories are also accepted here, although no threshold can
            // separate them -- confirm whether that is intended.
            if (all[splitBefore].category != all[splitBefore - 1].category) {
                float gain = computeGain(all, numCategories, splitBefore);
                if (gain > bestGain) {
                    bestGain = gain;
                    bestSplit = splitBefore;
                }
            }
        }

        if (bestSplit < 0) {
            bestSplit = all.length / 2; // in-half
        }

        // compute information gain-ratio
        Split result = new Split();
        result.score = computeGainRatio(all, numCategories, bestSplit);
        // Midpoint between the last value left of the split and the first
        // value right of it. bestSplit >= 1 here because all.length >= 2.
        result.thresh = (all[bestSplit - 1].value + all[bestSplit].value) / 2;

        System.out.println("att=" + attribute + " split=" + bestSplit + "/"
                + all.length + " thresh=" + result.thresh + " gain=" + bestGain
                + " gainratio=" + result.score);

        return result;
    }

    /**
     * Computes the gain of splitting the sorted samples into
     * {@code [0, split)} and {@code [split, all.length)}: the entropy of the
     * whole set minus the size-weighted entropies of the two parts, minus a
     * correction of {@code log2(n - 1) / n}.
     *
     * @param all
     *            samples sorted by value
     * @param numCategories
     *            number of distinct categories
     * @param split
     *            index of the first sample of the right-hand part
     * @return the corrected gain; may be negative
     */
    private float computeGain(ValueAndCategory[] all, int numCategories,
            int split) {
        float info_d = computeInfo(all, numCategories, 0, all.length);
        float info_d0 = computeInfo(all, numCategories, 0, split);
        float info_d1 = computeInfo(all, numCategories, split, all.length);
        // NOTE(review): presumably an MDL-style penalty for choosing among
        // n - 1 candidate thresholds -- confirm. It is the same for every
        // split, so it does not change which split has the largest gain.
        float mld_corr = log2(all.length - 1) / all.length;
        float gain = info_d - (info_d0 * split) / all.length
                - (info_d1 * (all.length - split)) / all.length - mld_corr;
        return gain;
    }

    /**
     * Computes the base-2 entropy of the categories of the samples in
     * {@code [start, end)}.
     *
     * @param all
     *            samples
     * @param numCategories
     *            number of distinct categories; every category in the range
     *            must be a valid index below this value
     * @param start
     *            first index, inclusive
     * @param end
     *            last index, exclusive
     * @return the negated sum of {@code p * log2(p)} over the categories that
     *         occur in the range
     */
    private float computeInfo(ValueAndCategory[] all, int numCategories,
            int start, int end) {
        int[] freq = new int[numCategories];
        for (int i = start; i < end; ++i) {
            int category = all[i].category;
            ++freq[category];
        }
        int total = end - start;

        float sum = 0;
        for (int j = 0; j < numCategories; ++j) {
            // skip absent categories: log2(0) is -infinity
            if (freq[j] != 0) {
                float prob_dj = (float) freq[j] / total;
                sum += prob_dj * log2(prob_dj);
            }
        }
        return (-sum);
    }

    /**
     * Returns the base-2 logarithm of {@code p}.
     */
    private static float log2(float p) {
        float result = (float) (Math.log(p) / LOG2);
        return result;
    }

    /**
     * Computes the gain ratio for a split: the gain from {@link #computeGain}
     * divided by the base-2 entropy of the two part sizes.
     *
     * @param all
     *            samples sorted by value
     * @param numCategories
     *            number of distinct categories
     * @param split
     *            index of the first sample of the right-hand part
     * @return gain divided by the split entropy
     */
    private float computeGainRatio(ValueAndCategory[] all, int numCategories,
            int split) {
        float gain = computeGain(all, numCategories, split);
        float frac = (float) split / all.length;
        float cost_0 = (split == 0) ? 0 : (frac * log2(frac));
        float cost_1 = (split == all.length) ? 0
                : ((1 - frac) * log2(1 - frac));
        // cost_0 + cost_1 is the negated split entropy, hence the minus sign.
        return -gain / (cost_0 + cost_1);
    }

}