1 package classification;
2
3 import java.util.ArrayList;
4 import types.Alphabet;
5 import types.ClassificationInstance;
6 import types.LinearClassifier;
7
8 public class NaiveBayes {
9
10 double[] counts;
11 Alphabet xAlphabet;
12 Alphabet yAlphabet;
13 CompleteFeatureFunction fxy;
14
15 public NaiveBayes(double smoothTrue, double smoothFalse,
16 Alphabet xAlphabet, Alphabet yAlphabet) {
17 this.xAlphabet = xAlphabet;
18 this.yAlphabet = yAlphabet;
19 fxy = new CompleteFeatureFunction(xAlphabet, yAlphabet);
20 counts = new double[fxy.wSize()];
21 int defaultFeatureIndex = fxy.defalutFeatureIndex;
22 for (int y = 0; y < yAlphabet.size(); y++) {
23 counts[indexOf(y, defaultFeatureIndex)] = smoothTrue + smoothFalse;
24 for (int f = 0; f < xAlphabet.size(); f++) {
25 counts[indexOf(y, f)] = smoothTrue;
26 }
27 }
28 }
29
30 private int indexOf(int y, int feat) {
31 return y * (fxy.defalutFeatureIndex + 1) + feat;
32 }
33
34 public LinearClassifier batchTrain(
35 ArrayList<ClassificationInstance> trainingData) {
36 LinearClassifier res = new LinearClassifier(xAlphabet, yAlphabet, fxy);
37 int defaultFeatureIndex = fxy.defalutFeatureIndex;
38
39
40 for (ClassificationInstance inst : trainingData) {
41 counts[indexOf(inst.y, defaultFeatureIndex)] += 1;
42 for (int i = 0; i < inst.x.numEntries(); i++) {
43 counts[indexOf(inst.y, inst.x.getIndexAt(i))] += 1;
44 }
45 }
46
47 double sumYCounts = 0;
48 for (int y = 0; y < yAlphabet.size(); y++) {
49 sumYCounts += counts[indexOf(y, defaultFeatureIndex)];
50 }
51
52
53 for (int y = 0; y < yAlphabet.size(); y++) {
54 double countOfY = counts[indexOf(y, defaultFeatureIndex)];
55 double prY = countOfY / sumYCounts;
56 double weightY = Math.log(prY);
57 if (Double.isNaN(weightY))
58 throw new AssertionError();
59 for (int f = 0; f < defaultFeatureIndex; f++) {
60 double prXfgivenY = counts[indexOf(y, f)] / countOfY;
61 double prNotXfgivenY = 1 - prXfgivenY;
62 weightY += Math.log(prNotXfgivenY);
63 if (Double.isNaN(weightY))
64 throw new AssertionError();
65 res.w[indexOf(y, f)] -= Math.log(prNotXfgivenY);
66 res.w[indexOf(y, f)] += Math.log(prXfgivenY);
67 }
68 res.w[indexOf(y, defaultFeatureIndex)] = weightY;
69 }
70 return res;
71 }
72
73 }