1 package types; 2 3 import java.io.IOException; 4 import java.io.ObjectOutputStream; 5 6 import classification.CompleteFeatureFunction; 7 8 /*** 9 * A linear model for classification. It has the form h(x) = arg_max_y f(x,y).w 10 * 11 * @author kuzman 12 * 13 */ 14 15 public class LinearClassifier { 16 private static final long serialVersionUID = 1L; 17 public double[] w; 18 public Alphabet yAlphabet; 19 public Alphabet xAlphabet; 20 FeatureFunction fxy; 21 22 public LinearClassifier(Alphabet xAlpha, Alphabet yAlpha, 23 FeatureFunction fxy) { 24 w = new double[fxy.wSize()]; 25 yAlphabet = yAlpha; 26 xAlphabet = xAlpha; 27 this.fxy = fxy; 28 } 29 30 /*** 31 * computes the score of each label 'y' defined as f(x,y) . w 32 * 33 * @param x 34 * @return [f(x,0).w, f(x,1).w, ...] 35 */ 36 public double[] scores(SparseVector x) { 37 double[] res = new double[yAlphabet.size()]; 38 for (int y = 0; y < yAlphabet.size(); y++) { 39 res[y] = StaticUtils.dotProduct(fxy.apply(x, y), (w)); 40 } 41 return res; 42 } 43 44 /*** 45 * computes the classification according to this linear classifier. 46 * arg_max_y f(x,y) . w 47 * 48 * @param x 49 * @return y that maximizes f(x,y) . w 50 */ 51 public int label(SparseVector x) { 52 double[] scores = scores(x); 53 int max = 0; 54 for (int y = 0; y < yAlphabet.size(); y++) { 55 if (scores[max] < scores[y]) 56 max = y; 57 } 58 return max; 59 } 60 61 public void writeObject(ObjectOutputStream out) throws IOException { 62 out.writeLong(serialVersionUID); 63 out.writeInt(w.length); 64 for (double d : w) 65 out.writeDouble(d); 66 out.writeObject(xAlphabet); 67 out.writeObject(yAlphabet); 68 out.writeObject(fxy); 69 } 70 71 @SuppressWarnings("unchecked") 72 public void readObject(java.io.ObjectInputStream in) throws IOException, 73 ClassNotFoundException { 74 long inid = in.readLong(); 75 if (inid != serialVersionUID) 76 throw new IOException("Serial version mismatch: expected " 77 + serialVersionUID + " got " + inid); 78 w = new double[in.readInt()]; 79 for (int i = 0; i < w.length; i++) 80 w[i] = in.readDouble(); 81 xAlphabet = (Alphabet) in.readObject(); 82 yAlphabet = (Alphabet) in.readObject(); 83 fxy = (CompleteFeatureFunction) in.readObject(); 84 } 85 86 }