diff --git a/AI/WxMini/Start.py b/AI/WxMini/Start.py index 2c67c501..610bbdee 100644 --- a/AI/WxMini/Start.py +++ b/AI/WxMini/Start.py @@ -1,72 +1,24 @@ -import os -import uuid +import asyncio +import logging import time -import jieba -from fastapi import FastAPI, Form, HTTPException -from openai import AsyncOpenAI # 使用异步客户端 -from gensim.models import KeyedVectors +import uuid from contextlib import asynccontextmanager -from WxMini.Utils.OssUtil import upload_mp3_to_oss, upload_mp3_to_oss_from_memory -from WxMini.Utils.TtsUtil import TTS + +from fastapi import FastAPI, Form, HTTPException +from openai import AsyncOpenAI + +from WxMini.Milvus.Config.MulvusConfig import * from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager from WxMini.Milvus.Utils.MilvusConnectionPool import * -from WxMini.Milvus.Config.MulvusConfig import * -import asyncio # 引入异步支持 -import logging # 增加日志记录 - -import jieba.analyse +from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory +from WxMini.Utils.TtsUtil import TTS +from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql +from WxMini.Utils.EmbeddingUtil import text_to_embedding # 配置日志 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) -from aiomysql import create_pool - -# MySQL 配置 -MYSQL_CONFIG = { - "host": MYSQL_HOST, - "port": MYSQL_PORT, - "user": MYSQL_USER, - "password": MYSQL_PASSWORD, - "db": MYSQL_DB_NAME, - "minsize": 1, - "maxsize": 20, -} - - -# 保存聊天记录到 MySQL -async def save_chat_to_mysql(mysql_pool, session_id, prompt, result): - async with mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute( - "INSERT INTO t_chat_log (session_id, user_input, model_response, create_time) VALUES (%s, %s, %s, NOW())", - (session_id, prompt, result) - ) - await conn.commit() - - -# 初始化 MySQL 连接池 -async def init_mysql_pool(): - return await create_pool(**MYSQL_CONFIG) - - -# 提取用户输入的关键词 -def extract_keywords(text, topK=3): - """ - 提取用户输入的关键词 - :param text: 用户输入的文本 - :param topK: 返回的关键词数量 - :return: 关键词列表 - """ - keywords = jieba.analyse.extract_tags(text, topK=topK) - return keywords - - -# 初始化 Word2Vec 模型 -model_path = MS_MODEL_PATH -model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT) -logger.info(f"模型加载成功,词向量维度: {model.vector_size}") - # 初始化 Milvus 连接池 milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS) @@ -74,22 +26,6 @@ milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=M collection_name = MS_COLLECTION_NAME collection_manager = MilvusCollectionManager(collection_name) - -# 将文本转换为嵌入向量 -def text_to_embedding(text): - words = jieba.lcut(text) # 使用 jieba 分词 - logger.info(f"文本: {text}, 分词结果: {words}") - embeddings = [model[word] for word in words if word in model] - logger.info(f"有效词向量数量: {len(embeddings)}") - if embeddings: - avg_embedding = sum(embeddings) / len(embeddings) - logger.info(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维 - return avg_embedding - else: - logger.warning("未找到有效词,返回零向量") - return [0.0] * model.vector_size - - # 使用 Lifespan Events 处理应用启动和关闭逻辑 @asynccontextmanager async def lifespan(app: FastAPI): @@ -106,7 +42,6 @@ async def lifespan(app: FastAPI): await app.state.mysql_pool.wait_closed() logger.info("Milvus 和 MySQL 连接池已关闭。") - # 初始化 FastAPI 应用 app = FastAPI(lifespan=lifespan) @@ -116,7 +51,6 @@ client = AsyncOpenAI( base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", ) - @app.post("/reply") async def reply(session_id: str = Form(...), prompt: str = Form(...)): """ @@ -220,8 +154,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)): audio_data = await asyncio.to_thread(t.generate_audio, result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据 # 将音频数据直接上传到 OSS - await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, - audio_data) # 假设 upload_mp3_to_oss_from_memory 支持从内存上传 + await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data) logger.info(f"TTS 文件已直接上传到 OSS: {tts_file}") # 完整的 URL @@ -241,9 +174,8 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)): # 释放连接 milvus_pool.release_connection(connection) - # 运行 FastAPI 应用 if __name__ == "__main__": import uvicorn - uvicorn.run("Start:app", host="0.0.0.0", port=5800, workers=1) + uvicorn.run("Start:app", host="0.0.0.0", port=5800, workers=1) \ No newline at end of file diff --git a/AI/WxMini/Utils/EmbeddingUtil.py b/AI/WxMini/Utils/EmbeddingUtil.py new file mode 100644 index 00000000..fde367c0 --- /dev/null +++ b/AI/WxMini/Utils/EmbeddingUtil.py @@ -0,0 +1,26 @@ +import logging +import jieba +from gensim.models import KeyedVectors + +# 配置日志 +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +# 初始化 Word2Vec 模型 +model_path = r"D:\Tencent_AILab_ChineseEmbedding\Tencent_AILab_ChineseEmbedding.txt" +model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=10000) +logger.info(f"模型加载成功,词向量维度: {model.vector_size}") + +# 将文本转换为嵌入向量 +def text_to_embedding(text): + words = jieba.lcut(text) # 使用 jieba 分词 + logger.info(f"文本: {text}, 分词结果: {words}") + embeddings = [model[word] for word in words if word in model] + logger.info(f"有效词向量数量: {len(embeddings)}") + if embeddings: + avg_embedding = sum(embeddings) / len(embeddings) + logger.info(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维 + return avg_embedding + else: + logger.warning("未找到有效词,返回零向量") + return [0.0] * model.vector_size \ No newline at end of file diff --git a/AI/WxMini/Utils/MySQLUtil.py b/AI/WxMini/Utils/MySQLUtil.py new file mode 100644 index 00000000..147b8342 --- /dev/null +++ b/AI/WxMini/Utils/MySQLUtil.py @@ -0,0 +1,36 @@ +import asyncio +import logging +from aiomysql import create_pool +from WxMini.Milvus.Config.MulvusConfig import * + +# 配置日志 +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +# MySQL 配置 +MYSQL_CONFIG = { + "host": MYSQL_HOST, + "port": MYSQL_PORT, + "user": MYSQL_USER, + "password": MYSQL_PASSWORD, + "db": MYSQL_DB_NAME, + "minsize": 1, + "maxsize": 20, +} + + +# 初始化 MySQL 连接池 +async def init_mysql_pool(): + return await create_pool(**MYSQL_CONFIG) + + +# 保存聊天记录到 MySQL +async def save_chat_to_mysql(mysql_pool, session_id, prompt, result): + async with mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "INSERT INTO t_chat_log (session_id, user_input, model_response, create_time) VALUES (%s, %s, %s, NOW())", + (session_id, prompt, result) + ) + await conn.commit() + logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。") diff --git a/AI/WxMini/Utils/__pycache__/EmbeddingUtil.cpython-310.pyc b/AI/WxMini/Utils/__pycache__/EmbeddingUtil.cpython-310.pyc new file mode 100644 index 00000000..454f8a58 Binary files /dev/null and b/AI/WxMini/Utils/__pycache__/EmbeddingUtil.cpython-310.pyc differ diff --git a/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc b/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc new file mode 100644 index 00000000..710d61e1 Binary files /dev/null and b/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc differ diff --git a/AI/WxMini/__pycache__/Start.cpython-310.pyc b/AI/WxMini/__pycache__/Start.cpython-310.pyc index 3c11859d..fd3591d3 100644 Binary files a/AI/WxMini/__pycache__/Start.cpython-310.pyc and b/AI/WxMini/__pycache__/Start.cpython-310.pyc differ