1 package sequence;
2
import java.util.ArrayList;
import java.util.Arrays;

import types.Alphabet;
import types.StaticUtils;
7
8 public class Perceptron {
9
10 boolean performAveraging;
11 int numIterations;
12 Alphabet xAlphabet;
13 Alphabet yAlphabet;
14 SequenceFeatureFunction fxy;
15
16 public Perceptron(boolean performAveraging, int numIterations,
17 Alphabet xAlphabet, Alphabet yAlphabet, SequenceFeatureFunction fxy) {
18 this.performAveraging = performAveraging;
19 this.numIterations = numIterations;
20 this.xAlphabet = xAlphabet;
21 this.yAlphabet = yAlphabet;
22 this.fxy = fxy;
23 }
24
25 public LinearTagger batchTrain(ArrayList<SequenceInstance> trainingData) {
26 LinearTagger w = new LinearTagger(xAlphabet, yAlphabet, fxy);
27 LinearTagger theta = null;
28 if (performAveraging)
29 theta = new LinearTagger(xAlphabet, yAlphabet, fxy);
30 for (int iter = 0; iter < numIterations; iter++) {
31 for (SequenceInstance inst : trainingData) {
32 int[] yhat = w.label(inst.x);
33
34 StaticUtils.plusEquals(w.w, fxy.apply(inst.x, inst.y));
35 StaticUtils.plusEquals(w.w, fxy.apply(inst.x, yhat), -1);
36 if (performAveraging)
37 StaticUtils.plusEquals(theta.w, w.w, 1);
38 }
39 }
40 if (performAveraging)
41 return theta;
42 return w;
43 }
44
45 }