package classification;

import java.util.ArrayList;

import algo.ConjugateGradient;
import types.Alphabet;
import types.ClassificationInstance;
import types.DifferentiableObjective;
import types.FeatureFunction;
import types.LinearClassifier;
import types.StaticUtils;
public class MaxEntropy {

    double gaussianPriorVariance; // variance of the Gaussian prior on weights
    double numObservations;
    Alphabet xAlphabet; // alphabet of input features
    Alphabet yAlphabet; // alphabet of labels
    FeatureFunction fxy; // joint feature function f(x, y)

    public MaxEntropy(double gaussianPriorVariance, Alphabet xAlphabet,
            Alphabet yAlphabet, FeatureFunction fxy) {
        this.gaussianPriorVariance = gaussianPriorVariance;
        this.xAlphabet = xAlphabet;
        this.yAlphabet = yAlphabet;
        this.fxy = fxy;
    }

    public LinearClassifier batchTrain(
            ArrayList<ClassificationInstance> trainingData) {
        Objective obj = new Objective(trainingData);

        // Maximize the regularized log-likelihood by conjugate gradient.
        // (algo.GradientAscent is an alternative optimizer in this codebase.)
        ConjugateGradient optimizer = new ConjugateGradient(
                obj.getNumParameters());
        boolean success = optimizer.maximize(obj);
        System.out.println("converged = " + success + " valCalls = "
                + obj.numValueCalls + " gradientCalls = "
                + obj.numGradientCalls);
        return obj.classifier;
    }
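
    /*
     * A minimal usage sketch, not part of the original class: trains a
     * classifier with an illustrative prior variance of 1.0. It assumes the
     * caller has already built the alphabets, feature function, and training
     * data with the rest of this codebase; trainExample itself is a
     * hypothetical helper added for illustration.
     */
    public static LinearClassifier trainExample(Alphabet xAlphabet,
            Alphabet yAlphabet, FeatureFunction fxy,
            ArrayList<ClassificationInstance> trainingData) {
        MaxEntropy maxent = new MaxEntropy(1.0, xAlphabet, yAlphabet, fxy);
        return maxent.batchTrain(trainingData);
    }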

    /***
     * An objective for our max-ent model. That is:
     * max_\lambda sum_i log Pr(y_i|x_i) - 1/(2 var) * ||\lambda||^2
     * where var is the Gaussian prior variance, and
     * p(y|x) = exp(f(x,y) * \lambda) / Z(x).
     *
     * @author kuzman
     */
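    /*
     * Term-by-term gradient of the objective above (this is what
     * getGradient() computes):
     *
     *   d/d\lambda_j = sum_i f_j(x_i, y_i)               (empirical expectations)
     *                - sum_i sum_y p(y|x_i) f_j(x_i, y)  (model expectations)
     *                - \lambda_j / var                   (gradient of the prior)
     */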
    class Objective implements DifferentiableObjective {
        double[] empiricalExpectations;
        LinearClassifier classifier;
        ArrayList<ClassificationInstance> trainingData;
        int numValueCalls = 0;
        int numGradientCalls = 0;

        Objective(ArrayList<ClassificationInstance> trainingData) {
            this.trainingData = trainingData;

            // Empirical expectations: the sum of feature vectors at the
            // observed labels. They never change, so compute them once.
            empiricalExpectations = new double[fxy.wSize()];
            for (ClassificationInstance inst : trainingData) {
                StaticUtils.plusEquals(empiricalExpectations, fxy.apply(inst.x,
                        inst.y));
            }
            classifier = new LinearClassifier(xAlphabet, yAlphabet, fxy);
        }

        public double getValue() {
            numValueCalls++;

            // Log-likelihood: sum_i [ score(x_i, y_i) - log Z(x_i) ].
            // Note: exponentiating raw scores can overflow for large weights;
            // see the logSumExp sketch below for a standard fix.
            double val = 0;
            for (ClassificationInstance inst : trainingData) {
                double[] scores = classifier.scores(inst.x);
                double[] probs = StaticUtils.exp(scores);
                double Z = StaticUtils.sum(probs);
                val += scores[inst.y] - Math.log(Z);
            }
            // Gaussian prior: subtract ||w||^2 / (2 var).
            val -= 1 / (2 * gaussianPriorVariance)
                    * StaticUtils.twoNormSquared(classifier.w);
            return val;
        }

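        /*
         * A hedged numerical-stability sketch, not in the original code:
         * getValue() computes log Z by exponentiating raw scores, which can
         * overflow. The standard log-sum-exp trick shifts by the max score
         * before exponentiating; getValue() could use this in place of the
         * direct Math.log(Z).
         */
        double logSumExp(double[] scores) {
            double max = Double.NEGATIVE_INFINITY;
            for (double s : scores)
                max = Math.max(max, s);
            double sum = 0;
            for (double s : scores)
                sum += Math.exp(s - max);
            return max + Math.log(sum);
        }
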
        public void getGradient(double[] gradient) {
            numGradientCalls++;

            // Start from the empirical expectations...
            double[] modelExpectations = new double[gradient.length];
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] = empiricalExpectations[i];
                modelExpectations[i] = 0;
            }
            // ...subtract the model expectations E_{p(y|x_i)}[f(x_i, y)]...
            for (ClassificationInstance inst : trainingData) {
                double[] scores = classifier.scores(inst.x);
                double[] probs = StaticUtils.exp(scores);
                double Z = StaticUtils.sum(probs);
                for (int y = 0; y < yAlphabet.size(); y++) {
                    StaticUtils.plusEquals(modelExpectations, fxy.apply(inst.x,
                            y), probs[y] / Z);
                }
            }
            // ...and subtract the gradient of the prior, w / var.
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] -= modelExpectations[i];
                gradient[i] -= 1 / gaussianPriorVariance * classifier.w[i];
            }
        }

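        /*
         * A hedged debugging sketch, not in the original code: estimates
         * d objective / d w_j by a central finite difference, for comparison
         * against getGradient() when modifying the objective. The step size
         * 1e-5 is an arbitrary choice.
         */
        double finiteDifferenceGradient(int j) {
            double eps = 1e-5;
            double[] w = new double[getNumParameters()];
            getParameters(w);
            double orig = w[j];
            w[j] = orig + eps;
            setParameters(w);
            double up = getValue();
            w[j] = orig - eps;
            setParameters(w);
            double down = getValue();
            w[j] = orig; // restore the original weight
            setParameters(w);
            return (up - down) / (2 * eps);
        }
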
        public void setParameters(double[] newParameters) {
            System.arraycopy(newParameters, 0, classifier.w, 0,
                    newParameters.length);
        }

        public void getParameters(double[] params) {
            System.arraycopy(classifier.w, 0, params, 0, params.length);
        }

        public int getNumParameters() {
            return classifier.w.length;
        }

    }

}