001    /**
002     * 
003     */
004    package org.wdssii.decisiontree;
005    
006    import java.text.DecimalFormat;
007    
/**
 * Accumulates a confusion matrix for a multi-category forecast and computes
 * skill scores from it. For definitions of the skill scores used, see:
 * http://www.bom.gov.au/bmrc/wefor/staff/eee/verif/verif_web_page.html
 *
 * <p>Usage:
 * <pre>
 * float[][] data = new float[numTesting][numAttr];
 * int[] categories = new int[numTesting]; // true categories
 * MulticategorySkillScore tss = new MulticategorySkillScore(tree.getNumCategories());
 * for (int i = 0; i &lt; numTesting; ++i) {
 *     int result = tree.classify(data[i]);
 *     tss.update(categories[i], result);
 * }
 * float trueSkillScore = tss.getTSS();
 * </pre>
 *
 * @author lakshman
 */
public class MulticategorySkillScore {
    /** Width of one right-aligned integer column in {@link #toString()}. */
    private static final int COLUMN_WIDTH = 10;

    /** Confusion matrix of counts, indexed as {@code stats[observed][forecast]}. */
    private final int[][] stats;

    /**
     * Creates an empty score table.
     *
     * @param numCategories number of distinct categories; valid category
     *                      indices are {@code 0 .. numCategories-1}
     */
    public MulticategorySkillScore(int numCategories) {
        stats = new int[numCategories][numCategories];
    }

    /**
     * Records one forecast/observation pair.
     *
     * @param expectedValue the true (observed) category index
     * @param alg_stormtype the category index produced by the algorithm
     * @throws ArrayIndexOutOfBoundsException if either index is outside the
     *         range given at construction
     */
    public void update(int expectedValue, int alg_stormtype) {
        stats[expectedValue][alg_stormtype]++;
    }

    /**
     * Records a batch of forecast/observation pairs, element by element.
     *
     * @param expectedValue the true (observed) category indices
     * @param algValues     the category indices produced by the algorithm
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public void update(int[] expectedValue, int[] algValues) {
        if (expectedValue.length != algValues.length) {
            throw new IllegalArgumentException(
                    "The two arrays should be of same length");
        }
        for (int i = 0; i < expectedValue.length; ++i) {
            update(expectedValue[i], algValues[i]);
        }
    }

    /** Right-aligns {@code x} in a field of {@link #COLUMN_WIDTH} characters. */
    private static String align(int x) {
        // Pad on the left, then keep only the trailing COLUMN_WIDTH characters.
        StringBuilder padded = new StringBuilder("          ").append(x);
        return padded.substring(padded.length() - COLUMN_WIDTH);
    }

    /** Percentage of {@code hits} in {@code total}; reported as 100 when there is nothing to score. */
    private static float percent(int hits, int total) {
        return (total == 0) ? 100 : (100.0f * hits / total);
    }

    /**
     * Renders the confusion matrix as text: one row per expected category with
     * a per-row accuracy, a trailing line of per-column accuracies, then the
     * overall accuracy and the True Skill Score.
     */
    @Override
    public String toString() {
        final String newline = System.getProperty("line.separator");
        final String colsep = "    ";
        final DecimalFormat adf = new DecimalFormat("     ###.00");
        final int n = stats.length;
        StringBuilder sb = new StringBuilder();

        // Header: one column per category assigned by the algorithm.
        sb.append("From tree:           ");
        for (int col = 0; col < n; ++col) {
            sb.append(align(col)).append(colsep);
        }
        sb.append("   Accuracy").append(newline);

        // Matrix: one row per expected category, followed by that row's accuracy.
        for (int row = 0; row < n; ++row) {
            sb.append("Expected ").append(align(row)).append(": ");
            int rowTotal = 0;
            for (int col = 0; col < stats[row].length; ++col) {
                sb.append(align(stats[row][col])).append(colsep);
                rowTotal += stats[row][col];
            }
            sb.append(adf.format(percent(stats[row][row], rowTotal))).append(newline);
        }

        // Trailer: accuracy per algorithm-assigned category (column-wise).
        sb.append("Accuracy             ");
        for (int col = 0; col < n; ++col) {
            int colTotal = 0;
            for (int row = 0; row < n; ++row) {
                colTotal += stats[row][col];
            }
            sb.append(adf.format(percent(stats[col][col], colTotal))).append(colsep);
        }

        // Summary scores.
        sb.append(newline);
        sb.append("Overall accuracy=").append(getOverallAccuracy()).append(newline);
        sb.append("True Skill Score (TSS)=").append(getTSS()).append(newline);

        return sb.toString();
    }

    /**
     * Fraction of recorded pairs whose forecast category equals the observed
     * category.
     *
     * @return a value in {@code [0, 1]}; 1 when nothing has been recorded
     */
    public float getOverallAccuracy() {
        int denom = 0;
        int num = 0;
        for (int i = 0; i < stats.length; ++i) {
            num += stats[i][i];
            for (int j = 0; j < stats.length; ++j) {
                denom += stats[i][j];
            }
        }
        return (denom == 0) ? 1 : (1.0f * num / denom);
    }

    /**
     * Computes the True Skill Score (Hanssen/Kuipers discriminant, Peirce's
     * skill score) for the accumulated counts:
     *
     * <pre>
     *        (1/N) sum_i n(F_i,O_i)  -  (1/N^2) sum_i N(F_i) N(O_i)
     * TSS = --------------------------------------------------------
     *                  1  -  (1/N^2) sum_i N(O_i)^2
     * </pre>
     *
     * where {@code N(F_i)} and {@code N(O_i)} are the forecast and observed
     * totals of category {@code i}. For two categories this reduces to
     * probability of detection minus probability of false detection.
     *
     * @return the score, or {@code NaN} when it is undefined (nothing recorded,
     *         or every observation falls in a single category)
     */
    public float getTSS() {
        final int numCategories = stats.length;
        final long[] numForecasts = new long[numCategories];
        final long[] numObservations = new long[numCategories];
        long total = 0;
        long correct = 0;
        for (int obs = 0; obs < numCategories; ++obs) {
            correct += stats[obs][obs];
            for (int fc = 0; fc < numCategories; ++fc) {
                // stats is stats[observed][forecast]
                final int count = stats[obs][fc];
                numObservations[obs] += count;
                numForecasts[fc] += count;
                total += count;
            }
        }

        long chanceAgreement = 0;   // sum_i N(F_i) * N(O_i)
        long observedSquares = 0;   // sum_i N(O_i)^2
        for (int k = 0; k < numCategories; ++k) {
            chanceAgreement += numForecasts[k] * numObservations[k];
            observedSquares += numObservations[k] * numObservations[k];
        }

        // Numerator and denominator are both scaled by N^2 so that they stay
        // exact integers; 64-bit arithmetic avoids the overflow that an int
        // N*N suffers beyond 46340 samples.
        final long denominator = total * total - observedSquares;
        if (denominator == 0) {
            return Float.NaN;
        }
        final long numerator = total * correct - chanceAgreement;
        return (float) ((double) numerator / (double) denominator);
    }

}