main
HuangHai 4 months ago
parent 6fc1c0449c
commit fd08e3b904

@ -1,23 +1,25 @@
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil, postgresql_pool
# 删除数据 # 删除数据
def delete_question(db, question_id: str): def delete_question(question_id: str):
# 删除 t_bi_question 表中的数据 with PostgreSQLUtil(postgresql_pool.getconn()) as db:
delete_sql = """ # 删除 t_bi_question 表中的数据
DELETE FROM t_bi_question WHERE id = %s delete_sql = """
""" DELETE FROM t_bi_question WHERE id = %s
db.execute_query(delete_sql, (question_id,)) """
db.execute_query(delete_sql, (question_id,))
# 插入数据 # 插入数据
def insert_question(db, question_id: str, question: str): def insert_question(question_id: str, question: str):
# 向 t_bi_question 表插入数据 # 向 t_bi_question 表插入数据
insert_sql = """ with PostgreSQLUtil(postgresql_pool.getconn()) as db:
INSERT INTO t_bi_question (id,question, state_id, is_system, is_collect) insert_sql = """
VALUES (%s,%s, %s, %s, %s) INSERT INTO t_bi_question (id,question, state_id, is_system, is_collect)
""" VALUES (%s,%s, %s, %s, %s)
db.execute_query(insert_sql, (question_id, question, 0, 0, 0)) """
db.execute_query(insert_sql, (question_id, question, 0, 0, 0))
# 修改数据 # 修改数据
@ -37,7 +39,7 @@ update_question_by_id(db, question_id=1, question=None, state_id=None)
''' '''
def update_question_by_id(db: PostgreSQLUtil, question_id: str, **kwargs): def update_question_by_id(question_id: str, **kwargs):
""" """
根据主键更新 t_bi_question 只更新非 None 的字段 根据主键更新 t_bi_question 只更新非 None 的字段
:param db: PostgreSQLUtil 实例 :param db: PostgreSQLUtil 实例
@ -52,74 +54,81 @@ def update_question_by_id(db: PostgreSQLUtil, question_id: str, **kwargs):
# 动态构建 SET 子句 # 动态构建 SET 子句
set_clause = ", ".join([f"{field} = %s" for field in update_fields.keys()]) set_clause = ", ".join([f"{field} = %s" for field in update_fields.keys()])
# 构建完整 SQL
sql = f"""
UPDATE t_bi_question
SET {set_clause}
WHERE id = %s
"""
# 参数列表
params = list(update_fields.values()) + [question_id]
# 执行更新 with PostgreSQLUtil(postgresql_pool.getconn()) as db:
try: # 构建完整 SQL
db.execute_query(sql, params) sql = f"""
return True UPDATE t_bi_question
except Exception as e: SET {set_clause}
print(f"更新失败: {e}") WHERE id = %s
return False """
# 参数列表
params = list(update_fields.values()) + [question_id]
# 执行更新
try:
db.execute_query(sql, params)
return True
except Exception as e:
print(f"更新失败: {e}")
return False
# 根据 问题id 查询 sql # 根据 问题id 查询 sql
def get_question_by_id(db, question_id: str): def get_question_by_id(question_id: str):
select_sql = """ with PostgreSQLUtil(postgresql_pool.getconn()) as db:
select * from t_bi_question where id=%s select_sql = """
""" select * from t_bi_question where id=%s
_data = db.execute_query(select_sql, (question_id,)) """
return _data _data = db.execute_query(select_sql, (question_id,))
return _data
# 保存系统推荐 # 保存系统推荐
def set_system_recommend_questions(db, question_id: str, flag: str): def set_system_recommend_questions(question_id: str, flag: str):
sql = f""" with PostgreSQLUtil(postgresql_pool.getconn()) as db:
UPDATE t_bi_question sql = f"""
SET is_system =%s WHERE id = %s UPDATE t_bi_question
""" SET is_system =%s WHERE id = %s
# 执行更新 """
try: # 执行更新
db.execute_query(sql, int(flag), question_id) try:
return True db.execute_query(sql, int(flag), question_id)
except Exception as e: return True
print(f"更新失败: {e}") except Exception as e:
return False print(f"更新失败: {e}")
return False
# 设置用户收藏 # 设置用户收藏
def set_user_collect_questions(db, question_id: str, flag: str): def set_user_collect_questions(question_id: str, flag: str):
sql = f""" sql = f"""
UPDATE t_bi_question UPDATE t_bi_question
SET is_collect =%s WHERE id = %s SET is_collect =%s WHERE id = %s
""" """
# 执行更新 with PostgreSQLUtil(postgresql_pool.getconn()) as db:
try: # 执行更新
db.execute_query(sql, int(flag), question_id) try:
return True db.execute_query(sql, int(flag), question_id)
except Exception as e: return True
print(f"更新失败: {e}") except Exception as e:
return False print(f"更新失败: {e}")
return False
# 查询有哪些系统推荐问题 # 查询有哪些系统推荐问题
def get_system_recommend_questions(db): def get_system_recommend_questions():
sql = """ sql = """
SELECT * FROM t_bi_question WHERE is_system = 1 SELECT * FROM t_bi_question WHERE is_system = 1
""" """
_data = db.execute_query(sql) with PostgreSQLUtil(postgresql_pool.getconn()) as db:
return _data _data = db.execute_query(sql)
return _data
# 查询有哪些用户收藏问题 # 查询有哪些用户收藏问题
def get_user_collect_questions(db): def get_user_collect_questions():
# 从t_bi_question表中获取所有is_collect=1的数据 # 从t_bi_question表中获取所有is_collect=1的数据
sql=""" sql="""
SELECT * FROM t_bi_question WHERE is_collect = 1 SELECT * FROM t_bi_question WHERE is_collect = 1
""" """
_data = db.execute_query(sql) with PostgreSQLUtil(postgresql_pool.getconn()) as db:
return _data _data = db.execute_query(sql)
return _data

@ -22,6 +22,13 @@ class PostgreSQLUtil:
def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.connection.commit()
postgresql_pool.putconn(self.connection)
def execute_query(self, sql, params=None, return_dict=True): def execute_query(self, sql, params=None, return_dict=True):
"""执行查询并返回结果""" """执行查询并返回结果"""
try: try:
@ -46,20 +53,6 @@ class PostgreSQLUtil:
print(f"执行SQL出错: {e}") print(f"执行SQL出错: {e}")
self.connection.rollback() self.connection.rollback()
raise raise
finally:
self.connection.commit()
def query_to_json(self, sql, params=None):
"""返回JSON格式结果"""
data = self.execute_query(sql, params)
return json.dumps(data, default=self.json_serializer)
@staticmethod
def json_serializer(obj):
"""处理JSON无法序列化的类型"""
if isinstance(obj, (date, datetime)):
return obj.isoformat()
raise TypeError(f"Type {type(obj)} not serializable")
def get_db(): def get_db():

@ -2,15 +2,14 @@ import json
import uuid import uuid
import uvicorn # 导入 uvicorn import uvicorn # 导入 uvicorn
from fastapi import FastAPI, Depends, Form, Query from fastapi import FastAPI, Form, Query
from openai import OpenAI from openai import OpenAI
from starlette.responses import StreamingResponse from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles from starlette.staticfiles import StaticFiles
from Config import MODEL_API_KEY, MODEL_API_URL, MODEL_NAME from Config import MODEL_API_KEY, MODEL_API_URL, QWEN_MODEL_NAME
from Model.biModel import * from Model.biModel import *
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
from Text2Sql.Util.PostgreSQLUtil import get_db
from Text2Sql.Util.SaveToExcel import save_to_excel from Text2Sql.Util.SaveToExcel import save_to_excel
from Text2Sql.Util.VannaUtil import VannaUtil from Text2Sql.Util.VannaUtil import VannaUtil
@ -31,7 +30,7 @@ def read_root():
# 通过语义生成Excel # 通过语义生成Excel
# http://10.10.21.20:8000/questions/get_excel # http://10.10.21.20:8000/questions/get_excel
@app.post("/questions/get_excel") @app.post("/questions/get_excel")
def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: PostgreSQLUtil = Depends(get_db)): def get_excel(question_id: str = Form(...), question_str: str = Form(...)):
# 只接受guid号 # 只接受guid号
if len(question_id) != 36: if len(question_id) != 36:
return {"success": False, "message": "question_id格式错误"} return {"success": False, "message": "question_id格式错误"}
@ -44,24 +43,25 @@ def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: P
question = question_str + common_prompt question = question_str + common_prompt
# 先删除后插入,防止重复插入 # 先删除后插入,防止重复插入
delete_question(db, question_id) delete_question(question_id)
insert_question(db, question_id, question) insert_question(question_id, question)
# 获取完整 SQL # 获取完整 SQL
sql = vn.generate_sql(question) sql = vn.generate_sql(question)
print("生成的查询 SQL:\n", sql) print("生成的查询 SQL:\n", sql)
# 更新question_id # 更新question_id
update_question_by_id(db, question_id=question_id, sql=sql, state_id=1) update_question_by_id(question_id=question_id, sql=sql, state_id=1)
# 执行SQL查询 # 执行SQL查询
_data = db.execute_query(sql) with PostgreSQLUtil(postgresql_pool.getconn()) as db:
_data = db.execute_query(sql)
# 在static目录下生成一个guid号的临时文件 # 在static目录下生成一个guid号的临时文件
uuid_str = str(uuid.uuid4()) uuid_str = str(uuid.uuid4())
filename = f"static/{uuid_str}.xlsx" filename = f"static/{uuid_str}.xlsx"
save_to_excel(_data, filename) save_to_excel(_data, filename)
# 更新EXCEL文件名称 # 更新EXCEL文件名称
update_question_by_id(db, question_id, excel_file_name=filename) update_question_by_id(question_id, excel_file_name=filename)
# 返回静态文件URL # 返回静态文件URL
return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"} return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"}
@ -70,8 +70,7 @@ def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: P
@app.api_route("/questions/get_docx_stream", methods=["POST", "GET"]) @app.api_route("/questions/get_docx_stream", methods=["POST", "GET"])
async def get_docx_stream( async def get_docx_stream(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数 question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数 question_id_get: str = Query(None, description="问题IDGET请求") # GET 请求参数
db: PostgreSQLUtil = Depends(get_db)
): ):
# 根据请求方式获取 question_id # 根据请求方式获取 question_id
if question_id is not None: # POST 请求 if question_id is not None: # POST 请求
@ -82,7 +81,7 @@ async def get_docx_stream(
return {"success": False, "message": "缺少问题ID参数"} return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询sql # 根据问题ID获取查询sql
sql = get_question_by_id(db, question_id)[0]['sql'] sql = get_question_by_id(question_id)[0]['sql']
# 4、生成word报告 # 4、生成word报告
prompt = ''' prompt = '''
请根据以下 JSON 数据整理出2000字左右的话描述当前数据情况要求 请根据以下 JSON 数据整理出2000字左右的话描述当前数据情况要求
@ -92,7 +91,8 @@ async def get_docx_stream(
4尽量以条目列出这样更清晰 4尽量以条目列出这样更清晰
5数据 5数据
''' '''
_data = db.execute_query(sql) with PostgreSQLUtil(postgresql_pool.getconn()) as db:
_data = db.execute_query(sql)
prompt = prompt + json.dumps(_data, ensure_ascii=False) prompt = prompt + json.dumps(_data, ensure_ascii=False)
# 初始化 OpenAI 客户端 # 初始化 OpenAI 客户端
@ -103,7 +103,8 @@ async def get_docx_stream(
# 调用 OpenAI API 生成总结(流式输出) # 调用 OpenAI API 生成总结(流式输出)
response = client.chat.completions.create( response = client.chat.completions.create(
model=MODEL_NAME, #model=MODEL_NAME,
model=QWEN_MODEL_NAME,
messages=[ messages=[
{"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"}, {"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
{"role": "user", "content": prompt} {"role": "user", "content": prompt}
@ -126,6 +127,7 @@ async def get_docx_stream(
chunk_content = chunk.choices[0].delta.content chunk_content = chunk.choices[0].delta.content
# 逐字拆分并返回 # 逐字拆分并返回
for char in chunk_content: for char in chunk_content:
print(char, end="", flush=True) # 逐字输出到控制台
yield char.encode("utf-8") # 将字符编码为 UTF-8 字节 yield char.encode("utf-8") # 将字符编码为 UTF-8 字节
summary += char # 将内容拼接到 summary 中 summary += char # 将内容拼接到 summary 中
@ -133,7 +135,7 @@ async def get_docx_stream(
markdown_to_docx(summary, output_file=filename) markdown_to_docx(summary, output_file=filename)
# 记录到数据库 # 记录到数据库
update_question_by_id(db, question_id, docx_file_name=filename) update_question_by_id(question_id, docx_file_name=filename)
except Exception as e: except Exception as e:
# 如果发生异常,返回错误信息 # 如果发生异常,返回错误信息
@ -141,6 +143,7 @@ async def get_docx_stream(
"success": False, "success": False,
"message": f"生成Word文件失败: {str(e)}" "message": f"生成Word文件失败: {str(e)}"
}) })
print(error_response) # 输出错误信息到控制台
yield error_response.encode("utf-8") # 将错误信息编码为 UTF-8 字节 yield error_response.encode("utf-8") # 将错误信息编码为 UTF-8 字节
finally: finally:
@ -168,7 +171,6 @@ async def get_docx_stream(
async def get_docx_file( async def get_docx_file(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数 question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数 question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数
db: PostgreSQLUtil = Depends(get_db)
): ):
# 根据请求方式获取 question_id # 根据请求方式获取 question_id
if question_id is not None: # POST 请求 if question_id is not None: # POST 请求
@ -179,7 +181,7 @@ async def get_docx_file(
return {"success": False, "message": "缺少问题ID参数"} return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询docx_file_name # 根据问题ID获取查询docx_file_name
docx_file_name = get_question_by_id(db, question_id)[0]['docx_file_name'] docx_file_name = get_question_by_id(question_id)[0]['docx_file_name']
# 返回成功和静态文件的URL # 返回成功和静态文件的URL
return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"} return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"}
@ -187,31 +189,31 @@ async def get_docx_file(
# 设置问题为系统推荐问题 ,0:取消1设置 # 设置问题为系统推荐问题 ,0:取消1设置
@app.post("/questions/set_system_recommend") @app.post("/questions/set_system_recommend")
def set_system_recommend(question_id: str = Form(...), flag: str = Form(...), db: PostgreSQLUtil = Depends(get_db)): def set_system_recommend(question_id: str = Form(...), flag: str = Form(...)):
set_system_recommend_questions(db, question_id, flag) set_system_recommend_questions(question_id, flag)
# 提示保存成功 # 提示保存成功
return {"success": True, "message": "保存成功"} return {"success": True, "message": "保存成功"}
# 设置问题为用户收藏问题 ,0:取消1设置 # 设置问题为用户收藏问题 ,0:取消1设置
@app.post("/questions/set_user_collect") @app.post("/questions/set_user_collect")
def set_user_collect(question_id: str = Form(...), flag: str = Form(...), db: PostgreSQLUtil = Depends(get_db)): def set_user_collect(question_id: str = Form(...), flag: str = Form(...)):
set_user_collect_questions(db, question_id, flag) set_user_collect_questions(question_id, flag)
# 提示保存成功 # 提示保存成功
return {"success": True, "message": "保存成功"} return {"success": True, "message": "保存成功"}
# 查询有哪些系统推荐问题 # 查询有哪些系统推荐问题
@app.get("/questions/get_system_recommend") @app.get("/questions/get_system_recommend")
def get_system_recommend(db: PostgreSQLUtil = Depends(get_db)): def get_system_recommend():
# 查询所有系统推荐问题 # 查询所有系统推荐问题
system_recommend_questions = get_system_recommend_questions(db) system_recommend_questions = get_system_recommend_questions()
# 返回查询结果 # 返回查询结果
return {"success": True, "data": system_recommend_questions} return {"success": True, "data": system_recommend_questions}
# 查询有哪些用户收藏问题 # 查询有哪些用户收藏问题
@app.get("/questions/get_user_collect") @app.get("/questions/get_user_collect")
def get_user_collect(db: PostgreSQLUtil = Depends(get_db)): def get_user_collect():
# 查询所有用户收藏问题 # 查询所有用户收藏问题
user_collect_questions = get_user_collect_questions(db) user_collect_questions = get_user_collect_questions()
# 返回查询结果 # 返回查询结果
return {"success": True, "data": user_collect_questions} return {"success": True, "data": user_collect_questions}

Loading…
Cancel
Save