001    /**
002     * 
003     */
004    package org.wdssii.decisiontree;
005    
import java.beans.XMLDecoder;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
015    
016    /**
017     * 
018     * Uses a  trained decision tree to classify cases. If labels are provided, also reports skill score.
019     * 
020     * @author lakshman
021     *
022     */
023    public class Classify {
024            /**
025             * Classifies all inputs in a file based on a trained decision tree
026             */
027            public static void main(String[] args) throws Exception {
028                    if (args.length < 2) {
029                            System.err
030                                            .println("Usage: java org.jscience.statistics.decisiontree.Classify decisiontree.xml testinfile.csv outdir [separator]\n");
031                            System.err
032                                            .println("The data file should have each example on a line, with all the attributes being numeric and separated by commas.");
033                            System.err
034                                            .println("The final attribute can be the true classification if known. In this case, skill will be computed.");
035                            return;
036                    }
037    
038                    String decisiontree = args[0];
039                    String infile = args[1];
040                    String outdir = args[2];
041                    String separator = (args.length > 3) ? args[3] : ",";
042    
043                    XMLDecoder decoder = new XMLDecoder(new FileInputStream(decisiontree));
044                    AxialDecisionTree tree = (AxialDecisionTree) decoder.readObject();
045                    
046                    BufferedReader reader = null;
047                    try {
048                            new File(outdir).mkdirs();
049                            // read the file
050                            reader = new BufferedReader(new FileReader(infile));
051                            String line;
052                            List<String[]> testingData = new ArrayList<String[]>();
053                            while ((line = reader.readLine()) != null) {
054                                    String[] splitValues = line.split(separator);
055                                    testingData.add(splitValues);
056                            }
057                            // get the data and categories
058                            int numTesting = testingData.size();
059                            if (numTesting == 0){
060                                    System.err.println("Empty file: not processed");
061                                    return;
062                            }
063                            int numAttrInFile = testingData.get(0).length;
064                            int numAttr = tree.getNumAttributes();
065                            if (numAttrInFile < numAttr || numAttrInFile > (numAttr+1) ){
066                                    // numAttr and numAttr+1 are ok
067                                    throw new IllegalArgumentException("The file contains only " + numAttrInFile + " columns but the decision tree was trained with " + numAttr);
068                            }
069                            float[][] data = new float[numTesting][numAttr];
070                            int[] categories = (numAttrInFile > numAttr)? new int[numTesting] : null;
071                            for (int i = 0; i < numTesting; ++i) {
072                                    String[] indata = testingData.get(i);
073                                    if (indata.length != numAttrInFile) {
074                                            throw new IllegalArgumentException("Row number " + (i + 1)
075                                                            + " of file has " + indata.length
076                                                            + " attributes. Expected " + numAttrInFile);
077                                    }
078                                    for (int j = 0; j < numAttr; ++j) {
079                                            try {
080                                                    data[i][j] = Float.parseFloat(indata[j]);
081                                            } catch (NumberFormatException e) {
082                                                    System.err.println(e.getMessage() + " in attribute no."
083                                                                    + (j + 1) + " of line " + (i + 1)
084                                                                    + " Assuming zero.");
085                                                    data[i][j] = 0;
086                                            }
087                                    }
088                                    if ( categories != null ){
089                                            categories[i] = Integer.parseInt(indata[numAttr]);
090                                    }
091                            }
092    
093                            // classify and compute skill
094                            MulticategorySkillScore tss = new MulticategorySkillScore(tree.getNumCategories());
095                            PrintWriter writer = new PrintWriter(new FileWriter(outdir+"/classified.txt"));
096                            for (int i=0; i < numTesting; ++i){
097                                    int result = tree.classify(data[i]);
098                                    writer.println(result);
099                                    if (categories != null){
100                                            tss.update(categories[i], result);
101                                    }
102                            }
103    
104                            if ( categories != null ){
105                                    System.out.println(tss);
106                                    PrintWriter skillWriter = new PrintWriter(new FileWriter(outdir
107                                                    + "/testing_skill.txt"));
108                                    skillWriter.println(tss);
109                                    skillWriter.close();
110                            }
111                            
112                    } finally {
113                            if (reader != null) {
114                                    reader.close();
115                            }
116                    }
117            }
118    }