001 /** 002 * 003 */ 004 package org.wdssii.decisiontree; 005 006 import java.io.Serializable; 007 008 /** 009 * A node in an axial decision tree has two branches and a condition to decide 010 * which branch to take. 011 * 012 * @author lakshman 013 * 014 */ 015 @SuppressWarnings("serial") 016 public class AxialTreeNode implements Serializable{ 017 private float thresh = 0; 018 019 private int attributeNumber = -1; 020 021 private AxialTreeNode left = null; 022 023 private AxialTreeNode right = null; 024 025 private int defaultCategory = -1; 026 027 public int classify(float[] data) { 028 if (attributeNumber >= 0 && attributeNumber < data.length) { 029 if (data[attributeNumber] < thresh) { 030 return left.classify(data); 031 } else { 032 return right.classify(data); 033 } 034 } 035 return defaultCategory; 036 } 037 038 boolean isHandledByLeftBranch(float[] data) { 039 return (attributeNumber >= 0 && attributeNumber < data.length && data[attributeNumber] < thresh); 040 } 041 042 boolean isHandledByRightBranch(float[] data) { 043 return (attributeNumber >= 0 && attributeNumber < data.length && data[attributeNumber] >= thresh); 044 } 045 046 public AxialTreeNode getLeft() { 047 return left; 048 } 049 050 public void setLeft(AxialTreeNode left) { 051 this.left = left; 052 } 053 054 public AxialTreeNode getRight() { 055 return right; 056 } 057 058 public void setRight(AxialTreeNode right) { 059 this.right = right; 060 } 061 062 public AxialTreeNode() { 063 this(0); 064 } 065 066 /** node that supplies the same category for all inputs */ 067 public AxialTreeNode(int category) { 068 this.defaultCategory = category; 069 } 070 071 /** A node in the tree that branches out. */ 072 public AxialTreeNode(int attributeNo, float thresh, AxialTreeNode leftNode, 073 AxialTreeNode rightNode, int defaultCategory) { 074 this.attributeNumber = attributeNo; 075 this.thresh = thresh; 076 this.left = leftNode; 077 this.right = rightNode; 078 this.defaultCategory = defaultCategory; 079 } 080 081 @Override 082 public String toString() { 083 if (attributeNumber < 0) { 084 return "Stump-node: " + defaultCategory; 085 } else { 086 return "Branch-node on attr=" + attributeNumber + " thresh=" 087 + thresh; 088 } 089 } 090 091 public void appendJava(StringBuilder sb, int depth, String newline) { 092 // Indent 093 char[] indent = new char[depth * 2]; 094 for (int i = 0; i < indent.length; ++i) { 095 indent[i] = ' '; 096 } 097 098 // valid attribute to make decision? 099 if (attributeNumber >= 0) { 100 // left 101 sb.append(indent).append("if ( data[").append(attributeNumber) 102 .append("] < ").append(thresh).append(" ){") 103 .append(newline); 104 left.appendJava(sb, depth + 1, newline); 105 106 // right 107 sb.append(indent).append("} else {").append(newline); 108 right.appendJava(sb, depth + 1, newline); 109 110 sb.append(indent).append("}").append(newline); 111 } else { 112 // Stump node 113 sb.append(indent).append("return ").append(defaultCategory).append( 114 ";").append(newline); 115 116 } 117 } 118 119 public void normalize() { 120 if (attributeNumber >= 0 && left.attributeNumber < 0 121 && right.attributeNumber < 0 122 && left.defaultCategory == right.defaultCategory) { 123 this.attributeNumber = -1; 124 this.defaultCategory = left.defaultCategory; 125 this.left = null; 126 this.right = null; 127 } 128 } 129 130 public int getDefaultCategory() { 131 return defaultCategory; 132 } 133 134 public int getAttributeNumber() { 135 return attributeNumber; 136 } 137 138 public void setAttributeNumber(int attributeNumber) { 139 this.attributeNumber = attributeNumber; 140 } 141 142 public float getThresh() { 143 return thresh; 144 } 145 146 public void setThresh(float thresh) { 147 this.thresh = thresh; 148 } 149 150 public void setDefaultCategory(int defaultCategory) { 151 this.defaultCategory = defaultCategory; 152 } 153 }