001    /**
002     * 
003     */
004    package org.wdssii.decisiontree;
005    
006    import java.io.Serializable;
007    
/**
 * A node in an axial decision tree has two branches and a condition to decide
 * which branch to take.
 * <p>
 * A node whose attribute number is negative is a stump (leaf) that always
 * answers its default category. A branch node compares
 * {@code data[attributeNumber]} against {@code thresh}: values strictly below
 * the threshold go left, everything else (including NaN) goes right.
 * <p>
 * NOTE(review): this class relies on the compiler-computed serialVersionUID
 * (hence {@code @SuppressWarnings("serial")}). That value changes whenever
 * non-private members change, which would break deserialization of trees
 * saved with an earlier version, so the non-private surface is kept as-is.
 *
 * @author lakshman
 */
@SuppressWarnings("serial")
public class AxialTreeNode implements Serializable {
	/** Split threshold: attribute values strictly below it take the left branch. */
	private float thresh = 0;

	/** Index of the attribute used for the split; negative marks a stump node. */
	private int attributeNumber = -1;

	/** Subtree for attribute values below {@link #thresh}; may be null. */
	private AxialTreeNode left = null;

	/** Subtree for all other attribute values; may be null. */
	private AxialTreeNode right = null;

	/**
	 * Category answered by a stump, or by a branch node that cannot route an
	 * input (attribute index outside the data, or the chosen child missing).
	 */
	private int defaultCategory = -1;

	/**
	 * Classifies a feature vector by descending from this node.
	 *
	 * @param data
	 *            the feature vector; must not be null
	 * @return the category of the stump reached, or this node's default
	 *         category when the split attribute is not present in
	 *         {@code data} or the selected child is null
	 */
	public int classify(float[] data) {
		if (attributeNumber >= 0 && attributeNumber < data.length) {
			// "< thresh" is false for NaN, so NaN values are routed right.
			AxialTreeNode next = (data[attributeNumber] < thresh) ? left : right;
			if (next != null) {
				return next.classify(data);
			}
		}
		return defaultCategory;
	}

	/**
	 * Returns true if this is a branch node whose split attribute exists in
	 * {@code data} and is strictly below the threshold.
	 */
	boolean isHandledByLeftBranch(float[] data) {
		return (attributeNumber >= 0 && attributeNumber < data.length && data[attributeNumber] < thresh);
	}

	/**
	 * Returns true if this is a branch node whose split attribute exists in
	 * {@code data} and is at or above the threshold.
	 * <p>
	 * NOTE(review): a NaN attribute value is handled by neither branch here,
	 * although {@link #classify(float[])} sends it to the right branch.
	 */
	boolean isHandledByRightBranch(float[] data) {
		return (attributeNumber >= 0 && attributeNumber < data.length && data[attributeNumber] >= thresh);
	}

	/** @return the left (below-threshold) subtree, possibly null */
	public AxialTreeNode getLeft() {
		return left;
	}

	/** @param left the new left (below-threshold) subtree */
	public void setLeft(AxialTreeNode left) {
		this.left = left;
	}

	/** @return the right (at-or-above-threshold) subtree, possibly null */
	public AxialTreeNode getRight() {
		return right;
	}

	/** @param right the new right (at-or-above-threshold) subtree */
	public void setRight(AxialTreeNode right) {
		this.right = right;
	}

	/** Creates a stump node that answers category 0 for every input. */
	public AxialTreeNode() {
		this(0);
	}

	/** node that supplies the same category for all inputs */
	public AxialTreeNode(int category) {
		this.defaultCategory = category;
	}

	/**
	 * A node in the tree that branches out.
	 *
	 * @param attributeNo
	 *            index of the attribute to split on (negative makes a stump)
	 * @param thresh
	 *            values strictly below this go to {@code leftNode}
	 * @param leftNode
	 *            below-threshold subtree
	 * @param rightNode
	 *            at-or-above-threshold subtree
	 * @param defaultCategory
	 *            fallback category when an input cannot be routed
	 */
	public AxialTreeNode(int attributeNo, float thresh, AxialTreeNode leftNode,
			AxialTreeNode rightNode, int defaultCategory) {
		this.attributeNumber = attributeNo;
		this.thresh = thresh;
		this.left = leftNode;
		this.right = rightNode;
		this.defaultCategory = defaultCategory;
	}

	@Override
	public String toString() {
		if (attributeNumber < 0) {
			return "Stump-node: " + defaultCategory;
		} else {
			return "Branch-node on attr=" + attributeNumber + " thresh="
					+ thresh;
		}
	}

	/**
	 * Appends Java source that mirrors {@link #classify(float[])} for the
	 * subtree rooted at this node. The emitted statements read a
	 * {@code float[]} named {@code data} and end in {@code return} statements,
	 * so they must be placed inside a method returning {@code int}.
	 * <p>
	 * NOTE(review): unlike {@code classify}, the generated code does not check
	 * {@code data.length} before indexing.
	 *
	 * @param sb
	 *            buffer receiving the generated code
	 * @param depth
	 *            nesting level; each level is indented by two spaces
	 * @param newline
	 *            line separator to emit
	 */
	public void appendJava(StringBuilder sb, int depth, String newline) {
		char[] indent = indentFor(depth);

		// valid attribute to make decision?
		if (attributeNumber >= 0) {
			// left
			sb.append(indent).append("if ( data[").append(attributeNumber)
					.append("] < ").append(toJavaFloatLiteral(thresh))
					.append(" ){").append(newline);
			appendBranch(sb, left, depth + 1, newline);

			// right
			sb.append(indent).append("} else {").append(newline);
			appendBranch(sb, right, depth + 1, newline);

			sb.append(indent).append("}").append(newline);
		} else {
			// Stump node
			sb.append(indent).append("return ").append(defaultCategory)
					.append(";").append(newline);
		}
	}

	/**
	 * Emits code for one child. A missing child emits a return of this node's
	 * default category, matching what {@link #classify(float[])} does.
	 */
	private void appendBranch(StringBuilder sb, AxialTreeNode child,
			int depth, String newline) {
		if (child != null) {
			child.appendJava(sb, depth, newline);
		} else {
			sb.append(indentFor(depth)).append("return ")
					.append(defaultCategory).append(";").append(newline);
		}
	}

	/** Returns {@code depth * 2} spaces. */
	private static char[] indentFor(int depth) {
		char[] indent = new char[depth * 2];
		for (int i = 0; i < indent.length; ++i) {
			indent[i] = ' ';
		}
		return indent;
	}

	/**
	 * Renders {@code value} as a Java expression of type {@code float} that
	 * evaluates to exactly {@code value}. The {@code f} suffix matters: without
	 * it the literal is a double, {@code data[i]} is widened, and the
	 * comparison can disagree with the float comparison in {@code classify}
	 * (e.g. for 0.7f). Non-finite values have no literal form.
	 */
	private static String toJavaFloatLiteral(float value) {
		if (Float.isNaN(value)) {
			return "Float.NaN";
		}
		if (value == Float.POSITIVE_INFINITY) {
			return "Float.POSITIVE_INFINITY";
		}
		if (value == Float.NEGATIVE_INFINITY) {
			return "Float.NEGATIVE_INFINITY";
		}
		// Float.toString yields a decimal that parses back to the same float.
		return Float.toString(value) + "f";
	}

	/**
	 * Collapses this branch node into a stump when both children are stumps
	 * answering the same category. Only this one level is examined; children
	 * are not normalized. Does nothing if either child is null.
	 */
	public void normalize() {
		if (attributeNumber < 0 || left == null || right == null) {
			return;
		}
		if (left.attributeNumber < 0 && right.attributeNumber < 0
				&& left.defaultCategory == right.defaultCategory) {
			this.attributeNumber = -1;
			this.defaultCategory = left.defaultCategory;
			this.left = null;
			this.right = null;
		}
	}

	/** @return the fallback / stump category */
	public int getDefaultCategory() {
		return defaultCategory;
	}

	/** @return the split attribute index; negative for a stump */
	public int getAttributeNumber() {
		return attributeNumber;
	}

	/** @param attributeNumber the split attribute index; negative makes a stump */
	public void setAttributeNumber(int attributeNumber) {
		this.attributeNumber = attributeNumber;
	}

	/** @return the split threshold */
	public float getThresh() {
		return thresh;
	}

	/** @param thresh the split threshold */
	public void setThresh(float thresh) {
		this.thresh = thresh;
	}

	/** @param defaultCategory the fallback / stump category */
	public void setDefaultCategory(int defaultCategory) {
		this.defaultCategory = defaultCategory;
	}
}