1 package sequence;
2
3 import java.io.IOException;
4 import java.io.ObjectOutputStream;
5 import java.io.Serializable;
6
7 import types.Alphabet;
8 import types.SparseVector;
9 import types.StaticUtils;
10
11 /***
12 * A linear model for sequence classification. It has the form h(x) = arg_max_y
13 * f(x,y).w where x and y are sequences of identical length, and f(x,y)
14 * decomposes over pairs of positions on y.
15 *
16 * @author kuzman
17 *
18 */
19
20 public class LinearTagger implements Serializable {
21
22 private static final long serialVersionUID = 1L;
23 public double[] w;
24 Alphabet yAlphabet;
25 Alphabet xAlphabet;
26 SequenceFeatureFunction fxy;
27
28 public LinearTagger(Alphabet xAlpha, Alphabet yAlpha,
29 SequenceFeatureFunction fxy) {
30 w = new double[fxy.wSize()];
31 yAlphabet = yAlpha;
32 xAlphabet = xAlpha;
33 this.fxy = fxy;
34 }
35
36 /***
37 * at each position 0<=t<x.length, computes the score of each label pair
38 * 'ytm1','yt' as f(x,ytm1,yt) . w where yt is the label at position t and
39 * ytm1 is the label at position t-1.
40 *
41 * @param x
42 * @return result[t][ytm1][yt] = f(x,ytm1,yt).w
43 */
44 public double[][][] scores(SparseVector[] x) {
45 double[][][] res = new double[x.length][yAlphabet.size()][yAlphabet
46 .size()];
47 for (int t = 0; t < x.length; t++) {
48 for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
49 for (int yt = 0; yt < yAlphabet.size(); yt++) {
50 res[t][ytm1][yt] = StaticUtils.dotProduct(fxy.apply(x,
51 ytm1, yt, t), w);
52 }
53 }
54 }
55 return res;
56 }
57
58 /***
59 * use the Viterbi algorithm to find arg_max_y f(x,y) . w
60 *
61 * @param x
62 * @return y that maximizes f(x,y) . w
63 */
64 public int[] label(SparseVector[] x) {
65 double[][][] scores = scores(x);
66 double[][] gamma = new double[x.length][yAlphabet.size()];
67 int[][] back = new int[x.length][yAlphabet.size()];
68 for (int y = 0; y < yAlphabet.size(); y++) {
69 gamma[0][y] = scores[0][0][y];
70 }
71 for (int t = 1; t < x.length; t++) {
72 for (int yt = 0; yt < yAlphabet.size(); yt++) {
73 gamma[t][yt] = Double.NEGATIVE_INFINITY;
74 for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
75 if (gamma[t][yt] < gamma[t - 1][ytm1] + scores[t][ytm1][yt]) {
76 back[t][yt] = ytm1;
77 gamma[t][yt] = gamma[t - 1][ytm1] + scores[t][ytm1][yt];
78 }
79 }
80 }
81 }
82 int[] tags = new int[x.length];
83 for (int y = 0; y < yAlphabet.size(); y++) {
84 if (gamma[x.length - 1][tags[x.length - 1]] < gamma[x.length - 1][y]) {
85 tags[x.length - 1] = y;
86 }
87 }
88 for (int t = x.length - 2; t >= 0; t--) {
89 tags[t] = back[t + 1][tags[t + 1]];
90 }
91 return tags;
92 }
93
94 public void writeObject(ObjectOutputStream out) throws IOException {
95 out.writeLong(serialVersionUID);
96 out.writeInt(w.length);
97 for (double d : w)
98 out.writeDouble(d);
99 out.writeObject(xAlphabet);
100 out.writeObject(yAlphabet);
101 out.writeObject(fxy);
102 }
103
104 @SuppressWarnings("unchecked")
105 public void readObject(java.io.ObjectInputStream in) throws IOException,
106 ClassNotFoundException {
107 long inid = in.readLong();
108 if (inid != serialVersionUID)
109 throw new IOException("Serial version mismatch: expected "
110 + serialVersionUID + " got " + inid);
111 w = new double[in.readInt()];
112 for (int i = 0; i < w.length; i++)
113 w[i] = in.readDouble();
114 xAlphabet = (Alphabet) in.readObject();
115 yAlphabet = (Alphabet) in.readObject();
116 fxy = (SequenceFeatureFunction) in.readObject();
117 }
118
119 }