View Javadoc

1   package sequence;
2   
3   import java.util.ArrayList;
4   
5   import types.Alphabet;
6   import types.StaticUtils;
7   
8   public class Perceptron {
9   
10  	boolean performAveraging;
11  	int numIterations;
12  	Alphabet xAlphabet;
13  	Alphabet yAlphabet;
14  	SequenceFeatureFunction fxy;
15  
16  	public Perceptron(boolean performAveraging, int numIterations,
17  			Alphabet xAlphabet, Alphabet yAlphabet, SequenceFeatureFunction fxy) {
18  		this.performAveraging = performAveraging;
19  		this.numIterations = numIterations;
20  		this.xAlphabet = xAlphabet;
21  		this.yAlphabet = yAlphabet;
22  		this.fxy = fxy;
23  	}
24  
25  	public LinearTagger batchTrain(ArrayList<SequenceInstance> trainingData) {
26  		LinearTagger w = new LinearTagger(xAlphabet, yAlphabet, fxy);
27  		LinearTagger theta = null;
28  		if (performAveraging)
29  			theta = new LinearTagger(xAlphabet, yAlphabet, fxy);
30  		for (int iter = 0; iter < numIterations; iter++) {
31  			for (SequenceInstance inst : trainingData) {
32  				int[] yhat = w.label(inst.x);
33  				// if y = yhat then this update will not change w.
34  				StaticUtils.plusEquals(w.w, fxy.apply(inst.x, inst.y));
35  				StaticUtils.plusEquals(w.w, fxy.apply(inst.x, yhat), -1);
36  				if (performAveraging)
37  					StaticUtils.plusEquals(theta.w, w.w, 1);
38  			}
39  		}
40  		if (performAveraging)
41  			return theta;
42  		return w;
43  	}
44  
45  }