
package sequence;

import java.util.ArrayList;

import algo.ConjugateGradient;
import algo.GradientAscent;
import types.Alphabet;
import types.DifferentiableObjective;
import types.StaticUtils;

public class CRF {

	double gaussianPriorVariance;
	double numObservations;
	Alphabet xAlphabet;
	Alphabet yAlphabet;
	SequenceFeatureFunction fxy;

	public CRF(double gaussianPriorVariance, Alphabet xAlphabet,
			Alphabet yAlphabet, SequenceFeatureFunction fxy) {
		this.gaussianPriorVariance = gaussianPriorVariance;
		this.xAlphabet = xAlphabet;
		this.yAlphabet = yAlphabet;
		this.fxy = fxy;
	}

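	/**
	 * Trains the CRF on the given data by maximizing the regularized
	 * conditional log-likelihood (see {@link Objective}).
	 * 
	 * @param trainingData the labeled sequences to train on
	 * @return a LinearTagger holding the learned weights
	 */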
	public LinearTagger batchTrain(ArrayList<SequenceInstance> trainingData) {
		Objective obj = new Objective(trainingData);
		// maximize the objective; GradientAscent is an unused alternative
		// to conjugate gradient
		@SuppressWarnings("unused")
		GradientAscent gaoptimizer = new GradientAscent();
		ConjugateGradient optimizer = new ConjugateGradient(
				obj.getNumParameters());
		boolean success = optimizer.maximize(obj);
		if (!success)
			System.err.println("optimization did not converge");
		System.out.println("valCalls = " + obj.numValueCalls
				+ "   gradientCalls=" + obj.numGradientCalls);
		return obj.tagger;
	}

	/**
	 * The objective for our CRF, a sequential max-ent model. That is:
	 * max_\lambda sum_i log Pr(y_i|x_i) - 1/(2*var) * ||\lambda||^2, where
	 * var is the Gaussian prior variance, and
	 * p(y|x) = exp(f(x,y)*\lambda) / Z(x).
	 * 
	 * @author kuzman
	 */
	class Objective implements DifferentiableObjective {
		double[] empiricalExpectations;
		LinearTagger tagger;
		ArrayList<SequenceInstance> trainingData;
		int numValueCalls = 0;
		int numGradientCalls = 0;

		Objective(ArrayList<SequenceInstance> trainingData) {
			this.trainingData = trainingData;
			// compute empirical expectations...
			empiricalExpectations = new double[fxy.wSize()];
			for (SequenceInstance inst : trainingData) {
				StaticUtils.plusEquals(empiricalExpectations,
						fxy.apply(inst.x, inst.y));
			}
			tagger = new LinearTagger(xAlphabet, yAlphabet, fxy);
		}

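		/**
		 * Forward pass of forward-backward: alpha[t][y] is the total
		 * (exponentiated, locally shifted) score of all label prefixes
		 * ending in state y at position t, computed by the recursion
		 * alpha[t][y] = sum_{y'} alpha[t-1][y'] * expS[t][y'][y].
		 */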
		private double[][] forward(double[][][] expS) {
			double[][] res = new double[expS.length][yAlphabet.size()];
			for (int y = 0; y < yAlphabet.size(); y++) {
				res[0][y] = expS[0][0][y];
			}
			for (int t = 1; t < expS.length; t++) {
				for (int yt = 0; yt < yAlphabet.size(); yt++) {
					for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
						res[t][yt] += res[t - 1][ytm1] * expS[t][ytm1][yt];
					}
				}
			}
			return res;
		}

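		/**
		 * Backward pass: beta[t][y] is the total score of all label suffixes
		 * starting in state y at position t, computed by the recursion
		 * beta[t-1][y'] = sum_{y} beta[t][y] * expS[t][y'][y], with
		 * beta[T-1][y] = 1 for all y.
		 */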
		private double[][] backward(double[][][] expS) {
			double[][] res = new double[expS.length][yAlphabet.size()];
			for (int y = 0; y < yAlphabet.size(); y++) {
				res[expS.length - 1][y] = 1;
			}
			for (int t = expS.length - 1; t > 0; t--) {
				for (int yt = 0; yt < yAlphabet.size(); yt++) {
					for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
						res[t - 1][ytm1] += res[t][yt] * expS[t][ytm1][yt];
					}
				}
			}
			return res;
		}

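		/**
		 * Shifts the log-domain scores at each position so their maximum is
		 * zero. Since alpha, beta, the numerator, and Z all use the same
		 * shifted scores, the shift cancels in every ratio we compute; it
		 * only guards exp() against overflow.
		 */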
		private void normalizeScores(double[][][] scores) {
			for (int t = 0; t < scores.length; t++) {
				// start at -infinity so all-negative scores are still
				// shifted to have a zero maximum
				double max = Double.NEGATIVE_INFINITY;
				for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
					for (int yt = 0; yt < yAlphabet.size(); yt++) {
						max = Math.max(max, scores[t][ytm1][yt]);
					}
				}
				for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
					for (int yt = 0; yt < yAlphabet.size(); yt++) {
						scores[t][ytm1][yt] -= max;
					}
				}
			}
		}

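		/**
		 * Returns the regularized conditional log-likelihood of the training
		 * data: sum_i [score(x_i, y_i) - log Z(x_i)] minus the Gaussian prior
		 * penalty ||w||^2 / (2 * gaussianPriorVariance). Instances whose
		 * partition function over- or underflows are skipped with a warning.
		 */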
		public double getValue() {
			numValueCalls++;
			// value = log(prob(data)) - 1/(2*gaussianPriorVariance) * ||lambda||^2
			double val = 0;
			int numUnnormalizedInstances = 0;
			for (SequenceInstance inst : trainingData) {
				double[][][] scores = tagger.scores(inst.x);
				normalizeScores(scores);
				double[][][] expScores = StaticUtils.exp(scores);
				double[][] alpha = forward(expScores);
				// only the likelihood is needed here, so beta is not computed
				double Z = StaticUtils.sum(alpha[inst.x.length - 1]);
				if (Z == 0 || Double.isNaN(Z) || Double.isInfinite(Z)) {
					if (numUnnormalizedInstances < 3) {
						System.err.println("Could not normalize instance (" + Z
								+ "), skipping");
					} else if (numUnnormalizedInstances == 3) {
						System.err.println("    ...");
					}
					numUnnormalizedInstances++;
					continue;
				}
				val += Math.log(expScores[0][0][inst.y[0]]);
				for (int t = 1; t < inst.y.length; t++) {
					val += Math.log(expScores[t][inst.y[t - 1]][inst.y[t]]);
				}
				val -= Math.log(Z);
			}
			if (numUnnormalizedInstances != 0)
				System.err.println("Could not normalize "
						+ numUnnormalizedInstances + " instances");
			val -= 1 / (2 * gaussianPriorVariance)
					* StaticUtils.twoNormSquared(tagger.w);
			return val;
		}

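		/**
		 * Computes the gradient of the objective: for each feature, the
		 * empirical expectation minus the model expectation (accumulated from
		 * the forward-backward edge posteriors alpha * expScore * beta / Z),
		 * minus the prior term w / gaussianPriorVariance.
		 */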
		public void getGradient(double[] gradient) {
			numGradientCalls++;
			// gradient = empiricalExpectations - modelExpectations
			// - 1/gaussianPriorVariance * params
			double[] modelExpectations = new double[gradient.length];
			for (int i = 0; i < gradient.length; i++) {
				gradient[i] = empiricalExpectations[i];
				modelExpectations[i] = 0;
			}
			int numUnnormalizedInstances = 0;
			for (SequenceInstance inst : trainingData) {
				double[][][] scores = tagger.scores(inst.x);
				normalizeScores(scores);
				double[][][] expScores = StaticUtils.exp(scores);
				double[][] alpha = forward(expScores);
				// the gradient needs the posteriors, so compute beta as well
				double[][] beta = backward(expScores);
				double Z = StaticUtils.sum(alpha[inst.x.length - 1]);
				if (Z == 0 || Double.isNaN(Z) || Double.isInfinite(Z)) {
					if (numUnnormalizedInstances < 3) {
						System.err.println("Could not normalize instance (" + Z
								+ "), skipping");
					} else if (numUnnormalizedInstances == 3) {
						System.err.println("    ...");
					}
					numUnnormalizedInstances++;
					continue;
				}
				for (int yt = 0; yt < yAlphabet.size(); yt++) {
					// alpha[0][yt] already includes expScores[0][0][yt], so
					// the initial-edge posterior is just alpha * beta / Z
					StaticUtils.plusEquals(modelExpectations,
							fxy.apply(inst.x, 0, yt, 0),
							alpha[0][yt] * beta[0][yt] / Z);
				}
				for (int t = 1; t < inst.x.length; t++) {
					for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
						for (int yt = 0; yt < yAlphabet.size(); yt++) {
							StaticUtils.plusEquals(modelExpectations,
									fxy.apply(inst.x, ytm1, yt, t),
									alpha[t - 1][ytm1] * beta[t][yt]
											* expScores[t][ytm1][yt] / Z);
						}
					}
				}
			}

			for (int i = 0; i < gradient.length; i++) {
				gradient[i] -= modelExpectations[i];
				gradient[i] -= 1 / gaussianPriorVariance * tagger.w[i];
			}
		}

		public void setParameters(double[] newParameters) {
			System.arraycopy(newParameters, 0, tagger.w, 0,
					newParameters.length);
		}

		public void getParameters(double[] params) {
			System.arraycopy(tagger.w, 0, params, 0, params.length);
		}

		public int getNumParameters() {
			return tagger.w.length;
		}

	}

}
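
/*
 * A minimal usage sketch, not part of the original file. It assumes the
 * surrounding project provides the alphabets, a SequenceFeatureFunction
 * implementation, and labeled SequenceInstance data; how those are built
 * is hypothetical here. Only CRF, batchTrain, and LinearTagger.scores are
 * taken from this file.
 *
 *   Alphabet xAlphabet = ...;            // observation alphabet
 *   Alphabet yAlphabet = ...;            // label alphabet
 *   SequenceFeatureFunction fxy = ...;   // features f(x, y)
 *   ArrayList<SequenceInstance> data = ...;
 *
 *   CRF crf = new CRF(1.0, xAlphabet, yAlphabet, fxy); // prior variance 1.0
 *   LinearTagger tagger = crf.batchTrain(data);
 *   double[][][] s = tagger.scores(someX); // per-position transition scores
 */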