1 package types;
2
3 import java.io.IOException;
4 import java.io.ObjectOutputStream;
5
6 import classification.CompleteFeatureFunction;
7
8 /***
9 * A linear model for classification. It has the form h(x) = arg_max_y f(x,y).w
10 *
11 * @author kuzman
12 *
13 */
14
15 public class LinearClassifier {
16 private static final long serialVersionUID = 1L;
17 public double[] w;
18 public Alphabet yAlphabet;
19 public Alphabet xAlphabet;
20 FeatureFunction fxy;
21
22 public LinearClassifier(Alphabet xAlpha, Alphabet yAlpha,
23 FeatureFunction fxy) {
24 w = new double[fxy.wSize()];
25 yAlphabet = yAlpha;
26 xAlphabet = xAlpha;
27 this.fxy = fxy;
28 }
29
30 /***
31 * computes the score of each label 'y' defined as f(x,y) . w
32 *
33 * @param x
34 * @return [f(x,0).w, f(x,1).w, ...]
35 */
36 public double[] scores(SparseVector x) {
37 double[] res = new double[yAlphabet.size()];
38 for (int y = 0; y < yAlphabet.size(); y++) {
39 res[y] = StaticUtils.dotProduct(fxy.apply(x, y), (w));
40 }
41 return res;
42 }
43
44 /***
45 * computes the classification according to this linear classifier.
46 * arg_max_y f(x,y) . w
47 *
48 * @param x
49 * @return y that maximizes f(x,y) . w
50 */
51 public int label(SparseVector x) {
52 double[] scores = scores(x);
53 int max = 0;
54 for (int y = 0; y < yAlphabet.size(); y++) {
55 if (scores[max] < scores[y])
56 max = y;
57 }
58 return max;
59 }
60
61 public void writeObject(ObjectOutputStream out) throws IOException {
62 out.writeLong(serialVersionUID);
63 out.writeInt(w.length);
64 for (double d : w)
65 out.writeDouble(d);
66 out.writeObject(xAlphabet);
67 out.writeObject(yAlphabet);
68 out.writeObject(fxy);
69 }
70
71 @SuppressWarnings("unchecked")
72 public void readObject(java.io.ObjectInputStream in) throws IOException,
73 ClassNotFoundException {
74 long inid = in.readLong();
75 if (inid != serialVersionUID)
76 throw new IOException("Serial version mismatch: expected "
77 + serialVersionUID + " got " + inid);
78 w = new double[in.readInt()];
79 for (int i = 0; i < w.length; i++)
80 w[i] = in.readDouble();
81 xAlphabet = (Alphabet) in.readObject();
82 yAlphabet = (Alphabet) in.readObject();
83 fxy = (CompleteFeatureFunction) in.readObject();
84 }
85
86 }