package sequence;

import java.util.ArrayList;

import algo.ConjugateGradient;
import algo.GradientAscent;
import types.Alphabet;
import types.DifferentiableObjective;
import types.StaticUtils;

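/**
 * A linear-chain conditional random field. Training maximizes the
 * Gaussian-regularized conditional log-likelihood by conjugate gradient,
 * using the forward-backward algorithm to compute the model expectations
 * needed for the gradient.
 *
 * Typical use is a sketch along these lines (how the alphabets, feature
 * function, and training data are built depends on the rest of this
 * package):
 *
 *   CRF crf = new CRF(1.0, xAlphabet, yAlphabet, fxy);
 *   LinearTagger tagger = crf.batchTrain(trainingData);
 */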
public class CRF {

    double gaussianPriorVariance;
    double numObservations; // currently unused in this class
    Alphabet xAlphabet;
    Alphabet yAlphabet;
    SequenceFeatureFunction fxy;

    public CRF(double gaussianPriorVariance, Alphabet xAlphabet,
            Alphabet yAlphabet, SequenceFeatureFunction fxy) {
        this.gaussianPriorVariance = gaussianPriorVariance;
        this.xAlphabet = xAlphabet;
        this.yAlphabet = yAlphabet;
        this.fxy = fxy;
    }

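    /**
     * Trains the CRF by maximizing the objective below over the training
     * data, and returns the resulting tagger.
     */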
    public LinearTagger batchTrain(ArrayList<SequenceInstance> trainingData) {
        Objective obj = new Objective(trainingData);

        // Plain gradient ascent is available as an alternative, but
        // conjugate gradient typically needs far fewer iterations.
        @SuppressWarnings("unused")
        GradientAscent gaoptimizer = new GradientAscent();
        ConjugateGradient optimizer = new ConjugateGradient(obj
                .getNumParameters());
        boolean success = optimizer.maximize(obj);
        if (!success)
            System.err.println("optimization did not converge");
        System.out.println("valCalls = " + obj.numValueCalls
                + " gradientCalls=" + obj.numGradientCalls);
        return obj.tagger;
    }

    /***
     * An objective for our max-ent sequence model (a CRF). That is:
     *
     *   max_\lambda sum_i log Pr(y_i|x_i) - 1/(2*var) * ||\lambda||^2
     *
     * where var is the Gaussian prior variance, and
     * p(y|x) = exp(f(x,y) . \lambda) / Z(x).
     *
     * @author kuzman
     */
    class Objective implements DifferentiableObjective {
        double[] empiricalExpectations;
        LinearTagger tagger;
        ArrayList<SequenceInstance> trainingData;
        int numValueCalls = 0;
        int numGradientCalls = 0;

        Objective(ArrayList<SequenceInstance> trainingData) {
            this.trainingData = trainingData;

            // Empirical feature expectations: total feature counts of the
            // observed label sequences over the training data.
            empiricalExpectations = new double[fxy.wSize()];
            for (SequenceInstance inst : trainingData) {
                StaticUtils.plusEquals(empiricalExpectations, fxy.apply(inst.x,
                        inst.y));
            }
            tagger = new LinearTagger(xAlphabet, yAlphabet, fxy);
        }

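        /**
         * Forward pass: alpha[t][y] is the total (unnormalized) score of
         * all label prefixes ending in state y at position t, via
         * alpha[t][y] = sum_{y'} alpha[t-1][y'] * expS[t][y'][y].
         * At t = 0 the "previous label" index is fixed at 0.
         */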
        private double[][] forward(double[][][] expS) {
            double[][] res = new double[expS.length][yAlphabet.size()];
            for (int y = 0; y < yAlphabet.size(); y++) {
                res[0][y] = expS[0][0][y];
            }
            for (int t = 1; t < expS.length; t++) {
                for (int yt = 0; yt < yAlphabet.size(); yt++) {
                    for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
                        res[t][yt] += res[t - 1][ytm1] * expS[t][ytm1][yt];
                    }
                }
            }
            return res;
        }

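        /**
         * Backward pass: beta[t][y] is the total (unnormalized) score of
         * all label suffixes starting in state y at position t, via
         * beta[t-1][y'] = sum_y beta[t][y] * expS[t][y'][y], with
         * beta[T-1][y] = 1 at the final position.
         */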
        private double[][] backward(double[][][] expS) {
            double[][] res = new double[expS.length][yAlphabet.size()];
            for (int y = 0; y < yAlphabet.size(); y++) {
                res[expS.length - 1][y] = 1;
            }
            for (int t = expS.length - 1; t > 0; t--) {
                for (int yt = 0; yt < yAlphabet.size(); yt++) {
                    for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
                        res[t - 1][ytm1] += res[t][yt] * expS[t][ytm1][yt];
                    }
                }
            }
            return res;
        }

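        /**
         * Shifts the scores at each position so their maximum is 0, which
         * keeps exp(score) from overflowing. This does not change the
         * distribution, since p(y|x) is invariant to adding a constant to
         * all scores at a position.
         */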
        private void normalizeScores(double[][][] scores) {
            for (int t = 0; t < scores.length; t++) {
                // Start at -Infinity, not 0: all scores may be negative.
                double max = Double.NEGATIVE_INFINITY;
                for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
                    for (int yt = 0; yt < yAlphabet.size(); yt++) {
                        max = Math.max(max, scores[t][ytm1][yt]);
                    }
                }

                for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
                    for (int yt = 0; yt < yAlphabet.size(); yt++) {
                        scores[t][ytm1][yt] -= max;
                    }
                }
            }
        }

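        /**
         * The regularized log-likelihood. For each instance,
         * log p(y|x) = sum_t score(t, y_{t-1}, y_t) - log Z(x), where Z(x)
         * is the sum of the final forward scores; the Gaussian prior then
         * contributes -||w||^2 / (2*var).
         */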
        public double getValue() {
            numValueCalls++;

            double val = 0;
            int numUnnormalizedInstances = 0;
            for (SequenceInstance inst : trainingData) {
                double[][][] scores = tagger.scores(inst.x);
                normalizeScores(scores);
                double[][][] expScores = StaticUtils.exp(scores);
                double[][] alpha = forward(expScores);

                double Z = StaticUtils.sum(alpha[inst.x.length - 1]);
                if (Z == 0 || Double.isNaN(Z) || Double.isInfinite(Z)) {
                    if (numUnnormalizedInstances < 3) {
                        System.err.println("Could not normalize instance (" + Z
                                + "), skipping");
                    } else if (numUnnormalizedInstances == 3) {
                        System.err.println(" ...");
                    }
                    numUnnormalizedInstances++;
                    continue;
                }
                // Use the raw scores rather than log(exp(score)): the same
                // value, without underflow to log(0) = -Infinity.
                val += scores[0][0][inst.y[0]];
                for (int t = 1; t < inst.y.length; t++) {
                    val += scores[t][inst.y[t - 1]][inst.y[t]];
                }
                val -= Math.log(Z);
            }
            if (numUnnormalizedInstances != 0)
                System.err.println("Could not normalize "
                        + numUnnormalizedInstances + " instances");
            val -= 1 / (2 * gaussianPriorVariance)
                    * StaticUtils.twoNormSquared(tagger.w);
            return val;
        }

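        /**
         * Gradient of the objective: empirical feature expectations minus
         * model feature expectations, minus the prior term w/var. The model
         * expectations weight each transition (y_{t-1}, y_t) at position t
         * by its posterior marginal
         * alpha[t-1][y_{t-1}] * expS[t][y_{t-1}][y_t] * beta[t][y_t] / Z.
         */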
        public void getGradient(double[] gradient) {
            numGradientCalls++;

            // Start from the empirical expectations; model expectations and
            // the prior term are subtracted at the end.
            double[] modelExpectations = new double[gradient.length];
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] = empiricalExpectations[i];
                modelExpectations[i] = 0;
            }
            int numUnnormalizedInstances = 0;
            for (SequenceInstance inst : trainingData) {
                double[][][] scores = tagger.scores(inst.x);
                normalizeScores(scores);
                double[][][] expScores = StaticUtils.exp(scores);
                double[][] alpha = forward(expScores);
                double[][] beta = backward(expScores);
                double Z = StaticUtils.sum(alpha[inst.x.length - 1]);
                if (Z == 0 || Double.isNaN(Z) || Double.isInfinite(Z)) {
                    if (numUnnormalizedInstances < 3) {
                        System.err.println("Could not normalize instance (" + Z
                                + "), skipping");
                    } else if (numUnnormalizedInstances == 3) {
                        System.err.println(" ...");
                    }
                    numUnnormalizedInstances++;
                    continue;
                }
                // Position 0: alpha[0][yt] already includes the initial
                // score expScores[0][0][yt], so the marginal is just
                // alpha[0][yt] * beta[0][yt] / Z (multiplying in
                // expScores[0][0][yt] again would double-count it).
                for (int yt = 0; yt < yAlphabet.size(); yt++) {
                    StaticUtils.plusEquals(modelExpectations, fxy.apply(inst.x,
                            0, yt, 0), alpha[0][yt] * beta[0][yt] / Z);
                }
                for (int t = 1; t < inst.x.length; t++) {
                    for (int ytm1 = 0; ytm1 < yAlphabet.size(); ytm1++) {
                        for (int yt = 0; yt < yAlphabet.size(); yt++) {
                            StaticUtils.plusEquals(modelExpectations, fxy
                                    .apply(inst.x, ytm1, yt, t),
                                    alpha[t - 1][ytm1] * beta[t][yt]
                                            * expScores[t][ytm1][yt] / Z);
                        }
                    }
                }
            }

            for (int i = 0; i < gradient.length; i++) {
                gradient[i] -= modelExpectations[i];
                // Derivative of the prior -||w||^2 / (2*var) is -w/var.
                gradient[i] -= 1 / gaussianPriorVariance * tagger.w[i];
            }
        }

        public void setParameters(double[] newParameters) {
            System.arraycopy(newParameters, 0, tagger.w, 0,
                    newParameters.length);
        }

        public void getParameters(double[] params) {
            System.arraycopy(tagger.w, 0, params, 0, params.length);
        }

        public int getNumParameters() {
            return tagger.w.length;
        }

    }

}