You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
3.8 KiB
110 lines
3.8 KiB
package UnitTest;
|
|
|
|
import org.apache.commons.codec.binary.Base64;
|
|
|
|
import java.security.KeyFactory;
|
|
import java.security.NoSuchAlgorithmException;
|
|
import java.security.interfaces.RSAPublicKey;
|
|
import java.security.spec.InvalidKeySpecException;
|
|
import java.security.spec.X509EncodedKeySpec;
|
|
import java.util.HashMap;
|
|
import java.util.Map;
|
|
|
|
/**
 * A small bag-of-words text classifier driven by per-class word counts.
 *
 * <p>Documents are supplied as arrays of already-tokenized words. Each call to
 * {@link #train(String[], String)} updates the counters; {@link #classify(String[])}
 * returns the class label with the highest log score.
 *
 * <p>The counters are plain {@link HashMap}s and no synchronization is performed.
 */
public class NaiveBayesClassifier {

    /** Class label -> number of training documents seen with that label. */
    private final Map<String, Integer> classCounts;

    /** Class label -> (word -> number of occurrences of the word in that class). */
    private final Map<String, Map<String, Integer>> wordCounts;

    /** Word -> number of occurrences of the word across all classes. */
    private final Map<String, Integer> wordTotalCounts;

    /** Total number of training documents seen so far. */
    private int totalDocuments;

    /** Creates an untrained classifier with empty counters. */
    public NaiveBayesClassifier() {
        classCounts = new HashMap<>();
        wordCounts = new HashMap<>();
        wordTotalCounts = new HashMap<>();
        totalDocuments = 0;
    }

    /**
     * Adds one labelled document to the training data.
     *
     * @param document  the words of the document; repeated words are counted each time
     * @param className the label of the document
     * @throws NullPointerException if {@code document} is {@code null}; in that case no
     *                              counter is modified
     */
    public void train(String[] document, String className) {
        // Validate before touching any counter: failing halfway through would leave
        // classCounts and totalDocuments out of step and skew later priors.
        if (document == null) {
            throw new NullPointerException("document must not be null");
        }

        // Count the document towards its class.
        classCounts.merge(className, 1, Integer::sum);

        // Count every word within the class and across all classes.
        Map<String, Integer> wordClassCounts =
                wordCounts.computeIfAbsent(className, key -> new HashMap<>());
        for (String word : document) {
            wordClassCounts.merge(word, 1, Integer::sum);
            wordTotalCounts.merge(word, 1, Integer::sum);
        }

        totalDocuments++;
    }

    /**
     * Picks the class label with the highest score for the given document.
     *
     * @param document the words of the document to classify
     * @return the best-scoring class label, or {@code null} if no document has been
     *         trained yet
     */
    public String classify(String[] document) {
        String bestClass = null;
        double bestScore = Double.NEGATIVE_INFINITY;

        for (String className : classCounts.keySet()) {
            double score = calculateClassScore(document, className);

            // Strict comparison: on equal scores the class encountered first is kept.
            if (score > bestScore) {
                bestScore = score;
                bestClass = className;
            }
        }

        return bestClass;
    }

    /**
     * Computes the log score of a document for one trained class.
     *
     * <p>The score is {@code log(classCount / totalDocuments)} plus, for every word of
     * the document, {@code log((countInClass + 1) / (countInAllClasses + 2))}.
     *
     * <p>NOTE(review): the per-word term divides by the word's count across all
     * classes rather than by the class's total word count plus vocabulary size, so it
     * is not the textbook multinomial Naive Bayes likelihood. It is kept as-is so that
     * classification results do not change; confirm whether this is intended.
     *
     * @param document  the words of the document
     * @param className a class label that has been trained at least once
     * @return the log score of the document for {@code className}
     */
    private double calculateClassScore(String[] document, String className) {
        double score = Math.log((double) classCounts.get(className) / totalDocuments);

        // The per-class table does not change while scoring, so look it up once.
        Map<String, Integer> wordClassCounts = wordCounts.get(className);
        for (String word : document) {
            int wordCount = wordClassCounts.getOrDefault(word, 0);
            int totalWordCount = wordTotalCounts.getOrDefault(word, 0);
            double wordProbability = (wordCount + 1.0) / (totalWordCount + 2.0);

            score += Math.log(wordProbability);
        }

        return score;
    }

    /**
     * Trains the classifier on a tiny hard-coded corpus and prints the predicted class
     * of a sample document.
     *
     * <p>An earlier version built an RSA public key from an empty Base64 string before
     * training. The key was never used and the empty encoding made key generation
     * fail, so the demo never ran; that code has been removed. The checked exceptions
     * remain declared only so that existing callers keep compiling.
     *
     * @param args ignored
     * @throws NoSuchAlgorithmException declared for source compatibility
     * @throws InvalidKeySpecException  declared for source compatibility
     */
    public static void main(String[] args) throws NoSuchAlgorithmException, InvalidKeySpecException {
        NaiveBayesClassifier classifier = new NaiveBayesClassifier();

        // Training data
        String[] doc1 = {"Chinese", "Beijing", "Chinese"};
        classifier.train(doc1, "China");

        String[] doc2 = {"Chinese", "Chinese", "Shanghai"};
        classifier.train(doc2, "China");

        String[] doc3 = {"Chinese", "Macao"};
        classifier.train(doc3, "China");

        String[] doc4 = {"Tokyo", "Japan", "Chinese"};
        classifier.train(doc4, "Japan");

        String[] doc5 = {"Chinese", "Tokyo", "Japan"};
        classifier.train(doc5, "Japan");

        String[] doc6 = {"Chinese", "Chinese", "Chinese", "Tokyo"};
        classifier.train(doc6, "Japan");

        // Test data
        String[] testDoc = {"Chinese", "Chinese", "Chinese", "Shanghai", "Shanghai", "Chinese", "Chinese", "Shanghai", "Shanghai", "Japan"};

        // Classify
        String classifiedClass = classifier.classify(testDoc);
        System.out.println("Classifier predicted class: " + classifiedClass);
    }
}