View Javadoc

1   package algo;
2   
3   import types.DifferentiableObjective;
4   import types.StaticUtils;
5   
6   public class ConjugateGradient {
7   
8   	int maxNumSteps = 2000;
9   	double defaultStepSize = 1e-3;
10  	double minStepSize = 1e-100;
11  
12  	public ConjugateGradient(int numParams) {
13  		evaluateAtStorage = new double[numParams];
14  	}
15  
16  	public boolean maximize(DifferentiableObjective o) {
17  		double[] p = new double[o.getNumParameters()];
18  		o.getParameters(p);
19  		double[] g = new double[o.getNumParameters()];
20  		double[] h = new double[o.getNumParameters()];
21  		double[] xi = new double[o.getNumParameters()];
22  		double oldScore = o.getValue();
23  		@SuppressWarnings("unused")
24  		long time = System.currentTimeMillis();
25  		System.err.println("	Score: " + oldScore);
26  		o.getGradient(g);
27  		System.arraycopy(g, 0, h, 0, g.length);
28  		System.arraycopy(g, 0, xi, 0, g.length);
29  		for (int iteration = 0; iteration < maxNumSteps; iteration++) {
30  			double newScore = lineSearch(o, p, xi);
31  			// System.err.println(" Score:
32  			// "+newScore+"\t("+((System.currentTimeMillis()-time)/1000.0)+"
33  			// seconds)");
34  			time = System.currentTimeMillis();
35  			if (newScore - oldScore < 1e-30)
36  				return true;
37  			oldScore = newScore;
38  			// check if we've converged.
39  			o.getGradient(xi);
40  			double numerator = StaticUtils.dotProduct(xi, xi)
41  					- StaticUtils.dotProduct(g, xi);
42  			double denom = StaticUtils.dotProduct(g, g);
43  			if (denom < minStepSize)
44  				return true;
45  			double gamma = numerator / denom;
46  			System.arraycopy(xi, 0, g, 0, g.length);
47  			StaticUtils.add(xi, g, h, gamma);
48  			System.arraycopy(xi, 0, h, 0, g.length);
49  		}
50  		return false;
51  	}
52  
53  	/***
54  	 * finds the maximizer of o(parameters + lambda*direction)
55  	 * 
56  	 * @param o
57  	 * @param parameters
58  	 * @param direction
59  	 * @return the score at the new parameters
60  	 */
61  	public double lineSearch(DifferentiableObjective o, double[] parameters,
62  			double[] direction) {
63  		double min = 0;
64  		double minVal = evalueateAt(o, parameters, direction, min);
65  		if (Double.isNaN(minVal))
66  			throw new RuntimeException("Invalid function value: " + minVal);
67  		double max = defaultStepSize;
68  		double maxVal = evalueateAt(o, parameters, direction, max);
69  		if (Double.isNaN(maxVal))
70  			throw new RuntimeException("Invalid function value: " + maxVal);
71  		// make sure maxVal < minVal, that way we know that if direction is a
72  		// descent direction, there is a max somewhere in the middle.
73  		while (maxVal > minVal) {
74  			max = 2 * max;
75  			maxVal = evalueateAt(o, parameters, direction, max);
76  			if (Double.isNaN(maxVal))
77  				throw new RuntimeException("Invalid function value: " + maxVal);
78  		}
79  		// binary search between the two values
80  		while (max - min > max * 0.05) {
81  			double mid = (max + min) / 2;
82  			double midVal = evalueateAt(o, parameters, direction, mid);
83  			if (Double.isNaN(midVal))
84  				throw new RuntimeException("Invalid function value: " + midVal);
85  			if (minVal > maxVal) {
86  				// min is better than max. discard max
87  				max = mid;
88  				maxVal = midVal;
89  			} else {
90  				min = mid;
91  				minVal = midVal;
92  			}
93  		}
94  		StaticUtils.add(parameters, parameters, direction, min);
95  		defaultStepSize = min * 2;
96  		return minVal;
97  	}
98  
99  	private double[] evaluateAtStorage;
100 
101 	private double evalueateAt(DifferentiableObjective o, double[] params,
102 			double[] direction, double step) {
103 		StaticUtils.add(evaluateAtStorage, params, direction, step);
104 		o.setParameters(evaluateAtStorage);
105 		return o.getValue();
106 	}
107 
108 }