View Javadoc

1   package types;
2   
3   import java.io.IOException;
4   import java.io.ObjectOutputStream;
5   
6   import classification.CompleteFeatureFunction;
7   
8   /***
9    * A linear model for classification. It has the form h(x) = arg_max_y f(x,y).w
10   * 
11   * @author kuzman
12   * 
13   */
14  
15  public class LinearClassifier {
16  	private static final long serialVersionUID = 1L;
17  	public double[] w;
18  	public Alphabet yAlphabet;
19  	public Alphabet xAlphabet;
20  	FeatureFunction fxy;
21  
22  	public LinearClassifier(Alphabet xAlpha, Alphabet yAlpha,
23  			FeatureFunction fxy) {
24  		w = new double[fxy.wSize()];
25  		yAlphabet = yAlpha;
26  		xAlphabet = xAlpha;
27  		this.fxy = fxy;
28  	}
29  
30  	/***
31  	 * computes the score of each label 'y' defined as f(x,y) . w
32  	 * 
33  	 * @param x
34  	 * @return [f(x,0).w, f(x,1).w, ...]
35  	 */
36  	public double[] scores(SparseVector x) {
37  		double[] res = new double[yAlphabet.size()];
38  		for (int y = 0; y < yAlphabet.size(); y++) {
39  			res[y] = StaticUtils.dotProduct(fxy.apply(x, y), (w));
40  		}
41  		return res;
42  	}
43  
44  	/***
45  	 * computes the classification according to this linear classifier.
46  	 * arg_max_y f(x,y) . w
47  	 * 
48  	 * @param x
49  	 * @return y that maximizes f(x,y) . w
50  	 */
51  	public int label(SparseVector x) {
52  		double[] scores = scores(x);
53  		int max = 0;
54  		for (int y = 0; y < yAlphabet.size(); y++) {
55  			if (scores[max] < scores[y])
56  				max = y;
57  		}
58  		return max;
59  	}
60  
61  	public void writeObject(ObjectOutputStream out) throws IOException {
62  		out.writeLong(serialVersionUID);
63  		out.writeInt(w.length);
64  		for (double d : w)
65  			out.writeDouble(d);
66  		out.writeObject(xAlphabet);
67  		out.writeObject(yAlphabet);
68  		out.writeObject(fxy);
69  	}
70  
71  	@SuppressWarnings("unchecked")
72  	public void readObject(java.io.ObjectInputStream in) throws IOException,
73  			ClassNotFoundException {
74  		long inid = in.readLong();
75  		if (inid != serialVersionUID)
76  			throw new IOException("Serial version mismatch: expected "
77  					+ serialVersionUID + " got " + inid);
78  		w = new double[in.readInt()];
79  		for (int i = 0; i < w.length; i++)
80  			w[i] = in.readDouble();
81  		xAlphabet = (Alphabet) in.readObject();
82  		yAlphabet = (Alphabet) in.readObject();
83  		fxy = (CompleteFeatureFunction) in.readObject();
84  	}
85  
86  }