/**
 *
 */
package org.wdssii.decisiontree;

import java.beans.XMLEncoder;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Command-line program that invokes Quinlan's C4.5 algorithm.
 *
 * @author lakshman
 */
public class Train {
    /**
     * Trains and prints out a decision tree.
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err
                    .println("Usage: java org.wdssii.decisiontree.Train trainingfile.csv outdir [pruningFraction=0.1] [shuffle=0/1] [separator]\n");
            System.err
                    .println("The training file should have one example per line, with all attributes numeric and separated by commas. The final attribute should be the true classification.\n");
            System.err
                    .println("Note that if shuffle is turned on, line numbers reported by this program in error messages will be incorrect.");
            return;
        }
        String infile = args[0];
        String outdir = args[1];
        String pruningFraction = (args.length > 2) ? args[2] : "0.1";
        boolean shuffle = (args.length > 3) ? args[3].equals("1") : false;
        String separator = (args.length > 4) ? args[4] : ",";
        BufferedReader reader = null;
        try {
            new File(outdir).mkdirs();

            // Read the training file, one example per line
            reader = new BufferedReader(new FileReader(infile));
            String line;
            List<String[]> trainingData = new ArrayList<String[]>();
            while ((line = reader.readLine()) != null) {
                String[] splitValues = line.split(separator);
                trainingData.add(splitValues);
            }

            if (shuffle) {
                Collections.shuffle(trainingData);
            }

            // Separate the attribute values from the categories
            int numTraining = trainingData.size();
            int numAttr = trainingData.get(0).length - 1; // last column is the category
            float[][] data = new float[numTraining][numAttr];
            int[] categories = new int[numTraining];
            for (int i = 0; i < numTraining; ++i) {
                String[] indata = trainingData.get(i);
                if (indata.length != (numAttr + 1)) {
                    throw new IllegalArgumentException("Row number " + (i + 1)
                            + " of file has " + indata.length
                            + " attributes. Expected " + (numAttr + 1));
                }
                for (int j = 0; j < numAttr; ++j) {
                    try {
                        data[i][j] = Float.parseFloat(indata[j]);
                    } catch (NumberFormatException e) {
                        System.err.println(e.getMessage() + " in attribute no. "
                                + (j + 1) + " of line " + (i + 1)
                                + ". Assuming zero.");
                        data[i][j] = 0;
                    }
                }
                categories[i] = Integer.parseInt(indata[numAttr]);
            }

            // Create classifier and learn decision tree
            QuinlanC45AxialDecisionTreeCreator classifier = new QuinlanC45AxialDecisionTreeCreator(
                    Float.parseFloat(pruningFraction));
            DecisionTree tree = classifier.learn(data, categories);

            // Write out the learned tree as a Java program
            String program = tree.toJava();
            System.out.println(program);
            PrintWriter javaWriter = new PrintWriter(new FileWriter(outdir
                    + "/DecisionTree.java"));
            javaWriter.println(program);
            javaWriter.close();

            // Compute the skill score on the training data
            int[] algResults = new int[categories.length];
            for (int i = 0; i < algResults.length; ++i) {
                algResults[i] = tree.classify(data[i]);
            }
            MulticategorySkillScore tss = new MulticategorySkillScore(
                    classifier.getNumCategories());
            tss.update(categories, algResults);
            System.out.println(tss);
            PrintWriter skillWriter = new PrintWriter(new FileWriter(outdir
                    + "/skill.txt"));
            skillWriter.println(tss);
            skillWriter.close();

            // Serialize the tree itself as XML
            XMLEncoder encoder = new XMLEncoder(new FileOutputStream(outdir
                    + "/decisiontree.xml"));
            encoder.writeObject(tree);
            encoder.close();
            System.out.println("decisiontree.xml written out to " + outdir);
        } finally {
            if (reader != null) {
                reader.close();
            }
        }
    }
}