main
HuangHai 4 months ago
parent c9b8601646
commit 4ce91f1f5b

@ -1,134 +1,87 @@
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil, postgresql_pool
import asyncpg
# 删除数据
def delete_question(question_id: str):
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
# 删除 t_bi_question 表中的数据
delete_sql = """
DELETE FROM t_bi_question WHERE id = %s
"""
db.execute_query(delete_sql, (question_id,))
async def delete_question(db: asyncpg.Connection, question_id: str):
delete_sql = """
DELETE FROM t_bi_question WHERE id = $1
"""
await db.execute(delete_sql, question_id)
# 插入数据
def insert_question(question_id: str, question: str):
# 向 t_bi_question 表插入数据
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
insert_sql = """
INSERT INTO t_bi_question (id,question, state_id, is_system, is_collect)
VALUES (%s,%s, %s, %s, %s)
"""
db.execute_query(insert_sql, (question_id, question, 0, 0, 0))
async def insert_question(db: asyncpg.Connection, question_id: str, question: str):
insert_sql = """
INSERT INTO t_bi_question (id, question, state_id, is_system, is_collect)
VALUES ($1, $2, $3, $4, $5)
"""
await db.execute(insert_sql, question_id, question, 0, 0, 0)
# 修改数据
'''
示例:
# 更新 question 和 state_id 字段
update_question_by_id(db, question_id=1, question="新的问题描述", state_id=1)
# 只更新 excel_file_name 字段
update_question_by_id(db, question_id=1, excel_file_name="new_excel.xlsx")
# 只更新 is_collect 字段
update_question_by_id(db, question_id=1, is_collect=1)
# 不更新任何字段(因为所有参数都是 None
update_question_by_id(db, question_id=1, question=None, state_id=None)
'''
def update_question_by_id(question_id: str, **kwargs):
"""
根据主键更新 t_bi_question 只更新非 None 的字段
:param db: PostgreSQLUtil 实例
:param question_id: 主键 id
:param kwargs: 需要更新的字段和值
:return: 更新是否成功
"""
# 过滤掉值为 None 的字段
async def update_question_by_id(db: asyncpg.Connection, question_id: str, **kwargs):
update_fields = {k: v for k, v in kwargs.items() if v is not None}
if not update_fields:
return False # 没有需要更新的字段
# 动态构建 SET 子句
set_clause = ", ".join([f"{field} = %s" for field in update_fields.keys()])
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
# 构建完整 SQL
sql = f"""
UPDATE t_bi_question
SET {set_clause}
WHERE id = %s
"""
# 参数列表
params = list(update_fields.values()) + [question_id]
# 执行更新
try:
db.execute_query(sql, params)
return True
except Exception as e:
print(f"更新失败: {e}")
return False
# 根据 问题id 查询 sql
def get_question_by_id(question_id: str):
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
select_sql = """
select * from t_bi_question where id=%s
"""
_data = db.execute_query(select_sql, (question_id,))
return _data
return False
set_clause = ", ".join([f"{field} = ${i+1}" for i, field in enumerate(update_fields.keys())])
sql = f"""
UPDATE t_bi_question
SET {set_clause}
WHERE id = ${len(update_fields) + 1}
"""
params = list(update_fields.values()) + [question_id]
try:
await db.execute(sql, *params)
return True
except Exception as e:
print(f"更新失败: {e}")
return False
# 根据问题 ID 查询 SQL
async def get_question_by_id(db: asyncpg.Connection, question_id: str):
select_sql = """
SELECT * FROM t_bi_question WHERE id = $1
"""
_data = await db.fetch(select_sql, question_id)
return _data
# 保存系统推荐
def set_system_recommend_questions(question_id: str, flag: str):
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
sql = f"""
UPDATE t_bi_question
SET is_system =%s WHERE id = %s
"""
# 执行更新
try:
db.execute_query(sql, int(flag), question_id)
return True
except Exception as e:
print(f"更新失败: {e}")
return False
async def set_system_recommend_questions(db: asyncpg.Connection, question_id: str, flag: str):
sql = """
UPDATE t_bi_question
SET is_system = $1 WHERE id = $2
"""
try:
await db.execute(sql, int(flag), question_id)
return True
except Exception as e:
print(f"更新失败: {e}")
return False
# 设置用户收藏
def set_user_collect_questions(question_id: str, flag: str):
sql = f"""
UPDATE t_bi_question
SET is_collect =%s WHERE id = %s
"""
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
# 执行更新
try:
db.execute_query(sql, int(flag), question_id)
return True
except Exception as e:
print(f"更新失败: {e}")
return False
# 查询有哪些系统推荐问题
def get_system_recommend_questions():
async def set_user_collect_questions(db: asyncpg.Connection, question_id: str, flag: str):
sql = """
SELECT * FROM t_bi_question WHERE is_system = 1
UPDATE t_bi_question
SET is_collect = $1 WHERE id = $2
"""
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
_data = db.execute_query(sql)
return _data
try:
await db.execute(sql, int(flag), question_id)
return True
except Exception as e:
print(f"更新失败: {e}")
return False
# 查询系统推荐问题
async def get_system_recommend_questions(db: asyncpg.Connection):
sql = """
SELECT * FROM t_bi_question WHERE is_system = 1
"""
_data = await db.fetch(sql)
return _data
# 查询有哪些用户收藏问题
def get_user_collect_questions():
# 从t_bi_question表中获取所有is_collect=1的数据
sql="""
SELECT * FROM t_bi_question WHERE is_collect = 1
# 查询用户收藏问题
async def get_user_collect_questions(db: asyncpg.Connection):
sql = """
SELECT * FROM t_bi_question WHERE is_collect = 1
"""
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
_data = db.execute_query(sql)
return _data
_data = await db.fetch(sql)
return _data

@ -1,86 +0,0 @@
import json
from datetime import date, datetime
import psycopg2
from psycopg2 import pool
from psycopg2.extras import RealDictCursor
from Config import *
# 创建连接池
postgresql_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=10,
host=PG_HOST,
port=PG_PORT,
dbname=PG_DATABASE,
user=PG_USER,
password=PG_PASSWORD
)
class PostgreSQLUtil:
def __init__(self, connection):
self.connection = connection
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.connection.commit()
postgresql_pool.putconn(self.connection)
def execute_query(self, sql, params=None, return_dict=True):
"""执行查询并返回结果"""
try:
with self.connection.cursor(
cursor_factory=RealDictCursor if return_dict else None
) as cursor:
cursor.execute(sql, params)
if cursor.description:
columns = [desc[0] for desc in cursor.description]
results = cursor.fetchall()
# 转换字典格式
if return_dict:
return results
else:
return [dict(zip(columns, row)) for row in results]
else:
return {"rowcount": cursor.rowcount}
except Exception as e:
print(f"执行SQL出错: {e}")
self.connection.rollback()
raise
def get_db():
connection = postgresql_pool.getconn()
try:
yield PostgreSQLUtil(connection)
finally:
postgresql_pool.putconn(connection)
# 使用示例
if __name__ == "__main__":
'''
db_gen = get_db()调用生成器函数返回生成器对象
db = next(db_gen)从生成器中获取 PostgreSQLUtil 实例
生成器函数确保连接在使用后正确归还到连接池
'''
# 从生成器中获取数据库实例
db_gen = get_db()
db = next(db_gen)
try:
# 示例查询
result = db.execute_query("SELECT version()")
print("数据库版本:", result)
# 返回JSON
json_data = db.query_to_json("SELECT * FROM t_base_class LIMIT 2")
print("JSON结果:", json_data)
finally:
# 手动关闭生成器
db_gen.close()

@ -1,13 +1,17 @@
import json
import uuid
import uvicorn # 导入 uvicorn
from datetime import date, datetime
from asyncpg.pool import Pool
from fastapi import FastAPI, Form, Query
from openai import OpenAI
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends
import asyncpg
import uvicorn
from openai import AsyncOpenAI
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles
from Config import MODEL_API_KEY, MODEL_API_URL, QWEN_MODEL_NAME
from Config import *
from Model.biModel import *
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
from Text2Sql.Util.SaveToExcel import save_to_excel
@ -15,22 +19,66 @@ from Text2Sql.Util.VannaUtil import VannaUtil
# 初始化 FastAPI
app = FastAPI()
# 配置静态文件目录
app.mount("/static", StaticFiles(directory="static"), name="static")
# 初始化一次vanna的类
vn = VannaUtil()
# 初始化 FastAPI 应用
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动时初始化连接池
app.state.pool = await asyncpg.create_pool(
host=PG_HOST,
port=PG_PORT,
database=PG_DATABASE,
user=PG_USER,
password=PG_PASSWORD,
min_size=1,
max_size=10
)
yield
# 关闭时释放连接池
await app.state.pool.close()
app = FastAPI(lifespan=lifespan)
# 依赖注入连接池
async def get_db():
async with app.state.pool.acquire() as connection:
yield connection
class PostgreSQLUtil:
def __init__(self, pool: Pool):
self.pool = pool
async def execute_query(self, sql, params=None):
async with self.pool.acquire() as connection:
result = await connection.fetch(sql, params)
return result
async def query_to_json(self, sql, params=None):
data = await self.execute_query(sql, params)
return json.dumps(data, default=self.json_serializer)
@staticmethod
def json_serializer(obj):
"""处理JSON无法序列化的类型"""
if isinstance(obj, (date, datetime)):
return obj.isoformat()
raise TypeError(f"Type {type(obj)} not serializable")
async def create_pool():
return await asyncpg.create_pool(
host=PG_HOST,
port=PG_PORT,
database=PG_DATABASE,
user=PG_USER,
password=PG_PASSWORD,
min_size=1,
max_size=10
)
@app.get("/")
def read_root():
return {"message": "Welcome to AI SQL World!"}
# 通过语义生成Excel
# http://10.10.21.20:8000/questions/get_excel
@app.post("/questions/get_excel")
def get_excel(question_id: str = Form(...), question_str: str = Form(...)):
async def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: asyncpg.Connection = Depends(get_db)):
# 只接受guid号
if len(question_id) != 36:
return {"success": False, "message": "question_id格式错误"}
@ -43,34 +91,40 @@ def get_excel(question_id: str = Form(...), question_str: str = Form(...)):
question = question_str + common_prompt
# 先删除后插入,防止重复插入
delete_question(question_id)
insert_question(question_id, question)
await delete_question(db, question_id)
await insert_question(db, question_id, question)
# 获取完整 SQL
sql = vn.generate_sql(question)
print("生成的查询 SQL:\n", sql)
# 更新question_id
update_question_by_id(question_id=question_id, sql=sql, state_id=1)
await update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)
# 执行SQL查询
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
_data = db.execute_query(sql)
_data = await db.fetch(sql)
# 在static目录下生成一个guid号的临时文件
uuid_str = str(uuid.uuid4())
filename = f"static/{uuid_str}.xlsx"
save_to_excel(_data, filename)
# 更新EXCEL文件名称
update_question_by_id(question_id, excel_file_name=filename)
await update_question_by_id(db, question_id, excel_file_name=filename)
# 返回静态文件URL
return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"}
# http://10.10.21.20:8000/questions/get_docx_stream?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
# 初始化 OpenAI 客户端
client = AsyncOpenAI(
api_key=MODEL_API_KEY,
base_url=MODEL_API_URL,
)
@app.api_route("/questions/get_docx_stream", methods=["POST", "GET"])
async def get_docx_stream(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求") # GET 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数
db: asyncpg.Connection = Depends(get_db)
):
# 根据请求方式获取 question_id
if question_id is not None: # POST 请求
@ -81,8 +135,9 @@ async def get_docx_stream(
return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询sql
sql = get_question_by_id(question_id)[0]['sql']
# 4、生成word报告
sql = (await db.fetch("SELECT * FROM t_bi_question WHERE id = $1", question_id))[0]['sql']
# 生成word报告
prompt = '''
请根据以下 JSON 数据整理出2000字左右的话描述当前数据情况要求
1以Markdown格式返回我将直接通过markdown格式生成Word
@ -91,19 +146,15 @@ async def get_docx_stream(
4尽量以条目列出这样更清晰
5数据
'''
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
_data = db.execute_query(sql)
prompt = prompt + json.dumps(_data, ensure_ascii=False)
# 初始化 OpenAI 客户端
client = OpenAI(
api_key=MODEL_API_KEY,
base_url=MODEL_API_URL,
)
_data = await db.fetch(sql)
#print(_data)
# 将 asyncpg.Record 转换为 JSON 格式
json_data = json.dumps([dict(record) for record in _data], ensure_ascii=False)
print(json_data) # 打印 JSON 数据
prompt = prompt + json.dumps(json_data, ensure_ascii=False)
# 调用 OpenAI API 生成总结(流式输出)
response = client.chat.completions.create(
#model=MODEL_NAME,
response = await client.chat.completions.create(
model=QWEN_MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
@ -122,7 +173,7 @@ async def get_docx_stream(
async def generate_stream():
summary = ""
try:
for chunk in response:
async for chunk in response: # 使用 async for 处理流式响应
if chunk.choices[0].delta.content: # 检查是否有内容
chunk_content = chunk.choices[0].delta.content
# 逐字拆分并返回
@ -135,7 +186,7 @@ async def get_docx_stream(
markdown_to_docx(summary, output_file=filename)
# 记录到数据库
update_question_by_id(question_id, docx_file_name=filename)
await db.execute("UPDATE t_bi_question SET docx_file_name = $1 WHERE id = $2", filename, question_id)
except Exception as e:
# 如果发生异常,返回错误信息
@ -149,7 +200,7 @@ async def get_docx_stream(
finally:
# 确保资源释放
if "response" in locals():
response.close()
await response.aclose()
# 使用 StreamingResponse 返回流式结果
return StreamingResponse(
@ -164,60 +215,6 @@ async def get_docx_stream(
}
)
# 返回生成的Word文件下载地址
# http://10.10.21.20:8000/questions/get_docx_file?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
@app.api_route("/questions/get_docx_file", methods=["POST", "GET"])
async def get_docx_file(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数
):
# 根据请求方式获取 question_id
if question_id is not None: # POST 请求
question_id = question_id
elif question_id_get is not None: # GET 请求
question_id = question_id_get
else:
return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询docx_file_name
docx_file_name = get_question_by_id(question_id)[0]['docx_file_name']
# 返回成功和静态文件的URL
return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"}
# 设置问题为系统推荐问题 ,0:取消1设置
@app.post("/questions/set_system_recommend")
def set_system_recommend(question_id: str = Form(...), flag: str = Form(...)):
set_system_recommend_questions(question_id, flag)
# 提示保存成功
return {"success": True, "message": "保存成功"}
# 设置问题为用户收藏问题 ,0:取消1设置
@app.post("/questions/set_user_collect")
def set_user_collect(question_id: str = Form(...), flag: str = Form(...)):
set_user_collect_questions(question_id, flag)
# 提示保存成功
return {"success": True, "message": "保存成功"}
# 查询有哪些系统推荐问题
@app.get("/questions/get_system_recommend")
def get_system_recommend():
# 查询所有系统推荐问题
system_recommend_questions = get_system_recommend_questions()
# 返回查询结果
return {"success": True, "data": system_recommend_questions}
# 查询有哪些用户收藏问题
@app.get("/questions/get_user_collect")
def get_user_collect():
# 查询所有用户收藏问题
user_collect_questions = get_user_collect_questions()
# 返回查询结果
return {"success": True, "data": user_collect_questions}
# 确保直接运行脚本时启动 FastAPI 应用
# 启动 FastAPI
if __name__ == "__main__":
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
Loading…
Cancel
Save