package classification;

import java.util.ArrayList;
import java.util.Random;

import types.Alphabet;
import types.ClassificationInstance;
import types.FeatureFunction;
import types.LinearClassifier;
import types.SparseVector;
import types.StaticUtils;

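/**
 * AdaBoost for multi-class classification, using single features as weak
 * learners: each round picks the feature with the largest weighted
 * correct-minus-wrong score and adds its vote
 * alpha = 0.5 * ln(correct / wrong) to that feature's weight in the
 * resulting linear classifier.
 */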
public class AdaBoost {

	int numIterations;
	Alphabet xAlphabet;
	Alphabet yAlphabet;
	FeatureFunction fxy;
	// Additive smoothing for the weighted correct/wrong counts, so that
	// alpha = 0.5 * ln(correct / wrong) stays finite.
	double smooth = 0.01;

	public AdaBoost(int numIterations, Alphabet xAlphabet, Alphabet yAlphabet,
			FeatureFunction fxy) {
		this.numIterations = numIterations;
		this.xAlphabet = xAlphabet;
		this.yAlphabet = yAlphabet;
		this.fxy = fxy;
	}

	public void printArray(double[] a) {
		for (int i = 0; i < a.length; i++) {
			System.out.print(a[i] + " ");
		}
		System.out.println();
	}

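	/**
	 * Runs numIterations rounds of boosting over the training set and
	 * returns the boosted linear classifier. Instance weights start uniform
	 * and are re-weighted after every round.
	 */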
	public LinearClassifier batchTrain(
			ArrayList<ClassificationInstance> trainingData) {
		LinearClassifier result = new LinearClassifier(xAlphabet, yAlphabet,
				fxy);
		// Start from a uniform distribution over training instances.
		double[] w = new double[trainingData.size()];
		for (int i = 0; i < w.length; i++)
			w[i] = 1.0 / trainingData.size();
		double[] correct = new double[fxy.wSize()];
		double[] wrong = new double[fxy.wSize()];

		for (int iter = 0; iter < numIterations; iter++) {
			computeAccuracies(correct, wrong, trainingData, w);
			// Pick the feature with the largest weighted margin and give it
			// the usual AdaBoost vote alpha = 0.5 * ln(correct / wrong).
			int bestFeature = chooseBest(correct, wrong);
			double alpha = Math
					.log(correct[bestFeature] / wrong[bestFeature]) / 2;
			result.w[bestFeature] += alpha;
			updateW(w, bestFeature, alpha, trainingData);
		}
		return result;
	}

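	/**
	 * Returns the index of the feature with the largest weighted
	 * correct-minus-wrong score.
	 */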
	private int chooseBest(double[] correct, double[] wrong) {
		int res = 0;
		// Note: Double.MIN_VALUE is the smallest *positive* double, so a
		// negative score could never beat it; negative infinity is the
		// correct initial value here.
		double bestval = Double.NEGATIVE_INFINITY;
		for (int i = 0; i < correct.length; i++) {
			double val = correct[i] - wrong[i];
			if (val > bestval) {
				res = i;
				bestval = val;
			}
		}
		return res;
	}

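	/**
	 * Re-weights the instances touched by the chosen feature: exp(-alpha)
	 * when the feature fires with an instance's true label, exp(alpha) when
	 * it fires with a wrong label; the weights are then renormalized to sum
	 * to one.
	 */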
	private void updateW(double[] w, int bestFeature, double alpha,
			ArrayList<ClassificationInstance> trainingData) {
		double wrongUpdate = Math.exp(alpha);
		double correctUpdate = Math.exp(-alpha);
		for (int instInd = 0; instInd < trainingData.size(); instInd++) {
			ClassificationInstance inst = trainingData.get(instInd);
			for (int y = 0; y < yAlphabet.size(); y++) {
				SparseVector fv = fxy.apply(inst.x, y);
				for (int i = 0; i < fv.numEntries(); i++) {
					if (fv.getIndexAt(i) == bestFeature) {
						if (y == inst.y)
							w[instInd] *= correctUpdate;
						else
							w[instInd] *= wrongUpdate;
					}
				}
			}
		}
		// Renormalize so the instance weights remain a distribution.
		double sum = StaticUtils.sum(w);
		for (int i = 0; i < w.length; i++) {
			w[i] /= sum;
		}
	}

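	/**
	 * Fills correct[i] and wrongs[i] with the smoothed, weighted mass of
	 * instances for which feature i fires with the true label (correct) or
	 * with some other label (wrongs).
	 */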
	private void computeAccuracies(double[] correct, double[] wrongs,
			ArrayList<ClassificationInstance> trainingData, double[] w) {
		double total = 2 * smooth;
		// Smooth both counts so the ratio correct/wrong is always defined.
		for (int i = 0; i < correct.length; i++) {
			correct[i] = smooth;
			wrongs[i] = smooth;
		}
		for (int instInd = 0; instInd < trainingData.size(); instInd++) {
			ClassificationInstance inst = trainingData.get(instInd);
			total += w[instInd];
			for (int y = 0; y < yAlphabet.size(); y++) {
				SparseVector fv = fxy.apply(inst.x, y);
				if (y == inst.y) {
					for (int i = 0; i < fv.numEntries(); i++) {
						correct[fv.getIndexAt(i)] += w[instInd];
					}
				} else {
					for (int i = 0; i < fv.numEntries(); i++) {
						wrongs[fv.getIndexAt(i)] += w[instInd];
					}
				}
			}
		}
		// Dividing by total only rescales both counts; the scale cancels in
		// the correct/wrong ratio used to compute alpha.
		for (int i = 0; i < correct.length; i++) {
			correct[i] /= total;
			wrongs[i] /= total;
		}
	}

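	/**
	 * Small smoke test: trains on synthetic two-class data in which each
	 * feature is dropped half the time and paired with a random label half
	 * the time, then prints the training accuracy.
	 */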
	public static void main(String[] args) {
		ArrayList<ClassificationInstance> train = new ArrayList<ClassificationInstance>();
		Alphabet xAlphabet = new Alphabet();
		Alphabet yAlphabet = new Alphabet();
		String[] classes = new String[] { "a", "b" };
		Random r = new Random(10);
		int numFeats = 5;
		double randomFrac = 0.5;
		double missingFrac = 0.5;
		for (int instInd = 0; instInd < 10; instInd++) {
			String label = classes[r.nextInt(classes.length)];
			SparseVector sv = new SparseVector();
			for (int fInd = 0; fInd < numFeats; fInd++) {
				// Drop roughly half of the features entirely...
				if (r.nextDouble() < missingFrac)
					continue;
				// ...and make roughly half of the rest uninformative by
				// emitting them under a random label instead of the true one.
				String tmpLab = label;
				if (r.nextDouble() < randomFrac)
					tmpLab = classes[r.nextInt(classes.length)];
				sv.add(xAlphabet.lookupObject(tmpLab + fInd), 1);
			}
			train.add(new ClassificationInstance(xAlphabet, yAlphabet, sv,
					label));
		}
		AdaBoost boost = new AdaBoost(10, xAlphabet, yAlphabet,
				new CompleteFeatureFunction(xAlphabet, yAlphabet));
		LinearClassifier h = boost.batchTrain(train);
		System.out.println(StaticUtils.computeAccuracy(h, train));
	}

}