View Javadoc

1   package classification;
2   
3   import java.util.ArrayList;
4   import types.Alphabet;
5   import types.ClassificationInstance;
6   import types.LinearClassifier;
7   
8   public class NaiveBayes {
9   
10  	double[] counts;
11  	Alphabet xAlphabet;
12  	Alphabet yAlphabet;
13  	CompleteFeatureFunction fxy;
14  
15  	public NaiveBayes(double smoothTrue, double smoothFalse,
16  			Alphabet xAlphabet, Alphabet yAlphabet) {
17  		this.xAlphabet = xAlphabet;
18  		this.yAlphabet = yAlphabet;
19  		fxy = new CompleteFeatureFunction(xAlphabet, yAlphabet);
20  		counts = new double[fxy.wSize()];
21  		int defaultFeatureIndex = fxy.defalutFeatureIndex;
22  		for (int y = 0; y < yAlphabet.size(); y++) {
23  			counts[indexOf(y, defaultFeatureIndex)] = smoothTrue + smoothFalse;
24  			for (int f = 0; f < xAlphabet.size(); f++) {
25  				counts[indexOf(y, f)] = smoothTrue;
26  			}
27  		}
28  	}
29  
30  	private int indexOf(int y, int feat) {
31  		return y * (fxy.defalutFeatureIndex + 1) + feat;
32  	}
33  
34  	public LinearClassifier batchTrain(
35  			ArrayList<ClassificationInstance> trainingData) {
36  		LinearClassifier res = new LinearClassifier(xAlphabet, yAlphabet, fxy);
37  		int defaultFeatureIndex = fxy.defalutFeatureIndex;
38  
39  		// update the counts that we've seen so far
40  		for (ClassificationInstance inst : trainingData) {
41  			counts[indexOf(inst.y, defaultFeatureIndex)] += 1;
42  			for (int i = 0; i < inst.x.numEntries(); i++) {
43  				counts[indexOf(inst.y, inst.x.getIndexAt(i))] += 1;
44  			}
45  		}
46  
47  		double sumYCounts = 0;
48  		for (int y = 0; y < yAlphabet.size(); y++) {
49  			sumYCounts += counts[indexOf(y, defaultFeatureIndex)];
50  		}
51  
52  		// compute the probabilities given the current counts
53  		for (int y = 0; y < yAlphabet.size(); y++) {
54  			double countOfY = counts[indexOf(y, defaultFeatureIndex)];
55  			double prY = countOfY / sumYCounts;
56  			double weightY = Math.log(prY);
57  			if (Double.isNaN(weightY))
58  				throw new AssertionError();
59  			for (int f = 0; f < defaultFeatureIndex; f++) {
60  				double prXfgivenY = counts[indexOf(y, f)] / countOfY;
61  				double prNotXfgivenY = 1 - prXfgivenY;
62  				weightY += Math.log(prNotXfgivenY);
63  				if (Double.isNaN(weightY))
64  					throw new AssertionError();
65  				res.w[indexOf(y, f)] -= Math.log(prNotXfgivenY);
66  				res.w[indexOf(y, f)] += Math.log(prXfgivenY);
67  			}
68  			res.w[indexOf(y, defaultFeatureIndex)] = weightY;
69  		}
70  		return res;
71  	}
72  
73  }