1 package classification;
2
3 import java.util.ArrayList;
4
5 import types.Alphabet;
6 import types.ClassificationInstance;
7 import types.FeatureFunction;
8 import types.LinearClassifier;
9 import types.StaticUtils;
10
11 public class Perceptron {
12
13 boolean performAveraging;
14 int numIterations;
15 Alphabet xAlphabet;
16 Alphabet yAlphabet;
17 FeatureFunction fxy;
18
19 public Perceptron(boolean performAveraging, int numIterations,
20 Alphabet xAlphabet, Alphabet yAlphabet, FeatureFunction fxy) {
21 this.performAveraging = performAveraging;
22 this.numIterations = numIterations;
23 this.xAlphabet = xAlphabet;
24 this.yAlphabet = yAlphabet;
25 this.fxy = fxy;
26 }
27
28 public LinearClassifier batchTrain(
29 ArrayList<ClassificationInstance> trainingData) {
30 LinearClassifier w = new LinearClassifier(xAlphabet, yAlphabet, fxy);
31 LinearClassifier theta = null;
32 if (performAveraging)
33 theta = new LinearClassifier(xAlphabet, yAlphabet, fxy);
34 for (int iter = 0; iter < numIterations; iter++) {
35 for (ClassificationInstance inst : trainingData) {
36 int yhat = w.label(inst.x);
37 if (yhat != inst.y) {
38 StaticUtils.plusEquals(w.w, fxy.apply(inst.x, inst.y));
39 StaticUtils.plusEquals(w.w, fxy.apply(inst.x, yhat), -1);
40 }
41 if (performAveraging)
42 StaticUtils.plusEquals(theta.w, w.w, 1);
43 }
44 }
45 if (performAveraging)
46 return theta;
47 return w;
48 }
49
50 }