1 package sequence; 2 3 import java.io.IOException; 4 import java.io.ObjectOutputStream; 5 import java.io.Serializable; 6 7 import types.Alphabet; 8 import types.SparseVector; 9 import types.StaticUtils; 10 11 /*** 12 * A linear model for sequence classification. It has the form h(x) = arg_max_y 13 * f(x,y).w where x and y are sequences of identical length, and f(x,y) 14 * decomposes over pairs of positions on y. 15 * 16 * @author kuzman 17 * 18 */ 19 20 public class LinearTagger implements Serializable { 21 22 private static final long serialVersionUID = 1L; 23 public double[] w; 24 Alphabet yAlphabet; 25 Alphabet xAlphabet; 26 SequenceFeatureFunction fxy; 27 28 public LinearTagger(Alphabet xAlpha, Alphabet yAlpha, 29 SequenceFeatureFunction fxy) { 30 w = new double[fxy.wSize()]; 31 yAlphabet = yAlpha; 32 xAlphabet = xAlpha; 33 this.fxy = fxy; 34 } 35 36 /*** 37 * at each position 0<=t<x.length, computes the score of each label pair 38 * 'ytm1','yt' as f(x,ytm1,yt) . w where yt is the label at position t and 39 * ytm1 is the label at position t-1. 40 * 41 * @param x 42 * @return result[t][ytm1][yt] = f(x,ytm1,yt).w 43 */ 44 public double[][][] scores(SparseVector[] x) { 45 double[][][] res = new double[x.length][yAlphabet.size()][yAlphabet 46 .size()]; 47 for (int t = 0; t < x.length; t++) { 48 for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) { 49 for (int yt = 0; yt < yAlphabet.size(); yt++) { 50 res[t][ytm1][yt] = StaticUtils.dotProduct(fxy.apply(x, 51 ytm1, yt, t), w); 52 } 53 } 54 } 55 return res; 56 } 57 58 /*** 59 * use the Viterbi algorithm to find arg_max_y f(x,y) . w 60 * 61 * @param x 62 * @return y that maximizes f(x,y) . w 63 */ 64 public int[] label(SparseVector[] x) { 65 double[][][] scores = scores(x); 66 double[][] gamma = new double[x.length][yAlphabet.size()]; 67 int[][] back = new int[x.length][yAlphabet.size()]; 68 for (int y = 0; y < yAlphabet.size(); y++) { 69 gamma[0][y] = scores[0][0][y]; 70 } 71 for (int t = 1; t < x.length; t++) { 72 for (int yt = 0; yt < yAlphabet.size(); yt++) { 73 gamma[t][yt] = Double.NEGATIVE_INFINITY; 74 for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) { 75 if (gamma[t][yt] < gamma[t - 1][ytm1] + scores[t][ytm1][yt]) { 76 back[t][yt] = ytm1; 77 gamma[t][yt] = gamma[t - 1][ytm1] + scores[t][ytm1][yt]; 78 } 79 } 80 } 81 } 82 int[] tags = new int[x.length]; 83 for (int y = 0; y < yAlphabet.size(); y++) { 84 if (gamma[x.length - 1][tags[x.length - 1]] < gamma[x.length - 1][y]) { 85 tags[x.length - 1] = y; 86 } 87 } 88 for (int t = x.length - 2; t >= 0; t--) { 89 tags[t] = back[t + 1][tags[t + 1]]; 90 } 91 return tags; 92 } 93 94 public void writeObject(ObjectOutputStream out) throws IOException { 95 out.writeLong(serialVersionUID); 96 out.writeInt(w.length); 97 for (double d : w) 98 out.writeDouble(d); 99 out.writeObject(xAlphabet); 100 out.writeObject(yAlphabet); 101 out.writeObject(fxy); 102 } 103 104 @SuppressWarnings("unchecked") 105 public void readObject(java.io.ObjectInputStream in) throws IOException, 106 ClassNotFoundException { 107 long inid = in.readLong(); 108 if (inid != serialVersionUID) 109 throw new IOException("Serial version mismatch: expected " 110 + serialVersionUID + " got " + inid); 111 w = new double[in.readInt()]; 112 for (int i = 0; i < w.length; i++) 113 w[i] = in.readDouble(); 114 xAlphabet = (Alphabet) in.readObject(); 115 yAlphabet = (Alphabet) in.readObject(); 116 fxy = (SequenceFeatureFunction) in.readObject(); 117 } 118 119 }