1 package algo;
2
3 import types.DifferentiableObjective;
4 import types.StaticUtils;
5
6 public class ConjugateGradient {
7
8 int maxNumSteps = 2000;
9 double defaultStepSize = 1e-3;
10 double minStepSize = 1e-100;
11
12 public ConjugateGradient(int numParams) {
13 evaluateAtStorage = new double[numParams];
14 }
15
16 public boolean maximize(DifferentiableObjective o) {
17 double[] p = new double[o.getNumParameters()];
18 o.getParameters(p);
19 double[] g = new double[o.getNumParameters()];
20 double[] h = new double[o.getNumParameters()];
21 double[] xi = new double[o.getNumParameters()];
22 double oldScore = o.getValue();
23 @SuppressWarnings("unused")
24 long time = System.currentTimeMillis();
25 System.err.println(" Score: " + oldScore);
26 o.getGradient(g);
27 System.arraycopy(g, 0, h, 0, g.length);
28 System.arraycopy(g, 0, xi, 0, g.length);
29 for (int iteration = 0; iteration < maxNumSteps; iteration++) {
30 double newScore = lineSearch(o, p, xi);
31
32
33
34 time = System.currentTimeMillis();
35 if (newScore - oldScore < 1e-30)
36 return true;
37 oldScore = newScore;
38
39 o.getGradient(xi);
40 double numerator = StaticUtils.dotProduct(xi, xi)
41 - StaticUtils.dotProduct(g, xi);
42 double denom = StaticUtils.dotProduct(g, g);
43 if (denom < minStepSize)
44 return true;
45 double gamma = numerator / denom;
46 System.arraycopy(xi, 0, g, 0, g.length);
47 StaticUtils.add(xi, g, h, gamma);
48 System.arraycopy(xi, 0, h, 0, g.length);
49 }
50 return false;
51 }
52
53 /***
54 * finds the maximizer of o(parameters + lambda*direction)
55 *
56 * @param o
57 * @param parameters
58 * @param direction
59 * @return the score at the new parameters
60 */
61 public double lineSearch(DifferentiableObjective o, double[] parameters,
62 double[] direction) {
63 double min = 0;
64 double minVal = evalueateAt(o, parameters, direction, min);
65 if (Double.isNaN(minVal))
66 throw new RuntimeException("Invalid function value: " + minVal);
67 double max = defaultStepSize;
68 double maxVal = evalueateAt(o, parameters, direction, max);
69 if (Double.isNaN(maxVal))
70 throw new RuntimeException("Invalid function value: " + maxVal);
71
72
73 while (maxVal > minVal) {
74 max = 2 * max;
75 maxVal = evalueateAt(o, parameters, direction, max);
76 if (Double.isNaN(maxVal))
77 throw new RuntimeException("Invalid function value: " + maxVal);
78 }
79
80 while (max - min > max * 0.05) {
81 double mid = (max + min) / 2;
82 double midVal = evalueateAt(o, parameters, direction, mid);
83 if (Double.isNaN(midVal))
84 throw new RuntimeException("Invalid function value: " + midVal);
85 if (minVal > maxVal) {
86
87 max = mid;
88 maxVal = midVal;
89 } else {
90 min = mid;
91 minVal = midVal;
92 }
93 }
94 StaticUtils.add(parameters, parameters, direction, min);
95 defaultStepSize = min * 2;
96 return minVal;
97 }
98
99 private double[] evaluateAtStorage;
100
101 private double evalueateAt(DifferentiableObjective o, double[] params,
102 double[] direction, double step) {
103 StaticUtils.add(evaluateAtStorage, params, direction, step);
104 o.setParameters(evaluateAtStorage);
105 return o.getValue();
106 }
107
108 }