diff --git a/AI/Text2Sql/YunXiao.py b/AI/Text2Sql/YunXiao.py index 25f27c1c..6e251308 100644 --- a/AI/Text2Sql/YunXiao.py +++ b/AI/Text2Sql/YunXiao.py @@ -16,47 +16,6 @@ from Util.EchartsUtil import * 3、应该有类似于 保存为用例,查询历史等功能,让用户方便利旧。 ''' - -def infer_axes_fields(field_names, sample_data): - """ - 使用 AI 大模型推断 X 轴和 Y 轴的字段 - :param field_names: 数据字段名列表 - :param sample_data: 数据示例(前几行) - :return: X 轴字段, Y 轴字段 - """ - client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL) # 初始化OpenAI客户端 - - # 构造提示词 - prompt = f""" - 以下是数据的字段名和示例数据: - 字段名: {field_names} - 示例数据: {sample_data} - - 请根据这些数据,推荐适合用于柱状图的 X 轴和 Y 轴字段。 - X 轴应为分类字段(如学段、科目、行政区划等),规则如下: - 1. 如果学段、科目同时存在,返回学段+科目。 - 2. 如果学段、科目都不存在,返回学段。 - 3. 如果只有学段或科目存在,返回存在的字段。 - Y 轴应为数值字段(如课程数量、数量等)。 - 请直接返回 X 轴和 Y 轴的字段名,格式为:X_轴字段, Y_轴字段 - """ - - # 调用 AI 大模型 - response = client.chat.completions.create( - model=MODEL_NAME, - messages=[ - {"role": "system", "content": "你是一个数据分析助手,帮助用户选择合适的字段生成图表。"}, - {"role": "user", "content": prompt} - ], - max_tokens=50 - ) - - # 解析 AI 返回的结果 - result = response.choices[0].message.content.strip() - x_column, y_column = result.split(", ") - return x_column, y_column - - if __name__ == "__main__": vn = VannaUtil() @@ -70,7 +29,7 @@ if __name__ == "__main__": ddl=ddl ) # 添加有关业务术语或定义的文档 - vn.train(documentation="Sql/AreaSchoolLesson.md") + # vn.train(documentation="Sql/AreaSchoolLesson.md") # 使用 SQL 进行训练 with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file: @@ -127,26 +86,20 @@ if __name__ == "__main__": field_names = list(_data[0].keys()) if _data else [] sample_data = _data[:3] # 取前 3 行作为示例数据 - # 推断 X 轴和 Y 轴字段 - x_column, y_column = infer_axes_fields(field_names, sample_data) - - x_columns = x_column.split('+') - y_columns = y_column.split('+') - # 1、生成柱状图 generate_bar_chart( _data=_data, title="学段+科目课程数量柱状图", - x_columns=x_columns, # 动态指定 X 轴列 - y_columns=y_columns, # 动态指定 Y 轴列 + x_columns=['学段', '科目'], # 动态指定 X 轴列 + y_columns=['课程数量'], # 动态指定 Y 轴列 output_file="d:/lesson_bar_chart.html" ) # 2、生成饼状图 generate_pie_chart( _data=_data, title="学段+科目分布", - category_columns=x_columns, # 多列组合参数 - value_column=y_columns[0], + category_columns=['学段', '科目'], # 多列组合参数 + value_column='课程数量', output_file="d:/lesson_pie_chart.html" ) diff --git a/AI/Text2Sql/__pycache__/app.cpython-310.pyc b/AI/Text2Sql/__pycache__/app.cpython-310.pyc new file mode 100644 index 00000000..df88b321 Binary files /dev/null and b/AI/Text2Sql/__pycache__/app.cpython-310.pyc differ diff --git a/AI/Text2Sql/app.py b/AI/Text2Sql/app.py new file mode 100644 index 00000000..28e50597 --- /dev/null +++ b/AI/Text2Sql/app.py @@ -0,0 +1,188 @@ +import re + +from fastapi import FastAPI, HTTPException, Depends +from pydantic import BaseModel +from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date +from sqlalchemy.orm import declarative_base # 更新导入路径 +from sqlalchemy.orm import sessionmaker, Session +import uvicorn # 导入 uvicorn +from starlette.staticfiles import StaticFiles + +from Text2Sql.Util.VannaUtil import VannaUtil + +# 数据库连接配置 +DATABASE_URL = "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db" + +# 创建数据库引擎 +engine = create_engine(DATABASE_URL) + +# 创建 SessionLocal 类 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建 Base 类 +Base = declarative_base() + +# 定义 t_bi_question 表模型 +class TBIQuestion(Base): + __tablename__ = "t_bi_question" + + id = Column(Integer, primary_key=True, index=True) + question = Column(String(255), nullable=False) + state_id = Column(SmallInteger, nullable=False, default=0) + sql = Column(String(2048)) + bar_chart_x_columns = Column(String(255)) + bar_chart_y_columns = Column(String(255)) + pie_chart_category_columns = Column(String(255)) + pie_chart_value_column = Column(String(255)) + session_id = Column(String(255), nullable=False) + excel_file_name = Column(String(255)) + bar_chart_file_name = Column(String(255)) + pie_chart_file_name = Column(String(255)) + report_file_name = Column(String(255)) + create_time = Column(Date, nullable=False, default="CURRENT_TIMESTAMP") + is_system = Column(SmallInteger, nullable=False, default=0) + is_collect = Column(SmallInteger, nullable=False, default=0) + +# 创建 Pydantic 模型 +class TBIQuestionCreate(BaseModel): + question: str + state_id: int = 0 + sql: str = None + bar_chart_x_columns: str = None + bar_chart_y_columns: str = None + pie_chart_category_columns: str = None + pie_chart_value_column: str = None + session_id: str + excel_file_name: str = None + bar_chart_file_name: str = None + pie_chart_file_name: str = None + report_file_name: str = None + is_system: int = 0 + is_collect: int = 0 + +class TBIQuestionUpdate(BaseModel): + question: str = None + state_id: int = None + sql: str = None + bar_chart_x_columns: str = None + bar_chart_y_columns: str = None + pie_chart_category_columns: str = None + pie_chart_value_column: str = None + session_id: str = None + excel_file_name: str = None + bar_chart_file_name: str = None + pie_chart_file_name: str = None + report_file_name: str = None + is_system: int = None + is_collect: int = None + +# 初始化 FastAPI +app = FastAPI() + +@app.get("/") +def read_root(): + return {"message": "Hello, World!"} + +# 获取数据库会话 +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + +# 创建记录 +@app.post("/questions/", response_model=TBIQuestionCreate) +def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)): + db_question = TBIQuestion(**question.dict()) + db.add(db_question) + db.commit() + db.refresh(db_question) + return db_question + +# 读取记录 +@app.get("/questions/{question_id}", response_model=TBIQuestionCreate) +def read_question(question_id: int, db: Session = Depends(get_db)): + db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first() + if db_question is None: + raise HTTPException(status_code=404, detail="Question not found") + return db_question + +# 更新记录 +@app.put("/questions/{question_id}", response_model=TBIQuestionCreate) +def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)): + db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first() + if db_question is None: + raise HTTPException(status_code=404, detail="Question not found") + for key, value in question.dict().items(): + if value is not None: + setattr(db_question, key, value) + db.commit() + db.refresh(db_question) + return db_question + +# 删除记录 +@app.delete("/questions/{question_id}") +def delete_question(question_id: int, db: Session = Depends(get_db)): + db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first() + if db_question is None: + raise HTTPException(status_code=404, detail="Question not found") + db.delete(db_question) + db.commit() + return {"message": "Question deleted successfully"} + +@app.post("/questions/generate_sql") +def generate_sql(question: str): + # 指定学段 + question = ''' + 查询: + 1、发布时间是2024年度 + 2、每个学段,每个科目,上传课程数量,按由多到少排序 + 3、字段名: 学段,科目,排名,课程数量 + ''' + common_prompt = ''' + 返回的信息要求: + 1、行政区划为NULL 或者是空字符的不参加统计 + 2、目标数据库是Postgresql 16 + ''' + question = question + common_prompt + # 开始查询 + print("开始查询...") + # 获取完整 SQL + sql = vn.generate_sql(question) + print("生成的查询 SQL:\n", sql) + +# 添加 main 函数 +def main(): + # 开始训练 + print("开始训练...") + # 打开AreaSchoolLesson.sql文件内容 + with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file: + ddl = file.read() + # 训练数据 + vn.train( + ddl=ddl + ) + # 添加有关业务术语或定义的文档 + # vn.train(documentation="Sql/AreaSchoolLesson.md") + + # 使用 SQL 进行训练 + with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file: + sql_content = file.read() + # 使用正则表达式提取注释和 SQL 语句 + sql_pattern = r'/\*(.*?)\*/(.*?);' + sql_snippets = re.findall(sql_pattern, sql_content, re.DOTALL) + + # 打印提取的注释和 SQL 语句 + for i, (comment, sql) in enumerate(sql_snippets, 1): + vn.train(sql=comment.strip() + '\n' + sql.strip() + '\n') + + + uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) + + + +# 确保直接运行脚本时启动 FastAPI 应用 +if __name__ == "__main__": + vn = VannaUtil() + main() \ No newline at end of file