View Javadoc

1   package algo;
2   
3   import types.DifferentiableObjective;
4   import types.StaticUtils;
5   
6   public class GradientAscent {
7   
8   	int maxNumSteps = 2000;
9   	double maxStepSize = 0.1;
10  	double minStepSize = 1e-100;
11  
12  	public boolean maximize(DifferentiableObjective o) {
13  		double[] gradient = new double[o.getNumParameters()];
14  		double[] currParameters = new double[o.getNumParameters()];
15  		double[] newParameters = new double[o.getNumParameters()];
16  		for (int step = 0; step < maxNumSteps; step++) {
17  			double currValue = o.getValue();
18  			o.getParameters(currParameters);
19  			o.getGradient(gradient);
20  			double stepSize = maxStepSize;
21  			while (true) {
22  				StaticUtils.add(newParameters, currParameters, gradient,
23  						stepSize);
24  				o.setParameters(newParameters);
25  				double newValue = o.getValue();
26  				if (newValue > currValue)
27  					break;
28  				if (stepSize < minStepSize) {
29  					System.out.println("Converged in "
30  							+ step
31  							+ " steps. TwoNorm of gradient is "
32  							+ Math.pow(StaticUtils.twoNormSquared(gradient),
33  									0.5));
34  					return true;
35  				}
36  				stepSize /= 2;
37  			}
38  		}
39  		System.out.println("Did not converge in " + maxNumSteps
40  				+ " gradient steps. TwoNorm of gradient is "
41  				+ Math.pow(StaticUtils.twoNormSquared(gradient), 0.5));
42  		return false;
43  	}
44  
45  }