You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

272 lines
10 KiB

1 month ago
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from logging.handlers import RotatingFileHandler
from typing import List

import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse

from Config.Config import ES_CONFIG
from Util.ALiYunUtil import ALiYunUtil
from Util.EsSearchUtil import EsSearchUtil
1 month ago
# Initialise the module-level logger; records below INFO are discarded.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
1 month ago
# Configure log handlers. The log file lives in a "Logs" directory next to
# this module; the directory is created on first use.
log_dir = os.path.join(os.path.dirname(__file__), 'Logs')
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, 'app.log')

# Both handlers emit records in the same layout, so one formatter is shared.
log_formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# File handler: rotate at 1 MiB and keep five backups.
file_handler = RotatingFileHandler(
    log_file,
    maxBytes=1024 * 1024,
    backupCount=5,
    encoding='utf-8',
)
file_handler.setFormatter(log_formatter)

# Console handler.
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)

logger.addHandler(file_handler)
logger.addHandler(console_handler)
1 month ago
# Convert an HTML file into a Word file
def html_to_word_pandoc(html_file, output_file):
    """Convert ``html_file`` to ``output_file`` by invoking the pandoc CLI.

    Args:
        html_file: Path of the HTML file passed to pandoc as input.
        output_file: Path passed to pandoc's ``-o`` option.

    Raises:
        subprocess.CalledProcessError: If pandoc exits with a non-zero status.
        OSError: If the ``pandoc`` executable cannot be started.
    """
    # check=True makes a failed conversion raise instead of silently leaving
    # no (or a stale) output file behind; this matches the /api/save-word
    # handler, which also runs pandoc with check=True.
    subprocess.run(['pandoc', html_file, '-o', output_file], check=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook that prepares shared state before serving.

    Creates one ``ALiYunUtil`` instance (the Aliyun large-model helper) and
    stores it on ``app.state.aliyun_util`` so request handlers can reuse it.
    Nothing is torn down after the ``yield``.

    The function is wrapped with ``asynccontextmanager`` because FastAPI's
    ``lifespan=`` parameter is documented to take an async context manager;
    passing a bare async generator function relies on deprecated behaviour.
    """
    # Initialise the Aliyun large-model helper
    app.state.aliyun_util = ALiYunUtil()
    yield
1 month ago
1 month ago
# Create the FastAPI application; `lifespan` populates app.state at startup.
app = FastAPI(lifespan=lifespan)
# Mount the static file directory: files under "Static" are served at /static
app.mount("/static", StaticFiles(directory="Static"), name="static")
# Request schema holding a user question plus the documents the user uploaded.
# NOTE(review): the /api/rag handler in this file reads 'query' and 'tags'
# from the raw JSON body and does not use this model -- confirm whether it is
# still needed.
class QueryRequest(BaseModel):
    # The question asked by the user (required).
    query: str = Field(
        ...,
        description="用户查询的问题",
    )
    # The documents uploaded by the user (required).
    documents: List[str] = Field(
        ...,
        description="用户上传的文档",
    )
# Request schema for HTML content that should be saved as a Word document.
# NOTE(review): the /api/save-word handler in this file reads the key
# 'html_content' from the raw JSON body rather than validating against this
# model (whose field is named 'html') -- confirm which key clients send.
class SaveWordRequest(BaseModel):
    # The HTML markup to convert (required).
    html: str = Field(
        ...,
        description="要保存为Word的HTML内容",
    )
@app.post("/api/save-word")
async def save_to_word(request: Request):
    """Convert posted HTML into a Word document and return it as a download.

    The JSON body must contain a non-empty string under ``html_content``.
    The HTML is written to a temporary file, converted to ``.docx`` by the
    ``pandoc`` command-line tool, read back into memory and returned as an
    attachment named ``【理想大模型】问答.docx``.

    Raises:
        HTTPException: 400 when the body cannot be parsed or ``html_content``
            is missing, empty or not a string; 500 when the conversion or
            any later step fails.
    """
    download_name = "【理想大模型】问答.docx"
    temp_html = None
    output_file = None
    try:
        # Parse and validate the request data
        try:
            data = await request.json()
            html_content = data.get('html_content', '')
            if not html_content:
                raise ValueError("Empty HTML content")
            if not isinstance(html_content, str):
                # Reject early: writing a non-string below would surface as
                # a 500 although the client sent the bad payload.
                raise ValueError("html_content must be a string")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") from e

        # Both temporary files share one random stem. The output path used
        # to be a fixed name in the temp directory, so concurrent requests
        # could overwrite or delete each other's document.
        temp_stem = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
        temp_html = temp_stem + ".html"
        output_file = temp_stem + ".docx"

        # Create the temporary HTML file
        with open(temp_html, "w", encoding="utf-8") as f:
            f.write(html_content)

        # Convert with pandoc; check=True raises on a non-zero exit status.
        # NOTE(review): the HTML is client-supplied; consider whether pandoc
        # should run with its --sandbox option (if the installed version
        # supports it) to keep it from reading other resources.
        subprocess.run(['pandoc', temp_html, '-o', output_file], check=True)

        # Read the generated Word file fully into memory so the temporary
        # file can be removed before the response body is sent.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        # Build the response; the non-ASCII filename is percent-encoded for
        # the RFC 5987 style filename* parameter.
        encoded_filename = urllib.parse.quote(download_name)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback as well as the message.
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        # Clean up temporary files. Each removal is attempted independently
        # so a failure on one file does not leave the other behind.
        for path in (temp_html, output_file):
            try:
                if path and os.path.exists(path):
                    os.remove(path)
            except Exception as e:
                logger.warning(f"Failed to clean up temp files: {str(e)}")
def _build_vector_search_body(query_embedding, query_tags, size=3):
    """Return the ES body for a tag-filtered cosine-similarity search."""
    return {
        "query": {
            "script_score": {
                "query": {
                    "bool": {
                        "should": [
                            {
                                "terms": {
                                    "tags.tags": query_tags
                                }
                            }
                        ],
                        "minimum_should_match": 1
                    }
                },
                "script": {
                    # Negative similarities are clamped to 0 by the script.
                    "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
                    "params": {"query_vector": query_embedding}
                }
            }
        },
        "size": size
    }


def _build_text_search_body(query, query_tags, size=3):
    """Return the ES body matching ``user_input`` text within the given tags."""
    return {
        "query": {
            "bool": {
                "must": [
                    {
                        "match": {
                            "user_input": query
                        }
                    },
                    {
                        "terms": {
                            "tags.tags": query_tags
                        }
                    }
                ]
            }
        },
        "size": size
    }


def _log_hits(hits, score_label):
    """Log id, score and ``user_input`` of every hit, numbered from 1."""
    for i, hit in enumerate(hits, 1):
        logger.info(f" {i}. 文档ID: {hit['_id']}, {score_label}: {hit['_score']:.2f}")
        logger.info(f" 内容: {hit['_source']['user_input']}")


def _build_rag_prompt(query, context):
    """Return the LLM prompt embedding the user query and retrieved context."""
    return f"""
信息检索与回答助手
根据以下关于'{query}'的相关信息
基本信息
- 语言: 中文
- 描述: 根据提供的材料检索信息并回答问题
- 特点: 快速准确提取关键信息清晰简洁地回答
相关信息
{context}
回答要求
1. 依托给定的资料快速准确地回答问题可以添加一些额外的信息但请勿重复内容
2. 如果未提供相关信息请不要回答
3. 如果发现相关信息与原来的问题契合度低也不要回答
4. 使用HTML格式返回包含适当的段落列表和标题标签,一定不要使用 ```html 或者 ```!
5. 确保内容结构清晰便于前端展示
"""


@app.post("/api/rag")
async def rag_stream(request: Request):
    """Answer a question from Elasticsearch results summarised by the LLM.

    Reads ``query`` (default ``''``) and ``tags`` (default ``[]``) from the
    JSON body, runs a tag-filtered vector search and a tag-filtered text
    search (three hits each), joins the ``tags.full_content`` of the hits
    into a context block and asks the Aliyun model stored on
    ``app.state.aliyun_util`` for an HTML answer.

    Returns:
        ``{"data": <html or message>}``. When nothing was retrieved a fixed
        "not found" message is returned together with a ``debug`` entry.
        Failures during search or generation are reported inside ``data``.
        Despite the function name, the answer is returned in one piece.

    Raises:
        HTTPException: 500 when the request body cannot be read or the
            Elasticsearch connection cannot be obtained.
    """
    # NOTE(review): the ES searches and aliyun_util.chat look like
    # synchronous calls made inside an async handler; if so they block the
    # event loop while they run -- confirm and consider a thread pool.
    try:
        data = await request.json()
        query = data.get('query', '')
        query_tags = data.get('tags', [])

        # Obtain an EsSearchUtil instance and borrow a pooled connection
        es_search_util = EsSearchUtil(ES_CONFIG)
        es_conn = es_search_util.es_pool.get_connection()
        try:
            # Vector search
            logger.info("\n=== 开始执行查询 ===")
            logger.info(f"原始查询文本: {query}")
            logger.info(f"查询标签: {query_tags}")

            logger.info("\n=== 向量搜索阶段 ===")
            logger.info("1. 文本分词和向量化处理中...")
            query_embedding = es_search_util.text_to_embedding(query)
            logger.info(f"2. 生成的查询向量维度: {len(query_embedding)}")
            logger.info(f"3. 前3维向量值: {query_embedding[:3]}")

            logger.info("4. 正在执行Elasticsearch向量搜索...")
            vector_results = es_conn.search(
                index=ES_CONFIG['index_name'],
                body=_build_vector_search_body(query_embedding, query_tags)
            )
            vector_hits = vector_results['hits']['hits']
            logger.info(f"5. 向量搜索结果数量: {len(vector_hits)}")

            # Exact text search
            logger.info("\n=== 文本精确搜索阶段 ===")
            logger.info("1. 正在执行Elasticsearch文本精确搜索...")
            text_results = es_conn.search(
                index=ES_CONFIG['index_name'],
                body=_build_text_search_body(query, query_tags)
            )
            text_hits = text_results['hits']['hits']
            logger.info(f"2. 文本搜索结果数量: {len(text_hits)}")

            # Report both result sets
            logger.info("\n=== 最终搜索结果 ===")
            logger.info(f"向量搜索结果: {len(vector_hits)}")
            _log_hits(vector_hits, "相似度分数")
            logger.info("文本精确搜索结果:")
            _log_hits(text_hits, "匹配分数")

            # Merge the two result lists, vector hits first. A document
            # found by both searches is kept once so the same passage is
            # not sent to the model twice.
            seen_ids = set()
            merged_sources = []
            for hit in vector_hits + text_hits:
                if hit['_id'] in seen_ids:
                    continue
                seen_ids.add(hit['_id'])
                merged_sources.append(hit['_source'])

            context = "\n".join(
                f"结果{i + 1}: {source['tags']['full_content']}"
                for i, source in enumerate(merged_sources)
            )

            if not context:
                # Nothing retrieved: answer without calling the model.
                logger.warning(f"未找到查询'{query}'的相关数据tags: {query_tags}")
                return {"data": "没有在知识库中找到相关的信息,无法回答此问题。",
                        "debug": {"query": query, "tags": query_tags}}

            # Ask the Aliyun model to consolidate the retrieved material
            aliyun_util = request.app.state.aliyun_util
            prompt = _build_rag_prompt(query, context)
            logger.info("正在调用阿里云大模型生成回答...")
            html_content = aliyun_util.chat(prompt)
            logger.info("调用阿里云大模型生成回答成功完成!")
            return {"data": html_content}
        except Exception as e:
            # The error is still reported to the client inside "data" as
            # before, but it is now logged with its traceback instead of
            # being dropped silently.
            logger.exception(f"RAG answer generation failed: {str(e)}")
            return {"data": f"生成报告时出错: {str(e)}"}
        finally:
            # Always hand the connection back to the pool.
            es_search_util.es_pool.release_connection(es_conn)
    except Exception as e:
        logger.error(f"RAG search error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
1 month ago
# Script entry point: serve the app with uvicorn on every network interface
# (0.0.0.0), port 8000.
if __name__ == "__main__":
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)