main
HuangHai 4 months ago
parent d7b4fc92cb
commit 570322904b

@ -1,72 +1,24 @@
import os
import uuid
import asyncio
import logging
import time
import jieba
from fastapi import FastAPI, Form, HTTPException
from openai import AsyncOpenAI # 使用异步客户端
from gensim.models import KeyedVectors
import uuid
from contextlib import asynccontextmanager
from WxMini.Utils.OssUtil import upload_mp3_to_oss, upload_mp3_to_oss_from_memory
from WxMini.Utils.TtsUtil import TTS
from fastapi import FastAPI, Form, HTTPException
from openai import AsyncOpenAI
from WxMini.Milvus.Config.MulvusConfig import *
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Milvus.Config.MulvusConfig import *
import asyncio # 引入异步支持
import logging # 增加日志记录
import jieba.analyse
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory
from WxMini.Utils.TtsUtil import TTS
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql
from WxMini.Utils.EmbeddingUtil import text_to_embedding
# Configure logging: INFO level, with timestamp, level name and message in every record
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
from aiomysql import create_pool
# MySQL settings, unpacked as keyword arguments into create_pool() by init_mysql_pool().
# The MYSQL_* names are not defined in the visible code — presumably they come from the
# star import of WxMini.Milvus.Config.MulvusConfig; confirm.
MYSQL_CONFIG = {
    "host": MYSQL_HOST,
    "port": MYSQL_PORT,
    "user": MYSQL_USER,
    "password": MYSQL_PASSWORD,
    "db": MYSQL_DB_NAME,
    "minsize": 1,
    "maxsize": 20,
}
# Save one chat exchange to MySQL
async def save_chat_to_mysql(mysql_pool, session_id, prompt, result):
    """
    Insert one row into t_chat_log and commit it.

    :param mysql_pool: pool object whose acquire() yields a connection
    :param session_id: value stored in the session_id column
    :param prompt: value stored in the user_input column
    :param result: value stored in the model_response column
    """
    insert_sql = "INSERT INTO t_chat_log (session_id, user_input, model_response, create_time) VALUES (%s, %s, %s, NOW())"
    row_values = (session_id, prompt, result)
    async with mysql_pool.acquire() as db_conn, db_conn.cursor() as cursor:
        # Values are passed separately so the driver handles parameter binding.
        await cursor.execute(insert_sql, row_values)
        await db_conn.commit()
# Initialise the MySQL connection pool
async def init_mysql_pool():
    """Create a connection pool from MYSQL_CONFIG and return it."""
    pool = await create_pool(**MYSQL_CONFIG)
    return pool
# Extract keywords from the user's input
def extract_keywords(text, topK=3):
    """
    Extract keywords from the user's input.

    :param text: the text entered by the user
    :param topK: number of keywords to return
    :return: list of keywords
    """
    return jieba.analyse.extract_tags(text, topK=topK)
# Initialise the Word2Vec model from a non-binary word2vec file; the path and the
# limit passed to the loader come from the MS_MODEL_PATH / MS_MODEL_LIMIT settings.
model_path = MS_MODEL_PATH
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")
# Initialise the Milvus connection pool from the MS_HOST / MS_PORT / MS_MAX_CONNECTIONS settings
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
@ -74,22 +26,6 @@ milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=M
# Build the collection manager for the collection named by MS_COLLECTION_NAME
collection_name = MS_COLLECTION_NAME
collection_manager = MilvusCollectionManager(collection_name)
# Convert text into an embedding vector
def text_to_embedding(text):
    """
    Average the word vectors of the tokens found in ``text``.

    The text is segmented with jieba and tokens absent from the module-level
    ``model`` are ignored. If no token is present in the model, a list of
    ``model.vector_size`` zeros is returned instead of an average.

    :param text: text to embed
    :return: mean of the matched word vectors, or a zero-filled list
    """
    tokens = jieba.lcut(text)  # segment with jieba
    logger.info(f"文本: {text}, 分词结果: {tokens}")
    vectors = [model[token] for token in tokens if token in model]
    logger.info(f"有效词向量数量: {len(vectors)}")
    if not vectors:
        logger.warning("未找到有效词,返回零向量")
        return [0.0] * model.vector_size
    mean_vector = sum(vectors) / len(vectors)
    logger.info(f"生成的平均向量: {mean_vector[:5]}...")  # log only the first 5 dimensions
    return mean_vector
# 使用 Lifespan Events 处理应用启动和关闭逻辑
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -106,7 +42,6 @@ async def lifespan(app: FastAPI):
await app.state.mysql_pool.wait_closed()
logger.info("Milvus 和 MySQL 连接池已关闭。")
# Initialise the FastAPI application, using the lifespan handler for startup/shutdown logic
app = FastAPI(lifespan=lifespan)
@ -116,7 +51,6 @@ client = AsyncOpenAI(
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
@app.post("/reply")
async def reply(session_id: str = Form(...), prompt: str = Form(...)):
"""
@ -220,8 +154,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
audio_data = await asyncio.to_thread(t.generate_audio, result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据
# 将音频数据直接上传到 OSS
await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file,
audio_data) # 假设 upload_mp3_to_oss_from_memory 支持从内存上传
await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data)
logger.info(f"TTS 文件已直接上传到 OSS: {tts_file}")
# 完整的 URL
@ -241,9 +174,8 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
# 释放连接
milvus_pool.release_connection(connection)
# Run the FastAPI application
if __name__ == "__main__":
    import uvicorn

    # Serve the `app` object of the `Start` module on all interfaces, port 5800,
    # with a single worker. The server is started exactly once.
    uvicorn.run("Start:app", host="0.0.0.0", port=5800, workers=1)

@ -0,0 +1,26 @@
import logging
import jieba
from gensim.models import KeyedVectors
# Configure logging: INFO level, with timestamp, level name and message in every record
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Initialise the Word2Vec model from a non-binary word2vec file.
# NOTE(review): the model path is a hard-coded absolute Windows path and the limit is a
# literal 10000 — confirm whether both should come from shared configuration instead.
model_path = r"D:\Tencent_AILab_ChineseEmbedding\Tencent_AILab_ChineseEmbedding.txt"
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=10000)
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")
# Convert text into an embedding vector
def text_to_embedding(text):
    """
    Average the word vectors of the tokens found in ``text``.

    The text is segmented with jieba and tokens absent from the module-level
    ``model`` are ignored. If no token is present in the model, a list of
    ``model.vector_size`` zeros is returned instead of an average.

    :param text: text to embed
    :return: mean of the matched word vectors, or a zero-filled list
    """
    tokens = jieba.lcut(text)  # segment with jieba
    logger.info(f"文本: {text}, 分词结果: {tokens}")
    vectors = [model[token] for token in tokens if token in model]
    logger.info(f"有效词向量数量: {len(vectors)}")
    if not vectors:
        logger.warning("未找到有效词,返回零向量")
        return [0.0] * model.vector_size
    mean_vector = sum(vectors) / len(vectors)
    logger.info(f"生成的平均向量: {mean_vector[:5]}...")  # log only the first 5 dimensions
    return mean_vector

@ -0,0 +1,36 @@
import asyncio
import logging
from aiomysql import create_pool
from WxMini.Milvus.Config.MulvusConfig import *
# Configure logging: INFO level, with timestamp, level name and message in every record
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# MySQL settings, unpacked as keyword arguments into create_pool() by init_mysql_pool().
# The MYSQL_* names are not defined in the visible code — presumably they come from the
# star import of WxMini.Milvus.Config.MulvusConfig; confirm.
MYSQL_CONFIG = {
    "host": MYSQL_HOST,
    "port": MYSQL_PORT,
    "user": MYSQL_USER,
    "password": MYSQL_PASSWORD,
    "db": MYSQL_DB_NAME,
    "minsize": 1,
    "maxsize": 20,
}
# Initialise the MySQL connection pool
async def init_mysql_pool():
    """Create a connection pool from MYSQL_CONFIG and return it."""
    pool = await create_pool(**MYSQL_CONFIG)
    return pool
# Save one chat exchange to MySQL
async def save_chat_to_mysql(mysql_pool, session_id, prompt, result):
    """
    Insert one row into t_chat_log, commit it, and log that the record was saved.

    :param mysql_pool: pool object whose acquire() yields a connection
    :param session_id: value stored in the session_id column
    :param prompt: value stored in the user_input column
    :param result: value stored in the model_response column
    """
    insert_sql = "INSERT INTO t_chat_log (session_id, user_input, model_response, create_time) VALUES (%s, %s, %s, NOW())"
    row_values = (session_id, prompt, result)
    async with mysql_pool.acquire() as db_conn, db_conn.cursor() as cursor:
        # Values are passed separately so the driver handles parameter binding.
        await cursor.execute(insert_sql, row_values)
        await db_conn.commit()
        logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
Loading…
Cancel
Save