1 package classification;
2
3 /****
4 * This class implements one best structured Mira
5 * @author Kuzman Ganchev and Georgi Georgiev
 * <A HREF="mailto:georgi.georgiev@ontotext.com">georgi.georgiev@ontotext.com</A>
 * <A HREF="mailto:kuzman.ganchev@ontotext.com">kuzman.ganchev@ontotext.com</A>
8 * Date: Thu Feb 26 12:27:56 EET 2009
9 */
10
11 import gnu.trove.TIntDoubleHashMap;
12 import gnu.trove.TIntDoubleIterator;
13
14 import java.util.ArrayList;
15
16 import types.Alphabet;
17 import types.ClassificationInstance;
18 import types.LinearClassifier;
19 import types.SparseVector;
20 import types.StaticUtils;
21
22 public class Mira {
23
24 boolean performAveraging;
25 int numIterations;
26 Alphabet xAlphabet;
27 Alphabet yAlphabet;
28 CompleteFeatureFunction fxy;
29 Loss loss;
30
31 public Mira(boolean performAveraging, int numIterations,
32 Alphabet xAlphabet, Alphabet yAlphabet,
33 CompleteFeatureFunction fxy, Loss loss) {
34 this.performAveraging = performAveraging;
35 this.numIterations = numIterations;
36 this.xAlphabet = xAlphabet;
37 this.yAlphabet = yAlphabet;
38 this.fxy = fxy;
39 this.loss = loss;
40 }
41
42 private double calculateDenom(SparseVector a, SparseVector b) {
43
44 double result = 0;
45
46 TIntDoubleHashMap diff = new TIntDoubleHashMap();
47
48 for (int i = 0; i < a.numEntries(); i++) {
49 int ind = a.getIndexAt(i);
50 double val = a.getValueAt(i);
51 if (!diff.containsKey(ind)) {
52 diff.put(ind, 0);
53 }
54 diff.put(ind, diff.get(ind) + val);
55 }
56
57 for (int i = 0; i < b.numEntries(); i++) {
58 int ind = b.getIndexAt(i);
59 double val = b.getValueAt(i);
60 if (!diff.containsKey(ind)) {
61 diff.put(ind, 0);
62 }
63 diff.put(ind, diff.get(ind) - val);
64 }
65
66 for (TIntDoubleIterator iterator = diff.iterator(); iterator.hasNext();) {
67 iterator.advance();
68 result += Math.pow(iterator.value(), 2);
69 }
70
71 return result;
72
73 }
74
75 public LinearClassifier batchTrain(
76 ArrayList<ClassificationInstance> trainingData) {
77 LinearClassifier w = new LinearClassifier(this.xAlphabet,
78 this.yAlphabet, this.fxy);
79 LinearClassifier theta = null;
80 if (this.performAveraging) {
81 theta = new LinearClassifier(this.xAlphabet, this.yAlphabet,
82 this.fxy);
83 }
84 for (int iter = 0; iter < this.numIterations; iter++) {
85 for (ClassificationInstance inst : trainingData) {
86 int yhat = w.label(inst.x);
87
88 double lloss = this.loss.calculate(inst.y, yhat);
89
90
91 double alpha = lloss
92 + StaticUtils.dotProduct(this.fxy.apply(inst.x, yhat),
93 w.w)
94 - StaticUtils.dotProduct(
95 this.fxy.apply(inst.x, inst.y), w.w);
96 if (alpha <= 0) {
97 continue;
98 }
99 alpha = alpha
100 / this.calculateDenom(this.fxy.apply(inst.x, inst.y),
101 this.fxy.apply(inst.x, yhat));
102
103
104 StaticUtils.plusEquals(w.w, this.fxy.apply(inst.x, inst.y),
105 alpha);
106 StaticUtils.plusEquals(w.w, this.fxy.apply(inst.x, yhat),
107 -alpha);
108 if (this.performAveraging) {
109 StaticUtils.plusEquals(theta.w, w.w, 1);
110 }
111 }
112 }
113 if (this.performAveraging) {
114 return theta;
115 }
116 return w;
117 }
118
119 }