You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

100 lines
4.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import re
import time
from typing import Iterator, Tuple
from openai import OpenAI
from openai.types.chat import ChatCompletionChunk
from Config import *
class KnowledgeGraph:
    """Ask a chat model to analyse an exercise and extract the Cypher it returns.

    The exercise text is sent to the model configured through ``MODEL_API_KEY``,
    ``MODEL_API_URL`` and ``MODEL_NAME``; the streamed answer is echoed to
    stdout and any fenced Cypher code blocks are collected into one script.
    """

    # Frames of the progress indicator shown while waiting for the answer text.
    # NOTE(review): the frames were empty strings in the reviewed copy (the
    # glyphs look lost in transit); the usual 10-frame braille spinner is used.
    _SPINNER_FRAMES = ('⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏')

    # Matches every fenced code block: group 1 is the info string after the
    # opening fence, group 2 is the body. Matching *all* fences (not only the
    # Cypher ones) keeps opening and closing fences paired correctly.
    _FENCE_RE = re.compile(r"```([^\n`]*)\n(.*?)```", re.DOTALL)

    def __init__(self, shiti_content: str):
        """Store the exercise text and create the API client.

        Args:
            shiti_content: The exercise ("shiti") text sent as the user message.
        """
        self.shiti_content = shiti_content
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)

    def _generate_stream(self) -> "Iterator[ChatCompletionChunk]":
        """Start a streaming chat completion for the stored exercise text.

        Returns:
            The streaming response returned by the client; iterating it yields
            completion chunks.
        """
        system_prompt = (
            "回答以下内容:\n"
            "1. 这道题目有哪些知识点,哪些能力点\n"
            "2. 生成Neo4j 5.26.2的插入语句"
        )
        return self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self.shiti_content},
            ],
            stream=True,
            timeout=300,
        )

    @staticmethod
    def _strip_line_comment(line: str) -> str:
        """Return *line* without a trailing ``//`` comment.

        A ``//`` that appears inside a quoted section (single quotes, double
        quotes or backticks) is kept, so values such as ``'http://host'`` are
        not truncated. Quote state is tracked per line only.
        """
        quote = None
        pos = 0
        length = len(line)
        while pos < length:
            char = line[pos]
            if quote is not None:
                # Backslash escapes the next character inside '...' and "...".
                if char == "\\" and quote != "`":
                    pos += 2
                    continue
                if char == quote:
                    quote = None
            elif char in "'\"`":
                quote = char
            elif line.startswith("//", pos):
                return line[:pos]
            pos += 1
        return line

    def _extract_cypher(self, content: str) -> str:
        """Collect the Cypher statements found in fenced code blocks.

        Blocks whose fence has no language tag, or the tag ``cypher`` (any
        case), are used; blocks tagged with another language are skipped.
        Within a block, ``//`` comments, surrounding whitespace and empty lines
        are removed.

        Args:
            content: The complete model answer (Markdown text).

        Returns:
            The cleaned blocks separated by a blank line, every block except
            the last terminated by exactly one ``;``. Empty string when no
            usable block is found.
        """
        blocks = []
        for info, body in self._FENCE_RE.findall(content):
            words = info.split()
            language = words[0].lower() if words else ""
            if language not in ("", "cypher"):
                continue
            stripped_lines = (
                self._strip_line_comment(raw).strip() for raw in body.split("\n")
            )
            kept = [text for text in stripped_lines if text]
            if kept:
                blocks.append("\n".join(kept))

        if not blocks:
            return ""
        # Separate blocks with ';' but do not double one that is already there.
        terminated = [
            block if block.endswith(";") else block + ";" for block in blocks[:-1]
        ]
        return "\n\n".join(terminated + blocks[-1:])

    def run(self) -> Tuple[bool, str, str]:
        """Stream the model answer, echo it, and extract the Cypher script.

        Returns:
            ``(True, full_content, cypher_script)`` when the model produced
            text (``cypher_script`` may be empty), ``(False, "", "")`` when the
            stream carried no text, and ``(False, error_message, "")`` when an
            exception was raised while generating.
        """
        started = time.monotonic()
        frames = self._SPINNER_FRAMES
        content_buffer = []
        stream = None
        try:
            print("🚀 开始生成知识点和能力点的总结和插入语句")
            stream = self._generate_stream()
            for idx, chunk in enumerate(stream):
                text = chunk.choices[0].delta.content if chunk.choices else None
                if not text:
                    # Only animate while no answer text has been printed yet;
                    # afterwards a '\r' status line would garble the echo.
                    if not content_buffer:
                        elapsed = int(time.monotonic() - started)
                        print(
                            f"\r{frames[idx % len(frames)]} 生成中({elapsed}秒)",
                            end="",
                            flush=True,
                        )
                    continue
                if not content_buffer:
                    print("\n\n📝 内容生成开始:")
                content_buffer.append(text)
                print(text, end="", flush=True)

            if not content_buffer:
                return False, "", ""

            full_content = ''.join(content_buffer)
            cypher_script = self._extract_cypher(full_content)
            print(f"\n\n✅ 生成成功!耗时 {int(time.monotonic() - started)}")
            print("\n================ 完整结果 ================")
            print(full_content)
            print("\n================ Cypher语句 ===============")
            print(cypher_script if cypher_script else "未检测到Cypher语句")
            print("==========================================")
            return True, full_content, cypher_script
        except Exception as e:
            # Boundary of the workflow: report the failure to the caller
            # through the return value instead of propagating.
            print(f"\n\n❌ 生成失败:{str(e)}")
            return False, str(e), ""
        finally:
            # Release the underlying HTTP response if the stream object offers
            # close(); guarded so any plain iterable also works.
            close = getattr(stream, "close", None)
            if callable(close):
                try:
                    close()
                except Exception:
                    # Best-effort cleanup; the result above is already decided.
                    pass
if __name__ == '__main__':
    # Demo run: analyse one sample exercise and, when the model returned
    # Cypher, store the script next to the program.
    sample_question = '''
下面是一道小学三年级的数学题目,巧求周长:
把7个完全相同的小长方形拼成如图的样子已知每个小长方形的长是10厘米则拼成的大长方形的周长是多少厘米
'''
    graph = KnowledgeGraph(sample_question)
    outcome = graph.run()
    succeeded, cypher_text = outcome[0], outcome[2]

    # Nothing is written when generation failed or no Cypher was detected.
    should_save = bool(succeeded) and bool(cypher_text)
    if should_save:
        with open("knowledge_graph.cypher", "w", encoding="utf-8") as out_file:
            out_file.write(cypher_text)
        print("\nCypher语句已保存至 knowledge_graph.cypher")