Files
QingLong/AI/Text2Sql/Train.py
2025-08-15 09:13:13 +08:00

37 lines
1.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re

from Text2Sql.Util.VannaUtil import *
'''
经验:
1、尽量使用宽表,少用关联,关联越少越好。
2、应该预置一些固定的组合用法并给出范例,让用户简单修改后就能使用。
3、应该提供类似"保存为用例"、"查询历史"等功能,方便用户复用已有成果。
'''
# Matches "/* description */ statement;" pairs in an annotated SQL file.
# Group 1 is the text inside the block comment, group 2 is everything up to the
# next ';'. DOTALL lets both parts span multiple lines.
# NOTE: a ';' or '*/' inside a SQL string literal would cut a snippet short.
_SQL_EXAMPLE_PATTERN = re.compile(r'/\*(.*?)\*/(.*?);', re.DOTALL)


def extract_training_sql(sql_content):
    """Split an annotated SQL script into training texts.

    Every ``/* comment */ statement;`` pair in *sql_content* becomes one
    string made of the stripped comment, a newline, the stripped statement
    (without its trailing ``;``) and a final newline. Text that does not
    match the pattern is ignored.

    :param sql_content: Contents of an annotated SQL file.
    :return: List of training texts in file order; empty if nothing matches.
    """
    return [
        comment.strip() + '\n' + sql.strip() + '\n'
        for comment, sql in _SQL_EXAMPLE_PATTERN.findall(sql_content)
    ]


if __name__ == "__main__":
    vn = VannaUtil()
    print("开始训练...")

    # Step 1: train on the table definitions (DDL file contents).
    with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file:
        ddl = file.read()
    vn.train(ddl=ddl)

    # Step 2 (disabled): documentation of business terms / definitions.
    # NOTE(review): this passes the path string itself, whereas the ddl= call
    # above passes file contents — presumably the file should be read first
    # if this is re-enabled; verify against VannaUtil.train.
    # vn.train(documentation="Sql/AreaSchoolLesson.md")

    # Step 3: train on annotated example queries.
    with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file:
        sql_content = file.read()
    training_texts = extract_training_sql(sql_content)
    # NOTE(review): the description is prepended to the SQL as plain text. If
    # VannaUtil.train mirrors vanna's train(), it also accepts question=, which
    # may be a cleaner way to pair description and SQL — confirm before changing.
    for training_text in training_texts:
        vn.train(sql=training_text)
    # Report the count so an empty or malformed example file is not silent.
    print(f"示例 SQL 训练完成,共 {len(training_texts)} 条")