You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

64 lines
2.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from openai import OpenAI
from Config import MODEL_NAME, MODEL_API_KEY
# Module-wide client for the OpenAI-compatible chat API; the base URL targets
# DashScope's compatible-mode endpoint and the key comes from Config.
client = OpenAI(
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    api_key=MODEL_API_KEY,
)
def _parse_column_response(result):
    """Extract ``(category_columns_str, value_column_str)`` from a model reply.

    The prompt asks for JSON but also shows a ``key: value`` layout, so the
    model may answer in either shape; both are accepted here.

    :param result: raw text returned by the model (``None`` is treated as "").
    :return: tuple ``(category_columns_str, value_column_str)``.
    :raises ValueError: if either field cannot be found in the reply.
    """
    import json  # local import keeps this block self-contained

    keys = ("category_columns_str", "value_column_str")
    text = (result or "").strip()
    found = {}

    # Shape 1: a JSON object, possibly surrounded by prose or a ``` fence.
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        try:
            payload = json.loads(text[start:end + 1])
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            payload = None
        if isinstance(payload, dict):
            for key in keys:
                if key not in payload:
                    continue
                value = payload[key]
                if isinstance(value, (list, tuple)):
                    value = ",".join(str(item) for item in value)
                found[key] = "" if value is None else str(value).strip()

    # Shape 2: "key: value" lines. Only the remainder of the matching line is
    # taken, so trailing explanations never leak into a value.
    for line in text.splitlines():
        for key in keys:
            marker = key + ":"
            if key not in found and marker in line:
                found[key] = line.split(marker, 1)[1].strip()

    missing = [key for key in keys if key not in found]
    if missing:
        raise ValueError(
            f"model reply is missing {', '.join(missing)}: {result!r}"
        )
    return found["category_columns_str"], found["value_column_str"]


# Use the model to pick the dataset's X-axis (category) and Y-axis (value) columns.
def generate_columns_with_ai(data):
    """Ask the model which columns of ``data`` are categories and which is the value.

    Only the field names of the first row are sent to the model.

    :param data: dataset as a list of dicts, e.g.
        ``[{"行政区划名": "二道区", "学校名称": "清华附中", "课程数量": 100}, ...]``
    :return: tuple ``(category_columns_str, value_column_str)``; the first is a
        comma-separated list of field names as returned by the model.
        ``("", "")`` is returned for empty ``data`` without calling the model.
    :raises ValueError: if the model reply contains neither field.
    """
    # Nothing to classify: skip the (billable) API call entirely.
    if not data:
        return "", ""

    # Collect all field names from the first row.
    columns = list(data[0].keys())

    # Build the prompt. The text is unchanged from the original prompt.
    prompt = (
        "\n"
        f"给定以下字段名列表:{columns},请分析并回答以下问题:\n"
        "1. 哪些字段适合作为分类字段category_columns_str请用逗号分隔。\n"
        "2. 哪个字段适合作为数值字段value_column_str\n"
        "3. 以JSON格式返回结果,但不要输出 json``` ```\n"
        "返回格式:\n"
        "category_columns_str: <字段1,字段2,...>\n"
        "value_column_str: <字段>\n"
    )

    # Call the model.
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system",
             "content": "你是一个专业的语义分类助手。"},
            {"role": "user", "content": prompt},
        ],
        max_tokens=500,
    )

    # Parse the reply; the content may be None, which the parser treats as "".
    return _parse_column_response(response.choices[0].message.content)
# Sample data: one dict per row, keyed by column name.
data = [
    {"行政区划名": "二道区", "学校名称": "清华附中", "课程数量": 100},
    {"行政区划名": "朝阳区", "学校名称": "复旦附中", "课程数量": 120},
]

# Run the demo only when executed as a script, so that importing this module
# does not trigger a network call to the model or print to stdout.
if __name__ == "__main__":
    category_columns_str, value_column_str = generate_columns_with_ai(data)
    print(f"category_columns_str: {category_columns_str}")
    print(f"value_column_str: {value_column_str}")