1 package classification;
2
3 import java.util.ArrayList;
4 import java.util.Random;
5
6 import types.Alphabet;
7 import types.ClassificationInstance;
8 import types.FeatureFunction;
9 import types.LinearClassifier;
10 import types.SparseVector;
11 import types.StaticUtils;
12
13 public class AdaBoost {
14
15 int numIterations;
16 Alphabet xAlphabet;
17 Alphabet yAlphabet;
18 FeatureFunction fxy;
19 double smooth = 0.01;
20
21 public AdaBoost(int numIterations, Alphabet xAlphabet, Alphabet yAlphabet,
22 FeatureFunction fxy) {
23 this.numIterations = numIterations;
24 this.xAlphabet = xAlphabet;
25 this.yAlphabet = yAlphabet;
26 this.fxy = fxy;
27 }
28
29 public void printArray(double[] a) {
30 for (int i = 0; i < a.length; i++) {
31 System.out.print(a[i] + " ");
32 }
33 System.out.println();
34 }
35
36 public LinearClassifier batchTrain(
37 ArrayList<ClassificationInstance> trainingData) {
38 LinearClassifier result = new LinearClassifier(xAlphabet, yAlphabet,
39 fxy);
40 double[] w = new double[trainingData.size()];
41 for (int i = 0; i < w.length; i++)
42 w[i] = 1.0 / trainingData.size();
43
44 double[] correct = new double[fxy.wSize()];
45 double[] wrong = new double[fxy.wSize()];
46
47 @SuppressWarnings("unused")
48 int oldBest = 0;
49 for (int iter = 0; iter < numIterations; iter++) {
50 computeAccuracies(correct, wrong, trainingData, w);
51
52
53
54
55
56 int bestFeature = chooseBest(correct, wrong);
57 double alpha = Math
58 .log((correct[bestFeature]) / wrong[bestFeature]) / 2;
59 result.w[bestFeature] += alpha;
60 updateW(w, bestFeature, alpha, trainingData);
61
62
63 }
64 return result;
65 }
66
67 private int chooseBest(double[] correct, double[] wrong) {
68 int res = 0;
69 double bestval = Double.MIN_VALUE;
70 for (int i = 0; i < correct.length; i++) {
71 double val = correct[i] - wrong[i];
72 if (val > bestval) {
73 res = i;
74 bestval = val;
75 }
76 }
77 return res;
78 }
79
80 private void updateW(double[] w, int bestFeature, double alpha,
81 ArrayList<ClassificationInstance> trainingData) {
82 double wrongUpdate = Math.exp(alpha);
83 double correctUpdate = Math.exp(-alpha);
84 for (int instInd = 0; instInd < trainingData.size(); instInd++) {
85 ClassificationInstance inst = trainingData.get(instInd);
86 for (int y = 0; y < yAlphabet.size(); y++) {
87 SparseVector fv = fxy.apply(inst.x, y);
88 for (int i = 0; i < fv.numEntries(); i++) {
89 if (fv.getIndexAt(i) == bestFeature) {
90 if (y == inst.y)
91 w[instInd] *= correctUpdate;
92 else
93 w[instInd] *= wrongUpdate;
94 }
95 }
96 }
97 }
98 double sum = StaticUtils.sum(w);
99 for (int i = 0; i < w.length; i++) {
100 w[i] /= sum;
101 }
102
103 }
104
105 private void computeAccuracies(double[] correct, double[] wrongs,
106 ArrayList<ClassificationInstance> trainingData, double[] w) {
107 double total = 2 * smooth;
108 for (int i = 0; i < correct.length; i++) {
109 correct[i] = smooth;
110 wrongs[i] = smooth;
111 }
112 for (int instInd = 0; instInd < trainingData.size(); instInd++) {
113 ClassificationInstance inst = trainingData.get(instInd);
114 total += w[instInd];
115 for (int y = 0; y < yAlphabet.size(); y++) {
116 SparseVector fv = fxy.apply(inst.x, y);
117 if (y == inst.y) {
118 for (int i = 0; i < fv.numEntries(); i++) {
119 correct[fv.getIndexAt(i)] += w[instInd];
120 }
121 } else {
122 for (int i = 0; i < fv.numEntries(); i++) {
123 wrongs[fv.getIndexAt(i)] += w[instInd];
124 }
125 }
126 }
127 }
128 for (int i = 0; i < correct.length; i++) {
129 correct[i] /= total;
130 wrongs[i] /= total;
131 }
132 }
133
134 public static void main(String[] args) {
135 ArrayList<ClassificationInstance> train = new ArrayList<ClassificationInstance>();
136 Alphabet xAlphabet = new Alphabet();
137 Alphabet yAlphabet = new Alphabet();
138 String[] classes = new String[] { "a", "b" };
139 Random r = new Random(10);
140 int numFeats = 5;
141 double randomFrac = 0.5;
142 double missingFrac = 0.5;
143 for (int instInd = 0; instInd < 10; instInd++) {
144 String label = classes[r.nextInt(classes.length)];
145 SparseVector sv = new SparseVector();
146 for (int fInd = 0; fInd < numFeats; fInd++) {
147 if (r.nextDouble() < missingFrac)
148 continue;
149 String tmpLab = label;
150 if (r.nextDouble() < randomFrac)
151 tmpLab = classes[r.nextInt(classes.length)];
152 sv.add(xAlphabet.lookupObject(tmpLab + fInd), 1);
153 }
154 train.add(new ClassificationInstance(xAlphabet, yAlphabet, sv,
155 label));
156 }
157 AdaBoost boost = new AdaBoost(10, xAlphabet, yAlphabet,
158 new CompleteFeatureFunction(xAlphabet, yAlphabet));
159 LinearClassifier h = boost.batchTrain(train);
160 System.out.println(StaticUtils.computeAccuracy(h, train));
161 }
162
163 }