View Javadoc

1   package classification;
2   
3   import java.util.ArrayList;
4   
5   import types.Alphabet;
6   import types.ClassificationInstance;
7   import types.FeatureFunction;
8   import types.LinearClassifier;
9   import types.StaticUtils;
10  
11  public class Perceptron {
12  
13  	boolean performAveraging;
14  	int numIterations;
15  	Alphabet xAlphabet;
16  	Alphabet yAlphabet;
17  	FeatureFunction fxy;
18  
19  	public Perceptron(boolean performAveraging, int numIterations,
20  			Alphabet xAlphabet, Alphabet yAlphabet, FeatureFunction fxy) {
21  		this.performAveraging = performAveraging;
22  		this.numIterations = numIterations;
23  		this.xAlphabet = xAlphabet;
24  		this.yAlphabet = yAlphabet;
25  		this.fxy = fxy;
26  	}
27  
28  	public LinearClassifier batchTrain(
29  			ArrayList<ClassificationInstance> trainingData) {
30  		LinearClassifier w = new LinearClassifier(xAlphabet, yAlphabet, fxy);
31  		LinearClassifier theta = null;
32  		if (performAveraging)
33  			theta = new LinearClassifier(xAlphabet, yAlphabet, fxy);
34  		for (int iter = 0; iter < numIterations; iter++) {
35  			for (ClassificationInstance inst : trainingData) {
36  				int yhat = w.label(inst.x);
37  				if (yhat != inst.y) {
38  					StaticUtils.plusEquals(w.w, fxy.apply(inst.x, inst.y));
39  					StaticUtils.plusEquals(w.w, fxy.apply(inst.x, yhat), -1);
40  				}
41  				if (performAveraging)
42  					StaticUtils.plusEquals(theta.w, w.w, 1);
43  			}
44  		}
45  		if (performAveraging)
46  			return theta;
47  		return w;
48  	}
49  
50  }