diff --git a/AI/Text2Sql/YunXiao.py b/AI/Text2Sql/YunXiao.py index e239722a..25f27c1c 100644 --- a/AI/Text2Sql/YunXiao.py +++ b/AI/Text2Sql/YunXiao.py @@ -15,6 +15,48 @@ from Util.EchartsUtil import * 2、应该有一些固定的组合用法预置出来,给出范例,让用户可以简单修改后就能使用 3、应该有类似于 保存为用例,查询历史等功能,让用户方便利旧。 ''' + + +def infer_axes_fields(field_names, sample_data): + """ + 使用 AI 大模型推断 X 轴和 Y 轴的字段 + :param field_names: 数据字段名列表 + :param sample_data: 数据示例(前几行) + :return: X 轴字段, Y 轴字段 + """ + client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL) # 初始化OpenAI客户端 + + # 构造提示词 + prompt = f""" + 以下是数据的字段名和示例数据: + 字段名: {field_names} + 示例数据: {sample_data} + + 请根据这些数据,推荐适合用于柱状图的 X 轴和 Y 轴字段。 + X 轴应为分类字段(如学段、科目、行政区划等),规则如下: + 1. 如果学段、科目同时存在,返回学段+科目。 + 2. 如果学段、科目都不存在,返回学段。 + 3. 如果只有学段或科目存在,返回存在的字段。 + Y 轴应为数值字段(如课程数量、数量等)。 + 请直接返回 X 轴和 Y 轴的字段名,格式为:X_轴字段, Y_轴字段 + """ + + # 调用 AI 大模型 + response = client.chat.completions.create( + model=MODEL_NAME, + messages=[ + {"role": "system", "content": "你是一个数据分析助手,帮助用户选择合适的字段生成图表。"}, + {"role": "user", "content": prompt} + ], + max_tokens=50 + ) + + # 解析 AI 返回的结果 + result = response.choices[0].message.content.strip() + x_column, y_column = result.split(", ") + return x_column, y_column + + if __name__ == "__main__": vn = VannaUtil() @@ -31,7 +73,6 @@ if __name__ == "__main__": vn.train(documentation="Sql/AreaSchoolLesson.md") # 使用 SQL 进行训练 - # 读取 SQL 文件 with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file: sql_content = file.read() # 使用正则表达式提取注释和 SQL 语句 @@ -40,7 +81,7 @@ if __name__ == "__main__": # 打印提取的注释和 SQL 语句 for i, (comment, sql) in enumerate(sql_snippets, 1): - vn.train(sql=comment.strip() + '\n' + sql.strip()+'\n') + vn.train(sql=comment.strip() + '\n' + sql.strip() + '\n') # 自然语言提问 # 整体情况 @@ -82,20 +123,30 @@ if __name__ == "__main__": with PostgreSQLUtil() as db: _data = db.execute_query(sql) + # 获取字段名和数据示例 + field_names = list(_data[0].keys()) if _data else [] + sample_data = _data[:3] # 取前 3 行作为示例数据 + + # 推断 X 轴和 Y 轴字段 + x_column, y_column = infer_axes_fields(field_names, sample_data) + + x_columns = x_column.split('+') + y_columns = y_column.split('+') + # 1、生成柱状图 generate_bar_chart( _data=_data, title="学段+科目课程数量柱状图", - x_columns=["学段", "科目"], # 动态指定 X 轴列 - y_columns=["课程数量"], # 动态指定 Y 轴列 + x_columns=x_columns, # 动态指定 X 轴列 + y_columns=y_columns, # 动态指定 Y 轴列 output_file="d:/lesson_bar_chart.html" ) # 2、生成饼状图 generate_pie_chart( _data=_data, title="学段+科目分布", - category_columns=["学段", "科目"], # 多列组合参数 - value_column="课程数量", + category_columns=x_columns, # 多列组合参数 + value_column=y_columns[0], output_file="d:/lesson_pie_chart.html" )