package sequence;

/**
 * This class implements 1-best structured MIRA (Margin Infused Relaxed
 * Algorithm) for training a linear sequence tagger.
 * @author Kuzman Ganchev and Georgi Georgiev
 * <A HREF="mailto:georgiev@ontotext.com">georgi.georgiev@ontotext.com</A>
 * <A HREF="mailto:ganchev@ontotext.com">kuzman.ganchev@ontotext.com</A>
 * Date: Thu Feb 26 12:27:56 EET 2009
 */

import gnu.trove.TIntDoubleHashMap;
import gnu.trove.TIntDoubleIterator;

import java.util.ArrayList;

import types.Alphabet;
import types.SparseVector;
import types.StaticUtils;

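/*
 * A minimal usage sketch; it assumes that the alphabets, the feature
 * function, the loss and the training data have already been constructed
 * elsewhere (that setup is not shown here):
 *
 *   Mira mira = new Mira(true, 10, xAlphabet, yAlphabet, fxy, loss);
 *   LinearTagger tagger = mira.batchTrain(trainingData);
 *   int[] predicted = tagger.label(trainingData.get(0).x);
 */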
public class Mira {

    boolean performAveraging;    // if true, return the averaged (summed) weights instead of the final ones
    int numIterations;           // number of passes over the training data
    Alphabet xAlphabet;          // input (observation) alphabet
    Alphabet yAlphabet;          // label alphabet
    SequenceFeatureFunction fxy; // feature function over (x, y) sequence pairs
    Loss loss;                   // loss between the gold and predicted label sequences

    public Mira(boolean performAveraging, int numIterations,
            Alphabet xAlphabet, Alphabet yAlphabet,
            SequenceFeatureFunction fxy, Loss loss) {
        this.performAveraging = performAveraging;
        this.numIterations = numIterations;
        this.xAlphabet = xAlphabet;
        this.yAlphabet = yAlphabet;
        this.fxy = fxy;
        this.loss = loss;
    }

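    /**
     * Computes the squared Euclidean distance ||a - b||^2 between two sparse
     * feature vectors, used as the denominator of the MIRA update.
     */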
    private double calculateDenom(SparseVector a, SparseVector b) {

        double result = 0;

        // Accumulate the sparse difference a - b, indexed by feature.
        TIntDoubleHashMap diff = new TIntDoubleHashMap();

        for (int i = 0; i < a.numEntries(); i++) {
            int ind = a.getIndexAt(i);
            double val = a.getValueAt(i);
            if (!diff.containsKey(ind)) {
                diff.put(ind, 0);
            }
            diff.put(ind, diff.get(ind) + val);
        }

        for (int i = 0; i < b.numEntries(); i++) {
            int ind = b.getIndexAt(i);
            double val = b.getValueAt(i);
            if (!diff.containsKey(ind)) {
                diff.put(ind, 0);
            }
            diff.put(ind, diff.get(ind) - val);
        }

        // Sum the squared entries of the difference vector.
        for (TIntDoubleIterator iterator = diff.iterator(); iterator.hasNext();) {
            iterator.advance();
            result += Math.pow(iterator.value(), 2);
        }

        return result;
    }

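    /**
     * Trains a {@link LinearTagger} with 1-best MIRA. For each instance the
     * current model's best labeling yhat is found; if the loss-augmented
     * margin is violated, the weights are moved by
     * alpha * (f(x, y) - f(x, yhat)), where
     * alpha = (loss(y, yhat) + w.f(x, yhat) - w.f(x, y)) / ||f(x, y) - f(x, yhat)||^2.
     * If averaging is enabled, the sum of the weight vectors taken after each
     * update is returned instead of the final weights.
     */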
    public LinearTagger batchTrain(ArrayList<SequenceInstance> trainingData) {
        LinearTagger w = new LinearTagger(this.xAlphabet, this.yAlphabet,
                this.fxy);
        LinearTagger theta = null;
        if (this.performAveraging) {
            theta = new LinearTagger(this.xAlphabet, this.yAlphabet, this.fxy);
        }
        for (int iter = 0; iter < this.numIterations; iter++) {
            for (SequenceInstance inst : trainingData) {
                // Best labeling under the current model.
                int[] yhat = w.label(inst.x);

                double lloss = this.loss.calculate(inst.y, yhat);

                SparseVector goldFeatures = this.fxy.apply(inst.x, inst.y);
                SparseVector guessFeatures = this.fxy.apply(inst.x, yhat);

                // Margin violation: loss(y, yhat) + w.f(x, yhat) - w.f(x, y).
                double alpha = lloss
                        + StaticUtils.dotProduct(guessFeatures, w.w)
                        - StaticUtils.dotProduct(goldFeatures, w.w);
                if (alpha <= 0) {
                    // No violation: leave the weights unchanged.
                    continue;
                }
                double denom = this.calculateDenom(goldFeatures, guessFeatures);
                if (denom == 0) {
                    // Identical feature vectors; avoid dividing by zero.
                    continue;
                }
                alpha = alpha / denom;

                // Move the weights towards the gold features and away from the guess.
                StaticUtils.plusEquals(w.w, goldFeatures, alpha);
                StaticUtils.plusEquals(w.w, guessFeatures, -alpha);
                if (this.performAveraging) {
                    StaticUtils.plusEquals(theta.w, w.w, 1);
                }
            }
        }
        if (this.performAveraging) {
            return theta;
        }
        return w;
    }

}