package UnitTest; import org.apache.commons.codec.binary.Base64; import java.security.KeyFactory; import java.security.NoSuchAlgorithmException; import java.security.interfaces.RSAPublicKey; import java.security.spec.InvalidKeySpecException; import java.security.spec.X509EncodedKeySpec; import java.util.HashMap; import java.util.Map; public class NaiveBayesClassifier { private Map classCounts; private Map> wordCounts; private Map wordTotalCounts; private int totalDocuments; public NaiveBayesClassifier() { classCounts = new HashMap<>(); wordCounts = new HashMap<>(); wordTotalCounts = new HashMap<>(); totalDocuments = 0; } public void train(String[] document, String className) { // 统计类别计数 classCounts.put(className, classCounts.getOrDefault(className, 0) + 1); // 统计每个单词在每个类别中的计数 if (!wordCounts.containsKey(className)) { wordCounts.put(className, new HashMap<>()); } for (String word : document) { Map wordClassCounts = wordCounts.get(className); wordClassCounts.put(word, wordClassCounts.getOrDefault(word, 0) + 1); // 统计每个单词在所有类别中的计数 wordTotalCounts.put(word, wordTotalCounts.getOrDefault(word, 0) + 1); } totalDocuments++; } public String classify(String[] document) { String bestClass = null; double bestScore = Double.NEGATIVE_INFINITY; for (String className : classCounts.keySet()) { double score = calculateClassScore(document, className); if (score > bestScore) { bestScore = score; bestClass = className; } } return bestClass; } private double calculateClassScore(String[] document, String className) { double score = Math.log((double) classCounts.get(className) / totalDocuments); for (String word : document) { int wordCount = wordCounts.get(className).getOrDefault(word, 0); int totalWordCount = wordTotalCounts.getOrDefault(word, 0); double wordProbability = (wordCount + 1.0) / (totalWordCount + 2.0); score += Math.log(wordProbability); } return score; } public static void main(String[] args) throws NoSuchAlgorithmException, InvalidKeySpecException { byte[] decoded = Base64.decodeBase64(""); RSAPublicKey pubKey = (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(new X509EncodedKeySpec(decoded)); NaiveBayesClassifier classifier = new NaiveBayesClassifier(); // 训练数据 String[] doc1 = {"Chinese", "Beijing", "Chinese"}; classifier.train(doc1, "China"); String[] doc2 = {"Chinese", "Chinese", "Shanghai"}; classifier.train(doc2, "China"); String[] doc3 = {"Chinese", "Macao"}; classifier.train(doc3, "China"); String[] doc4 = {"Tokyo", "Japan", "Chinese"}; classifier.train(doc4, "Japan"); String[] doc5 = {"Chinese", "Tokyo", "Japan"}; classifier.train(doc5, "Japan"); String[] doc6 = {"Chinese", "Chinese", "Chinese", "Tokyo"}; classifier.train(doc6, "Japan"); // 测试数据 String[] testDoc = {"Chinese", "Chinese", "Chinese", "Shanghai", "Shanghai", "Chinese", "Chinese", "Shanghai", "Shanghai", "Japan"}; // 分类 String classifiedClass = classifier.classify(testDoc); System.out.println("Classifier predicted class: " + classifiedClass); } }