37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
from Text2Sql.Util.VannaUtil import *
|
||
|
||
'''
|
||
经验:
|
||
1、尽量使用宽表,少用关联,越少越好
|
||
2、应该有一些固定的组合用法预置出来,给出范例,让用户可以简单修改后就能使用
|
||
3、应该有类似于 保存为用例,查询历史等功能,让用户方便利旧。
|
||
'''
|
||
|
||
if __name__ == "__main__":
|
||
vn = VannaUtil()
|
||
|
||
# 开始训练
|
||
print("开始训练...")
|
||
# 打开AreaSchoolLesson.sql文件内容
|
||
with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file:
|
||
ddl = file.read()
|
||
# 训练数据
|
||
vn.train(
|
||
ddl=ddl
|
||
)
|
||
# 添加有关业务术语或定义的文档
|
||
# vn.train(documentation="Sql/AreaSchoolLesson.md")
|
||
|
||
# 使用 SQL 进行训练
|
||
with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file:
|
||
sql_content = file.read()
|
||
# 使用正则表达式提取注释和 SQL 语句
|
||
sql_pattern = r'/\*(.*?)\*/(.*?);'
|
||
sql_snippets = re.findall(sql_pattern, sql_content, re.DOTALL)
|
||
|
||
# 打印提取的注释和 SQL 语句
|
||
for i, (comment, sql) in enumerate(sql_snippets, 1):
|
||
vn.train(sql=comment.strip() + '\n' + sql.strip() + '\n')
|
||
|
||
|