You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

63 lines
2.2 KiB

5 months ago
from openai import OpenAI
# 阿里云中用来调用 deepseek v3 的密钥
MODEL_API_KEY = "sk-01d13a39e09844038322108ecdbd1bbc"
#MODEL_NAME = "qwen-plus"
MODEL_NAME="deepseek-v3"
# 初始化 OpenAI 客户端
client = OpenAI(
api_key=MODEL_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
def generate_sql_from_prompt(ddl: str, prompt: str) -> str:
"""
根据 DDL 和自然语言描述生成 SQL 查询
:param ddl: 数据库表结构的 DDL
:param prompt: 自然语言描述
:return: 生成的 SQL 查询
"""
# 构建完整的提示词
full_prompt = (
f"以下是数据库表结构的 DDL\n\n{ddl}\n\n"
f"请根据以下描述生成 SQL 查询:\n\n{prompt}\n\n"
"生成的 SQL 查询:"
)
# 调用大模型
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个专业的 SQL 生成助手,能够根据数据库表结构和自然语言描述生成正确的 SQL 查询。"},
{"role": "user", "content": full_prompt}
],
max_tokens=500
)
# 提取生成的 SQL
if response.choices and response.choices[0].message.content:
return response.choices[0].message.content.strip()
else:
raise ValueError("未能生成 SQL 查询")
if __name__ == '__main__':
4 months ago
# 读取 Sql/AreaSchoolLessonDDL.sql 文件
with open("../Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file:
5 months ago
ddl = file.read()
# 自然语言描述
5 months ago
prompt = "查询 2024 年每个学段下,上传课程数量排名前 10 的学校,显示排名,并按上传课程数量排序。"
5 months ago
common_prompt='''
要求
5 months ago
1只返回可以运行的SQL不要描述信息和```sql 还有```
2对于学校名称和行政区划名称等于NULL 或者为空的不要进行统计
5 months ago
'''
prompt = prompt + common_prompt
# 生成 SQL
try:
sql = generate_sql_from_prompt(ddl, prompt)
print(sql)
except Exception as e:
print(f"生成 SQL 时出错:{e}")