main
HuangHai 5 months ago
parent 669ebf65e5
commit cbdd0c99bf

@ -101,7 +101,7 @@ if __name__ == "__main__":
# 开始训练
print("开始训练...")
# 打开CreateTable.sql文件内容
with open("Sql/CreateTable.sql", "r", encoding="utf-8") as file:
with open("Sql/AreaSchoolLesson.sql", "r", encoding="utf-8") as file:
ddl = file.read()
# 训练数据
vn.train(

@ -42,8 +42,8 @@ def generate_sql_from_prompt(ddl: str, prompt: str) -> str:
raise ValueError("未能生成 SQL 查询")
if __name__ == '__main__':
# 读取 Sql/CreateTable.sql 文件
with open("../Sql/CreateTable.sql", "r", encoding="utf-8") as file:
# 读取 Sql/AreaSchoolLesson.sql 文件
with open("../Sql/AreaSchoolLesson.sql", "r", encoding="utf-8") as file:
ddl = file.read()
# 自然语言描述

@ -15,7 +15,7 @@ if __name__ == "__main__":
# 开始训练
print("开始训练...")
# 打开CreateTable.sql文件内容
with open("Sql/CreateTable.sql", "r", encoding="utf-8") as file:
with open("Sql/AreaSchoolLesson.sql", "r", encoding="utf-8") as file:
ddl = file.read()
# 训练数据
vn.train(
@ -23,16 +23,35 @@ if __name__ == "__main__":
)
# 自然语言提问
# '''
# 整体情况
question = '''
查询:
1发布时间是2024年度
2每个行政区每个学校都上传了多少课程数量
3格式: 行政区划名,学段,排名,学校名称,课程数量
'''
# 指定行政区域
# question = '''
# 查询:
# 1、发布时间是2024年度
# 2、二道区每个学校都上传了多少课程数量
# 3、格式: 行政区划名,学段,排名,学校名称,发布年份,课程数量
# '''
# 指定学段
# question = '''
# 查询:
# 1、发布时间是2024年度
# 2、每个学段每个科目上传课程数量按由多到少排序
# 3、字段名: 学段,科目,排名,课程数量
# '''
common_prompt = '''
要求
返回的信息要求
1行政区划为NULL 或者是空字符的不参加统计
2目标数据库是Postgresql 16,注意字段名称不要有二义性问题
3使用stage_name描述学段,一定不要使用stage_id
4使用subject_name描述科目,一定不要使用subject_id
'''
question = question + common_prompt
# 开始查询

Loading…
Cancel
Save