View Javadoc

1   package sequence;
2   
3   import java.io.IOException;
4   import java.io.ObjectOutputStream;
5   import java.io.Serializable;
6   
7   import types.Alphabet;
8   import types.SparseVector;
9   import types.StaticUtils;
10  
11  /***
12   * A linear model for sequence classification. It has the form h(x) = arg_max_y
13   * f(x,y).w where x and y are sequences of identical length, and f(x,y)
14   * decomposes over pairs of positions on y.
15   * 
16   * @author kuzman
17   * 
18   */
19  
20  public class LinearTagger implements Serializable {
21  
22  	private static final long serialVersionUID = 1L;
23  	public double[] w;
24  	Alphabet yAlphabet;
25  	Alphabet xAlphabet;
26  	SequenceFeatureFunction fxy;
27  
28  	public LinearTagger(Alphabet xAlpha, Alphabet yAlpha,
29  			SequenceFeatureFunction fxy) {
30  		w = new double[fxy.wSize()];
31  		yAlphabet = yAlpha;
32  		xAlphabet = xAlpha;
33  		this.fxy = fxy;
34  	}
35  
36  	/***
37  	 * at each position 0<=t<x.length, computes the score of each label pair
38  	 * 'ytm1','yt' as f(x,ytm1,yt) . w where yt is the label at position t and
39  	 * ytm1 is the label at position t-1.
40  	 * 
41  	 * @param x
42  	 * @return result[t][ytm1][yt] = f(x,ytm1,yt).w
43  	 */
44  	public double[][][] scores(SparseVector[] x) {
45  		double[][][] res = new double[x.length][yAlphabet.size()][yAlphabet
46  				.size()];
47  		for (int t = 0; t < x.length; t++) {
48  			for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
49  				for (int yt = 0; yt < yAlphabet.size(); yt++) {
50  					res[t][ytm1][yt] = StaticUtils.dotProduct(fxy.apply(x,
51  							ytm1, yt, t), w);
52  				}
53  			}
54  		}
55  		return res;
56  	}
57  
58  	/***
59  	 * use the Viterbi algorithm to find arg_max_y f(x,y) . w
60  	 * 
61  	 * @param x
62  	 * @return y that maximizes f(x,y) . w
63  	 */
64  	public int[] label(SparseVector[] x) {
65  		double[][][] scores = scores(x);
66  		double[][] gamma = new double[x.length][yAlphabet.size()];
67  		int[][] back = new int[x.length][yAlphabet.size()];
68  		for (int y = 0; y < yAlphabet.size(); y++) {
69  			gamma[0][y] = scores[0][0][y];
70  		}
71  		for (int t = 1; t < x.length; t++) {
72  			for (int yt = 0; yt < yAlphabet.size(); yt++) {
73  				gamma[t][yt] = Double.NEGATIVE_INFINITY;
74  				for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
75  					if (gamma[t][yt] < gamma[t - 1][ytm1] + scores[t][ytm1][yt]) {
76  						back[t][yt] = ytm1;
77  						gamma[t][yt] = gamma[t - 1][ytm1] + scores[t][ytm1][yt];
78  					}
79  				}
80  			}
81  		}
82  		int[] tags = new int[x.length];
83  		for (int y = 0; y < yAlphabet.size(); y++) {
84  			if (gamma[x.length - 1][tags[x.length - 1]] < gamma[x.length - 1][y]) {
85  				tags[x.length - 1] = y;
86  			}
87  		}
88  		for (int t = x.length - 2; t >= 0; t--) {
89  			tags[t] = back[t + 1][tags[t + 1]];
90  		}
91  		return tags;
92  	}
93  
94  	public void writeObject(ObjectOutputStream out) throws IOException {
95  		out.writeLong(serialVersionUID);
96  		out.writeInt(w.length);
97  		for (double d : w)
98  			out.writeDouble(d);
99  		out.writeObject(xAlphabet);
100 		out.writeObject(yAlphabet);
101 		out.writeObject(fxy);
102 	}
103 
104 	@SuppressWarnings("unchecked")
105 	public void readObject(java.io.ObjectInputStream in) throws IOException,
106 			ClassNotFoundException {
107 		long inid = in.readLong();
108 		if (inid != serialVersionUID)
109 			throw new IOException("Serial version mismatch: expected "
110 					+ serialVersionUID + " got " + inid);
111 		w = new double[in.readInt()];
112 		for (int i = 0; i < w.length; i++)
113 			w[i] = in.readDouble();
114 		xAlphabet = (Alphabet) in.readObject();
115 		yAlphabet = (Alphabet) in.readObject();
116 		fxy = (SequenceFeatureFunction) in.readObject();
117 	}
118 
119 }