You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
|
|
|
|
|
from Text2Sql.Util.SaveToExcel import save_to_excel
|
|
|
|
|
from Text2Sql.Util.Text2SqlUtil import *
|
|
|
|
|
'''
经验:
1、尽量使用宽表,少用关联,越少越好
2、应该有一些固定的组合用法预置出来,给出范例,让用户可以简单修改后就能使用
3、应该有类似于 保存为用例,查询历史等功能,让用户方便利旧。
'''
|
|
|
|
|
if __name__ == "__main__":
    # End-to-end demo: train the text-to-SQL engine on the table DDL, turn a
    # natural-language question into SQL, run it, and export the rows to Excel.
    # NOTE(review): DeepSeekVanna is not imported by name in this file; it
    # presumably comes from the Text2SqlUtil star import -- confirm.
    engine = DeepSeekVanna()

    # --- Training: feed the CREATE TABLE statements to the engine ---
    print("开始训练...")
    # The DDL path is relative, so the script must be launched from the
    # directory that contains the "Sql" folder.
    with open("Sql/CreateTable.sql", "r", encoding="utf-8") as ddl_file:
        schema_ddl = ddl_file.read()
    engine.train(ddl=schema_ddl)

    # --- Natural-language question ---
    # Asks for the number of courses uploaded per school per administrative
    # region for publish year 2024, and lists the columns to return.
    question_text = (
        "\n"
        "查询发布时间是2024年度,每个行政区划每个学校都上传了多少课程数量,\n"
        "返回: 行政区划名,学段,排名,学校名称,课程数量\n"
    )
    # Shared constraint appended to the question: rows whose administrative
    # region is NULL or an empty string are excluded from the statistics.
    shared_requirements = (
        "\n"
        "要求:\n"
        "1、行政区划为NULL 或者是空字符的不参加统计工作,\n"
    )
    full_question = "".join((question_text, shared_requirements))

    # --- SQL generation ---
    print("开始查询...")
    generated_sql = engine.generate_sql(full_question)
    print("生成的查询 SQL:\n", generated_sql)

    # --- Execution and export ---
    # The export happens while the database context is still open, in case the
    # query result depends on the live connection.
    with PostgreSQLUtil() as database:
        rows = database.execute_query(generated_sql)
        export_path = "d:/导出信息.xlsx"
        save_to_excel(rows, export_path)
        # os.startfile is only available on Windows; it opens the workbook
        # with the application associated with .xlsx files.
        os.startfile(export_path)