main
HuangHai 4 weeks ago
parent 3abdc92bc1
commit 4598a3132b

@ -4,12 +4,14 @@ from logging.handlers import RotatingFileHandler
import jieba # 导入 jieba 分词库
import uvicorn
from fastapi import FastAPI, Request, Body
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel, Field, ValidationError
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from sse_starlette.sse import EventSourceResponse
from gensim.models import KeyedVectors
from Config import Config
from Config.Config import *
from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, DEEPSEEK_API_KEY, DEEPSEEK_URL, MS_COLLECTION_NAME
from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from Milvus.Utils.MilvusConnectionPool import *
from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool
@ -24,7 +26,7 @@ logger.addHandler(handler)
# 1. Load the pre-trained Word2Vec model.
# MS_MODEL_PATH and MS_MODEL_LIMIT come from Config; point MS_MODEL_PATH at
# your own Word2Vec model file.
model_path = MS_MODEL_PATH
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
# Report success through the logger only, so the message is not emitted twice
# (once on stdout and once in the log).
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")
@asynccontextmanager
@ -49,19 +51,22 @@ async def lifespan(app: FastAPI):
# Create the FastAPI application, using `lifespan` as its lifespan handler.
app = FastAPI(lifespan=lifespan)
# Mount the static files directory: files under "Static" are served at /static.
app.mount("/static", StaticFiles(directory="Static"), name="static")
# Convert text into an embedding vector.
def text_to_embedding(text):
    """Return the mean Word2Vec vector of the words found in *text*.

    The text is segmented with jieba; words that are not present in the
    module-level ``model`` are skipped and the remaining word vectors are
    averaged element-wise.

    Args:
        text: The string to embed.

    Returns:
        A NumPy array of length ``model.vector_size``. When none of the
        segmented words is present in the model, an all-zero array is
        returned so that both paths yield the same type.
    """
    # Local import keeps this block self-contained; NumPy is already a
    # dependency of gensim.
    import numpy as np

    words = jieba.lcut(text)  # segment the text with jieba
    logger.info(f"文本: {text}, 分词结果: {words}")

    embeddings = [model[word] for word in words if word in model]
    logger.info(f"有效词向量数量: {len(embeddings)}")

    if not embeddings:
        logger.warning("未找到有效词,返回零向量")
        # NOTE(review): float32 is assumed to match the dtype of the model's
        # vectors (gensim's default) - confirm if a custom datatype is used.
        return np.zeros(model.vector_size, dtype=np.float32)

    avg_embedding = sum(embeddings) / len(embeddings)
    # Only the first 5 dimensions are logged to keep the log line short.
    logger.info(f"生成的平均向量: {avg_embedding[:5]}...")
    return avg_embedding
@ -83,7 +88,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query):
results = collection_manager.search(current_embedding, search_params, limit=5) # 返回 2 条结果
# 3. 处理搜索结果
print("最相关的历史对话:")
logger.info("最相关的历史对话:")
context = ""
if results:
for hits in results:
@ -91,17 +96,17 @@ async def generate_stream(client, milvus_pool, collection_manager, query):
try:
# 查询非向量字段
record = collection_manager.query_by_id(hit.id)
print(f"ID: {hit.id}")
print(f"标签: {record['tags']}")
print(f"用户问题: {record['user_input']}")
logger.info(f"ID: {hit.id}")
logger.info(f"标签: {record['tags']}")
logger.info(f"用户问题: {record['user_input']}")
context = context + record['user_input']
print(f"时间: {record['timestamp']}")
print(f"距离: {hit.distance}")
print("-" * 40) # 分隔线
logger.info(f"时间: {record['timestamp']}")
logger.info(f"距离: {hit.distance}")
logger.info("-" * 40) # 分隔线
except Exception as e:
print(f"查询失败: {e}")
logger.error(f"查询失败: {e}")
else:
print("未找到相关历史对话,请检查查询参数或数据。")
logger.warning("未找到相关历史对话,请检查查询参数或数据。")
prompt = f"""根据以下关于'{query}'的相关信息,# Role: 信息检索与回答助手
@ -178,19 +183,32 @@ async def generate_stream(client, milvus_pool, collection_manager, query):
"""
http://10.10.21.22:8000/api/rag?query=小学数学中有哪些模型
http://10.10.21.22:8000/static/chat.html
小学数学中有哪些模型
"""
class QueryRequest(BaseModel):
    """JSON request body for the ``/api/rag`` endpoint."""

    # Required field: the question asked by the user.
    query: str = Field(default=..., description="用户查询的问题")
@app.post("/api/rag")
async def rag_stream(request: Request):
    """RAG + DeepSeek streaming endpoint.

    Reads the JSON request body, validates it against ``QueryRequest`` and
    returns a Server-Sent Events response fed by ``generate_stream``.

    Args:
        request: The incoming request; the DeepSeek client, Milvus pool and
            collection manager are read from ``request.app.state``.

    Returns:
        An ``EventSourceResponse`` wrapping ``generate_stream``.

    Raises:
        HTTPException: 422 when the body fails ``QueryRequest`` validation;
            400 when the body cannot be turned into a ``QueryRequest`` for
            any other reason (for example invalid JSON, or a JSON value that
            is not an object).
    """
    try:
        data = await request.json()
        query_request = QueryRequest(**data)
    except ValidationError as e:
        logger.error(f"请求体验证失败: {e.errors()}")
        raise HTTPException(status_code=422, detail=e.errors()) from e
    except Exception as e:
        # Boundary handler: any other parsing failure is reported to the
        # client as a bad request instead of surfacing as a 500.
        logger.error(f"请求解析失败: {str(e)}")
        raise HTTPException(status_code=400, detail="无效的请求格式") from e

    return EventSourceResponse(
        generate_stream(
            request.app.state.deepseek_client,
            request.app.state.milvus_pool,
            request.app.state.collection_manager,
            query_request.query,
        )
    )

@ -1,10 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Title</title>
</head>
<body>
</body>
</html>

@ -0,0 +1,206 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>教育知识问答</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
background-color: #f5f5f5;
}
.container {
background: white;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
h1 {
color: #333;
text-align: center;
margin-bottom: 30px;
}
.data-area {
border: 1px solid #ddd;
border-radius: 5px;
padding: 15px;
min-height: 300px;
max-height: 400px;
overflow-y: auto;
background-color: #f8f9fa;
font-family: 'Courier New', monospace;
font-size: 14px;
line-height: 1.6;
white-space: pre-line;
word-wrap: break-word;
overflow-wrap: break-word;
margin-bottom: 20px;
}
.input-area {
display: flex;
gap: 10px;
}
#questionInput {
flex: 1;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
font-size: 16px;
}
#submitBtn {
background-color: #007bff;
color: white;
border: none;
padding: 10px 20px;
font-size: 16px;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s;
}
#submitBtn:hover {
background-color: #0056b3;
}
.status {
text-align: center;
margin-bottom: 20px;
font-weight: bold;
}
.status.connecting {
color: #ffc107;
}
.status.connected {
color: #28a745;
}
.status.error {
color: #dc3545;
}
.status.completed {
color: #17a2b8;
}
</style>
</head>
<body>
<div class="container">
<h1>教育知识问答</h1>
<div id="status" class="status">准备就绪</div>
<div class="data-area" id="dataArea">等待问题...</div>
<div class="input-area">
<input type="text" id="questionInput" placeholder="请输入您的问题,例如:小学数学的学习方法">
<button id="submitBtn" onclick="submitQuestion()">提问</button>
</div>
</div>
<script>
let eventSource = null;
let textBuffer = '';
/**
 * Send the question from the input box to POST /api/rag and render the
 * streamed answer into the data area.
 *
 * The response body is read incrementally and parsed as a Server-Sent Events
 * stream:
 *   - an incomplete trailing line is carried over to the next network chunk,
 *     so a line split across two chunks is not lost or corrupted;
 *   - lines may end in CRLF, LF or CR, and the terminator never reaches the
 *     rendered text;
 *   - the `data:` lines of one event are joined with '\n' and appended to the
 *     answer when the blank line that ends the event is seen;
 *   - an event whose data is `[DONE]` ends the answer.
 */
function submitQuestion() {
    const question = document.getElementById('questionInput').value.trim();
    if (!question) {
        alert('请输入问题!');
        return;
    }
    const statusDiv = document.getElementById('status');
    const dataArea = document.getElementById('dataArea');
    const submitBtn = document.getElementById('submitBtn');

    // Update the status line; `state` selects one of the .status.* CSS rules.
    function setStatus(message, state) {
        statusDiv.textContent = message;
        statusDiv.className = `status ${state}`;
    }

    // Re-enable the submit button once the request has finished or failed.
    function releaseButton() {
        submitBtn.disabled = false;
        submitBtn.textContent = '提问';
    }

    // Clear the previous answer.
    dataArea.textContent = '';
    textBuffer = '';
    // Disable the button while this request is in flight.
    submitBtn.disabled = true;
    submitBtn.textContent = '处理中...';
    setStatus('正在连接...', 'connecting');

    // Send a POST request with fetch and process the SSE stream by hand.
    fetch('/api/rag', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json'
        },
        body: JSON.stringify({ query: question })
    })
    .then(response => {
        if (!response.ok) {
            throw new Error(`HTTP error! status: ${response.status}`);
        }
        const reader = response.body.getReader();
        const decoder = new TextDecoder();
        let pending = '';    // decoded text that does not yet form a complete line
        let eventData = [];  // `data:` payloads of the event being assembled
        setStatus('连接成功,等待回答...', 'connected');

        function complete() {
            setStatus('回答完成', 'completed');
            releaseButton();
        }

        // Feed one complete line (terminator already removed) to the parser.
        // Returns true once the [DONE] event has been received.
        function consumeLine(line) {
            if (line !== '') {
                if (line.startsWith('data:')) {
                    // A single space after the colon is optional and is not
                    // part of the payload.
                    const value = line.slice(5);
                    eventData.push(value.startsWith(' ') ? value.slice(1) : value);
                }
                // Comment lines (':' prefix) and other fields are ignored.
                return false;
            }
            // A blank line ends the current event.
            if (eventData.length === 0) {
                return false;
            }
            const payload = eventData.join('\n');
            eventData = [];
            if (payload.trim() === '[DONE]') {
                return true;
            }
            textBuffer += payload;
            dataArea.textContent = textBuffer;
            dataArea.scrollTop = dataArea.scrollHeight;
            return false;
        }

        function readStream() {
            reader.read().then(({ done, value }) => {
                if (done) {
                    // Flush the decoder and any unterminated final line/event
                    // so that no received text is dropped.
                    pending += decoder.decode();
                    if (pending !== '') {
                        consumeLine(pending.replace(/\r$/, ''));
                        pending = '';
                    }
                    consumeLine('');
                    complete();
                    return;
                }
                pending += decoder.decode(value, { stream: true });
                // A '\r' at the very end is kept in `pending`: it may be the
                // first half of a '\r\n' completed by the next chunk.
                const lines = pending.split(/\r\n|\n|\r(?!$)/);
                pending = lines.pop();
                for (const line of lines) {
                    if (consumeLine(line)) {
                        complete();
                        reader.cancel().catch(() => {}); // stop reading the stream
                        return;
                    }
                }
                readStream(); // continue with the next chunk
            }).catch(error => {
                console.error('SSE连接错误:', error);
                setStatus(`连接错误或已断开: ${error.message}`, 'error');
                releaseButton();
            });
        }
        readStream();
    })
    .catch(error => {
        console.error('Fetch错误:', error);
        setStatus(`请求发送失败: ${error.message}`, 'error');
        releaseButton();
    });
}
</script>
</body>
</html>
Loading…
Cancel
Save