001 /** 002 * 003 */ 004 package org.wdssii.decisiontree; 005 006 import java.beans.XMLDecoder; 007 import java.io.BufferedReader; 008 import java.io.File; 009 import java.io.FileInputStream; 010 import java.io.FileReader; 011 import java.io.FileWriter; 012 import java.io.PrintWriter; 013 import java.util.ArrayList; 014 import java.util.List; 015 016 /** 017 * 018 * Uses a trained decision tree to classify cases. If labels are provided, also reports skill score. 019 * 020 * @author lakshman 021 * 022 */ 023 public class Classify { 024 /** 025 * Classifies all inputs in a file based on a trained decision tree 026 */ 027 public static void main(String[] args) throws Exception { 028 if (args.length < 2) { 029 System.err 030 .println("Usage: java org.jscience.statistics.decisiontree.Classify decisiontree.xml testinfile.csv outdir [separator]\n"); 031 System.err 032 .println("The data file should have each example on a line, with all the attributes being numeric and separated by commas."); 033 System.err 034 .println("The final attribute can be the true classification if known. In this case, skill will be computed."); 035 return; 036 } 037 038 String decisiontree = args[0]; 039 String infile = args[1]; 040 String outdir = args[2]; 041 String separator = (args.length > 3) ? args[3] : ","; 042 043 XMLDecoder decoder = new XMLDecoder(new FileInputStream(decisiontree)); 044 AxialDecisionTree tree = (AxialDecisionTree) decoder.readObject(); 045 046 BufferedReader reader = null; 047 try { 048 new File(outdir).mkdirs(); 049 // read the file 050 reader = new BufferedReader(new FileReader(infile)); 051 String line; 052 List<String[]> testingData = new ArrayList<String[]>(); 053 while ((line = reader.readLine()) != null) { 054 String[] splitValues = line.split(separator); 055 testingData.add(splitValues); 056 } 057 // get the data and categories 058 int numTesting = testingData.size(); 059 if (numTesting == 0){ 060 System.err.println("Empty file: not processed"); 061 return; 062 } 063 int numAttrInFile = testingData.get(0).length; 064 int numAttr = tree.getNumAttributes(); 065 if (numAttrInFile < numAttr || numAttrInFile > (numAttr+1) ){ 066 // numAttr and numAttr+1 are ok 067 throw new IllegalArgumentException("The file contains only " + numAttrInFile + " columns but the decision tree was trained with " + numAttr); 068 } 069 float[][] data = new float[numTesting][numAttr]; 070 int[] categories = (numAttrInFile > numAttr)? new int[numTesting] : null; 071 for (int i = 0; i < numTesting; ++i) { 072 String[] indata = testingData.get(i); 073 if (indata.length != numAttrInFile) { 074 throw new IllegalArgumentException("Row number " + (i + 1) 075 + " of file has " + indata.length 076 + " attributes. Expected " + numAttrInFile); 077 } 078 for (int j = 0; j < numAttr; ++j) { 079 try { 080 data[i][j] = Float.parseFloat(indata[j]); 081 } catch (NumberFormatException e) { 082 System.err.println(e.getMessage() + " in attribute no." 083 + (j + 1) + " of line " + (i + 1) 084 + " Assuming zero."); 085 data[i][j] = 0; 086 } 087 } 088 if ( categories != null ){ 089 categories[i] = Integer.parseInt(indata[numAttr]); 090 } 091 } 092 093 // classify and compute skill 094 MulticategorySkillScore tss = new MulticategorySkillScore(tree.getNumCategories()); 095 PrintWriter writer = new PrintWriter(new FileWriter(outdir+"/classified.txt")); 096 for (int i=0; i < numTesting; ++i){ 097 int result = tree.classify(data[i]); 098 writer.println(result); 099 if (categories != null){ 100 tss.update(categories[i], result); 101 } 102 } 103 104 if ( categories != null ){ 105 System.out.println(tss); 106 PrintWriter skillWriter = new PrintWriter(new FileWriter(outdir 107 + "/testing_skill.txt")); 108 skillWriter.println(tss); 109 skillWriter.close(); 110 } 111 112 } finally { 113 if (reader != null) { 114 reader.close(); 115 } 116 } 117 } 118 }