63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
from openai import OpenAI
|
||
|
||
# 阿里云中用来调用 deepseek v3 的密钥
|
||
MODEL_API_KEY = "sk-01d13a39e09844038322108ecdbd1bbc"
|
||
#MODEL_NAME = "qwen-plus"
|
||
MODEL_NAME="deepseek-v3"
|
||
|
||
# 初始化 OpenAI 客户端
|
||
client = OpenAI(
|
||
api_key=MODEL_API_KEY,
|
||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||
)
|
||
|
||
def generate_sql_from_prompt(ddl: str, prompt: str) -> str:
|
||
"""
|
||
根据 DDL 和自然语言描述生成 SQL 查询
|
||
:param ddl: 数据库表结构的 DDL
|
||
:param prompt: 自然语言描述
|
||
:return: 生成的 SQL 查询
|
||
"""
|
||
# 构建完整的提示词
|
||
full_prompt = (
|
||
f"以下是数据库表结构的 DDL:\n\n{ddl}\n\n"
|
||
f"请根据以下描述生成 SQL 查询:\n\n{prompt}\n\n"
|
||
"生成的 SQL 查询:"
|
||
)
|
||
|
||
# 调用大模型
|
||
response = client.chat.completions.create(
|
||
model=MODEL_NAME,
|
||
messages=[
|
||
{"role": "system", "content": "你是一个专业的 SQL 生成助手,能够根据数据库表结构和自然语言描述生成正确的 SQL 查询。"},
|
||
{"role": "user", "content": full_prompt}
|
||
],
|
||
max_tokens=500
|
||
)
|
||
|
||
# 提取生成的 SQL
|
||
if response.choices and response.choices[0].message.content:
|
||
return response.choices[0].message.content.strip()
|
||
else:
|
||
raise ValueError("未能生成 SQL 查询")
|
||
|
||
if __name__ == '__main__':
|
||
# 读取 Sql/AreaSchoolLessonDDL.sql 文件
|
||
with open("../Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file:
|
||
ddl = file.read()
|
||
|
||
# 自然语言描述
|
||
prompt = "查询 2024 年每个学段下,上传课程数量排名前 10 的学校,显示排名,并按上传课程数量排序。"
|
||
|
||
common_prompt='''
|
||
要求:
|
||
1、只返回可以运行的SQL,不要描述信息和```sql 还有```
|
||
2、对于学校名称和行政区划名称等于NULL 或者为空的不要进行统计
|
||
'''
|
||
prompt = prompt + common_prompt
|
||
# 生成 SQL
|
||
try:
|
||
sql = generate_sql_from_prompt(ddl, prompt)
|
||
print(sql)
|
||
except Exception as e:
|
||
print(f"生成 SQL 时出错:{e}") |