001 /** 002 * 003 */ 004 package org.wdssii.decisiontree; 005 006 import java.text.DecimalFormat; 007 008 /** 009 * 010 * Computes and returns skill scores for a multi-category forecast. For 011 * definitions of the skill scores used, see: 012 * http://www.bom.gov.au/bmrc/wefor/staff/eee/verif/verif_web_page.html 013 * 014 * Usage: 015 * <pre> 016 float[][] data = new float[numTesting][numAttr]; 017 int[] categories = new int[numTesting]; // true categories 018 MulticategorySkillScore tss = new MulticategorySkillScore(tree.getNumCategories()); 019 for (int i=0; i < numTesting; ++i){ 020 int result = tree.classify(data[i]); 021 tss.update(categories[i], result); 022 } 023 float trueSkillScore = tss.getTSS(); 024 * </pre> 025 * 026 * @author lakshman 027 * 028 */ 029 public class MulticategorySkillScore { 030 private int[][] stats; 031 032 public void update(int expectedValue, int alg_stormtype) { 033 stats[expectedValue][alg_stormtype]++; 034 } 035 036 public void update(int[] expectedValue, int[] algValues) { 037 if (expectedValue.length != algValues.length) { 038 throw new IllegalArgumentException( 039 "The two arrays should be of same length"); 040 } 041 for (int i = 0; i < expectedValue.length; ++i) { 042 update(expectedValue[i], algValues[i]); 043 } 044 } 045 046 private static String align(int x){ 047 StringBuilder in = new StringBuilder(" ").append(x); // 10 spaces + the number 048 return in.substring(in.length()-10); // last 10 digits 049 } 050 051 public String toString() { 052 String newline = System.getProperty("line.separator"); 053 String colsep = " "; 054 055 DecimalFormat adf = new DecimalFormat(" ###.00"); 056 // Header 057 StringBuilder sb = new StringBuilder(); 058 sb.append("From tree: "); 059 for (int i = 0; i < stats.length; ++i) { 060 sb.append(align(i)).append(colsep); 061 } 062 sb.append(" Accuracy").append(newline); 063 // Matrix 064 for (int i = 0; i < stats.length; ++i) { 065 sb.append("Expected ").append(align(i)).append(": "); 066 int denom = 0; 067 for (int j = 0; j < stats[i].length; ++j) { 068 sb.append(align(stats[i][j])).append(colsep); 069 denom += stats[i][j]; 070 } 071 float acc = (denom == 0)? 100 : (100.0f * stats[i][i] / denom); 072 sb.append(adf.format(acc)).append(newline); 073 } 074 075 // Trailer 076 sb.append("Accuracy "); 077 for (int i = 0; i < stats.length; ++i) { 078 int denom = 0; 079 for (int j = 0; j < stats[i].length; ++j) { 080 denom += stats[j][i]; 081 } 082 float acc = (denom == 0)? 100 : (100.0f * stats[i][i] / denom); 083 sb.append(adf.format(acc)).append(colsep); 084 } 085 086 // Final 087 sb.append(newline); 088 sb.append("Overall accuracy=").append(getOverallAccuracy()).append(newline); 089 sb.append("True Skill Score (TSS)=").append(getTSS()).append(newline); 090 091 return sb.toString(); 092 } 093 094 public float getOverallAccuracy() { 095 int denom = 0; 096 int num = 0; 097 for (int i=0; i < stats.length; ++i){ 098 num += stats[i][i]; 099 for (int j=0; j < stats.length; ++j){ 100 denom += stats[i][j]; 101 } 102 } 103 float acc = (denom == 0)? 1 : (1.0f * num / denom); 104 return acc; 105 } 106 107 /** 108 * Computes True Skill Score (Hanssen/Kuipers/Peirces' skill score) 109 * 110 */ 111 public float getTSS() { 112 float sum_a = 0; 113 float sum_b = 0; 114 float sum_c = 0; 115 int numCategories = stats.length; 116 float[] numForecasts = new float[numCategories]; 117 float[] numObservations = new float[numCategories]; 118 for (int i = 0; i < numCategories; ++i) { 119 for (int j = 0; j < numCategories; ++j) { 120 // stats is stats[observed][forecast] 121 numForecasts[i] += stats[j][i]; 122 numObservations[i] += stats[i][j]; 123 } 124 } 125 int N = 0; 126 for (int i = 0; i < numCategories; ++i) { 127 N += numForecasts[i]; 128 } 129 for (int j = 0; j < numCategories; ++j) { 130 sum_a += stats[j][j]; 131 sum_b += numForecasts[j] * numObservations[j]; 132 sum_c += numForecasts[j] * numForecasts[j]; 133 } 134 float TSS = ((sum_a / N) - (sum_b) / (N * N)) / (1 - sum_c / (N * N)); 135 return TSS; 136 } 137 138 public MulticategorySkillScore(int numCategories) { 139 stats = new int[numCategories][numCategories]; 140 } 141 142 }