// View Javadoc (viewer header left over from the rendered-source export)

1   package classification;
2   
3   import java.util.ArrayList;
4   
5   import algo.ConjugateGradient;
6   import algo.GradientAscent;
7   import types.Alphabet;
8   import types.ClassificationInstance;
9   import types.DifferentiableObjective;
10  import types.FeatureFunction;
11  import types.LinearClassifier;
12  import types.StaticUtils;
13  
14  public class MaxEntropy {
15  
16  	double gaussianPriorVariance;
17  	double numObservations;
18  	Alphabet xAlphabet;
19  	Alphabet yAlphabet;
20  	FeatureFunction fxy;
21  
22  	public MaxEntropy(double gaussianPriorVariance, Alphabet xAlphabet,
23  			Alphabet yAlphabet, FeatureFunction fxy) {
24  		this.gaussianPriorVariance = gaussianPriorVariance;
25  		this.xAlphabet = xAlphabet;
26  		this.yAlphabet = yAlphabet;
27  		this.fxy = fxy;
28  	}
29  
30  	public LinearClassifier batchTrain(
31  			ArrayList<ClassificationInstance> trainingData) {
32  		Objective obj = new Objective(trainingData);
33  		// perform gradient descent
34  		@SuppressWarnings("unused")
35  		GradientAscent gaoptimizer;
36  		ConjugateGradient optimizer = new ConjugateGradient(obj
37  				.getNumParameters());
38  		@SuppressWarnings("unused")
39  		boolean success = optimizer.maximize(obj);
40  		System.out.println("valCalls = " + obj.numValueCalls
41  				+ "   gradientCalls=" + obj.numGradientCalls);
42  		return obj.classifier;
43  	}
44  
45  	/***
46  	 * An objective for our max-ent model. That is: max_\lambda sum_i log
47  	 * Pr(y_i|x_i) - 1/var * ||\lambda||^2 where var is the Gaussian prior
48  	 * variance, and p(y|x) = exp(f(x,y)*lambda)/Z(x).
49  	 * 
50  	 * @author kuzman
51  	 * 
52  	 */
53  	class Objective implements DifferentiableObjective {
54  		double[] empiricalExpectations;
55  		LinearClassifier classifier;
56  		ArrayList<ClassificationInstance> trainingData;
57  		int numValueCalls = 0;
58  		int numGradientCalls = 0;
59  
60  		Objective(ArrayList<ClassificationInstance> trainingData) {
61  			this.trainingData = trainingData;
62  			// compute empirical expectations...
63  			empiricalExpectations = new double[fxy.wSize()];
64  			for (ClassificationInstance inst : trainingData) {
65  				StaticUtils.plusEquals(empiricalExpectations, fxy.apply(inst.x,
66  						inst.y));
67  			}
68  			classifier = new LinearClassifier(xAlphabet, yAlphabet, fxy);
69  		}
70  
71  		public double getValue() {
72  			numValueCalls++;
73  			// value = log(prob(data)) - 1/gaussianPriorVariance * ||lambda||^2
74  			double val = 0;
75  			for (ClassificationInstance inst : trainingData) {
76  				double[] scores = classifier.scores(inst.x);
77  				double[] probs = StaticUtils.exp(scores);
78  				double Z = StaticUtils.sum(probs);
79  				val += scores[inst.y] - Math.log(Z);
80  			}
81  			val -= 1 / (2 * gaussianPriorVariance)
82  					* StaticUtils.twoNormSquared(classifier.w);
83  			return val;
84  		}
85  
86  		public void getGradient(double[] gradient) {
87  			numGradientCalls++;
88  			// gradient = empiricalExpectations - modelExpectations
89  			// -2/gaussianPriorVariance * params
90  			double[] modelExpectations = new double[gradient.length];
91  			for (int i = 0; i < gradient.length; i++) {
92  				gradient[i] = empiricalExpectations[i];
93  				modelExpectations[i] = 0;
94  			}
95  			for (ClassificationInstance inst : trainingData) {
96  				double[] scores = classifier.scores(inst.x);
97  				double[] probs = StaticUtils.exp(scores);
98  				double Z = StaticUtils.sum(probs);
99  				for (int y = 0; y < yAlphabet.size(); y++) {
100 					StaticUtils.plusEquals(modelExpectations, fxy.apply(inst.x,
101 							y), probs[y] / Z);
102 				}
103 			}
104 			for (int i = 0; i < gradient.length; i++) {
105 				gradient[i] -= modelExpectations[i];
106 				gradient[i] -= 1 / gaussianPriorVariance * classifier.w[i];
107 			}
108 
109 		}
110 
111 		public void setParameters(double[] newParameters) {
112 			System.arraycopy(newParameters, 0, classifier.w, 0,
113 					newParameters.length);
114 		}
115 
116 		public void getParameters(double[] params) {
117 			System.arraycopy(classifier.w, 0, params, 0, params.length);
118 		}
119 
120 		public int getNumParameters() {
121 			return classifier.w.length;
122 		}
123 
124 	}
125 
126 }