001    /**
002     * 
003     */
004    package org.wdssii.decisiontree;
005    
006    import java.beans.XMLEncoder;
007    import java.io.BufferedReader;
008    import java.io.File;
009    import java.io.FileOutputStream;
010    import java.io.FileReader;
011    import java.io.FileWriter;
012    import java.io.PrintWriter;
013    import java.util.ArrayList;
014    import java.util.Collections;
015    import java.util.List;
016    
017    /**
018     * Command-line program that invokes Quinlan's C4.5 algorithm
019     * @author lakshman
020     *
021     */
022    public class Train {
023            /**
024             * Trains and prints out a decision tree
025             */
026            public static void main(String[] args) throws Exception {
027                    if (args.length < 2) {
028                            System.err
029                                            .println("Usage: java org.jscience.statistics.decisiontree.Train trainingfile.csv outdir [pruningFraction=0.1] [shuffle=0/1] [separator]\n");
030                            System.err
031                                            .println("The training file should have each example on a line, with all the attributes being numeric and separated by commas. The final attribute should be the true classification\n");
032                            System.err
033                                            .println("Note that if shuffle is turned on, line numbers reported by this program in error messages will be incorrect.");
034                            return;
035                    }
036                    String infile = args[0];
037                    String outdir = args[1];
038                    String pruningFraction = (args.length > 2) ? args[2] : "0.1";
039                    boolean shuffle = (args.length > 3)? (args[3].equals("1")) : false;
040                    String separator = (args.length > 4) ? args[4] : ",";
041                    BufferedReader reader = null;
042                    try {
043                            new File(outdir).mkdirs();
044                            // read the file
045                            reader = new BufferedReader(new FileReader(infile));
046                            String line;
047                            List<String[]> trainingData = new ArrayList<String[]>();
048                            while ((line = reader.readLine()) != null) {
049                                    String[] splitValues = line.split(separator);
050                                    trainingData.add(splitValues);
051                            }
052                            
053                            if ( shuffle ){
054                                    Collections.shuffle(trainingData);
055                            }
056                            
057                            // get the data and categories
058                            int numTraining = trainingData.size();
059                            int numAttr = trainingData.get(0).length - 1; // last one is
060                                                                                                                            // category
061                            float[][] data = new float[numTraining][numAttr];
062                            int[] categories = new int[numTraining];
063                            for (int i = 0; i < numTraining; ++i) {
064                                    String[] indata = trainingData.get(i);
065                                    if (indata.length != (numAttr + 1)) {
066                                            throw new IllegalArgumentException("Row number " + (i + 1)
067                                                            + " of file has " + indata.length
068                                                            + " attributes. Expected " + (numAttr + 1));
069                                    }
070                                    for (int j = 0; j < numAttr; ++j) {
071                                            try {
072                                                    data[i][j] = Float.parseFloat(indata[j]);
073                                            } catch (NumberFormatException e) {
074                                                    System.err.println(e.getMessage() + " in attribute no."
075                                                                    + (j + 1) + " of line " + (i + 1)
076                                                                    + " Assuming zero.");
077                                                    data[i][j] = 0;
078                                            }
079                                    }
080                                    categories[i] = Integer.parseInt(indata[numAttr]);
081                            }
082    
083                            
084                            
085                            // Create classifier and learn decision tree
086                            QuinlanC45AxialDecisionTreeCreator classifier = new QuinlanC45AxialDecisionTreeCreator(
087                                            Float.parseFloat(pruningFraction));
088                            DecisionTree tree = classifier.learn(data, categories);
089    
090                            // Java program
091                            String program = tree.toJava();
092                            System.out.println(program);
093                            PrintWriter javaWriter = new PrintWriter(new FileWriter(outdir
094                                            + "/DecisionTree.java"));
095                            javaWriter.println(program);
096                            javaWriter.close();
097    
098                            // skill
099                            int[] algResults = new int[categories.length];
100                            for (int i = 0; i < algResults.length; ++i) {
101                                    algResults[i] = tree.classify(data[i]);
102                            }
103                            MulticategorySkillScore tss = new MulticategorySkillScore(
104                                            classifier.getNumCategories());
105                            tss.update(categories, algResults);
106                            System.out.println(tss);
107                            PrintWriter skillWriter = new PrintWriter(new FileWriter(outdir
108                                            + "/skill.txt"));
109                            skillWriter.println(tss);
110                            skillWriter.close();
111    
112                            // the tree itself
113                            XMLEncoder encoder = new XMLEncoder(new FileOutputStream(outdir
114                                            + "/decisiontree.xml"));
115                            encoder.writeObject(tree);
116                            encoder.close();
117                            System.out.println("decisiontree.xml written out to " + outdir);
118                    } finally {
119                            if (reader != null) {
120                                    reader.close();
121                            }
122                    }
123            }
124    }