001    package org.wdssii.aicompetition;
002    
003    import java.io.BufferedReader;
004    import java.io.FileReader;
005    import java.io.FileWriter;
006    import java.io.PrintWriter;
007    import java.text.DecimalFormat;
008    
/**
 * Command-line driver that classifies storms with a fixed decision tree and,
 * when the input rows also carry the true storm type, reports a confusion
 * matrix, per-category accuracies and the true skill score.
 *
 * @author lakshman
 */
public class StormType {

    /**
     * Minimum number of CSV columns per row: {@link #computeStormType} reads
     * feature indices up to 21.
     */
    private static final int NUM_FEATURES = 22;

    /** Column count of a row whose final column is the true storm type. */
    private static final int NUM_COLUMNS_WITH_TRUTH = NUM_FEATURES + 1;

    /**
     * Classifies every row of a CSV file and writes one predicted storm type
     * per line to the output file. Blank lines are skipped. If rows carry the
     * true storm type as an extra final column, a confusion matrix and the
     * true skill score are printed to standard output.
     *
     * @param args {@code args[0]} is the input CSV path, {@code args[1]} the
     *        output path
     * @throws Exception if a file cannot be read or written, or if a row has
     *         the wrong number of columns or a non-numeric value
     */
    public static void main(String[] args) throws Exception {
        // Both paths are required; a lone argument must not fall through to args[1].
        if (args.length < 2) {
            System.err.println("Usage: java StormType filename.csv output.csv");
            return;
        }
        BufferedReader reader = new BufferedReader(new FileReader(args[0]));
        try {
            PrintWriter writer = new PrintWriter(new FileWriter(args[1]));
            try {
                Stat stat = new Stat();
                boolean statAvailable = false;
                int lineNumber = 0;
                String line;
                while ((line = reader.readLine()) != null) {
                    ++lineNumber;
                    if (line.trim().length() == 0) {
                        continue; // e.g. a trailing newline at the end of the file
                    }
                    String[] columns = line.split(",");
                    float[] inputs = parseInputs(columns, lineNumber);
                    int algStormType = computeStormType(inputs);
                    writer.println(algStormType); // one prediction per input row

                    // Is the true value available?
                    if (columns.length == NUM_COLUMNS_WITH_TRUTH) {
                        int expectedValue = Integer
                                .parseInt(columns[columns.length - 1].trim());
                        stat.update(expectedValue, algStormType);
                        statAvailable = true;
                    }
                }
                // PrintWriter never throws on write failure; surface it explicitly.
                if (writer.checkError()) {
                    throw new java.io.IOException("Error writing " + args[1]);
                }
                if (statAvailable) {
                    System.out.println(stat);
                }
            } finally {
                writer.close();
            }
        } finally {
            // Separate finally so the writer is closed even if this close fails.
            reader.close();
        }
    }

    /**
     * Parses one CSV row into a fresh input array, so values never leak from a
     * previous row.
     *
     * @param columns the comma-separated fields of the row
     * @param lineNumber 1-based line number, used in error messages
     * @return array of length {@link #NUM_COLUMNS_WITH_TRUTH}; unused trailing
     *         slots are 0
     * @throws IllegalArgumentException if the column count is not 22 or 23, or
     *         a field is not a number
     */
    private static float[] parseInputs(String[] columns, int lineNumber) {
        if (columns.length < NUM_FEATURES || columns.length > NUM_COLUMNS_WITH_TRUTH) {
            throw new IllegalArgumentException("Line " + lineNumber + ": expected "
                    + NUM_FEATURES + " or " + NUM_COLUMNS_WITH_TRUTH
                    + " columns but found " + columns.length);
        }
        float[] inputs = new float[NUM_COLUMNS_WITH_TRUTH];
        for (int i = 0; i < columns.length; ++i) {
            try {
                inputs[i] = Float.parseFloat(columns[i]);
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("Line " + lineNumber + ", column "
                        + (i + 1) + ": not a number: '" + columns[i] + "'", e);
            }
        }
        return inputs;
    }

    /**
     * Computes and maintains statistics: a confusion matrix of true (expected)
     * versus predicted storm types.
     */
    public static class Stat {
        /** Storm types must lie in [0, NUM_CATEGORIES). */
        private static final int NUM_CATEGORIES = 5;

        /** stats[expected][predicted] = number of samples with that outcome. */
        private final int[][] stats = new int[NUM_CATEGORIES][NUM_CATEGORIES];

        /**
         * Records one classified sample.
         *
         * @param expectedValue the true storm type
         * @param alg_stormtype the storm type produced by the algorithm
         * @throws IllegalArgumentException if either value is outside [0, 4]
         */
        public void update(int expectedValue, int alg_stormtype) {
            checkCategory("expected", expectedValue);
            checkCategory("predicted", alg_stormtype);
            stats[expectedValue][alg_stormtype]++;
        }

        /** Rejects a storm type that has no row/column in the matrix. */
        private static void checkCategory(String what, int value) {
            if (value < 0 || value >= NUM_CATEGORIES) {
                throw new IllegalArgumentException(what + " storm type " + value
                        + " is outside [0, " + (NUM_CATEGORIES - 1) + "]");
            }
        }

        /**
         * Renders the confusion matrix. Each row gets the percentage of that
         * expected category predicted correctly. A trailer gives, per predicted
         * category, the percentage of those predictions that were correct. The
         * true skill score follows. Undefined percentages print as "n/a".
         */
        @Override
        public String toString() {
            String newline = System.getProperty("line.separator");
            // No digit characters in the pattern: it prefixes the integer with two spaces.
            DecimalFormat idf = new DecimalFormat("  ");
            DecimalFormat adf = new DecimalFormat("#0.0");
            StringBuilder sb = new StringBuilder();

            // Header: predicted categories across the top.
            sb.append("       Got->  ");
            for (int i = 0; i < stats.length; ++i) {
                sb.append(idf.format(i)).append("\t");
            }
            sb.append("Accuracy").append(newline);

            // Matrix: one row per expected category, then that row's accuracy.
            for (int i = 0; i < stats.length; ++i) {
                sb.append("Expected ").append(idf.format(i)).append(": ");
                int rowTotal = 0;
                for (int j = 0; j < stats[i].length; ++j) {
                    sb.append(idf.format(stats[i][j])).append("\t");
                    rowTotal += stats[i][j];
                }
                sb.append(formatPercent(stats[i][i], rowTotal, adf)).append(newline);
            }

            // Trailer: accuracy of each predicted category (column-wise).
            sb.append("Accuracy        ");
            for (int j = 0; j < stats.length; ++j) {
                int columnTotal = 0;
                for (int i = 0; i < stats.length; ++i) {
                    columnTotal += stats[i][j];
                }
                sb.append(formatPercent(stats[j][j], columnTotal, adf)).append("\t");
            }

            // Final
            sb.append(newline).append("True Skill Score (TSS)=").append(
                    getTSS()).append(newline);

            return sb.toString();
        }

        /** Formats 100*count/total, or "n/a" when total is 0 (undefined ratio). */
        private static String formatPercent(int count, int total, DecimalFormat format) {
            if (total == 0) {
                return "n/a";
            }
            return format.format(100.0 * count / total);
        }

        /**
         * Computes the True Skill Score (Hanssen/Kuipers/Peirce's skill score)
         * for the multi-category table.
         * <pre>
         * TSS = [ (1/N) sum_i n(F_i,O_i) - (1/N^2) sum_i N(F_i) N(O_i) ]
         *       / [ 1 - (1/N^2) sum_i N(O_i)^2 ]
         * </pre>
         * Here N(F_i) and N(O_i) are the forecast and observed totals of
         * category i. For two categories this reduces to POD - POFD.
         * Arithmetic is in double, so N*N cannot overflow.
         *
         * @return the score, or NaN when there are no samples or every
         *         observation falls in one category (score undefined)
         */
        private float getTSS() {
            int numCategories = stats.length;
            double[] numForecasts = new double[numCategories];
            double[] numObservations = new double[numCategories];
            double total = 0;
            // stats is stats[observed][forecast]
            for (int obs = 0; obs < numCategories; ++obs) {
                for (int fc = 0; fc < numCategories; ++fc) {
                    numForecasts[fc] += stats[obs][fc];
                    numObservations[obs] += stats[obs][fc];
                    total += stats[obs][fc];
                }
            }
            if (total == 0) {
                return Float.NaN;
            }
            double hits = 0;
            double sumForecastTimesObserved = 0;
            double sumObservedSquared = 0;
            for (int k = 0; k < numCategories; ++k) {
                hits += stats[k][k];
                sumForecastTimesObserved += numForecasts[k] * numObservations[k];
                sumObservedSquared += numObservations[k] * numObservations[k];
            }
            double totalSquared = total * total;
            double denominator = 1 - sumObservedSquared / totalSquared;
            if (denominator == 0) {
                return Float.NaN;
            }
            double tss = (hits / total - sumForecastTimesObserved / totalSquared)
                    / denominator;
            return (float) tss;
        }
    }

    /**
     * Predicts the storm type for one row of inputs using a fixed decision tree
     * (the result of Quinlan's decision tree algorithm).
     *
     * The tree reads only input[0], input[8], input[11], input[12], input[20]
     * and input[21]. Thresholds are double literals, so each float input is
     * widened to double before comparison.
     *
     * @param input feature values; must have at least 22 elements
     * @return stormtype; this tree only ever returns 0, 1, 2 or 4
     */
    private static int computeStormType(float[] input) {

        if ((input[11]) < 11.4306) {
            if ((input[11]) < 3.51456) {
                return 0;
            } else {
                if ((input[12]) < 38.3377) {
                    if ((input[21]) < 14.7142) {
                        if ((input[12]) < 29.0601) {
                            return 0;
                        } else {
                            if ((input[0]) < 2.78166) {
                                return 0;
                            } else {
                                return 4;
                            }
                        }
                    } else {
                        if ((input[11]) < 6.52937) {
                            if ((input[20]) < 227.065) {
                                if ((input[12]) < 30.4995) {
                                    return 4;
                                } else {
                                    if ((input[12]) < 36.1114) {
                                        return 0;
                                    } else {
                                        return 4;
                                    }
                                }
                            } else {
                                return 0;
                            }
                        } else {
                            return 4;
                        }
                    }
                } else {
                    if ((input[11]) < 8.25122) {
                        return 0;
                    } else {
                        if ((input[8]) < 0.003965) {
                            if ((input[8]) < 0.00185) {
                                return 2;
                            } else {
                                return 4;
                            }
                        } else {
                            return 0;
                        }
                    }
                }
            }
        } else {
            if ((input[12]) < 44.9008) {
                if ((input[8]) < 0.00446) {
                    if ((input[12]) < 38.4237) {
                        return 4;
                    } else {
                        if ((input[20]) < 323.767) {
                            return 4;
                        } else {
                            if ((input[11]) < 43.7864) {
                                if ((input[8]) < 0.004249) {
                                    if ((input[0]) < 3.41001) {
                                        return 4;
                                    } else {
                                        return 2;
                                    }
                                } else {
                                    return 0;
                                }
                            } else {
                                return 1;
                            }
                        }
                    }
                } else {
                    if ((input[11]) < 25.8763) {
                        if ((input[12]) < 38.0125) {
                            return 4;
                        } else {
                            if ((input[11]) < 21.5562) {
                                return 2;
                            } else {
                                return 4;
                            }
                        }
                    } else {
                        if ((input[0]) < 4.55339) {
                            return 1;
                        } else {
                            return 2;
                        }
                    }
                }
            } else {
                if ((input[0]) < 4.70992) {
                    if ((input[11]) < 48.6246) {
                        if ((input[21]) < 21.9221) {
                            if ((input[8]) < 0.006211) {
                                if ((input[0]) < 2.91861) {
                                    if ((input[11]) < 21.3905) {
                                        return 4;
                                    } else {
                                        return 1;
                                    }
                                } else {
                                    if ((input[12]) < 47.8763) {
                                        if ((input[20]) < 476.565) {
                                            if ((input[21]) < 13.0049) {
                                                return 4;
                                            } else {
                                                if ((input[8]) < 0.002292) {
                                                    return 2;
                                                } else {
                                                    return 1;
                                                }
                                            }
                                        } else {
                                            return 2;
                                        }
                                    } else {
                                        if ((input[0]) < 3.6226) {
                                            return 2;
                                        } else {
                                            return 1;
                                        }
                                    }
                                }
                            } else {
                                return 1;
                            }
                        } else {
                            return 2;
                        }
                    } else {
                        return 1;
                    }
                } else {
                    return 2;
                }
            }
        }

    }

}