main
HuangHai 4 weeks ago
parent a1b5e5b049
commit 174d8b825b

@@ -5,6 +5,16 @@ MYSQL_USER = "root"
MYSQL_PASSWORD = "DsideaL147258369"
MYSQL_DB_NAME = "base_db"
# Elasticsearch configuration
ES_CONFIG = {
    "hosts": "https://10.10.14.206:9200",
    "basic_auth": ("elastic", "jv9h8uwRrRxmDi1dq6u8"),
    "verify_certs": False,
    "ssl_show_warn": False,
    "default_index": "knowledge_base"
}
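# Sketch (assumption, not part of this commit): every key except "default_index"
# is a valid Elasticsearch client kwarg, so the block can be splatted directly:
#   from elasticsearch import Elasticsearch
#   es = Elasticsearch(**{k: v for k, v in ES_CONFIG.items() if k != "default_index"})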
# Milvus server host
MS_HOST = "10.10.14.207"
# Milvus server port
@@ -29,10 +39,7 @@ DEEPSEEK_URL = 'https://api.deepseek.com'
# Aliyun API key used to call deepseek v3 (驿来特)
MODEL_API_KEY = "sk-f6da0c787eff4b0389e4ad03a35a911f"
MODEL_NAME = "qwen-plus"
# MODEL_NAME = "deepseek-v3"
# Custom dictionary entries for jieba word segmentation
JIEBA_CUSTOM_WORDS = ['文言虚词', '花呗']
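# Sketch (assumption, not part of this commit) of how these entries are typically
# registered; jieba.add_word is the standard API for extending the dictionary:
#   import jieba
#   for word in JIEBA_CUSTOM_WORDS:
#       jieba.add_word(word)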

@@ -0,0 +1,20 @@
# Initialize the ES connection
import urllib3
from elasticsearch import Elasticsearch
from Config import Config

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Pass verify_certs=False when creating the client
es = Elasticsearch(
    hosts=Config.ES_CONFIG['hosts'],
    basic_auth=Config.ES_CONFIG['basic_auth'],
    verify_certs=False  # Disable certificate verification (self-signed cert)
)

# List all indices in the cluster
indices = es.cat.indices(format='json')
print(f"The ES cluster currently has {len(indices)} indices:")
for idx in indices:
    print(f"- {idx['index']}")
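# Follow-up sketch (assumption: reuses the `es` client above); es.count returns
# the number of documents in a single index:
#   doc_count = es.count(index=Config.ES_CONFIG['default_index'])['count']
#   print(f"Index {Config.ES_CONFIG['default_index']} holds {doc_count} documents")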

@@ -0,0 +1,265 @@
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from logging.handlers import RotatingFileHandler
from typing import List
import jieba  # jieba Chinese word-segmentation library
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles
from gensim.models import KeyedVectors
from pydantic import BaseModel, Field, ValidationError
from starlette.responses import StreamingResponse
from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, \
MS_COLLECTION_NAME
from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool
from Util.ALiYunUtil import ALiYunUtil
# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
# 1. Load the pre-trained Word2Vec model
model = KeyedVectors.load_word2vec_format(MS_MODEL_PATH, binary=False, limit=MS_MODEL_LIMIT)
logger.info(f"Model loaded successfully; vector dimension: {model.vector_size}")
# Convert an HTML file to a Word document (pandoc must be installed and on PATH)
def html_to_word_pandoc(html_file, output_file):
    subprocess.run(['pandoc', html_file, '-o', output_file])
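# Usage sketch (hypothetical file names):
#   html_to_word_pandoc("answer.html", "answer.docx")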
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize the Milvus connection pool
    app.state.milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
    # Initialize the collection manager
    app.state.collection_manager = MilvusCollectionManager(MS_COLLECTION_NAME)
    app.state.collection_manager.load_collection()
    # Initialize the Aliyun LLM helper
    app.state.aliyun_util = ALiYunUtil()
    yield
    # Close the Milvus connection pool on shutdown
    app.state.milvus_pool.close()
app = FastAPI(lifespan=lifespan)
# Mount the static file directory
app.mount("/static", StaticFiles(directory="Static"), name="static")
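# With this mount, Static/ai.html is served at /static/ai.html (see the test notes below)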
# Convert text to an embedding vector by averaging its word vectors
def text_to_embedding(text):
    words = jieba.lcut(text)  # Tokenize with jieba
    logger.info(f"Text: {text}, tokens: {words}")
    embeddings = [model[word] for word in words if word in model]
    logger.info(f"Number of in-vocabulary word vectors: {len(embeddings)}")
    if embeddings:
        avg_embedding = sum(embeddings) / len(embeddings)
        logger.info(f"Averaged vector (first 5 dims): {avg_embedding[:5]}...")
        return avg_embedding
    else:
        logger.warning("No in-vocabulary words found; returning a zero vector")
        return [0.0] * model.vector_size
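# Usage sketch (the sample text is an assumption):
#   vec = text_to_embedding("小学数学")    # jieba might split this into ["小学", "数学"]
#   assert len(vec) == model.vector_size  # averaging preserves the model's dimensionality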
async def generate_stream(client, milvus_pool, collection_manager, query, documents):
    # Borrow a connection from the pool
    connection = milvus_pool.get_connection()
    try:
        # 1. Convert the query text to an embedding vector
        current_embedding = text_to_embedding(query)
        # 2. Build the search parameters
        search_params = {
            "metric_type": "L2",  # Use L2 distance
            "params": {"nprobe": MS_NPROBE}  # nprobe parameter for IVF_FLAT
        }
        # Build the expr filter dynamically from the selected documents (see the example below)
        if documents:
            conditions = [f"array_contains(tags['tags'], '{doc}')" for doc in documents]
            expr = " OR ".join(conditions)
        else:
            expr = ""  # No documents selected: no filter
        # 3. Run the vector search
        results = collection_manager.search(current_embedding,
                                            search_params,
                                            expr=expr,  # array_contains conditions joined with OR
                                            limit=5)  # Return up to 5 hits
        # 4. Process the search results
        logger.info("Most relevant knowledge-base entries:")
        context = ""
        if results:
            for hits in results:
                for hit in hits:
                    try:
                        # Fetch the non-vector fields
                        record = collection_manager.query_by_id(hit.id)
                        if hit.distance < 0.88:  # Distance threshold
                            logger.info(f"ID: {hit.id}")
                            logger.info(f"Tags: {record['tags']}")
                            logger.info(f"User question: {record['user_input']}")
                            logger.info(f"Timestamp: {record['timestamp']}")
                            logger.info(f"Distance: {hit.distance}")
                            logger.info("-" * 40)  # Separator
                            # Prefer the full content when it is stored in the tags
                            full_content = record['tags'].get('full_content', record['user_input'])
                            context = context + full_content
                        else:
                            logger.warning(f"Distance too large, skipping hit: {hit.id}")
                            logger.info(f"Tags: {record['tags']}")
                            logger.info(f"User question: {record['user_input']}")
                            logger.info(f"Timestamp: {record['timestamp']}")
                            logger.info(f"Distance: {hit.distance}")
                            continue
                    except Exception as e:
                        logger.error(f"Lookup failed: {e}")
        else:
            logger.warning("No related records found; check the query parameters or the data.")
        prompt = f"""
Information retrieval and answering assistant
Based on the following information about '{query}'
Basic information
- Language: Chinese
- Description: retrieve information from the supplied material and answer the question
- Traits: extract key information quickly and accurately; answer clearly and concisely
Related information
{context}
Answer requirements
1. Answer quickly and accurately from the given material; extra information may be added, but do not repeat content
2. If no related information is provided, do not answer
3. If the related information matches the question poorly, do not answer either
4. Return HTML with appropriate paragraph, list and heading tags
5. Keep the structure clear so the front end can render it
"""
        # Call the Aliyun LLM only when some context was retrieved
        if len(context) > 0:
            html_content = client.chat(prompt)
            yield {"data": html_content}
        else:
            yield {"data": "No related information was found in the knowledge base, so this question cannot be answered."}
    except Exception as e:
        yield {"data": f"Error while generating the answer: {str(e)}"}
    finally:
        # Return the connection to the pool
        milvus_pool.release_connection(connection)
"""
http://10.10.21.22:8000/static/ai.html
知识库中有的内容
小学数学中有哪些模型
帮我写一下 如何理解点线的教学设计
知识库中没有的内容
你知道黄海是谁吗
"""
class QueryRequest(BaseModel):
    query: str = Field(..., description="The user's question")
    documents: List[str] = Field(..., description="Documents uploaded by the user")

class SaveWordRequest(BaseModel):
    html: str = Field(..., description="HTML content to save as a Word document")
@app.post("/api/save-word")
async def save_to_word(request: Request):
temp_html = None
output_file = None
try:
# Parse request data
try:
data = await request.json()
html_content = data.get('html_content', '')
if not html_content:
raise ValueError("Empty HTML content")
except Exception as e:
logger.error(f"Request parsing failed: {str(e)}")
raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")
# 创建临时HTML文件
temp_html = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".html")
with open(temp_html, "w", encoding="utf-8") as f:
f.write(html_content)
# 使用pandoc转换
output_file = os.path.join(tempfile.gettempdir(), "小学数学问答.docx")
subprocess.run(['pandoc', temp_html, '-o', output_file], check=True)
# 读取生成的Word文件
with open(output_file, "rb") as f:
stream = BytesIO(f.read())
# 返回响应
encoded_filename = urllib.parse.quote("小学数学问答.docx")
return StreamingResponse(
stream,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
except HTTPException:
raise
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
finally:
# 清理临时文件
try:
if temp_html and os.path.exists(temp_html):
os.remove(temp_html)
if output_file and os.path.exists(output_file):
os.remove(output_file)
except Exception as e:
logger.warning(f"Failed to clean up temp files: {str(e)}")
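# Client-side sketch (assumptions: host taken from the test notes above, hypothetical payload):
#   curl -X POST http://10.10.21.22:8000/api/save-word \
#        -H "Content-Type: application/json" \
#        -d '{"html_content": "<h1>标题</h1>"}' -o 小学数学问答.docx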
@app.post("/api/rag")
async def rag_stream(request: Request):
try:
data = await request.json()
query_request = QueryRequest(**data)
except ValidationError as e:
logger.error(f"请求体验证失败: {e.errors()}")
raise HTTPException(status_code=422, detail=e.errors())
except Exception as e:
logger.error(f"请求解析失败: {str(e)}")
raise HTTPException(status_code=400, detail="无效的请求格式")
"""RAG+ALiYun接口"""
async for chunk in generate_stream(
request.app.state.aliyun_util,
request.app.state.milvus_pool,
request.app.state.collection_manager,
query_request.query,
query_request.documents
):
return chunk
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)