main
HuangHai 3 weeks ago
parent fa674aa9d3
commit fa7ce6fada

@ -1,6 +1,21 @@
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
import warnings
from io import BytesIO
from logging.handlers import RotatingFileHandler
import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException
from starlette.staticfiles import StaticFiles
from Util.ALiYunUtil import ALiYunUtil
from Util.SearchUtil import *
# 初始化日志
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@ -24,7 +39,6 @@ logger.addHandler(file_handler)
logger.addHandler(console_handler)
async def lifespan(app: FastAPI):
# 初始化阿里云大模型工具
app.state.aliyun_util = ALiYunUtil()
@ -42,7 +56,7 @@ app.mount("/static", StaticFiles(directory="Static"), name="static")
@app.post("/api/save-word")
async def save_to_word(request: Request):
async def save_to_word(request: fastapi.Request):
output_file = None
try:
# Parse request data
@ -91,18 +105,17 @@ async def save_to_word(request: Request):
logger.warning(f"Failed to clean up temp files: {str(e)}")
@app.post("/api/rag")
async def rag(request: Request):
@app.post("/api/rag", response_model=None)
async def rag(request: fastapi.Request):
data = await request.json()
query = data.get('query', '')
query_tags = data.get('tags', [])
# 调用es进行混合搜索
search_results = queryByEs(query, query_tags)
search_results = queryByEs(query, query_tags, logger)
# 调用大模型
markdown_content = callLLM(request, query, search_results)
markdown_content = callLLM(request, query, search_results, logger)
# 如果有正确的结果
if markdown_content:

@ -0,0 +1,4 @@
from typing import List
from pydantic import BaseModel, Field

@ -1,29 +1,10 @@
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
from io import BytesIO
from logging.handlers import RotatingFileHandler
from typing import List
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
from Config.Config import ES_CONFIG
import warnings
from Util.ALiYunUtil import ALiYunUtil
from Util.EsSearchUtil import EsSearchUtil
# 初始化日志
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def queryByEs(query, query_tags):
def queryByEs(query, query_tags,logger):
# 获取EsSearchUtil实例
es_search_util = EsSearchUtil(ES_CONFIG)
@ -144,7 +125,7 @@ def queryByEs(query, query_tags):
es_search_util.es_pool.release_connection(es_conn)
def callLLM(request, query, search_results):
def callLLM(request, query, search_results, logger,streamBack=False):
# 调用阿里云大模型整合结果
aliyun_util = request.app.state.aliyun_util
@ -183,7 +164,16 @@ def callLLM(request, query, search_results):
if len(context) > 0:
# 调用大模型生成回答
logger.info("正在调用阿里云大模型生成回答...")
markdown_content = aliyun_util.chat(prompt)
logger.info(f"调用阿里云大模型生成回答成功完成!")
return markdown_content
if streamBack:
# SSE流式返回
async def generate():
async for chunk in aliyun_util.chat_stream(prompt):
yield f"data: {chunk}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
else:
# 一次性返回
markdown_content = aliyun_util.chat(prompt)
logger.info(f"调用阿里云大模型生成回答成功完成!")
return markdown_content
return None

Loading…
Cancel
Save