You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.
import os
import re
from pathlib import Path

from openai import OpenAI
# API key for Alibaba Cloud DashScope, used to call DeepSeek V3 through the
# OpenAI-compatible endpoint below. It is read from the environment instead of
# being hard-coded: a key committed to source control must be treated as leaked
# (revoke and rotate it).
MODEL_API_KEY = os.environ.get("DASHSCOPE_API_KEY", "")
if not MODEL_API_KEY:
    raise RuntimeError(
        "DASHSCOPE_API_KEY is not set; export your DashScope API key before running."
    )

# Model to request; override with the DASHSCOPE_MODEL environment variable
# (e.g. "qwen-plus"). Defaults to "deepseek-v3" as before.
MODEL_NAME = os.environ.get("DASHSCOPE_MODEL", "deepseek-v3")

# OpenAI SDK client pointed at DashScope's OpenAI-compatible endpoint.
client = OpenAI(
    api_key=MODEL_API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
# Matches a Markdown code fence with an optional language tag (e.g. ```sql)
# and captures its contents. The closing fence is optional, so a reply cut off
# by max_tokens still has its opening fence removed.
_CODE_FENCE_RE = re.compile(r"```(?:[\w+-]*\r?\n)?(.*?)(?:```|\Z)", re.DOTALL)


def _strip_code_fences(text: str) -> str:
    """Return the contents of the first Markdown code fence in *text* (or all of *text* if unfenced), whitespace-stripped."""
    match = _CODE_FENCE_RE.search(text)
    return (match.group(1) if match else text).strip()


def generate_sql_from_prompt(
    ddl: str,
    prompt: str,
    *,
    model: "str | None" = None,
    max_tokens: int = 500,
) -> str:
    """Generate a SQL query from a table DDL and a natural-language description.

    The DDL and the description are combined into one user message and sent,
    together with a fixed system message, to the chat-completions endpoint of
    the module-level ``client``.

    :param ddl: DDL text describing the database tables.
    :param prompt: natural-language description of the desired query.
    :param model: model name to request; defaults to the module-level ``MODEL_NAME``.
    :param max_tokens: maximum number of tokens the model may generate
        (default 500). Raise it for long queries that would otherwise be
        truncated.
    :return: the generated SQL, with surrounding whitespace and any Markdown
        code fence (e.g. a ```sql ... ``` wrapper) removed.
    :raises ValueError: if the model returns no choices, or content that is
        empty after stripping.
    """
    # Build the full user prompt: schema first, then the request, then a cue
    # for the answer.
    full_prompt = (
        f"以下是数据库表结构的 DDL:\n\n{ddl}\n\n"
        f"请根据以下描述生成 SQL 查询:\n\n{prompt}\n\n"
        "生成的 SQL 查询:"
    )

    response = client.chat.completions.create(
        model=model or MODEL_NAME,
        messages=[
            {"role": "system", "content": "你是一个专业的 SQL 生成助手,能够根据数据库表结构和自然语言描述生成正确的 SQL 查询。"},
            {"role": "user", "content": full_prompt},
        ],
        max_tokens=max_tokens,
    )

    content = response.choices[0].message.content if response.choices else None
    # Models frequently wrap SQL in ```sql fences even when told not to; strip
    # them so the result can be executed directly.
    sql = _strip_code_fences(content) if content else ""
    if not sql:
        raise ValueError("未能生成 SQL 查询")
    return sql
# Output requirements appended to every natural-language request: return only
# runnable SQL (no prose, no ```sql fences), and skip rows whose school name or
# administrative-division name is NULL or empty.
_COMMON_REQUIREMENTS = '''
要求:
1、只返回可以运行的SQL, 不要描述信息和```sql 还有```
2、对于学校名称和行政区划名称等于NULL 或者为空的不要进行统计
'''


def main() -> None:
    """Generate SQL for a sample question against the AreaSchoolLesson DDL and print it.

    On failure, the error is reported on stderr and the process exits with
    status 1, so the message can never be mistaken for SQL on stdout.
    """
    # Resolve the DDL path relative to this script rather than the current
    # working directory, so the script works wherever it is launched from.
    # NOTE(review): assumes the original "../Sql" path was meant relative to
    # this file's directory — confirm against the repository layout.
    ddl_path = Path(__file__).resolve().parent.parent / "Sql" / "AreaSchoolLessonDDL.sql"
    ddl = ddl_path.read_text(encoding="utf-8")

    # Natural-language question followed by the shared output requirements.
    prompt = "查询 2024 年每个学段下,上传课程数量排名前 10 的学校,显示排名,并按上传课程数量排序。"
    prompt = prompt + _COMMON_REQUIREMENTS

    try:
        sql = generate_sql_from_prompt(ddl, prompt)
    except Exception as e:
        # Top-level boundary: SystemExit with a message prints it to stderr
        # and exits with status 1.
        raise SystemExit(f"生成 SQL 时出错: {e}") from e
    print(sql)


if __name__ == '__main__':
    main()