001    package org.wdssii.aicompetition;
002    
003    import java.io.BufferedReader;
004    import java.io.FileReader;
005    import java.text.DecimalFormat;
006    import java.util.ArrayList;
007    import java.util.List;
008    import java.util.Random;
009    
/**
 * Command-line tool that scores a file of candidate classifications against
 * a file of correct classifications.
 * <p>
 * Both files hold one integer category per line. The tool prints the mean
 * and spread of the True Skill Score (TSS) over random subsets and random
 * consecutive runs of the data, followed by the overall confusion matrix.
 *
 * @author lakshman
 */
public class Evaluate {

        /**
         * Entry point.
         *
         * @param args
         *            exactly two file names: the correct results and the
         *            candidate results
         * @throws IllegalArgumentException
         *             if the two files do not hold the same, non-zero number
         *             of results
         * @throws Exception
         *             if either file cannot be read or parsed
         */
        public static void main(String[] args) throws Exception {
                if (args.length != 2) {
                        System.err.println("Usage: java Evaluate correct_results.txt  candidate_results.txt");
                        return;
                }
                int[] correct_results = readResults(args[0]);
                int[] candidate_results = readResults(args[1]);

                // The two files are compared row by row, so a length mismatch
                // means the rows cannot be paired up reliably.
                if (correct_results.length != candidate_results.length) {
                        throw new IllegalArgumentException(args[0] + " has "
                                        + correct_results.length + " results but " + args[1]
                                        + " has " + candidate_results.length);
                }
                if (correct_results.length == 0) {
                        throw new IllegalArgumentException("No results found in " + args[0]);
                }

                // do N random tests
                computeTssOnRandomPatterns(correct_results, candidate_results, 10, 0.5);
                computeTssOnRandomRuns(correct_results, candidate_results, 10, 0.5);

                // overall
                Stat overall = new Stat();
                for (int i = 0; i < correct_results.length; ++i) {
                        overall.update(correct_results[i], candidate_results[i]);
                }
                System.out.println("Overall: " + overall.toString());
        }

        /**
         * Draws random samples from the entire set and computes TSS.
         * Each row is included in a sample independently with the given
         * probability; the mean and standard deviation of the TSS over all
         * samples are printed to standard output.
         *
         * @param correct_results
         *            expected category of each row
         * @param candidate_results
         *            predicted category of each row (same length)
         * @param NUM_ITER
         *            number of random samples to draw
         * @param selectionProbability
         *            probability with which a row is included in a sample
         */
        private static void computeTssOnRandomPatterns(int[] correct_results,
                        int[] candidate_results, final int NUM_ITER,
                        final double selectionProbability) {
                Random rand = new Random();
                double sumx = 0;
                double sumx2 = 0;

                for (int iter = 0; iter < NUM_ITER; ++iter) {
                        Stat s = new Stat();
                        for (int i = 0; i < correct_results.length; ++i) {
                                if (rand.nextDouble() < selectionProbability) {
                                        s.update(correct_results[i], candidate_results[i]);
                                }
                        }
                        double tss = s.getTSS();
                        sumx += tss;
                        sumx2 += tss * tss;
                }
                System.out.println("Random rows: TSS= " + meanAndSpread(sumx, sumx2, NUM_ITER)
                                + " estimated over " + NUM_ITER
                                + " iterations with selectionProb=" + selectionProbability);
        }

        /**
         * Chooses random runs of the input sets i.e. chosen patterns are
         * consecutive. Each run starts at a random row, wraps around the end
         * of the data, and has a fixed length derived from the selection
         * probability. The mean and standard deviation of the TSS over all
         * runs are printed to standard output.
         *
         * @param correct_results
         *            expected category of each row; must not be empty
         * @param candidate_results
         *            predicted category of each row (same length)
         * @param NUM_ITER
         *            number of random runs to draw
         * @param selectionProbability
         *            fraction of the rows that makes up one run
         */
        private static void computeTssOnRandomRuns(int[] correct_results,
                        int[] candidate_results, final int NUM_ITER,
                        final double selectionProbability) {
                final int N = (int) Math.round(selectionProbability * correct_results.length);
                Random rand = new Random();
                double sumx = 0;
                double sumx2 = 0;

                for (int iter = 0; iter < NUM_ITER; ++iter) {
                        Stat s = new Stat();
                        final int start = rand.nextInt(correct_results.length);
                        final int end = start + N;
                        for (int j = start; j < end; ++j) {
                                // wrap around so that a run may start near the end
                                int i = j % correct_results.length;
                                s.update(correct_results[i], candidate_results[i]);
                        }
                        double tss = s.getTSS();
                        sumx += tss;
                        sumx2 += tss * tss;
                }
                System.out.println("Random runs: TSS= " + meanAndSpread(sumx, sumx2, NUM_ITER)
                                + " estimated over " + NUM_ITER + " iterations with N=" + N);
        }

        /**
         * Formats the sample mean and sample standard deviation as
         * {@code "mean +/- stddev"}.
         *
         * @param sum
         *            sum of the observed values
         * @param sumOfSquares
         *            sum of the squares of the observed values
         * @param count
         *            number of observed values
         * @return the formatted summary; the standard deviation is reported
         *         as 0 when fewer than two values were observed or when
         *         rounding makes the variance non-positive
         */
        private static String meanAndSpread(double sum, double sumOfSquares, int count) {
                float mean = (float) (sum / count);
                float stddev = 0;
                if (count > 1) {
                        double variance = (sumOfSquares - (sum * sum) / count) / (count - 1);
                        // Rounding can push a tiny variance below zero; a NaN
                        // variance also fails this test and is reported as 0.
                        if (variance > 0) {
                                stddev = (float) Math.sqrt(variance);
                        }
                }
                return mean + " +/- " + stddev;
        }

        /**
         * Reads one integer per line from a text file.
         * Leading and trailing whitespace on a line is ignored and blank
         * lines are skipped.
         *
         * @param fileName
         *            path of the file to read
         * @return the parsed integers in file order
         * @throws Exception
         *             if the file cannot be read, or a non-blank line is not
         *             a valid integer ({@link NumberFormatException})
         */
        public static int[] readResults(String fileName) throws Exception {
                List<Integer> results = new ArrayList<Integer>();
                BufferedReader reader = new BufferedReader(new FileReader(fileName));
                try {
                        String line;
                        while ((line = reader.readLine()) != null) {
                                String value = line.trim();
                                if (value.length() == 0) {
                                        continue;
                                }
                                results.add(Integer.parseInt(value));
                        }
                } finally {
                        // always release the file handle, even on a parse error
                        reader.close();
                }
                int[] v = new int[results.size()];
                for (int i = 0; i < v.length; ++i) {
                        v[i] = results.get(i);
                }
                return v;
        }

        /**
         * Computes and maintains statistics: a confusion matrix indexed as
         * {@code [expected][got]}, from which per-category accuracies and
         * the True Skill Score are derived.
         */
        public static class Stat {
                /** Number of categories; valid category values are 0 to 4. */
                private static final int NUM_CATEGORIES = 5;

                private final int[][] stats = new int[NUM_CATEGORIES][NUM_CATEGORIES];

                /**
                 * Records one classification.
                 *
                 * @param expectedValue
                 *            the correct category (0 to 4)
                 * @param alg_stormtype
                 *            the category produced by the algorithm (0 to 4)
                 * @throws ArrayIndexOutOfBoundsException
                 *             if either value is outside 0 to 4
                 */
                public void update(int expectedValue, int alg_stormtype) {
                        stats[expectedValue][alg_stormtype]++;
                }

                /**
                 * Renders the confusion matrix, the accuracy of every row and
                 * column (in percent), and the True Skill Score. The accuracy
                 * of a row or column without any entries is 0/0, i.e. NaN.
                 */
                @Override
                public String toString() {
                        String newline = System.getProperty("line.separator");
                        DecimalFormat idf = new DecimalFormat("  ");
                        DecimalFormat adf = new DecimalFormat("##.0");
                        // Header
                        StringBuilder sb = new StringBuilder();
                        sb.append("       Got->  ");
                        for (int i = 0; i < stats.length; ++i) {
                                sb.append(idf.format(i)).append("\t");
                        }
                        sb.append("Accuracy").append(newline);
                        // Matrix: one row per expected category
                        for (int i = 0; i < stats.length; ++i) {
                                sb.append("Expected ").append(idf.format(i)).append(": ");
                                long rowTotal = 0;
                                for (int j = 0; j < stats[i].length; ++j) {
                                        sb.append(idf.format(stats[i][j])).append("\t");
                                        rowTotal += stats[i][j];
                                }
                                // multiply in float so that large counts cannot overflow int
                                float acc = 100.0f * stats[i][i] / rowTotal;
                                sb.append(adf.format(acc)).append(newline);
                        }

                        // Trailer: accuracy per "got" column
                        sb.append("Accuracy        ");
                        for (int i = 0; i < stats.length; ++i) {
                                long columnTotal = 0;
                                for (int j = 0; j < stats[i].length; ++j) {
                                        columnTotal += stats[j][i];
                                }
                                float acc = 100.0f * stats[i][i] / columnTotal;
                                sb.append(adf.format(acc)).append("\t");
                        }

                        // Final
                        sb.append(newline).append("True Skill Score (TSS)=").append(
                                        getTSS()).append(newline);

                        return sb.toString();
                }

                /**
                 * Computes True Skill Score (Hanssen/Kuipers/Peirces' skill score).
                 * <p>
                 * With N the number of recorded classifications, the value is
                 * {@code (hits/N - sum(F_j*O_j)/N^2) / (1 - sum(F_j^2)/N^2)},
                 * where {@code F_j} and {@code O_j} are the number of
                 * forecasts and observations of category j.
                 * <p>
                 * NOTE(review): the denominator is built from the forecast
                 * marginals. Textbook definitions of the Peirce /
                 * Hanssen-Kuipers score normalise by the observed marginals
                 * instead; confirm which variant is intended before changing
                 * it, because that would alter previously reported scores.
                 *
                 * @return the score, or NaN when nothing has been recorded or
                 *         when every forecast falls into a single category
                 *         (the denominator is then zero)
                 */
                public float getTSS() {
                        final int numCategories = stats.length;
                        // long/double arithmetic: int N*N overflows beyond 46340 samples
                        long[] numForecasts = new long[numCategories];
                        long[] numObservations = new long[numCategories];
                        for (int i = 0; i < numCategories; ++i) {
                                for (int j = 0; j < numCategories; ++j) {
                                        // stats is stats[observed][forecast]
                                        numForecasts[i] += stats[j][i];
                                        numObservations[i] += stats[i][j];
                                }
                        }
                        long total = 0;
                        for (int i = 0; i < numCategories; ++i) {
                                total += numForecasts[i];
                        }
                        if (total == 0) {
                                return Float.NaN;
                        }
                        double sum_a = 0;
                        double sum_b = 0;
                        double sum_c = 0;
                        for (int j = 0; j < numCategories; ++j) {
                                sum_a += stats[j][j];
                                sum_b += (double) numForecasts[j] * numObservations[j];
                                sum_c += (double) numForecasts[j] * numForecasts[j];
                        }
                        double totalSquared = (double) total * total;
                        double denominator = 1 - sum_c / totalSquared;
                        if (denominator == 0) {
                                return Float.NaN;
                        }
                        return (float) (((sum_a / total) - sum_b / totalSquared) / denominator);
                }
        }
}