View Javadoc

1   package sequence;
2   
3   /****
4    * This class implements a one-best structured MIRA (Margin Infused Relaxed Algorithm) trainer.
5    * @author Kuzman Ganchev and Georgi Georgiev
6    * <A HREF="mailto:georgiev@ontotext.com">georgi.georgiev@ontotext.com</A>
7    * <A HREF="mailto:ganchev@ontotext.com">kuzman.ganchev@ontotext.com</A>
8    * Date: Thu Feb 26 12:27:56 EET 2009
9    */
10  
11  import gnu.trove.TIntDoubleHashMap;
12  import gnu.trove.TIntDoubleIterator;
13  
14  import java.util.ArrayList;
15  
16  import types.Alphabet;
17  import types.SparseVector;
18  import types.StaticUtils;
19  
20  public class Mira {
21  
22  	boolean performAveraging;
23  	int numIterations;
24  	Alphabet xAlphabet;
25  	Alphabet yAlphabet;
26  	SequenceFeatureFunction fxy;
27  	Loss loss;
28  
29  	public Mira(boolean performAveraging, int numIterations,
30  			Alphabet xAlphabet, Alphabet yAlphabet,
31  			SequenceFeatureFunction fxy, Loss loss) {
32  		this.performAveraging = performAveraging;
33  		this.numIterations = numIterations;
34  		this.xAlphabet = xAlphabet;
35  		this.yAlphabet = yAlphabet;
36  		this.fxy = fxy;
37  		this.loss = loss;
38  	}
39  
40  	private double calculateDenom(SparseVector a, SparseVector b) {
41  
42  		double result = 0;
43  
44  		TIntDoubleHashMap diff = new TIntDoubleHashMap();
45  
46  		for (int i = 0; i < a.numEntries(); i++) {
47  			int ind = a.getIndexAt(i);
48  			double val = a.getValueAt(i);
49  			if (!diff.containsKey(ind)) {
50  				diff.put(ind, 0);
51  			}
52  			diff.put(ind, diff.get(ind) + val);
53  		}
54  
55  		for (int i = 0; i < b.numEntries(); i++) {
56  			int ind = b.getIndexAt(i);
57  			double val = b.getValueAt(i);
58  			if (!diff.containsKey(ind)) {
59  				diff.put(ind, 0);
60  			}
61  			diff.put(ind, diff.get(ind) - val);
62  		}
63  
64  		for (TIntDoubleIterator iterator = diff.iterator(); iterator.hasNext();) {
65  			iterator.advance();
66  			result += Math.pow(iterator.value(), 2);
67  		}
68  
69  		return result;
70  
71  	}
72  
73  	public LinearTagger batchTrain(ArrayList<SequenceInstance> trainingData) {
74  		LinearTagger w = new LinearTagger(this.xAlphabet, this.yAlphabet,
75  				this.fxy);
76  		LinearTagger theta = null;
77  		if (this.performAveraging) {
78  			theta = new LinearTagger(this.xAlphabet, this.yAlphabet, this.fxy);
79  		}
80  		for (int iter = 0; iter < this.numIterations; iter++) {
81  			for (SequenceInstance inst : trainingData) {
82  				int[] yhat = w.label(inst.x);
83  				// calculate loss
84  				double lloss = this.loss.calculate(inst.y, yhat);
85  				// calculate alpha
86  
87  				double alpha = lloss
88  						+ StaticUtils.dotProduct(this.fxy.apply(inst.x, yhat),
89  								w.w)
90  						- StaticUtils.dotProduct(
91  								this.fxy.apply(inst.x, inst.y), w.w);
92  				if (alpha <= 0) {
93  					continue;
94  				}
95  				alpha = alpha
96  						/ this.calculateDenom(this.fxy.apply(inst.x, inst.y),
97  								this.fxy.apply(inst.x, yhat));
98  
99  				// if y = yhat then this update will not change w.
100 				StaticUtils.plusEquals(w.w, this.fxy.apply(inst.x, inst.y),
101 						alpha);
102 				StaticUtils.plusEquals(w.w, this.fxy.apply(inst.x, yhat),
103 						-alpha);
104 				if (this.performAveraging) {
105 					StaticUtils.plusEquals(theta.w, w.w, 1);
106 				}
107 			}
108 		}
109 		if (this.performAveraging) {
110 			return theta;
111 		}
112 		return w;
113 	}
114 
115 }