diff --git a/AI/Text2Sql/Model/__pycache__/biModel.cpython-310.pyc b/AI/Text2Sql/Model/__pycache__/biModel.cpython-310.pyc index eeaa4a4c..1cf7b09e 100644 Binary files a/AI/Text2Sql/Model/__pycache__/biModel.cpython-310.pyc and b/AI/Text2Sql/Model/__pycache__/biModel.cpython-310.pyc differ diff --git a/AI/Text2Sql/Model/biModel.py b/AI/Text2Sql/Model/biModel.py index 2f085611..d162583f 100644 --- a/AI/Text2Sql/Model/biModel.py +++ b/AI/Text2Sql/Model/biModel.py @@ -1,134 +1,87 @@ -from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil, postgresql_pool - +import asyncpg # 删除数据 -def delete_question(question_id: str): - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - # 删除 t_bi_question 表中的数据 - delete_sql = """ - DELETE FROM t_bi_question WHERE id = %s - """ - db.execute_query(delete_sql, (question_id,)) - +async def delete_question(db: asyncpg.Connection, question_id: str): + delete_sql = """ + DELETE FROM t_bi_question WHERE id = $1 + """ + await db.execute(delete_sql, question_id) # 插入数据 -def insert_question(question_id: str, question: str): - # 向 t_bi_question 表插入数据 - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - insert_sql = """ - INSERT INTO t_bi_question (id,question, state_id, is_system, is_collect) - VALUES (%s,%s, %s, %s, %s) - """ - db.execute_query(insert_sql, (question_id, question, 0, 0, 0)) - +async def insert_question(db: asyncpg.Connection, question_id: str, question: str): + insert_sql = """ + INSERT INTO t_bi_question (id, question, state_id, is_system, is_collect) + VALUES ($1, $2, $3, $4, $5) + """ + await db.execute(insert_sql, question_id, question, 0, 0, 0) # 修改数据 -''' -示例: -# 更新 question 和 state_id 字段 -update_question_by_id(db, question_id=1, question="新的问题描述", state_id=1) - -# 只更新 excel_file_name 字段 -update_question_by_id(db, question_id=1, excel_file_name="new_excel.xlsx") - -# 只更新 is_collect 字段 -update_question_by_id(db, question_id=1, is_collect=1) - -# 不更新任何字段(因为所有参数都是 None) -update_question_by_id(db, question_id=1, question=None, state_id=None) -''' - - -def update_question_by_id(question_id: str, **kwargs): - """ - 根据主键更新 t_bi_question 表,只更新非 None 的字段 - :param db: PostgreSQLUtil 实例 - :param question_id: 主键 id - :param kwargs: 需要更新的字段和值 - :return: 更新是否成功 - """ - # 过滤掉值为 None 的字段 +async def update_question_by_id(db: asyncpg.Connection, question_id: str, **kwargs): update_fields = {k: v for k, v in kwargs.items() if v is not None} if not update_fields: - return False # 没有需要更新的字段 - - # 动态构建 SET 子句 - set_clause = ", ".join([f"{field} = %s" for field in update_fields.keys()]) - - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - # 构建完整 SQL - sql = f""" - UPDATE t_bi_question - SET {set_clause} - WHERE id = %s - """ - # 参数列表 - params = list(update_fields.values()) + [question_id] - - # 执行更新 - try: - db.execute_query(sql, params) - return True - except Exception as e: - print(f"更新失败: {e}") - return False - - -# 根据 问题id 查询 sql -def get_question_by_id(question_id: str): - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - select_sql = """ - select * from t_bi_question where id=%s - """ - _data = db.execute_query(select_sql, (question_id,)) - return _data + return False + set_clause = ", ".join([f"{field} = ${i+1}" for i, field in enumerate(update_fields.keys())]) + sql = f""" + UPDATE t_bi_question + SET {set_clause} + WHERE id = ${len(update_fields) + 1} + """ + params = list(update_fields.values()) + [question_id] + + try: + await db.execute(sql, *params) + return True + except Exception as e: + print(f"更新失败: {e}") + return False + +# 根据问题 ID 查询 SQL +async def get_question_by_id(db: asyncpg.Connection, question_id: str): + select_sql = """ + SELECT * FROM t_bi_question WHERE id = $1 + """ + _data = await db.fetch(select_sql, question_id) + return _data # 保存系统推荐 -def set_system_recommend_questions(question_id: str, flag: str): - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - sql = f""" - UPDATE t_bi_question - SET is_system =%s WHERE id = %s - """ - # 执行更新 - try: - db.execute_query(sql, int(flag), question_id) - return True - except Exception as e: - print(f"更新失败: {e}") - return False +async def set_system_recommend_questions(db: asyncpg.Connection, question_id: str, flag: str): + sql = """ + UPDATE t_bi_question + SET is_system = $1 WHERE id = $2 + """ + try: + await db.execute(sql, int(flag), question_id) + return True + except Exception as e: + print(f"更新失败: {e}") + return False # 设置用户收藏 -def set_user_collect_questions(question_id: str, flag: str): - sql = f""" - UPDATE t_bi_question - SET is_collect =%s WHERE id = %s - """ - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - # 执行更新 - try: - db.execute_query(sql, int(flag), question_id) - return True - except Exception as e: - print(f"更新失败: {e}") - return False - -# 查询有哪些系统推荐问题 -def get_system_recommend_questions(): +async def set_user_collect_questions(db: asyncpg.Connection, question_id: str, flag: str): sql = """ - SELECT * FROM t_bi_question WHERE is_system = 1 + UPDATE t_bi_question + SET is_collect = $1 WHERE id = $2 """ - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - _data = db.execute_query(sql) - return _data + try: + await db.execute(sql, int(flag), question_id) + return True + except Exception as e: + print(f"更新失败: {e}") + return False + +# 查询系统推荐问题 +async def get_system_recommend_questions(db: asyncpg.Connection): + sql = """ + SELECT * FROM t_bi_question WHERE is_system = 1 + """ + _data = await db.fetch(sql) + return _data -# 查询有哪些用户收藏问题 -def get_user_collect_questions(): - # 从t_bi_question表中获取所有is_collect=1的数据 - sql=""" - SELECT * FROM t_bi_question WHERE is_collect = 1 +# 查询用户收藏问题 +async def get_user_collect_questions(db: asyncpg.Connection): + sql = """ + SELECT * FROM t_bi_question WHERE is_collect = 1 """ - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - _data = db.execute_query(sql) - return _data \ No newline at end of file + _data = await db.fetch(sql) + return _data \ No newline at end of file diff --git a/AI/Text2Sql/Util/PostgreSQLUtil.py b/AI/Text2Sql/Util/PostgreSQLUtil.py deleted file mode 100644 index b45ac52e..00000000 --- a/AI/Text2Sql/Util/PostgreSQLUtil.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -from datetime import date, datetime - -import psycopg2 -from psycopg2 import pool -from psycopg2.extras import RealDictCursor -from Config import * - -# 创建连接池 -postgresql_pool = psycopg2.pool.SimpleConnectionPool( - minconn=1, - maxconn=10, - host=PG_HOST, - port=PG_PORT, - dbname=PG_DATABASE, - user=PG_USER, - password=PG_PASSWORD -) - - -class PostgreSQLUtil: - def __init__(self, connection): - self.connection = connection - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.connection.commit() - postgresql_pool.putconn(self.connection) - - def execute_query(self, sql, params=None, return_dict=True): - """执行查询并返回结果""" - try: - with self.connection.cursor( - cursor_factory=RealDictCursor if return_dict else None - ) as cursor: - cursor.execute(sql, params) - - if cursor.description: - columns = [desc[0] for desc in cursor.description] - results = cursor.fetchall() - - # 转换字典格式 - if return_dict: - return results - else: - return [dict(zip(columns, row)) for row in results] - else: - return {"rowcount": cursor.rowcount} - - except Exception as e: - print(f"执行SQL出错: {e}") - self.connection.rollback() - raise - - -def get_db(): - connection = postgresql_pool.getconn() - try: - yield PostgreSQLUtil(connection) - finally: - postgresql_pool.putconn(connection) - - -# 使用示例 -if __name__ == "__main__": - ''' - db_gen = get_db():调用生成器函数,返回生成器对象。 - db = next(db_gen):从生成器中获取 PostgreSQLUtil 实例。 - 生成器函数确保连接在使用后正确归还到连接池。 - ''' - # 从生成器中获取数据库实例 - db_gen = get_db() - db = next(db_gen) - try: - # 示例查询 - result = db.execute_query("SELECT version()") - print("数据库版本:", result) - - # 返回JSON - json_data = db.query_to_json("SELECT * FROM t_base_class LIMIT 2") - print("JSON结果:", json_data) - finally: - # 手动关闭生成器 - db_gen.close() diff --git a/AI/Text2Sql/Util/__pycache__/PostgreSQLUtil.cpython-310.pyc b/AI/Text2Sql/Util/__pycache__/PostgreSQLUtil.cpython-310.pyc deleted file mode 100644 index 75c3c6a6..00000000 Binary files a/AI/Text2Sql/Util/__pycache__/PostgreSQLUtil.cpython-310.pyc and /dev/null differ diff --git a/AI/Text2Sql/__pycache__/app.cpython-310.pyc b/AI/Text2Sql/__pycache__/app.cpython-310.pyc index 291a307a..dd63c21a 100644 Binary files a/AI/Text2Sql/__pycache__/app.cpython-310.pyc and b/AI/Text2Sql/__pycache__/app.cpython-310.pyc differ diff --git a/AI/Text2Sql/app.py b/AI/Text2Sql/app.py index c5b9841c..568b9a3a 100644 --- a/AI/Text2Sql/app.py +++ b/AI/Text2Sql/app.py @@ -1,13 +1,17 @@ import json import uuid - -import uvicorn # 导入 uvicorn +from datetime import date, datetime +from asyncpg.pool import Pool from fastapi import FastAPI, Form, Query -from openai import OpenAI +from fastapi.staticfiles import StaticFiles +from contextlib import asynccontextmanager +from fastapi import FastAPI, Depends +import asyncpg +import uvicorn +from openai import AsyncOpenAI from starlette.responses import StreamingResponse -from starlette.staticfiles import StaticFiles -from Config import MODEL_API_KEY, MODEL_API_URL, QWEN_MODEL_NAME +from Config import * from Model.biModel import * from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx from Text2Sql.Util.SaveToExcel import save_to_excel @@ -15,22 +19,66 @@ from Text2Sql.Util.VannaUtil import VannaUtil # 初始化 FastAPI app = FastAPI() -# 配置静态文件目录 app.mount("/static", StaticFiles(directory="static"), name="static") -# 初始化一次vanna的类 vn = VannaUtil() +# 初始化 FastAPI 应用 +@asynccontextmanager +async def lifespan(app: FastAPI): + # 启动时初始化连接池 + app.state.pool = await asyncpg.create_pool( + host=PG_HOST, + port=PG_PORT, + database=PG_DATABASE, + user=PG_USER, + password=PG_PASSWORD, + min_size=1, + max_size=10 + ) + yield + # 关闭时释放连接池 + await app.state.pool.close() + +app = FastAPI(lifespan=lifespan) + +# 依赖注入连接池 +async def get_db(): + async with app.state.pool.acquire() as connection: + yield connection + +class PostgreSQLUtil: + def __init__(self, pool: Pool): + self.pool = pool + + async def execute_query(self, sql, params=None): + async with self.pool.acquire() as connection: + result = await connection.fetch(sql, params) + return result + + async def query_to_json(self, sql, params=None): + data = await self.execute_query(sql, params) + return json.dumps(data, default=self.json_serializer) + + @staticmethod + def json_serializer(obj): + """处理JSON无法序列化的类型""" + if isinstance(obj, (date, datetime)): + return obj.isoformat() + raise TypeError(f"Type {type(obj)} not serializable") + +async def create_pool(): + return await asyncpg.create_pool( + host=PG_HOST, + port=PG_PORT, + database=PG_DATABASE, + user=PG_USER, + password=PG_PASSWORD, + min_size=1, + max_size=10 + ) - -@app.get("/") -def read_root(): - return {"message": "Welcome to AI SQL World!"} - - -# 通过语义生成Excel -# http://10.10.21.20:8000/questions/get_excel @app.post("/questions/get_excel") -def get_excel(question_id: str = Form(...), question_str: str = Form(...)): +async def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: asyncpg.Connection = Depends(get_db)): # 只接受guid号 if len(question_id) != 36: return {"success": False, "message": "question_id格式错误"} @@ -43,34 +91,40 @@ def get_excel(question_id: str = Form(...), question_str: str = Form(...)): question = question_str + common_prompt # 先删除后插入,防止重复插入 - delete_question(question_id) - insert_question(question_id, question) + await delete_question(db, question_id) + await insert_question(db, question_id, question) # 获取完整 SQL sql = vn.generate_sql(question) print("生成的查询 SQL:\n", sql) # 更新question_id - update_question_by_id(question_id=question_id, sql=sql, state_id=1) + await update_question_by_id(db, question_id=question_id, sql=sql, state_id=1) # 执行SQL查询 - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - _data = db.execute_query(sql) + _data = await db.fetch(sql) # 在static目录下,生成一个guid号的临时文件 uuid_str = str(uuid.uuid4()) filename = f"static/{uuid_str}.xlsx" save_to_excel(_data, filename) # 更新EXCEL文件名称 - update_question_by_id(question_id, excel_file_name=filename) + await update_question_by_id(db, question_id, excel_file_name=filename) # 返回静态文件URL return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"} # http://10.10.21.20:8000/questions/get_docx_stream?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443 +# 初始化 OpenAI 客户端 +client = AsyncOpenAI( + api_key=MODEL_API_KEY, + base_url=MODEL_API_URL, +) + @app.api_route("/questions/get_docx_stream", methods=["POST", "GET"]) async def get_docx_stream( question_id: str = Form(None, description="问题ID(POST请求)"), # POST 请求参数 - question_id_get: str = Query(None, description="问题ID(GET请求)") # GET 请求参数 + question_id_get: str = Query(None, description="问题ID(GET请求)"), # GET 请求参数 + db: asyncpg.Connection = Depends(get_db) ): # 根据请求方式获取 question_id if question_id is not None: # POST 请求 @@ -81,8 +135,9 @@ async def get_docx_stream( return {"success": False, "message": "缺少问题ID参数"} # 根据问题ID获取查询sql - sql = get_question_by_id(question_id)[0]['sql'] - # 4、生成word报告 + sql = (await db.fetch("SELECT * FROM t_bi_question WHERE id = $1", question_id))[0]['sql'] + + # 生成word报告 prompt = ''' 请根据以下 JSON 数据,整理出2000字左右的话描述当前数据情况。要求: 1、以Markdown格式返回,我将直接通过markdown格式生成Word。 @@ -91,19 +146,15 @@ async def get_docx_stream( 4、尽量以条目列出,这样更清晰 5、数据: ''' - with PostgreSQLUtil(postgresql_pool.getconn()) as db: - _data = db.execute_query(sql) - prompt = prompt + json.dumps(_data, ensure_ascii=False) - - # 初始化 OpenAI 客户端 - client = OpenAI( - api_key=MODEL_API_KEY, - base_url=MODEL_API_URL, - ) + _data = await db.fetch(sql) + #print(_data) + # 将 asyncpg.Record 转换为 JSON 格式 + json_data = json.dumps([dict(record) for record in _data], ensure_ascii=False) + print(json_data) # 打印 JSON 数据 + prompt = prompt + json.dumps(json_data, ensure_ascii=False) # 调用 OpenAI API 生成总结(流式输出) - response = client.chat.completions.create( - #model=MODEL_NAME, + response = await client.chat.completions.create( model=QWEN_MODEL_NAME, messages=[ {"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"}, @@ -122,7 +173,7 @@ async def get_docx_stream( async def generate_stream(): summary = "" try: - for chunk in response: + async for chunk in response: # 使用 async for 处理流式响应 if chunk.choices[0].delta.content: # 检查是否有内容 chunk_content = chunk.choices[0].delta.content # 逐字拆分并返回 @@ -135,7 +186,7 @@ async def get_docx_stream( markdown_to_docx(summary, output_file=filename) # 记录到数据库 - update_question_by_id(question_id, docx_file_name=filename) + await db.execute("UPDATE t_bi_question SET docx_file_name = $1 WHERE id = $2", filename, question_id) except Exception as e: # 如果发生异常,返回错误信息 @@ -149,7 +200,7 @@ async def get_docx_stream( finally: # 确保资源释放 if "response" in locals(): - response.close() + await response.aclose() # 使用 StreamingResponse 返回流式结果 return StreamingResponse( @@ -164,60 +215,6 @@ async def get_docx_stream( } ) - -# 返回生成的Word文件下载地址 -# http://10.10.21.20:8000/questions/get_docx_file?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443 -@app.api_route("/questions/get_docx_file", methods=["POST", "GET"]) -async def get_docx_file( - question_id: str = Form(None, description="问题ID(POST请求)"), # POST 请求参数 - question_id_get: str = Query(None, description="问题ID(GET请求)"), # GET 请求参数 -): - # 根据请求方式获取 question_id - if question_id is not None: # POST 请求 - question_id = question_id - elif question_id_get is not None: # GET 请求 - question_id = question_id_get - else: - return {"success": False, "message": "缺少问题ID参数"} - - # 根据问题ID获取查询docx_file_name - docx_file_name = get_question_by_id(question_id)[0]['docx_file_name'] - - # 返回成功和静态文件的URL - return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"} - - -# 设置问题为系统推荐问题 ,0:取消,1:设置 -@app.post("/questions/set_system_recommend") -def set_system_recommend(question_id: str = Form(...), flag: str = Form(...)): - set_system_recommend_questions(question_id, flag) - # 提示保存成功 - return {"success": True, "message": "保存成功"} - -# 设置问题为用户收藏问题 ,0:取消,1:设置 -@app.post("/questions/set_user_collect") -def set_user_collect(question_id: str = Form(...), flag: str = Form(...)): - set_user_collect_questions(question_id, flag) - # 提示保存成功 - return {"success": True, "message": "保存成功"} - -# 查询有哪些系统推荐问题 -@app.get("/questions/get_system_recommend") -def get_system_recommend(): - # 查询所有系统推荐问题 - system_recommend_questions = get_system_recommend_questions() - # 返回查询结果 - return {"success": True, "data": system_recommend_questions} - -# 查询有哪些用户收藏问题 -@app.get("/questions/get_user_collect") -def get_user_collect(): - # 查询所有用户收藏问题 - user_collect_questions = get_user_collect_questions() - # 返回查询结果 - return {"success": True, "data": user_collect_questions} - - -# 确保直接运行脚本时启动 FastAPI 应用 +# 启动 FastAPI if __name__ == "__main__": - uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) + uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file diff --git a/AI/Text2Sql/static/1a38e532-5349-4df2-a61b-8e40f7a7261d.xlsx b/AI/Text2Sql/static/1a38e532-5349-4df2-a61b-8e40f7a7261d.xlsx new file mode 100644 index 00000000..45f5e16b Binary files /dev/null and b/AI/Text2Sql/static/1a38e532-5349-4df2-a61b-8e40f7a7261d.xlsx differ