diff --git a/AI/Text2Sql/Model/__init__.py b/AI/Text2Sql/Model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/AI/Text2Sql/Model/__pycache__/__init__.cpython-310.pyc b/AI/Text2Sql/Model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 00000000..3c7c72b0 Binary files /dev/null and b/AI/Text2Sql/Model/__pycache__/__init__.cpython-310.pyc differ diff --git a/AI/Text2Sql/Model/__pycache__/biModel.cpython-310.pyc b/AI/Text2Sql/Model/__pycache__/biModel.cpython-310.pyc new file mode 100644 index 00000000..7d3db5a1 Binary files /dev/null and b/AI/Text2Sql/Model/__pycache__/biModel.cpython-310.pyc differ diff --git a/AI/Text2Sql/Model/biModel.py b/AI/Text2Sql/Model/biModel.py new file mode 100644 index 00000000..ed21772b --- /dev/null +++ b/AI/Text2Sql/Model/biModel.py @@ -0,0 +1,70 @@ +from pydantic import BaseModel +from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date +from sqlalchemy.orm import declarative_base # 更新导入路径 +from sqlalchemy.orm import sessionmaker + +# 数据库连接配置 +DATABASE_URL = "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db" + +# 创建数据库引擎 +engine = create_engine(DATABASE_URL) + +# 创建 SessionLocal 类 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建 Base 类 +Base = declarative_base() + +# 定义 t_bi_question 表模型 +class TBIQuestion(Base): + __tablename__ = "t_bi_question" + + id = Column(Integer, primary_key=True, index=True) + question = Column(String(255), nullable=False) + state_id = Column(SmallInteger, nullable=False, default=0) + sql = Column(String(2048)) + bar_chart_x_columns = Column(String(255)) + bar_chart_y_columns = Column(String(255)) + pie_chart_category_columns = Column(String(255)) + pie_chart_value_column = Column(String(255)) + session_id = Column(String(255), nullable=False) + excel_file_name = Column(String(255)) + bar_chart_file_name = Column(String(255)) + pie_chart_file_name = Column(String(255)) + report_file_name = Column(String(255)) + create_time = Column(Date, nullable=False, default="CURRENT_TIMESTAMP") + is_system = Column(SmallInteger, nullable=False, default=0) + is_collect = Column(SmallInteger, nullable=False, default=0) + +# 创建 Pydantic 模型 +class TBIQuestionCreate(BaseModel): + question: str + state_id: int = 0 + sql: str = None + bar_chart_x_columns: str = None + bar_chart_y_columns: str = None + pie_chart_category_columns: str = None + pie_chart_value_column: str = None + session_id: str + excel_file_name: str = None + bar_chart_file_name: str = None + pie_chart_file_name: str = None + report_file_name: str = None + is_system: int = 0 + is_collect: int = 0 + +class TBIQuestionUpdate(BaseModel): + question: str = None + state_id: int = None + sql: str = None + bar_chart_x_columns: str = None + bar_chart_y_columns: str = None + pie_chart_category_columns: str = None + pie_chart_value_column: str = None + session_id: str = None + excel_file_name: str = None + bar_chart_file_name: str = None + pie_chart_file_name: str = None + report_file_name: str = None + is_system: int = None + is_collect: int = None \ No newline at end of file diff --git a/AI/Text2Sql/YunXiao.py b/AI/Text2Sql/Train.py similarity index 100% rename from AI/Text2Sql/YunXiao.py rename to AI/Text2Sql/Train.py diff --git a/AI/Text2Sql/Util/SaveToExcel.py b/AI/Text2Sql/Util/SaveToExcel.py index 31c497ad..e9d85010 100644 --- a/AI/Text2Sql/Util/SaveToExcel.py +++ b/AI/Text2Sql/Util/SaveToExcel.py @@ -1,6 +1,9 @@ +import io +import pandas as pd +from openpyxl import Workbook from openpyxl.styles import Font, PatternFill, Alignment, Border, Side from openpyxl.utils import get_column_letter -import pandas as pd + def save_to_excel(data, filename): """ @@ -61,4 +64,77 @@ def save_to_excel(data, filename): column_width = min(max(max_length + 2, 10)*2, 120) # 加 2 是为了留出一些空白 # 设置列宽 column_letter = get_column_letter(idx + 1) - worksheet.column_dimensions[column_letter].width = column_width \ No newline at end of file + worksheet.column_dimensions[column_letter].width = column_width + + +def save_to_excel_stream(data): + """ + 将数据集保存为格式化的Excel文件流 + + 参数: + data - 数据集 (列表字典格式,例如:[{"列1": "值1", "列2": "值2"}, ...]) + + 返回: + BytesIO - 包含Excel文件内容的流 + """ + # 转换数据为DataFrame + df = pd.DataFrame(data) + + # 创建一个 BytesIO 对象作为缓冲区 + output = io.BytesIO() + + # 创建Excel工作簿 + wb = Workbook() + ws = wb.active + ws.title = '统计报表' + + # 写入数据 + for row in pd.DataFrame(data).itertuples(index=False): + ws.append(row) + + # 定义边框样式 + thin_border = Border(left=Side(style='thin'), + right=Side(style='thin'), + top=Side(style='thin'), + bottom=Side(style='thin')) + + # 设置全局行高 + for row in ws.iter_rows(): + ws.row_dimensions[row[0].row].height = 20 + # 为所有单元格添加边框 + for cell in row: + cell.border = thin_border + + # 设置标题样式 + header_font = Font(bold=True, size=14) + header_fill = PatternFill(start_color='ADD8E6', end_color='ADD8E6', fill_type='solid') + + for cell in ws[1]: + cell.font = header_font + cell.fill = header_fill + cell.alignment = Alignment(horizontal='center', vertical='center') + + # 设置数据行样式 + data_font = Font(size=14) + for row in ws.iter_rows(min_row=2): + for cell in row: + cell.font = data_font + cell.alignment = Alignment(vertical='center', wrap_text=True) + + # 动态设置列宽 + for idx, column in enumerate(df.columns): + # 获取列的最大长度 + max_length = max( + df[column].astype(str).map(len).max(), # 数据列的最大长度 + len(str(column)) # 列名的长度 + ) + # 计算列宽,确保在 10 到 120 之间 + column_width = min(max(max_length + 2, 10) * 2, 120) # 加 2 是为了留出一些空白 + # 设置列宽 + column_letter = get_column_letter(idx + 1) + ws.column_dimensions[column_letter].width = column_width + + # 将工作簿保存到 BytesIO 对象 + wb.save(output) + output.seek(0) # 将指针移动到流的开头 + return output \ No newline at end of file diff --git a/AI/Text2Sql/Util/VannaUtil.py b/AI/Text2Sql/Util/VannaUtil.py index 3dfe7df1..d7e0e5a1 100644 --- a/AI/Text2Sql/Util/VannaUtil.py +++ b/AI/Text2Sql/Util/VannaUtil.py @@ -1,61 +1,128 @@ import re from typing import List, Dict, Any import requests +import sqlite3 from vanna.base import VannaBase from Config import * class VannaUtil(VannaBase): - def __init__(self): + def __init__(self, db_type='sqlite', db_uri=None): super().__init__() self.api_key = MODEL_API_KEY self.base_url = MODEL_GENERATION_TEXT_URL # 阿里云专用API地址 self.model = QWEN_MODEL_NAME # 根据实际模型名称调整 self.training_data = [] self.chat_history = [] + self.db_type = db_type + self.db_uri = db_uri or 'vanna.db' # 默认使用 SQLite + self._init_db() + + def _init_db(self): + """初始化数据库连接""" + if self.db_type == 'sqlite': + self.conn = sqlite3.connect(self.db_uri) + self._create_tables() + elif self.db_type == 'postgres': + import psycopg2 + self.conn = psycopg2.connect(self.db_uri) + self._create_tables() + else: + raise ValueError(f"Unsupported database type: {self.db_type}") + + def _create_tables(self): + """创建训练数据表""" + cursor = self.conn.cursor() + if self.db_type == 'sqlite': + cursor.execute(''' + CREATE TABLE IF NOT EXISTS training_data ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + type TEXT, + question TEXT, + sql TEXT, + content TEXT + ) + ''') + elif self.db_type == 'postgres': + cursor.execute(''' + CREATE TABLE IF NOT EXISTS training_data ( + id SERIAL PRIMARY KEY, + type TEXT, + question TEXT, + sql TEXT, + content TEXT + ) + ''') + self.conn.commit() - # ---------- 必须实现的抽象方法 ---------- def add_ddl(self, ddl: str, **kwargs) -> None: - self.training_data.append({"type": "ddl", "content": ddl}) + """添加 DDL""" + cursor = self.conn.cursor() + cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('ddl', ddl)) + self.conn.commit() def add_documentation(self, doc: str, **kwargs) -> None: - self.training_data.append({"type": "documentation", "content": doc}) + """添加文档""" + cursor = self.conn.cursor() + cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('documentation', doc)) + self.conn.commit() def add_question_sql(self, question: str, sql: str, **kwargs) -> None: - self.training_data.append({"type": "qa", "question": question, "sql": sql}) + """添加问答对""" + cursor = self.conn.cursor() + cursor.execute('INSERT INTO training_data (type, question, sql) VALUES (?, ?, ?)', ('qa', question, sql)) + self.conn.commit() def get_related_ddl(self, question: str, **kwargs) -> str: - return "\n".join([item["content"] for item in self.training_data if item["type"] == "ddl"]) + """获取相关 DDL""" + cursor = self.conn.cursor() + cursor.execute('SELECT content FROM training_data WHERE type = ?', ('ddl',)) + return "\n".join(row[0] for row in cursor.fetchall()) def get_related_documentation(self, question: str, **kwargs) -> str: - return "\n".join([item["content"] for item in self.training_data if item["type"] == "documentation"]) + """获取相关文档""" + cursor = self.conn.cursor() + cursor.execute('SELECT content FROM training_data WHERE type = ?', ('documentation',)) + return "\n".join(row[0] for row in cursor.fetchall()) def get_training_data(self, **kwargs) -> List[Dict[str, Any]]: - return self.training_data + """获取所有训练数据""" + cursor = self.conn.cursor() + cursor.execute('SELECT * FROM training_data') + columns = [column[0] for column in cursor.description] + return [dict(zip(columns, row)) for row in cursor.fetchall()] + + def remove_training_data(self, id: str, **kwargs) -> bool: + """删除训练数据""" + cursor = self.conn.cursor() + cursor.execute('DELETE FROM training_data WHERE id = ?', (id,)) + self.conn.commit() + return cursor.rowcount > 0 + + def generate_embedding(self, text: str, **kwargs) -> List[float]: + """生成嵌入向量""" + return [] + + def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]: + """获取相似问答对""" + return [] - # ---------- 对话方法 ---------- def system_message(self, message: str) -> None: + """添加系统消息""" self.chat_history = [{"role": "system", "content": message}] def user_message(self, message: str) -> None: + """添加用户消息""" self.chat_history.append({"role": "user", "content": message}) def assistant_message(self, message: str) -> None: + """添加助手消息""" self.chat_history.append({"role": "assistant", "content": message}) def submit_prompt(self, prompt: str, **kwargs) -> str: + """提交提示词""" return self.generate_sql(question=prompt) - # ---------- 其他方法 ---------- - def remove_training_data(self, id: str, **kwargs) -> bool: - return True - - def generate_embedding(self, text: str, **kwargs) -> List[float]: - return [] - - def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]: - return [] - def _clean_sql_output(self, raw_sql: str) -> str: """增强版清洗逻辑""" # 移除所有非SQL内容 @@ -119,4 +186,4 @@ class VannaUtil(VannaBase): except Exception as e: print(f"\nAPI请求错误: {str(e)}") - return "" + return "" \ No newline at end of file diff --git a/AI/Text2Sql/Util/__pycache__/SaveToExcel.cpython-310.pyc b/AI/Text2Sql/Util/__pycache__/SaveToExcel.cpython-310.pyc index 589d8762..9744971e 100644 Binary files a/AI/Text2Sql/Util/__pycache__/SaveToExcel.cpython-310.pyc and b/AI/Text2Sql/Util/__pycache__/SaveToExcel.cpython-310.pyc differ diff --git a/AI/Text2Sql/Util/__pycache__/VannaUtil.cpython-310.pyc b/AI/Text2Sql/Util/__pycache__/VannaUtil.cpython-310.pyc index 0cc2ffe9..c710a730 100644 Binary files a/AI/Text2Sql/Util/__pycache__/VannaUtil.cpython-310.pyc and b/AI/Text2Sql/Util/__pycache__/VannaUtil.cpython-310.pyc differ diff --git a/AI/Text2Sql/__pycache__/app.cpython-310.pyc b/AI/Text2Sql/__pycache__/app.cpython-310.pyc index df88b321..f4681ea9 100644 Binary files a/AI/Text2Sql/__pycache__/app.cpython-310.pyc and b/AI/Text2Sql/__pycache__/app.cpython-310.pyc differ diff --git a/AI/Text2Sql/app.py b/AI/Text2Sql/app.py index 22e8cfb1..d0df49f7 100644 --- a/AI/Text2Sql/app.py +++ b/AI/Text2Sql/app.py @@ -1,84 +1,13 @@ import re -from typing import io -import pandas as pd -from fastapi import FastAPI, HTTPException, Depends -from pydantic import BaseModel -from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date -from sqlalchemy.orm import declarative_base # 更新导入路径 -from sqlalchemy.orm import sessionmaker, Session import uvicorn # 导入 uvicorn +from fastapi import FastAPI, HTTPException, Depends +from sqlalchemy.orm import Session from starlette.responses import StreamingResponse - from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil -from Text2Sql.Util.SaveToExcel import save_to_excel +from Text2Sql.Util.SaveToExcel import save_to_excel_stream from Text2Sql.Util.VannaUtil import VannaUtil - -# 数据库连接配置 -DATABASE_URL = "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db" - -# 创建数据库引擎 -engine = create_engine(DATABASE_URL) - -# 创建 SessionLocal 类 -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -# 创建 Base 类 -Base = declarative_base() - -# 定义 t_bi_question 表模型 -class TBIQuestion(Base): - __tablename__ = "t_bi_question" - - id = Column(Integer, primary_key=True, index=True) - question = Column(String(255), nullable=False) - state_id = Column(SmallInteger, nullable=False, default=0) - sql = Column(String(2048)) - bar_chart_x_columns = Column(String(255)) - bar_chart_y_columns = Column(String(255)) - pie_chart_category_columns = Column(String(255)) - pie_chart_value_column = Column(String(255)) - session_id = Column(String(255), nullable=False) - excel_file_name = Column(String(255)) - bar_chart_file_name = Column(String(255)) - pie_chart_file_name = Column(String(255)) - report_file_name = Column(String(255)) - create_time = Column(Date, nullable=False, default="CURRENT_TIMESTAMP") - is_system = Column(SmallInteger, nullable=False, default=0) - is_collect = Column(SmallInteger, nullable=False, default=0) - -# 创建 Pydantic 模型 -class TBIQuestionCreate(BaseModel): - question: str - state_id: int = 0 - sql: str = None - bar_chart_x_columns: str = None - bar_chart_y_columns: str = None - pie_chart_category_columns: str = None - pie_chart_value_column: str = None - session_id: str - excel_file_name: str = None - bar_chart_file_name: str = None - pie_chart_file_name: str = None - report_file_name: str = None - is_system: int = 0 - is_collect: int = 0 - -class TBIQuestionUpdate(BaseModel): - question: str = None - state_id: int = None - sql: str = None - bar_chart_x_columns: str = None - bar_chart_y_columns: str = None - pie_chart_category_columns: str = None - pie_chart_value_column: str = None - session_id: str = None - excel_file_name: str = None - bar_chart_file_name: str = None - pie_chart_file_name: str = None - report_file_name: str = None - is_system: int = None - is_collect: int = None +from Model.biModel import * # 初始化 FastAPI app = FastAPI() @@ -87,6 +16,7 @@ app = FastAPI() def read_root(): return {"message": "Hello, World!"} + # 获取数据库会话 def get_db(): db = SessionLocal() @@ -95,6 +25,7 @@ def get_db(): finally: db.close() + # 创建记录 @app.post("/questions/", response_model=TBIQuestionCreate) def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)): @@ -104,6 +35,7 @@ def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)): db.refresh(db_question) return db_question + # 读取记录 @app.get("/questions/{question_id}", response_model=TBIQuestionCreate) def read_question(question_id: int, db: Session = Depends(get_db)): @@ -112,6 +44,7 @@ def read_question(question_id: int, db: Session = Depends(get_db)): raise HTTPException(status_code=404, detail="Question not found") return db_question + # 更新记录 @app.put("/questions/{question_id}", response_model=TBIQuestionCreate) def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)): @@ -125,6 +58,7 @@ def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = db.refresh(db_question) return db_question + # 删除记录 @app.delete("/questions/{question_id}") def delete_question(question_id: int, db: Session = Depends(get_db)): @@ -135,16 +69,17 @@ def delete_question(question_id: int, db: Session = Depends(get_db)): db.commit() return {"message": "Question deleted successfully"} + # 通过语义生成Excel -@app.post("/questions/getExcel") -def getExcel(question: str): +@app.post("/questions/get_excel") +def get_excel(question: str): # 指定学段 - question = ''' - 查询: - 1、发布时间是2024年度 - 2、每个学段,每个科目,上传课程数量,按由多到少排序 - 3、字段名: 学段,科目,排名,课程数量 - ''' + # question = ''' + # 查询: + # 1、发布时间是2024年度 + # 2、每个学段,每个科目,上传课程数量,按由多到少排序 + # 3、字段名: 学段,科目,排名,课程数量 + # ''' common_prompt = ''' 返回的信息要求: 1、行政区划为NULL 或者是空字符的不参加统计 @@ -165,45 +100,10 @@ def getExcel(question: str): media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", headers={"Content-Disposition": "attachment; filename=导出信息.xlsx"} ) -def save_to_excel_stream(data): - # 将数据保存为Excel文件流 - df = pd.DataFrame(data) - output = io.BytesIO() - with pd.ExcelWriter(output, engine='openpyxl') as writer: - df.to_excel(writer, index=False) - output.seek(0) # 将指针移动到流的开头 - return output -# 添加 main 函数 -def main(): - # 开始训练 - print("开始训练...") - # 打开AreaSchoolLesson.sql文件内容 - with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file: - ddl = file.read() - # 训练数据 - vn.train( - ddl=ddl - ) - # 添加有关业务术语或定义的文档 - # vn.train(documentation="Sql/AreaSchoolLesson.md") - - # 使用 SQL 进行训练 - with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file: - sql_content = file.read() - # 使用正则表达式提取注释和 SQL 语句 - sql_pattern = r'/\*(.*?)\*/(.*?);' - sql_snippets = re.findall(sql_pattern, sql_content, re.DOTALL) - - # 打印提取的注释和 SQL 语句 - for i, (comment, sql) in enumerate(sql_snippets, 1): - vn.train(sql=comment.strip() + '\n' + sql.strip() + '\n') - - - uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) # 确保直接运行脚本时启动 FastAPI 应用 if __name__ == "__main__": vn = VannaUtil() - main() \ No newline at end of file + uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) diff --git a/AI/Text2Sql/vanna.db b/AI/Text2Sql/vanna.db new file mode 100644 index 00000000..11abe5af Binary files /dev/null and b/AI/Text2Sql/vanna.db differ