You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

221 lines
8.7 KiB

5 months ago
# -*- coding: utf-8 -*-
5 months ago
import hashlib
import json
import re
import time
from typing import Dict, Iterator, Tuple

from openai import OpenAI
from openai.types.chat import ChatCompletionChunk

# NOTE(review): this star import presumably supplies Graph, NEO4J_URI,
# NEO4J_AUTH, MODEL_API_KEY, MODEL_API_URL, MODEL_NAME and K2_Neo4jExecutor;
# consider importing those names explicitly.
from K2_Neo4jExecutor import *
5 months ago
5 months ago
class KnowledgeGraph:
    """Generate Cypher that links a math question to existing graph nodes.

    On construction the instance connects to Neo4j, loads the existing
    ``KnowledgePoint`` / ``AbilityPoint`` nodes as ``{id: name}`` dicts and
    creates an OpenAI-compatible client. ``run()`` streams the model's answer.
    The ``_extract_cypher`` / ``_sanitize_cypher`` helpers filter model output
    line by line against the loaded node ids.
    """

    def __init__(self, content: str):
        """Connect to Neo4j, load existing nodes and create the LLM client.

        Args:
            content: The question text to analyse.

        Raises:
            ConnectionError: If the Neo4j connection test fails.
        """
        self.content = content
        # Content-derived id; the prompt MERGEs the Question node on it.
        self.question_id = self._generate_question_id()
        self.graph = self._init_graph_connection()
        self.existing_knowledge = self._fetch_existing_nodes("KnowledgePoint")
        self.existing_ability = self._fetch_existing_nodes("AbilityPoint")
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)

    @staticmethod
    def _lowercase_ids(nodes: Dict[str, str]) -> set:
        """Return the string keys of *nodes*, lower-cased, for case-insensitive lookup."""
        return {node_id.lower() for node_id in nodes if isinstance(node_id, str)}

    def _validate_ids(self, line: str) -> bool:
        """Return True if every KP_/AB_ id referenced in *line* exists in the database.

        An id is ``kp_``/``ab_`` followed by six ASCII lowercase letters or
        digits, which is the format the system prompt asks the model for.
        Placeholders such as ``KP_xxxxxx`` are therefore detected and rejected
        instead of passing unchecked. Both the line and the stored ids are
        compared in lower case, so ``KP_abc123`` matches a stored ``kp_abc123``
        and vice versa.
        """
        lowered = line.lower()
        found_kp = set(re.findall(r'(kp_[a-z0-9]{6})', lowered))
        found_ab = set(re.findall(r'(ab_[a-z0-9]{6})', lowered))

        if found_kp and not found_kp <= self._lowercase_ids(self.existing_knowledge):
            return False
        if found_ab and not found_ab <= self._lowercase_ids(self.existing_ability):
            return False
        return True

    def _init_graph_connection(self) -> Graph:
        """Open a Neo4j connection and verify it with a trivial query.

        Returns:
            The connected ``Graph`` handle.

        Raises:
            ConnectionError: If connecting or running the test query fails.
                The original exception is chained as ``__cause__``.
        """
        try:
            graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
            graph.run("RETURN 1").data()
        except Exception as e:
            raise ConnectionError(f"❌ 数据库连接失败: {str(e)}") from e
        print("✅ Neo4j连接成功")
        return graph

    def _validate_weight(self, line: str) -> bool:
        """Return True unless *line* contains a ``weight: <value>`` outside [0, 1].

        Every ``weight:`` occurrence is checked, not only the first. A leading
        minus sign is captured so that negative weights are rejected rather
        than read as positive. Lines without a weight property are valid.
        """
        for match in re.finditer(r"weight\s*:\s*(-?[0-9.]+)", line):
            try:
                weight = float(match.group(1))
            except ValueError:
                # Malformed numbers such as "1.2.3" or "."
                return False
            if not 0.0 <= weight <= 1.0:
                return False
        return True

    def _validate_cypher_structure(self, cypher: str) -> bool:
        """Return True unless ``MERGE (q:Question`` appears without a ``WITH q``.

        The generation template chains its clauses through ``WITH q``. A MERGE
        of the question node without it is treated as malformed. Always
        returns a real ``bool``.
        """
        has_merge = re.search(r'\bMERGE\s*\(q:Question\b', cypher, re.IGNORECASE)
        has_with = re.search(r'\bWITH\s+q\b', cypher, re.IGNORECASE)
        return not has_merge or bool(has_with)

    def _generate_question_id(self) -> str:
        """Return the first 8 hex chars of the MD5 digest of the question text.

        MD5 is used only as a short deterministic identifier, not for security.
        """
        return hashlib.md5(self.content.encode()).hexdigest()[:8]

    def _fetch_existing_nodes(self, label: str) -> Dict[str, str]:
        """Load every node carrying *label* from Neo4j as an ``{id: name}`` dict.

        Nodes whose ``id`` is null are skipped: they cannot be referenced by id,
        and the prompt builder slices each id as a string. On any query error
        the error is printed and an empty dict is returned. ``run()`` refuses
        to proceed when either dict is empty.

        Args:
            label: Node label. It is interpolated directly into the query text,
                so it must be a trusted constant, never user input.
        """
        try:
            cypher = f"MATCH (n:{label}) RETURN n.id as id, n.name as name"
            result = self.graph.run(cypher).data()
            return {item['id']: item['name'] for item in result if item['id'] is not None}
        except Exception as e:
            print(f"❌ 节点查询失败: {str(e)}")
            return {}

    def _generate_stream(self) -> Iterator[ChatCompletionChunk]:
        """Start a streaming chat completion that turns the question into Cypher.

        The system prompt lists up to five existing knowledge and ability nodes
        as examples. It instructs the model to reference only existing ids,
        using the MERGE / WITH q / MATCH template shown in the prompt.

        Returns:
            The streaming response from ``client.chat.completions.create``.
        """
        # k[3:] drops a 3-character id prefix (presumably "kp_"/"ab_") and the
        # sample is re-prefixed with the upper-case form the prompt asks for.
        knowledge_samples = '\n'.join(
            f"KP_{k[3:]}: {v}" for k, v in list(self.existing_knowledge.items())[:5]
        )
        ability_samples = '\n'.join(
            f"AB_{a[3:]}: {v}" for a, v in list(self.existing_ability.items())[:5]
        )
        system_prompt = f''' 将题目中涉及到的小学数学知识点、能力点进行总结并且按照以下格式生成在neo4j-community-5.26.2上的语句:
请严格使用以下已有节点ID不要创建新ID
现有知识点示例KP_后接6位小写字母/数字
{knowledge_samples}
现有能力点示例AB_后接6位小写字母/数字
{ability_samples}
生成格式要求
MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "..."
WITH q
MATCH (kp:KnowledgePoint {{id: "KP_xxxxxx"}}) // 必须使用已有KP_ID
MERGE (q)-[:TESTS_KNOWLEDGE]->(kp)
WITH q
MATCH (ab:AbilityPoint {{id: "AB_xxxxxx"}}) // 必须使用已有AB_ID
MERGE (q)-[:REQUIRES_ABILITY]->(ab)
'''

        return self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self.content},
            ],
            stream=True,
            timeout=300,
        )

    def _format_node_list(self, nodes: Dict[str, str], limit: int = 5) -> str:
        """Render *nodes* as an indented ``" - id: name"`` list for display.

        At most *limit* entries are listed. If there are more, a trailing
        summary line with the total count is appended.

        Args:
            nodes: Mapping of node id to node name.
            limit: Maximum number of entries to list (default 5).
        """
        if not nodes:
            return " (无相关节点)"
        items = list(nodes.items())
        sample = [f" - {k}: {v}" for k, v in items[:limit]]
        if len(items) > limit:
            sample.append(f" ...(共{len(nodes)}仅显示前{limit}个")
        return '\n'.join(sample)

    def _extract_cypher(self, content: str) -> str:
        """Extract and sanitize the fenced Cypher code blocks in model output.

        The extractor collects fenced blocks that are untagged or tagged
        ``cypher`` in any case. Windows line endings are normalised first.
        Each block is filtered by ``_sanitize_cypher``, and blocks that end up
        blank are dropped.

        Returns:
            The surviving blocks joined by a semicolon and a blank line, or an
            empty string when nothing survives.
        """
        normalized = content.replace('\r\n', '\n')
        safe_blocks = []
        fence = r"```[ \t]*(?:cypher)?[ \t]*\n(.*?)```"
        for block in re.findall(fence, normalized, re.DOTALL | re.IGNORECASE):
            cleaned = self._sanitize_cypher(block).strip()
            if cleaned:
                safe_blocks.append(cleaned)
        return ';\n\n'.join(safe_blocks)

    def _sanitize_cypher(self, cypher: str) -> str:
        """Drop unsafe or invalid lines from a Cypher snippet.

        Lines starting with ``SET`` are always kept, because they carry the
        free-form question content. Any other line is dropped if it:

        - contains ``CREATE``,
        - references an unknown KP_/AB_ id, or
        - has a weight outside [0, 1].

        Every line and every filtering decision is printed for debugging.
        """
        valid_lines = []
        for line_num, line in enumerate(cypher.split('\n'), 1):
            # Debug trace of the line being inspected.
            print(f"正在处理第{line_num}行: {line[:50]}...")
            # SET clauses hold free-form text, so they skip validation.
            if line.strip().startswith("SET"):
                valid_lines.append(line)
                continue
            # Collect every reason this line is rejected.
            filter_reason = []
            if re.search(r'\bCREATE\b', line, re.IGNORECASE):
                filter_reason.append("包含CREATE语句")
            if not self._validate_ids(line):
                filter_reason.append("存在无效ID")
            if not self._validate_weight(line):
                filter_reason.append("权重值非法")
            if filter_reason:
                print(f"{line_num}行被过滤: {line[:50]}... | 原因: {', '.join(filter_reason)}")
                continue
            valid_lines.append(line)
        return '\n'.join(valid_lines)

    def run(self) -> Tuple[bool, str]:
        """Stream the model's answer for the question, echoing it to stdout.

        Returns:
            ``(success, text)``. Every path returns a 2-tuple, so callers can
            always unpack ``success, text = kg.run()``.

            - On success, ``text`` is the full raw model output.
            - On failure, it is a short error message: "节点数据为空" when
              either node dict is empty, "空内容" when the stream produced no
              text, or the exception message when the request or stream
              raised.
        """
        if not self.existing_knowledge or not self.existing_ability:
            print("❌ 知识库或能力点为空,请检查数据库")
            return False, "节点数据为空"

        start_time = time.time()
        spinner = ['', '', '', '', '', '', '', '', '', '']
        content_buffer = []
        try:
            print(f"🚀 开始生成(知识点:{len(self.existing_knowledge)}个,能力点:{len(self.existing_ability)}个)")
            stream = self._generate_stream()
            for idx, chunk in enumerate(stream):
                print(f"\r{spinner[idx % len(spinner)]} 生成中({int(time.time() - start_time)}秒)", end="")
                if chunk.choices and chunk.choices[0].delta.content:
                    content_chunk = chunk.choices[0].delta.content
                    content_buffer.append(content_chunk)
                    # Print a header once, when the first text arrives.
                    if len(content_buffer) == 1:
                        print("\n\n📝 内容生成开始:")
                    print(content_chunk, end="", flush=True)
        except Exception as e:
            # Top-level boundary: report the failure instead of propagating.
            error_msg = str(e)
            print(f"\n\n❌ 生成失败:{error_msg}")
            return False, error_msg

        if content_buffer:
            full_content = ''.join(content_buffer)
            print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}")
            return True, full_content

        print("\n⚠️ 生成完成但未获取到有效内容")
        return False, "空内容"
5 months ago
if __name__ == '__main__':
    # Executor that sends the generated Cypher text to Neo4j.
    executor = K2_Neo4jExecutor(
        uri=NEO4J_URI,
        auth=NEO4J_AUTH
    )

    # Sample question used as a smoke test.
    test_content = '''
题目一个长方体的长是8厘米宽是5厘米高是3厘米求它的表面积是多少平方厘米
'''

    try:
        kg = KnowledgeGraph(test_content)
        # Read the first two fields positionally so this works whether run()
        # returns a 2-tuple or a 3-tuple on a given path.
        result = kg.run()
        success, cypher = result[0], result[1]
        if success:
            executor.execute_cypher_text(cypher)
            print("恭喜,执行数据插入完成!")
        else:
            # On failure the second field is an error message, not Cypher;
            # never send it to the database.
            print(f"⚠️ 生成失败,未执行数据插入: {cypher}")
    except Exception as e:
        print(f"程序初始化失败: {str(e)}")