You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

64 lines
2.2 KiB

4 months ago
from openai import OpenAI
from Config import MODEL_NAME, MODEL_API_KEY
# 初始化 OpenAI 客户端
client = OpenAI(
api_key=MODEL_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
# 利用AI获取数据集的X轴和Y轴的列名
def generate_columns_with_ai(data):
"""
利用大模型解析数据生成 category_columns_str value_column_str
:param data: 数据集列表字典格式例如[{"行政区划名": "二道区", "学校名称": "清华附中", "课程数量": 100}, ...]
:param api_key: OpenAI API 密钥
:return: (category_columns_str, value_column_str)
"""
# 获取所有字段名
columns = list(data[0].keys()) if data else []
# 构造提示词
prompt = f"""
给定以下字段名列表{columns}请分析并回答以下问题
1. 哪些字段适合作为分类字段category_columns_str请用逗号分隔
2. 哪个字段适合作为数值字段value_column_str
3. 以JSON格式返回结果,但不要输出 json``` ```
返回格式
category_columns_str: <字段1,字段2,...>
value_column_str: <字段>
"""
# 调用大模型
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system",
"content": "你是一个专业的语义分类助手。"},
{"role": "user", "content": prompt}
],
max_tokens=500
)
# 解析模型返回的结果
result = response.choices[0].message.content
category_columns_str = result.split("category_columns_str: ")[1].split("\n")[0].strip()
value_column_str = result.split("value_column_str: ")[1].strip()
return category_columns_str, value_column_str
# 示例数据
data = [
{"行政区划名": "二道区", "学校名称": "清华附中", "课程数量": 100},
{"行政区划名": "朝阳区", "学校名称": "复旦附中", "课程数量": 120},
]
# 调用函数
category_columns_str, value_column_str = generate_columns_with_ai(data)
print(f"category_columns_str: {category_columns_str}")
print(f"value_column_str: {value_column_str}")