You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
3.8 KiB

package UnitTest;
import org.apache.commons.codec.binary.Base64;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.util.HashMap;
import java.util.Map;
/**
 * A small Naive-Bayes-style text classifier over pre-tokenized documents.
 *
 * <p>Training accumulates three tallies: documents per class, word occurrences
 * per class, and word occurrences across all classes. Classification returns
 * the class with the highest log-score, where the score is the log of the class
 * prior plus, for every token of the document, the log of
 * {@code (count(word, class) + 1) / (count(word, all classes) + 2)}.
 *
 * <p>NOTE(review): the per-word term is the smoothed share of a word's
 * occurrences that fall into the class, not the textbook multinomial estimate
 * (word count in class over total words in class plus vocabulary size). It is
 * kept unchanged so existing predictions stay the same; confirm it is intended.
 *
 * <p>All state lives in plain {@link HashMap}s with no synchronization.
 */
public class NaiveBayesClassifier {
    /** Class name -> number of training documents labelled with that class. */
    private final Map<String, Integer> classCounts;
    /** Class name -> (word -> occurrences of the word in that class's training documents). */
    private final Map<String, Map<String, Integer>> wordCounts;
    /** Word -> occurrences of the word across the training documents of all classes. */
    private final Map<String, Integer> wordTotalCounts;
    /** Total number of training documents seen so far. */
    private int totalDocuments;

    /** Creates an untrained classifier with empty tallies. */
    public NaiveBayesClassifier() {
        classCounts = new HashMap<>();
        wordCounts = new HashMap<>();
        wordTotalCounts = new HashMap<>();
        totalDocuments = 0;
    }

    /**
     * Adds one labelled document to the training tallies.
     *
     * @param document the tokens of the document; every occurrence is counted
     * @param className the label of the document
     */
    public void train(String[] document, String className) {
        // Count the document towards its class.
        classCounts.merge(className, 1, Integer::sum);
        // The per-class word map is created even for an empty document, so
        // scoring can always look it up for any trained class.
        Map<String, Integer> wordClassCounts =
                wordCounts.computeIfAbsent(className, key -> new HashMap<>());
        for (String word : document) {
            // Count the word within this class ...
            wordClassCounts.merge(word, 1, Integer::sum);
            // ... and across all classes.
            wordTotalCounts.merge(word, 1, Integer::sum);
        }
        totalDocuments++;
    }

    /**
     * Returns the trained class with the highest score for the given document.
     *
     * <p>Only a strictly greater score replaces the current best, so on a tie
     * the class met first while iterating the class map is kept.
     *
     * @param document the tokens of the document to classify
     * @return the best-scoring class name, or {@code null} if nothing has been trained
     */
    public String classify(String[] document) {
        String bestClass = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (String className : classCounts.keySet()) {
            double score = calculateClassScore(document, className);
            if (score > bestScore) {
                bestScore = score;
                bestClass = className;
            }
        }
        return bestClass;
    }

    /**
     * Computes the log-score of a document for one trained class.
     *
     * @param document the tokens of the document
     * @param className a class that has been passed to {@link #train} at least once
     * @return log prior of the class plus the sum of the per-token log terms
     */
    private double calculateClassScore(String[] document, String className) {
        // Log prior: share of training documents that belong to this class.
        double score = Math.log((double) classCounts.get(className) / totalDocuments);
        Map<String, Integer> wordClassCounts = wordCounts.get(className);
        for (String word : document) {
            int wordCount = wordClassCounts.getOrDefault(word, 0);
            int totalWordCount = wordTotalCounts.getOrDefault(word, 0);
            // +1 / +2 smoothing keeps the term strictly between 0 and 1, so the
            // logarithm is always finite, even for words never seen in training.
            double wordProbability = (wordCount + 1.0) / (totalWordCount + 2.0);
            score += Math.log(wordProbability);
        }
        return score;
    }

    /**
     * Demo entry point: trains on six small documents and prints the predicted
     * class of one test document.
     *
     * <p>The checked exceptions in the signature are not thrown by this body;
     * the clause is retained so existing callers that catch them keep compiling.
     *
     * @param args ignored
     */
    public static void main(String[] args) throws NoSuchAlgorithmException, InvalidKeySpecException {
        NaiveBayesClassifier classifier = new NaiveBayesClassifier();

        // Training data
        String[] doc1 = {"Chinese", "Beijing", "Chinese"};
        classifier.train(doc1, "China");
        String[] doc2 = {"Chinese", "Chinese", "Shanghai"};
        classifier.train(doc2, "China");
        String[] doc3 = {"Chinese", "Macao"};
        classifier.train(doc3, "China");
        String[] doc4 = {"Tokyo", "Japan", "Chinese"};
        classifier.train(doc4, "Japan");
        String[] doc5 = {"Chinese", "Tokyo", "Japan"};
        classifier.train(doc5, "Japan");
        String[] doc6 = {"Chinese", "Chinese", "Chinese", "Tokyo"};
        classifier.train(doc6, "Japan");

        // Test data
        String[] testDoc = {"Chinese", "Chinese", "Chinese", "Shanghai", "Shanghai", "Chinese", "Chinese", "Shanghai", "Shanghai", "Japan"};

        // Classify
        String classifiedClass = classifier.classify(testDoc);
        System.out.println("Classifier predicted class: " + classifiedClass);
    }
}