main
HuangHai 4 months ago
parent 8336aa2b7d
commit 97d8e70b54

@ -34,3 +34,10 @@ DIFY_URL='http://10.10.14.207/v1'
# Helpers are imported under private aliases so that `from Config import *`
# in other modules does not re-export them.
import os as _os
from urllib.parse import quote as _quote

# PostgreSQL connection settings.
# Each value may be overridden by an environment variable of the same name;
# the literals below are the defaults used when the variable is not set.
# NOTE(review): the default password is still committed to source control —
# consider removing the default once every deployment sets PG_PASSWORD.
PG_HOST = _os.environ.get("PG_HOST", "10.10.14.71")
PG_PORT = int(_os.environ.get("PG_PORT", "5432"))
PG_DATABASE = _os.environ.get("PG_DATABASE", "szjz_db")
PG_USER = _os.environ.get("PG_USER", "postgres")
PG_PASSWORD = _os.environ.get("PG_PASSWORD", "DsideaL147258369")

# Vanna PostgreSQL URI, derived from the settings above so the two can never
# drift apart. User and password are percent-encoded so that characters such
# as '@' or ':' in a credential cannot corrupt the URI.
VANNA_POSTGRESQL_URI = (
    f"postgresql://{_quote(PG_USER, safe='')}:{_quote(PG_PASSWORD, safe='')}"
    f"@{PG_HOST}:{PG_PORT}/{PG_DATABASE}"
)

@ -1,78 +1,70 @@
from typing import Optional

from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date
from sqlalchemy.orm import declarative_base  # updated import path
from sqlalchemy.orm import sessionmaker

from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
# Database connection settings.
# NOTE(review): the credentials are embedded in this URL literal — consider
# loading them from configuration instead.
DATABASE_URL = "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db"
# Create the SQLAlchemy engine for DATABASE_URL.
engine = create_engine(DATABASE_URL)
# Delete a record
def delete_question(db, question_id: str):
    """Delete the t_bi_question row whose id equals ``question_id``.

    :param db: object exposing ``execute_query(sql, params)``
    :param question_id: primary-key value of the row to remove
    """
    statement = "\nDELETE FROM t_bi_question WHERE id = %s\n"
    db.execute_query(statement, (question_id,))
# Create the SessionLocal factory: autocommit and autoflush are disabled and
# sessions are bound to the module-level engine.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Create the declarative Base class that ORM models inherit from.
Base = declarative_base()
# Insert a record
def insert_question(db, question_id: str, question: str):
    """Insert a new row into t_bi_question.

    :param db: object exposing ``execute_query(sql, params)``
    :param question_id: value stored in the ``id`` column
    :param question: value stored in the ``question`` column
    """
    statement = (
        "\n"
        "INSERT INTO t_bi_question (id,question, state_id, is_system, is_collect)\n"
        "VALUES (%s,%s, %s, %s, %s)\n"
    )
    # state_id, is_system and is_collect are all written as 0 for a new row.
    row = (question_id, question, 0, 0, 0)
    db.execute_query(statement, row)
# Provide a database session
def get_db():
    """Yield a session created by ``SessionLocal`` and close it afterwards.

    The ``finally`` clause closes the session whether the consumer finishes
    normally or raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# 定义 t_bi_question 表模型
class TBIQuestion(Base):
__tablename__ = "t_bi_question"
# 修改数据
'''
示例:
# 更新 question 和 state_id 字段
update_question_by_id(db, question_id=1, question="新的问题描述", state_id=1)
# 只更新 excel_file_name 字段
update_question_by_id(db, question_id=1, excel_file_name="new_excel.xlsx")
# 只更新 is_collect 字段
update_question_by_id(db, question_id=1, is_collect=1)
id = Column(Integer, primary_key=True, index=True)
question = Column(String(255), nullable=False)
state_id = Column(SmallInteger, nullable=False, default=0)
sql = Column(String(2048))
bar_chart_x_columns = Column(String(255))
bar_chart_y_columns = Column(String(255))
pie_chart_category_columns = Column(String(255))
pie_chart_value_column = Column(String(255))
session_id = Column(String(255), nullable=False)
excel_file_name = Column(String(255))
bar_chart_file_name = Column(String(255))
pie_chart_file_name = Column(String(255))
report_file_name = Column(String(255))
create_time = Column(Date, nullable=False, default="CURRENT_TIMESTAMP")
is_system = Column(SmallInteger, nullable=False, default=0)
is_collect = Column(SmallInteger, nullable=False, default=0)
# 不更新任何字段(因为所有参数都是 None
update_question_by_id(db, question_id=1, question=None, state_id=None)
'''
# Pydantic models
class TBIQuestionCreate(BaseModel):
    """Schema for a t_bi_question record as sent to / returned by the API.

    ``question`` and ``session_id`` are required. Every other text field is
    optional and defaults to ``None``; the integer flags default to 0.
    The optional fields are declared ``Optional[...]`` explicitly so that a
    ``None`` value is accepted by validation rather than relying on the
    default alone.
    """

    question: str
    state_id: int = 0
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: str
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: int = 0
    is_collect: int = 0
class TBIQuestionUpdate(BaseModel):
    """Schema for a partial update of a t_bi_question record.

    Every field is optional and defaults to ``None``. The fields are declared
    ``Optional[...]`` explicitly so that a ``None`` value is accepted by
    validation rather than relying on the default alone.
    """

    question: Optional[str] = None
    state_id: Optional[int] = None
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: Optional[str] = None
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: Optional[int] = None
    is_collect: Optional[int] = None
def update_question_by_id(db: "PostgreSQLUtil", question_id: str, **kwargs):
    """Update a t_bi_question row by primary key, writing only non-None fields.

    Column names are taken from the keyword names and have to be spliced into
    the statement text (they cannot be bound as query parameters), so each
    name must be a plain identifier; anything else is rejected before a query
    is built. Values are always passed as bound parameters.

    :param db: PostgreSQLUtil instance (only ``execute_query(sql, params)`` is used)
    :param question_id: primary-key id of the row to update
    :param kwargs: ``column=value`` pairs; pairs whose value is None are ignored
    :return: True when the UPDATE ran without raising; False when there was
             nothing to update, a column name was rejected, or the query raised
    """
    # Drop fields whose value is None.
    update_fields = {k: v for k, v in kwargs.items() if v is not None}
    if not update_fields:
        return False  # nothing to update

    # Refuse column names that are not bare identifiers, so that a crafted
    # keyword cannot inject SQL through the SET clause.
    invalid = [name for name in update_fields if not name.isidentifier()]
    if invalid:
        print(f"更新失败: 非法字段名 {invalid}")
        return False

    # Build the SET clause dynamically, one placeholder per field.
    set_clause = ", ".join(f"{field} = %s" for field in update_fields)
    sql = f"""
        UPDATE t_bi_question
        SET {set_clause}
        WHERE id = %s
    """
    # Parameters follow placeholder order: field values first, then the id.
    params = [*update_fields.values(), question_id]

    # Best-effort by design: a failed update is reported and signalled via
    # the return value instead of propagating.
    try:
        db.execute_query(sql, params)
    except Exception as e:
        print(f"更新失败: {e}")
        return False
    return True

@ -1,41 +1,26 @@
import psycopg2
from psycopg2 import OperationalError
from psycopg2.extras import RealDictCursor
import json
from datetime import date, datetime
import psycopg2
from psycopg2 import pool
from psycopg2.extras import RealDictCursor
from Config import *
class PostgreSQLUtil:
def __init__(self, host="10.10.14.71", port=5432,
dbname="szjz_db", user="postgres", password="DsideaL147258369"):
self.conn_params = {
"host": host,
"port": port,
"dbname": dbname,
"user": user,
"password": password
}
self.connection = None
def __enter__(self):
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
# Create the module-level connection pool, holding between 1 and 10
# connections. The PG_* names are presumably supplied by
# `from Config import *` — verify against Config.
postgresql_pool = psycopg2.pool.SimpleConnectionPool(
    minconn=1,
    maxconn=10,
    host=PG_HOST,
    port=PG_PORT,
    dbname=PG_DATABASE,
    user=PG_USER,
    password=PG_PASSWORD
)
def connect(self):
try:
self.connection = psycopg2.connect(**self.conn_params)
print("成功连接到PostgreSQL数据库")
except OperationalError as e:
print(f"连接错误: {e}")
raise
def close(self):
if self.connection:
self.connection.close()
print("数据库连接已关闭")
class PostgreSQLUtil:
def __init__(self, connection):
self.connection = connection
def execute_query(self, sql, params=None, return_dict=True):
"""执行查询并返回结果"""
@ -77,13 +62,32 @@ class PostgreSQLUtil:
raise TypeError(f"Type {type(obj)} not serializable")
def get_db():
    """Yield a ``PostgreSQLUtil`` wrapping a connection taken from the pool.

    The connection is handed back to ``postgresql_pool`` in the ``finally``
    clause, so it is returned whether the consumer finishes normally or raises.
    """
    conn = postgresql_pool.getconn()
    try:
        wrapper = PostgreSQLUtil(conn)
        yield wrapper
    finally:
        postgresql_pool.putconn(conn)
# Usage example
if __name__ == "__main__":
    # get_db() is a generator function: calling it returns a generator object,
    # next() pulls the PostgreSQLUtil instance out of it, and closing the
    # generator runs its finally clause so the connection is put back into
    # the pool.
    db_gen = get_db()
    db = next(db_gen)
    try:
        # Sample query
        result = db.execute_query("SELECT version()")
        print("数据库版本:", result)
        # Return the rows as JSON
        json_data = db.query_to_json("SELECT * FROM t_base_class LIMIT 2")
        print("JSON结果:", json_data)
    finally:
        # Close the generator explicitly so the connection is released.
        db_gen.close()

@ -1,13 +1,15 @@
import json
import uuid
import uvicorn # 导入 uvicorn
from fastapi import FastAPI, HTTPException, Depends, Form
from sqlalchemy.orm import Session
from starlette.responses import JSONResponse
from fastapi import FastAPI, Depends, Form
from openai import OpenAI
from starlette.staticfiles import StaticFiles
from Config import MODEL_API_KEY, MODEL_API_URL, MODEL_NAME
from Model.biModel import *
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
from Text2Sql.Util.PostgreSQLUtil import get_db
from Text2Sql.Util.SaveToExcel import save_to_excel
from Text2Sql.Util.VannaUtil import VannaUtil
@ -15,81 +17,113 @@ from Text2Sql.Util.VannaUtil import VannaUtil
app = FastAPI()
# Serve the local "static" directory under the /static URL path.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Create the VannaUtil instance once, at import time, and share it across requests.
vn = VannaUtil()
@app.get("/")
def read_root():
return {"message": "Welcome to Vanna AI SQL !"}
# Create a record
@app.post("/questions/", response_model=TBIQuestionCreate)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
    """Persist a new TBIQuestion built from the request payload and return it."""
    payload = question.dict()
    record = TBIQuestion(**payload)
    db.add(record)
    db.commit()
    # Reload the instance after the commit before handing it back.
    db.refresh(record)
    return record
# Read a record
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
def read_question(question_id: int, db: Session = Depends(get_db)):
    """Return the TBIQuestion with the given id, or respond 404 if none exists."""
    record = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
    if record is not None:
        return record
    raise HTTPException(status_code=404, detail="Question not found")
# Update a record
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
    """Apply the non-None fields of the payload to an existing TBIQuestion.

    Responds 404 when no row has the given id; otherwise commits the change
    and returns the refreshed record.
    """
    record = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Question not found")
    # Only fields that were actually supplied (non-None) are written.
    changes = {name: value for name, value in question.dict().items() if value is not None}
    for name, value in changes.items():
        setattr(record, name, value)
    db.commit()
    db.refresh(record)
    return record
# Delete a record
@app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)):
    """Delete the TBIQuestion with the given id, or respond 404 if none exists."""
    record = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Question not found")
    db.delete(record)
    db.commit()
    outcome = {"message": "Question deleted successfully"}
    return outcome
return {"message": "Welcome to AI SQL World!"}
# Generate an Excel file from a natural-language question
# http://10.10.21.20:8000/questions/get_excel
# Form parameters: question_id, question_str
@app.post("/questions/get_excel")
def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: PostgreSQLUtil = Depends(get_db)):
    """Turn a question into SQL, run it, and export the rows to an Excel file.

    Steps: validate ``question_id``; record the question (delete then insert);
    let ``vn`` generate the SQL; store the SQL with ``state_id=1``; run the
    query; write the rows to ``static/<uuid>.xlsx``; store the file name.

    :param question_id: GUID in canonical 36-character form identifying the question
    :param question_str: the question text
    :param db: PostgreSQLUtil supplied by the ``get_db`` dependency
    :return: dict with ``success`` and ``message``; on success also ``download_url``
    """
    # Only a GUID is accepted. Parsing and comparing with the canonical form
    # rejects arbitrary 36-character strings that a length check would let through.
    try:
        is_guid = str(uuid.UUID(question_id)) == question_id.lower()
    except ValueError:
        is_guid = False
    if not is_guid:
        return {"success": False, "message": "question_id格式错误"}

    common_prompt = (
        "\n"
        "返回的信息要求\n"
        "1行政区划为NULL 或者是空字符的不参加统计\n"
        "2目标数据库是Postgresql 16\n"
    )
    question = question_str + common_prompt

    # Delete before inserting so a repeated call does not create a duplicate row.
    delete_question(db, question_id)
    insert_question(db, question_id, question)

    # Generate the full SQL statement.
    sql = vn.generate_sql(question)
    print("生成的查询 SQL:\n", sql)

    # Store the generated SQL on the question record.
    update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)

    # Run the generated query.
    _data = db.execute_query(sql)

    # Write the result to a GUID-named file under the static directory.
    uuid_str = str(uuid.uuid4())
    filename = f"static/{uuid_str}.xlsx"
    save_to_excel(_data, filename)

    # Store the Excel file name on the question record.
    update_question_by_id(db, question_id, excel_file_name=filename)

    # Return the static-file URL.
    return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"}
# Generate a Word report
# http://10.10.21.20:8000/questions/get_docx
@app.post("/questions/get_docx")
def get_docx(question_id: str = Form(...), db: PostgreSQLUtil = Depends(get_db)):
    """Build a Word report from the data behind a previously stored question.

    Looks up the SQL stored for ``question_id``, runs it, asks the chat model
    to summarise the rows as Markdown (streamed), and converts the Markdown
    to ``static/<uuid>.docx``.

    :param question_id: id of the t_bi_question row whose SQL is used
    :param db: PostgreSQLUtil supplied by the ``get_db`` dependency
    :return: dict with ``success`` and ``message``; on success also ``download_url``
    """
    select_sql = "\nselect * from t_bi_question where id=%s\n"
    rows = db.execute_query(select_sql, (question_id,))
    # An unknown id gives no rows; report it instead of failing on rows[0].
    if not rows:
        return {"success": False, "message": "question_id不存在"}
    sql = rows[0]['sql']
    # The sql column is empty until SQL has been generated for this question.
    if not sql:
        return {"success": False, "message": "该问题尚未生成SQL"}

    # Prompt for the Word report.
    prompt = (
        "\n"
        "请根据以下 JSON 数据整理出2000字左右的话描述当前数据情况要求\n"
        "1以Markdown格式返回我将直接通过markdown格式生成Word\n"
        "2标题统一为长春云校数据分析报告\n"
        "3内容中不要提到JSON数据统一称数据\n"
        "4尽量以条目列出这样更清晰\n"
        "5数据\n"
    )
    _data = db.execute_query(sql)
    # default=str keeps values json cannot encode natively (e.g. dates or
    # decimals in query results) from aborting the request.
    prompt = prompt + json.dumps(_data, ensure_ascii=False, default=str)

    # Initialise the OpenAI client.
    client = OpenAI(
        api_key=MODEL_API_KEY,
        base_url=MODEL_API_URL,
    )
    # Request the summary with streaming enabled.
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
            {"role": "user", "content": prompt}
        ],
        max_tokens=3000,  # limits the length of the generated text
        temperature=0.7,  # sampling temperature
        stream=True  # stream the output
    )

    # Collect the streamed pieces and join once at the end.
    pieces = []
    for chunk in response:
        # Skip chunks that carry no choices rather than indexing into an empty list.
        if not chunk.choices:
            continue
        chunk_content = chunk.choices[0].delta.content
        if chunk_content:
            print(chunk_content, end="", flush=True)  # echo to the console as it arrives
            pieces.append(chunk_content)
    # summary is the complete Markdown text.
    summary = "".join(pieces)
    print("\n\n流式输出完成summary 已拼接为完整字符串。")

    # Generate the Word document.
    uuid_str = str(uuid.uuid4())
    filename = f"static/{uuid_str}.docx"
    markdown_to_docx(summary, output_file=filename)

    # Return the static-file URL.
    return {"success": True, "message": "Word文件生成成功", "download_url": f"/static/{uuid_str}.docx"}
# 确保直接运行脚本时启动 FastAPI 应用

Loading…
Cancel
Save