package org.wdssii.aicompetition;

import java.io.BufferedReader;
import java.io.FileReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Scores a candidate's per-row category labels against the correct labels.
 *
 * <p>Both input files hold one integer category per line and are paired row by row.
 * The program prints the True Skill Score (TSS) averaged over random subsets of the
 * rows, followed by the full confusion matrix and the TSS over all rows.
 *
 * @author lakshman
 */
public class Evaluate {

    /**
     * Command-line entry point.
     *
     * @param args {@code args[0]} is the file of correct results and {@code args[1]} the
     *     file of candidate results; each holds one integer category per line
     * @throws Exception if either file cannot be read or contains a non-integer line
     */
    public static void main(String[] args) throws Exception {
        if (args.length != 2) {
            System.err.println("Usage: java Evaluate correct_results.txt candidate_results.txt");
            return;
        }
        int[] correct_results = readResults(args[0]);
        int[] candidate_results = readResults(args[1]);

        // Rows are paired by position, so the files must line up exactly; otherwise the
        // scores would crash on, or silently ignore, the unmatched tail.
        if (correct_results.length != candidate_results.length) {
            System.err.println("Mismatched result counts: " + args[0] + " has "
                    + correct_results.length + " rows but " + args[1] + " has "
                    + candidate_results.length);
            return;
        }
        if (correct_results.length == 0) {
            System.err.println("No results to evaluate in " + args[0]);
            return;
        }

        // Estimate the spread of the TSS from 10 random half-size samples of each kind.
        computeTssOnRandomPatterns(correct_results, candidate_results, 10, 0.5);
        computeTssOnRandomRuns(correct_results, candidate_results, 10, 0.5);

        // Overall score on every row.
        Stat overall = new Stat();
        for (int i = 0; i < correct_results.length; ++i) {
            overall.update(correct_results[i], candidate_results[i]);
        }
        System.out.println("Overall: " + overall.toString());
    }

    /**
     * Draws random samples from the entire set and computes TSS.
     *
     * <p>In each of {@code NUM_ITER} iterations every row is included independently with
     * probability {@code selectionProbability}. The mean and sample standard deviation of
     * the resulting TSS values are printed to standard output.
     *
     * @param correct_results expected category for each row
     * @param candidate_results candidate's category for each row (same length)
     * @param NUM_ITER number of random samples to draw
     * @param selectionProbability probability that any given row is included in a sample
     */
    private static void computeTssOnRandomPatterns(int[] correct_results,
            int[] candidate_results, final int NUM_ITER,
            final double selectionProbability) {
        Random rand = new Random();
        double[] tssSamples = new double[NUM_ITER];

        for (int iter = 0; iter < NUM_ITER; ++iter) {
            Stat s = new Stat();
            for (int i = 0; i < correct_results.length; ++i) {
                if (rand.nextDouble() < selectionProbability) {
                    s.update(correct_results[i], candidate_results[i]);
                }
            }
            tssSamples[iter] = s.getTSS();
        }
        double[] meanAndStdDev = meanAndSampleStdDev(tssSamples);
        System.out.println("Random rows: TSS= " + (float) meanAndStdDev[0] + " +/- "
                + (float) meanAndStdDev[1] + " estimated over " + NUM_ITER
                + " iterations with selectionProb=" + selectionProbability);
    }

    /**
     * Chooses random runs of the input sets, i.e. the chosen patterns are consecutive.
     *
     * <p>Each of {@code NUM_ITER} iterations scores a run of
     * {@code round(selectionProbability * length)} consecutive rows. The run starts at a
     * random row and wraps around the end of the input. The mean and sample standard
     * deviation of the resulting TSS values are printed to standard output.
     *
     * @param correct_results expected category for each row (must be non-empty)
     * @param candidate_results candidate's category for each row (same length)
     * @param NUM_ITER number of random runs to draw
     * @param selectionProbability fraction of the rows covered by each run
     */
    private static void computeTssOnRandomRuns(int[] correct_results,
            int[] candidate_results, final int NUM_ITER,
            final double selectionProbability) {
        final int total = correct_results.length;
        final int N = (int) Math.round(selectionProbability * total);
        Random rand = new Random();
        double[] tssSamples = new double[NUM_ITER];

        for (int iter = 0; iter < NUM_ITER; ++iter) {
            Stat s = new Stat();
            final int start = rand.nextInt(total);
            for (int j = start; j < start + N; ++j) {
                // Wrap past the last row so that every start position gives a full run.
                int i = j % total;
                s.update(correct_results[i], candidate_results[i]);
            }
            tssSamples[iter] = s.getTSS();
        }
        double[] meanAndStdDev = meanAndSampleStdDev(tssSamples);
        System.out.println("Random runs: TSS= " + (float) meanAndStdDev[0] + " +/- "
                + (float) meanAndStdDev[1] + " estimated over " + NUM_ITER
                + " iterations with N=" + N);
    }

    /**
     * Returns {mean, sample standard deviation} of {@code values}.
     *
     * <p>Uses a two-pass formula in double precision, so the variance can never come out
     * negative. The mean is NaN for an empty array. The deviation is 0 when there are
     * fewer than two values.
     */
    private static double[] meanAndSampleStdDev(double[] values) {
        int n = values.length;
        if (n == 0) {
            return new double[] {Double.NaN, 0};
        }
        double sum = 0;
        for (double v : values) {
            sum += v;
        }
        double mean = sum / n;
        if (n < 2) {
            return new double[] {mean, 0};
        }
        double squaredDeviations = 0;
        for (double v : values) {
            double d = v - mean;
            squaredDeviations += d * d;
        }
        return new double[] {mean, Math.sqrt(squaredDeviations / (n - 1))};
    }

    /**
     * Reads one integer per line from a text file.
     *
     * <p>Leading and trailing whitespace is ignored, and blank lines (such as a trailing
     * empty line) are skipped.
     *
     * @param fileName path of the file to read
     * @return the integers, in file order
     * @throws Exception if the file cannot be read; a {@link NumberFormatException} naming
     *     the file and line is thrown for a line that is not an integer
     */
    public static int[] readResults(String fileName) throws Exception {
        List<Integer> results = new ArrayList<Integer>();
        BufferedReader reader = new BufferedReader(new FileReader(fileName));
        try {
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                ++lineNumber;
                String trimmed = line.trim();
                if (trimmed.length() == 0) {
                    continue;
                }
                try {
                    results.add(Integer.parseInt(trimmed));
                } catch (NumberFormatException e) {
                    NumberFormatException withContext = new NumberFormatException(
                            fileName + ":" + lineNumber + ": not an integer: \"" + line + "\"");
                    withContext.initCause(e);
                    throw withContext;
                }
            }
        } finally {
            // Always release the file handle, even when parsing fails.
            reader.close();
        }
        int[] v = new int[results.size()];
        for (int i = 0; i < v.length; ++i) {
            v[i] = results.get(i);
        }
        return v;
    }

    /**
     * Computes and maintains statistics: a square confusion matrix of expected
     * (observed) category versus the candidate's (forecast) category.
     */
    public static class Stat {
        /** Number of categories used by the no-argument constructor. */
        public static final int DEFAULT_NUM_CATEGORIES = 5;

        /** Contingency table indexed as stats[observed][forecast]. */
        private final int[][] stats;

        /** Creates an empty table for categories 0..{@value #DEFAULT_NUM_CATEGORIES}-1. */
        public Stat() {
            this(DEFAULT_NUM_CATEGORIES);
        }

        /**
         * Creates an empty table for categories {@code 0..numCategories-1}.
         *
         * @throws IllegalArgumentException if {@code numCategories < 1}
         */
        public Stat(int numCategories) {
            if (numCategories < 1) {
                throw new IllegalArgumentException(
                        "numCategories must be positive but was " + numCategories);
            }
            stats = new int[numCategories][numCategories];
        }

        /**
         * Records one row.
         *
         * @param expectedValue the correct (observed) category
         * @param alg_stormtype the candidate's (forecast) category
         * @throws IllegalArgumentException if either category is out of range
         */
        public void update(int expectedValue, int alg_stormtype) {
            checkCategory("expectedValue", expectedValue);
            checkCategory("alg_stormtype", alg_stormtype);
            stats[expectedValue][alg_stormtype]++;
        }

        /** Rejects a category index outside 0..stats.length-1 with a descriptive message. */
        private void checkCategory(String name, int value) {
            if (value < 0 || value >= stats.length) {
                throw new IllegalArgumentException(name + " must be in [0, "
                        + (stats.length - 1) + "] but was " + value);
            }
        }

        /**
         * Renders the confusion matrix with per-row and per-column accuracies (as
         * percentages) followed by the TSS. An accuracy whose row or column is empty is
         * shown as "n/a".
         */
        @Override
        public String toString() {
            String newline = System.getProperty("line.separator");
            DecimalFormat idf = new DecimalFormat(" ");
            DecimalFormat adf = new DecimalFormat("##.0");
            int n = stats.length;

            // Header: one column per forecast category.
            StringBuilder sb = new StringBuilder();
            sb.append(" Got-> ");
            for (int i = 0; i < n; ++i) {
                sb.append(idf.format(i)).append("\t");
            }
            sb.append("Accuracy").append(newline);

            // Matrix: each row is one expected category; its accuracy is the share of
            // that row the candidate labelled correctly.
            for (int i = 0; i < n; ++i) {
                sb.append("Expected ").append(idf.format(i)).append(": ");
                int rowTotal = 0;
                for (int j = 0; j < n; ++j) {
                    sb.append(idf.format(stats[i][j])).append("\t");
                    rowTotal += stats[i][j];
                }
                sb.append(formatPercent(adf, stats[i][i], rowTotal)).append(newline);
            }

            // Trailer: for each forecast category, the share of its column that is correct.
            sb.append("Accuracy ");
            for (int i = 0; i < n; ++i) {
                int columnTotal = 0;
                for (int j = 0; j < n; ++j) {
                    columnTotal += stats[j][i];
                }
                sb.append(formatPercent(adf, stats[i][i], columnTotal)).append("\t");
            }

            // Final
            sb.append(newline).append("True Skill Score (TSS)=").append(
                    getTSS()).append(newline);

            return sb.toString();
        }

        /** Formats 100*part/total with {@code format}, or "n/a" when total is zero. */
        private static String formatPercent(DecimalFormat format, int part, int total) {
            if (total == 0) {
                return "n/a";
            }
            return format.format(100.0 * part / total);
        }

        /**
         * Computes the True Skill Score (the Hanssen-Kuipers discriminant, also known
         * as Peirce's skill score) for the multi-category table:
         *
         * <pre>
         * TSS = ( sum_i n(F_i,O_i)/N - sum_i N(F_i) N(O_i)/N^2 ) / ( 1 - sum_i N(O_i)^2/N^2 )
         * </pre>
         *
         * <p>N(F_i) is the number of rows the candidate labelled i, and N(O_i) is the
         * number of rows whose expected label is i. The denominator depends only on the
         * observed counts. As a result a perfect labelling scores 1 and a constant
         * labelling scores 0. Sums are kept in double so that N^2 cannot overflow.
         *
         * @return the TSS, or NaN when it is undefined: no rows have been recorded, or
         *     every row has the same expected category (the denominator is zero)
         */
        public float getTSS() {
            int numCategories = stats.length;
            double[] numForecasts = new double[numCategories];
            double[] numObservations = new double[numCategories];
            double hits = 0;
            for (int obs = 0; obs < numCategories; ++obs) {
                for (int fc = 0; fc < numCategories; ++fc) {
                    // stats is stats[observed][forecast]
                    numObservations[obs] += stats[obs][fc];
                    numForecasts[fc] += stats[obs][fc];
                }
                hits += stats[obs][obs];
            }

            double total = 0;
            for (int i = 0; i < numCategories; ++i) {
                total += numObservations[i];
            }
            if (total == 0) {
                return Float.NaN;
            }
            double totalSquared = total * total;

            double chanceHits = 0;
            double observedSquared = 0;
            for (int i = 0; i < numCategories; ++i) {
                chanceHits += numForecasts[i] * numObservations[i];
                observedSquared += numObservations[i] * numObservations[i];
            }
            double denominator = 1 - observedSquared / totalSquared;
            if (denominator == 0) {
                return Float.NaN;
            }
            return (float) ((hits / total - chanceHits / totalSquared) / denominator);
        }
    }
}