View Javadoc

1   package classification;
2   
3   /****
4    * This class implements one best structured Mira
5    * @author Kuzman Ganchev and Georgi Georgiev
6    * <A HREF="mailto:georgiev@ontotext.com">georgi.georgiev@ontotext.com</A>
7    * <A HREF="mailto:ganchev@ontotext.com">kuzman.ganchev@ontotext.com</A>
8    * Date: Thu Feb 26 12:27:56 EET 2009
9    */
10  
11  import gnu.trove.TIntDoubleHashMap;
12  import gnu.trove.TIntDoubleIterator;
13  
14  import java.util.ArrayList;
15  
16  import types.Alphabet;
17  import types.ClassificationInstance;
18  import types.LinearClassifier;
19  import types.SparseVector;
20  import types.StaticUtils;
21  
22  public class Mira {
23  
24  	boolean performAveraging;
25  	int numIterations;
26  	Alphabet xAlphabet;
27  	Alphabet yAlphabet;
28  	CompleteFeatureFunction fxy;
29  	Loss loss;
30  
31  	public Mira(boolean performAveraging, int numIterations,
32  			Alphabet xAlphabet, Alphabet yAlphabet,
33  			CompleteFeatureFunction fxy, Loss loss) {
34  		this.performAveraging = performAveraging;
35  		this.numIterations = numIterations;
36  		this.xAlphabet = xAlphabet;
37  		this.yAlphabet = yAlphabet;
38  		this.fxy = fxy;
39  		this.loss = loss;
40  	}
41  
42  	private double calculateDenom(SparseVector a, SparseVector b) {
43  
44  		double result = 0;
45  
46  		TIntDoubleHashMap diff = new TIntDoubleHashMap();
47  
48  		for (int i = 0; i < a.numEntries(); i++) {
49  			int ind = a.getIndexAt(i);
50  			double val = a.getValueAt(i);
51  			if (!diff.containsKey(ind)) {
52  				diff.put(ind, 0);
53  			}
54  			diff.put(ind, diff.get(ind) + val);
55  		}
56  
57  		for (int i = 0; i < b.numEntries(); i++) {
58  			int ind = b.getIndexAt(i);
59  			double val = b.getValueAt(i);
60  			if (!diff.containsKey(ind)) {
61  				diff.put(ind, 0);
62  			}
63  			diff.put(ind, diff.get(ind) - val);
64  		}
65  
66  		for (TIntDoubleIterator iterator = diff.iterator(); iterator.hasNext();) {
67  			iterator.advance();
68  			result += Math.pow(iterator.value(), 2);
69  		}
70  
71  		return result;
72  
73  	}
74  
75  	public LinearClassifier batchTrain(
76  			ArrayList<ClassificationInstance> trainingData) {
77  		LinearClassifier w = new LinearClassifier(this.xAlphabet,
78  				this.yAlphabet, this.fxy);
79  		LinearClassifier theta = null;
80  		if (this.performAveraging) {
81  			theta = new LinearClassifier(this.xAlphabet, this.yAlphabet,
82  					this.fxy);
83  		}
84  		for (int iter = 0; iter < this.numIterations; iter++) {
85  			for (ClassificationInstance inst : trainingData) {
86  				int yhat = w.label(inst.x);
87  				// calculate loss
88  				double lloss = this.loss.calculate(inst.y, yhat);
89  				// calculate alpha
90  
91  				double alpha = lloss
92  						+ StaticUtils.dotProduct(this.fxy.apply(inst.x, yhat),
93  								w.w)
94  						- StaticUtils.dotProduct(
95  								this.fxy.apply(inst.x, inst.y), w.w);
96  				if (alpha <= 0) {
97  					continue;
98  				}
99  				alpha = alpha
100 						/ this.calculateDenom(this.fxy.apply(inst.x, inst.y),
101 								this.fxy.apply(inst.x, yhat));
102 
103 				// if y = yhat then this update will not change w.
104 				StaticUtils.plusEquals(w.w, this.fxy.apply(inst.x, inst.y),
105 						alpha);
106 				StaticUtils.plusEquals(w.w, this.fxy.apply(inst.x, yhat),
107 						-alpha);
108 				if (this.performAveraging) {
109 					StaticUtils.plusEquals(theta.w, w.w, 1);
110 				}
111 			}
112 		}
113 		if (this.performAveraging) {
114 			return theta;
115 		}
116 		return w;
117 	}
118 
119 }