From 9ac006e60437ee56d5b1710f31452bd32a2825f8 Mon Sep 17 00:00:00 2001 From: HuangHai <10402852@qq.com> Date: Fri, 14 Mar 2025 16:13:12 +0800 Subject: [PATCH] 'commit' --- AI/Text2Sql/YunXiao.py | 57 +----- AI/Text2Sql/__pycache__/app.cpython-310.pyc | Bin 0 -> 5712 bytes AI/Text2Sql/app.py | 188 ++++++++++++++++++++ 3 files changed, 193 insertions(+), 52 deletions(-) create mode 100644 AI/Text2Sql/__pycache__/app.cpython-310.pyc create mode 100644 AI/Text2Sql/app.py diff --git a/AI/Text2Sql/YunXiao.py b/AI/Text2Sql/YunXiao.py index 25f27c1c..6e251308 100644 --- a/AI/Text2Sql/YunXiao.py +++ b/AI/Text2Sql/YunXiao.py @@ -16,47 +16,6 @@ from Util.EchartsUtil import * 3、应该有类似于 保存为用例,查询历史等功能,让用户方便利旧。 ''' - -def infer_axes_fields(field_names, sample_data): - """ - 使用 AI 大模型推断 X 轴和 Y 轴的字段 - :param field_names: 数据字段名列表 - :param sample_data: 数据示例(前几行) - :return: X 轴字段, Y 轴字段 - """ - client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL) # 初始化OpenAI客户端 - - # 构造提示词 - prompt = f""" - 以下是数据的字段名和示例数据: - 字段名: {field_names} - 示例数据: {sample_data} - - 请根据这些数据,推荐适合用于柱状图的 X 轴和 Y 轴字段。 - X 轴应为分类字段(如学段、科目、行政区划等),规则如下: - 1. 如果学段、科目同时存在,返回学段+科目。 - 2. 如果学段、科目都不存在,返回学段。 - 3. 如果只有学段或科目存在,返回存在的字段。 - Y 轴应为数值字段(如课程数量、数量等)。 - 请直接返回 X 轴和 Y 轴的字段名,格式为:X_轴字段, Y_轴字段 - """ - - # 调用 AI 大模型 - response = client.chat.completions.create( - model=MODEL_NAME, - messages=[ - {"role": "system", "content": "你是一个数据分析助手,帮助用户选择合适的字段生成图表。"}, - {"role": "user", "content": prompt} - ], - max_tokens=50 - ) - - # 解析 AI 返回的结果 - result = response.choices[0].message.content.strip() - x_column, y_column = result.split(", ") - return x_column, y_column - - if __name__ == "__main__": vn = VannaUtil() @@ -70,7 +29,7 @@ if __name__ == "__main__": ddl=ddl ) # 添加有关业务术语或定义的文档 - vn.train(documentation="Sql/AreaSchoolLesson.md") + # vn.train(documentation="Sql/AreaSchoolLesson.md") # 使用 SQL 进行训练 with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file: @@ -127,26 +86,20 @@ if __name__ == "__main__": field_names = list(_data[0].keys()) if _data else [] sample_data = _data[:3] # 取前 3 行作为示例数据 - # 推断 X 轴和 Y 轴字段 - x_column, y_column = infer_axes_fields(field_names, sample_data) - - x_columns = x_column.split('+') - y_columns = y_column.split('+') - # 1、生成柱状图 generate_bar_chart( _data=_data, title="学段+科目课程数量柱状图", - x_columns=x_columns, # 动态指定 X 轴列 - y_columns=y_columns, # 动态指定 Y 轴列 + x_columns=['学段', '科目'], # 动态指定 X 轴列 + y_columns=['课程数量'], # 动态指定 Y 轴列 output_file="d:/lesson_bar_chart.html" ) # 2、生成饼状图 generate_pie_chart( _data=_data, title="学段+科目分布", - category_columns=x_columns, # 多列组合参数 - value_column=y_columns[0], + category_columns=['学段', '科目'], # 多列组合参数 + value_column='课程数量', output_file="d:/lesson_pie_chart.html" ) diff --git a/AI/Text2Sql/__pycache__/app.cpython-310.pyc b/AI/Text2Sql/__pycache__/app.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df88b32180d36107b5f7db0f72a5459104603fa8 GIT binary patch literal 5712 zcmbtY>vI&x5#Q(C-b2z!fCQL_z{chPOCaz|CNTo>suD6FVd5>4t<%j&9PB-q-NQ&Z zRbh+7#74$;NSv}$v6OOfNX1}RIWB^!_+Lm>zU@bwlf)-~_lG#Yo;#fm5U!-M(#-YD z^z`&}_jLD6zM&zi;JbV0UnBchD9Sg~sQ)Z9cA{5HF?6cfu*;cz1aKvuo7H^wS?d3edqAbdnv)H6=ui#A+ioJ5b!s4vq zg2Ed3$_d?m1noxFgmx1Tp}h+21ZzgSS+-Ym4YStpHB)O@%e2atu%#Ci&~9hTSnCDF zUdPv*TF=_}Dz{oIP>D06&}A zMwSHsTlkY~(`il2@K)@zgKwEoAVYOPmgMHUr0f2q9)$O$Jiq7g!Jy&5;NaoC=hD33 zXY#oq+{+6*$GoH(M0Te}Dwj(g^E2+)fkNK%M+EoIxSI>!XgXgQ>FVz6obT=K^)ig7`nK(Ort7I^w?F;d zu5DX8=zmAYGuyhK+OfT>(;GWA=CGlCkU%t5^z-R_HkzMugjSeA9n55*NzJyGi zM~d8&R?*-)LmB5xwK)jw#R?*mO^H$GG#?F2h>V}R`#of{4+^w$Mb}LYxm>`9C`nDg z(I`Qf@!?d_^)m{5OwhdN*wLeV`v;xDgNODG4E7v4+@D+#SdNoRWx3-7QOC*VSEPLO8Cw> zCoRp`tM44Gb*?OAcoj1Z3L|-JcYZ?a!r0lATLfO^Kt!unP0KZL94pruh9S#s)vr?5 zxj<>5Ga+~(4-RWRB1T(@RZpRx$?_nQ@f>f|^LaLi0t5ptPy5#tu@$JY`ffvS_X*}5 z&x_M1j^K3m<#Qt^dJdi#6hNl5j}E5ebJS9F?#o z;h2OYJ1g{}vM(#z@L!(u(4e(@-^Nk(D=RgY8k1XUZUN>a6ns?Ahwp{j$bPE^Sjv5Mf;RINb;WAsJP z7N+$^CC}DDfaE=ltCulO#v}%(RG#NAv zn4BQomW0U(nvArqLX!lOwkr8a`HFIMV$U6_a)! zF7N+Bxl8X;ZiVc2BD;esxxiEZ=PHY;_Gu7(0i~+ig(G`E1l1DmcWUPY+SSm%hiH%+ zWh&D$Brm-`2xk$LQzLw=;Q)8t{FZij5|^#JpAqSA$k2ia&JA-!KJOnTdPLaL#B=z$ zXHg0Zv=()GB;HpQUzt$Gm9c0^c?k*2aiye8sne9YC<8V46(y0OG{Sv^`{T3@#RBboNY!BRt(y1|nq#rf`B33?f2!~BjMa3(k1qD8UgsA{O9 z7ky*P>o)!K8g_<_2eG==16wxgw~Xy5zB-L<`syjYq@2cPr;KR@Ti2LUR79I^k|iN6 z(q2{1*aH z({K)rlJ=18>Lq=V?W$9uX=N%rt@J8CZ9Jik8{=ll1fD83`&LOW86|TEuYyiJGB1RGF9>@OnC}bGLnFIUB|{aV z$PiYOOrXnjO=c|jOR1?SU1H6?3ZUZ2%s|R%sSFg_`Kl!y!|20JZ}AB=#av zt))#u`i`awLQNjy!13a!Dt1F>%2EZnma(?F6E(tf92Xt#AlFQjqnrOoX;k*0c(fK% z$B~nat0jo44bBg4M4y1m@X2{7({Ll`fzz1KDsC`o_7@|d7;RtVoB8lhb2t9jz3AAs zZ_dA3zWh%4=ETf=-S$~qx|>J%eSsBI-?6SW^P=*bMr6dPp-~f|J%2> zFIAeeAHOsE!F4q6+_3zzPs(c|FOQl_L+;nojd;v2xa!# zE$}q^*Q-S2=4AQuYvnhlX21M!?)r3%2uTJKn3;MVyPlc6UcU8f%sgDXm$q+vx`tb9 z0)OS(=gS|zSy4|%M@MlP`ev_uICJrBtWxQ2A2`z29mSQ@#HA|d2inXRbtZj<7@)ejX{k@Fh>!*y8ihHj#VcdBiq@X=O;rV(B z%o^8y6MjH1X})Ha(vj<^fAbm2Qcl+Q)^10sLz7D!D zs$$G6`om8?OIKT-LxjRr3b4+YJJ!C~%>7j#K-0&XJ5Oxd*stQ%!4ytPCAWp=GqltCDTaT10V7<3g3LV~;b6idd%PiVT4$2(16P-qh` zQ-|q`R3;Z_f(PMYT*&bZ6NGwS9PH`q3nDyM%u2Tun24I0LJ%&V&7|`p7wDpxOIm?J zH%;*~A`}DHh{Jf;fi22L0kMSxH%=p7E>kFQ-wV`C5GI2JS|xeD*AxRI3x0kNiZVB~ zv;^{yRo@%BriQ*xge!1DZPnMK2%lEQ^`GWAPO?*sP)B$euT-f*Mo|7$344@=1Xf{` zrE+*A5)`I`DELab>61J=Dkyd@JPODM0r?D2xj;8AV0Pq1R**jz4OBH!<)R9fVM&45 z;PVbjN-}8@jVI_Z2W|6L=MH+1>ZmxJru>*p&!hapQkH6QwH+ByLPqwuX6TmQsE71~o>-PhgqqOV literal 0 HcmV?d00001 diff --git a/AI/Text2Sql/app.py b/AI/Text2Sql/app.py new file mode 100644 index 00000000..28e50597 --- /dev/null +++ b/AI/Text2Sql/app.py @@ -0,0 +1,188 @@ +import re + +from fastapi import FastAPI, HTTPException, Depends +from pydantic import BaseModel +from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date +from sqlalchemy.orm import declarative_base # 更新导入路径 +from sqlalchemy.orm import sessionmaker, Session +import uvicorn # 导入 uvicorn +from starlette.staticfiles import StaticFiles + +from Text2Sql.Util.VannaUtil import VannaUtil + +# 数据库连接配置 +DATABASE_URL = "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db" + +# 创建数据库引擎 +engine = create_engine(DATABASE_URL) + +# 创建 SessionLocal 类 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建 Base 类 +Base = declarative_base() + +# 定义 t_bi_question 表模型 +class TBIQuestion(Base): + __tablename__ = "t_bi_question" + + id = Column(Integer, primary_key=True, index=True) + question = Column(String(255), nullable=False) + state_id = Column(SmallInteger, nullable=False, default=0) + sql = Column(String(2048)) + bar_chart_x_columns = Column(String(255)) + bar_chart_y_columns = Column(String(255)) + pie_chart_category_columns = Column(String(255)) + pie_chart_value_column = Column(String(255)) + session_id = Column(String(255), nullable=False) + excel_file_name = Column(String(255)) + bar_chart_file_name = Column(String(255)) + pie_chart_file_name = Column(String(255)) + report_file_name = Column(String(255)) + create_time = Column(Date, nullable=False, default="CURRENT_TIMESTAMP") + is_system = Column(SmallInteger, nullable=False, default=0) + is_collect = Column(SmallInteger, nullable=False, default=0) + +# 创建 Pydantic 模型 +class TBIQuestionCreate(BaseModel): + question: str + state_id: int = 0 + sql: str = None + bar_chart_x_columns: str = None + bar_chart_y_columns: str = None + pie_chart_category_columns: str = None + pie_chart_value_column: str = None + session_id: str + excel_file_name: str = None + bar_chart_file_name: str = None + pie_chart_file_name: str = None + report_file_name: str = None + is_system: int = 0 + is_collect: int = 0 + +class TBIQuestionUpdate(BaseModel): + question: str = None + state_id: int = None + sql: str = None + bar_chart_x_columns: str = None + bar_chart_y_columns: str = None + pie_chart_category_columns: str = None + pie_chart_value_column: str = None + session_id: str = None + excel_file_name: str = None + bar_chart_file_name: str = None + pie_chart_file_name: str = None + report_file_name: str = None + is_system: int = None + is_collect: int = None + +# 初始化 FastAPI +app = FastAPI() + +@app.get("/") +def read_root(): + return {"message": "Hello, World!"} + +# 获取数据库会话 +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + +# 创建记录 +@app.post("/questions/", response_model=TBIQuestionCreate) +def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)): + db_question = TBIQuestion(**question.dict()) + db.add(db_question) + db.commit() + db.refresh(db_question) + return db_question + +# 读取记录 +@app.get("/questions/{question_id}", response_model=TBIQuestionCreate) +def read_question(question_id: int, db: Session = Depends(get_db)): + db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first() + if db_question is None: + raise HTTPException(status_code=404, detail="Question not found") + return db_question + +# 更新记录 +@app.put("/questions/{question_id}", response_model=TBIQuestionCreate) +def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)): + db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first() + if db_question is None: + raise HTTPException(status_code=404, detail="Question not found") + for key, value in question.dict().items(): + if value is not None: + setattr(db_question, key, value) + db.commit() + db.refresh(db_question) + return db_question + +# 删除记录 +@app.delete("/questions/{question_id}") +def delete_question(question_id: int, db: Session = Depends(get_db)): + db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first() + if db_question is None: + raise HTTPException(status_code=404, detail="Question not found") + db.delete(db_question) + db.commit() + return {"message": "Question deleted successfully"} + +@app.post("/questions/generate_sql") +def generate_sql(question: str): + # 指定学段 + question = ''' + 查询: + 1、发布时间是2024年度 + 2、每个学段,每个科目,上传课程数量,按由多到少排序 + 3、字段名: 学段,科目,排名,课程数量 + ''' + common_prompt = ''' + 返回的信息要求: + 1、行政区划为NULL 或者是空字符的不参加统计 + 2、目标数据库是Postgresql 16 + ''' + question = question + common_prompt + # 开始查询 + print("开始查询...") + # 获取完整 SQL + sql = vn.generate_sql(question) + print("生成的查询 SQL:\n", sql) + +# 添加 main 函数 +def main(): + # 开始训练 + print("开始训练...") + # 打开AreaSchoolLesson.sql文件内容 + with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file: + ddl = file.read() + # 训练数据 + vn.train( + ddl=ddl + ) + # 添加有关业务术语或定义的文档 + # vn.train(documentation="Sql/AreaSchoolLesson.md") + + # 使用 SQL 进行训练 + with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file: + sql_content = file.read() + # 使用正则表达式提取注释和 SQL 语句 + sql_pattern = r'/\*(.*?)\*/(.*?);' + sql_snippets = re.findall(sql_pattern, sql_content, re.DOTALL) + + # 打印提取的注释和 SQL 语句 + for i, (comment, sql) in enumerate(sql_snippets, 1): + vn.train(sql=comment.strip() + '\n' + sql.strip() + '\n') + + + uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) + + + +# 确保直接运行脚本时启动 FastAPI 应用 +if __name__ == "__main__": + vn = VannaUtil() + main() \ No newline at end of file