package classification;

import java.util.ArrayList;

import types.Alphabet;
import types.ClassificationInstance;
import types.FeatureFunction;
import types.LinearClassifier;
import types.StaticUtils;

/**
 * Mistake-driven trainer that produces a {@link LinearClassifier}.
 *
 * <p>On every misclassified training instance the weight vector is moved
 * towards the features of the gold label and away from the features of the
 * predicted label. When averaging is enabled, a second classifier accumulates
 * the weight vector after every instance and is returned instead.
 */
public class Perceptron {

    /** When true, {@link #batchTrain} returns the accumulated weights. */
    boolean performAveraging;
    /** Number of full passes made over the training data. */
    int numIterations;
    /** Alphabet handed to every {@link LinearClassifier} this trainer creates. */
    Alphabet xAlphabet;
    /** Alphabet handed to every {@link LinearClassifier} this trainer creates. */
    Alphabet yAlphabet;
    /** Joint feature function applied to (input, label) pairs. */
    FeatureFunction fxy;

    /**
     * Stores the training configuration; no work is done here.
     *
     * @param performAveraging whether {@link #batchTrain} returns the accumulated weights
     * @param numIterations    number of passes over the training data
     * @param xAlphabet        alphabet passed through to the classifier
     * @param yAlphabet        alphabet passed through to the classifier
     * @param fxy              feature function used for updates and by the classifier
     */
    public Perceptron(boolean performAveraging, int numIterations,
            Alphabet xAlphabet, Alphabet yAlphabet, FeatureFunction fxy) {
        this.fxy = fxy;
        this.yAlphabet = yAlphabet;
        this.xAlphabet = xAlphabet;
        this.numIterations = numIterations;
        this.performAveraging = performAveraging;
    }

    /**
     * Trains a classifier by making {@code numIterations} passes over the data.
     *
     * @param trainingData instances to learn from, visited in list order on every pass
     * @return the accumulated-weights classifier when averaging is enabled,
     *         otherwise the classifier holding the final weights
     */
    public LinearClassifier batchTrain(
            ArrayList<ClassificationInstance> trainingData) {
        LinearClassifier current = new LinearClassifier(xAlphabet, yAlphabet, fxy);
        LinearClassifier accumulated = performAveraging
                ? new LinearClassifier(xAlphabet, yAlphabet, fxy)
                : null;

        int epoch = 0;
        while (epoch < numIterations) {
            for (ClassificationInstance example : trainingData) {
                int predicted = current.label(example.x);
                if (predicted != example.y) {
                    // Mistake: add the gold-label features, subtract the predicted-label features.
                    StaticUtils.plusEquals(current.w, fxy.apply(example.x, example.y));
                    StaticUtils.plusEquals(current.w, fxy.apply(example.x, predicted), -1);
                }
                if (performAveraging) {
                    // Accumulate after every instance, whether or not an update happened.
                    // Note: this is a running sum; it is never divided by a count.
                    StaticUtils.plusEquals(accumulated.w, current.w, 1);
                }
            }
            epoch++;
        }

        return performAveraging ? accumulated : current;
    }

}