You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

202 lines
8.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import hashlib
import json
import re
import time
from typing import Dict, Iterator, Tuple

from openai import OpenAI
from openai.types.chat import ChatCompletionChunk

from K2_Neo4jExecutor import *
class KnowledgeGraph:
    """Ask an LLM for Cypher that links one question to existing graph nodes.

    On construction the instance derives a short id from the question text,
    opens a Neo4j connection, loads the ids/names of the existing
    ``KnowledgePoint`` and ``AbilityPoint`` nodes and creates an OpenAI
    client. ``run`` then streams a completion and returns the generated text
    together with the sanitized Cypher extracted from it.

    ``Graph``, ``NEO4J_URI``, ``NEO4J_AUTH``, ``MODEL_API_KEY``,
    ``MODEL_API_URL`` and ``MODEL_NAME`` are expected to be provided by the
    module's star import (presumably py2neo-style ``Graph`` -- confirm).
    """

    # A quoted id literal is validated as a whole; the second alternative keeps
    # the historical "prefix + 6 lowercase hex digits" detection for unquoted
    # occurrences. Both patterns are applied to a lower-cased line.
    _KP_ID_RE = re.compile(r"""(?<=["'])kp_[0-9a-z]+(?=["'])|kp_[a-f0-9]{6}""")
    _AB_ID_RE = re.compile(r"""(?<=["'])ab_[0-9a-z]+(?=["'])|ab_[a-f0-9]{6}""")
    # Optional sign so that negative weights are seen (and rejected) instead
    # of being mistaken for "no weight present".
    _WEIGHT_RE = re.compile(r"weight\s*:\s*([-+]?[0-9.]+)")
    # Fenced code block, optionally tagged "cypher"; tolerates trailing
    # blanks after the tag and CRLF line endings.
    _FENCE_RE = re.compile(
        r"```(?:cypher)?[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE
    )
    # Progress frames shown while streaming.
    _SPINNER = ('⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏')

    def __init__(self, content: str):
        """Prepare everything needed to process one question.

        Args:
            content: The question text sent to the model as the user message.

        Raises:
            ConnectionError: If the Neo4j connection test fails.
        """
        self.content = content
        self.question_id = self._generate_question_id()
        self.graph = self._init_graph_connection()
        self.existing_knowledge = self._fetch_existing_nodes("KnowledgePoint")
        self.existing_ability = self._fetch_existing_nodes("AbilityPoint")
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)

    def _validate_ids(self, line: str) -> bool:
        """Return True when every KP/AB id mentioned in *line* is known.

        The comparison is case-insensitive on both sides, so ``"KP_a1b2c3"``
        in the generated text matches a stored ``kp_a1b2c3`` (and vice versa).
        A line that mentions no id at all is considered valid.
        """
        lowered = line.lower()
        known_kp = {str(key).lower() for key in self.existing_knowledge}
        known_ab = {str(key).lower() for key in self.existing_ability}
        kp_ok = all(kp in known_kp for kp in set(self._KP_ID_RE.findall(lowered)))
        ab_ok = all(ab in known_ab for ab in set(self._AB_ID_RE.findall(lowered)))
        return kp_ok and ab_ok

    def _init_graph_connection(self) -> "Graph":
        """Open the Neo4j connection and verify it with a trivial query.

        Raises:
            ConnectionError: If creating the graph or running ``RETURN 1``
                fails; the original exception is chained as the cause.
        """
        try:
            graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
            graph.run("RETURN 1").data()
        except Exception as e:
            raise ConnectionError(f"❌ 数据库连接失败: {str(e)}") from e
        print("✅ Neo4j连接成功")
        return graph

    def _validate_weight(self, line: str) -> bool:
        """Return True when *line* has no weight or one within [0.0, 1.0].

        Only the first ``weight: <number>`` occurrence is inspected. A value
        that cannot be parsed as a float (e.g. ``1.2.3``) is invalid.
        """
        match = self._WEIGHT_RE.search(line)
        if match is None:
            return True  # no weight property on this line
        try:
            weight = float(match.group(1))
        except ValueError:
            return False
        return 0.0 <= weight <= 1.0

    def _validate_cypher_structure(self, cypher: str) -> bool:
        """Check that a ``MERGE (q:Question`` is accompanied by a ``WITH q``.

        Returns True when the statement has no such MERGE, or when it has
        both the MERGE and a ``WITH q`` clause; False otherwise.
        """
        has_merge = re.search(r'\bMERGE\s*\(q:Question\b', cypher, re.IGNORECASE)
        if has_merge is None:
            return True
        has_with = re.search(r'\bWITH\s+q\b', cypher, re.IGNORECASE)
        return has_with is not None

    def _generate_question_id(self) -> str:
        """Return the first 8 hex characters of the MD5 of the question text."""
        return hashlib.md5(self.content.encode()).hexdigest()[:8]

    def _fetch_existing_nodes(self, label: str) -> Dict[str, str]:
        """Load ``id -> name`` for every node carrying *label*.

        Rows without an id are skipped. Any query failure is reported on
        stdout and yields an empty dict (``run`` treats that as "no data").
        """
        try:
            cypher = f"MATCH (n:{label}) RETURN n.id as id, n.name as name"
            rows = self.graph.run(cypher).data()
            return {
                row['id']: row['name'] for row in rows if row['id'] is not None
            }
        except Exception as e:
            print(f"❌ 节点查询失败: {str(e)}")
            return {}

    def _generate_stream(self) -> Iterator["ChatCompletionChunk"]:
        """Start the streaming chat completion for this question.

        The system prompt embeds up to five sample knowledge points and five
        sample ability points, plus the expected Cypher layout.

        Returns:
            The iterator of chunks returned by the OpenAI client.
        """
        # NOTE(review): the samples replace the first three characters of each
        # id with an upper-case prefix. If the stored ids are lower-case
        # ("kp_..."), the model is shown ids that do not literally exist in
        # the database -- confirm the stored id format.
        knowledge_samples = '\n'.join([f"KP_{k[3:]}: {v}" for k, v in list(self.existing_knowledge.items())[:5]])
        ability_samples = '\n'.join([f"AB_{a[3:]}: {v}" for a, v in list(self.existing_ability.items())[:5]])
        system_prompt = f''' 将题目中涉及到的小学数学知识点、能力点进行总结并且按照以下格式生成在neo4j-community-5.26.2上的语句:
请严格使用以下已有节点ID不要创建新ID
现有知识点示例KP_后接6位小写字母/数字):
{knowledge_samples}
现有能力点示例AB_后接6位小写字母/数字):
{ability_samples}
生成格式要求:
MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "..."
WITH q
MATCH (kp:KnowledgePoint {{id: "KP_xxxxxx"}}) // 必须使用已有KP_ID
MERGE (q)-[:TESTS_KNOWLEDGE]->(kp)
WITH q
MATCH (ab:AbilityPoint {{id: "AB_xxxxxx"}}) // 必须使用已有AB_ID
MERGE (q)-[:REQUIRES_ABILITY]->(ab)
'''
        return self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self.content},
            ],
            stream=True,
            timeout=300,
        )

    def _format_node_list(self, nodes: Dict[str, str]) -> str:
        """Render at most five ``id: name`` entries, one per line.

        An empty mapping yields a fixed placeholder line; when more than five
        entries exist a trailing summary line with the total count is added.
        """
        if not nodes:
            return " (无相关节点)"
        sample = []
        for i, (k, v) in enumerate(nodes.items()):
            if i >= 5:
                sample.append(f" ...(共{len(nodes)}仅显示前5个")
                break
            sample.append(f" - {k}: {v}")
        return '\n'.join(sample)

    def _extract_cypher(self, content: str) -> str:
        """Extract and sanitize all fenced Cypher blocks from model output.

        Each block is passed through ``_sanitize_cypher`` and stripped of
        surrounding whitespace; blocks that end up empty are dropped. The
        remaining blocks are joined with ``";\\n\\n"``.

        Returns:
            The combined Cypher text, or ``""`` when no usable block exists.
        """
        safe_blocks = []
        for block in self._FENCE_RE.findall(content):
            cleaned = self._sanitize_cypher(block).strip()
            if cleaned:
                safe_blocks.append(cleaned)
        return ';\n\n'.join(safe_blocks)

    def _sanitize_cypher(self, cypher: str) -> str:
        """Drop the lines of *cypher* that fail the safety checks.

        A line is removed when it contains the keyword ``CREATE``, mentions
        an unknown KP/AB id, or carries an out-of-range weight. Every
        processed line and every removal is reported on stdout.

        Returns:
            The surviving lines joined with newlines.
        """
        valid_lines = []
        for line_num, line in enumerate(cypher.split('\n'), 1):
            print(f"正在处理第{line_num}行: {line[:50]}...")
            # Lines beginning with SET are kept unchecked because the question
            # text they carry may legitimately contain flagged words.
            # NOTE(review): this also lets anything else on a SET line through.
            if line.strip().startswith("SET"):
                valid_lines.append(line)
                continue
            filter_reason = []
            if re.search(r'\bCREATE\b', line, re.IGNORECASE):
                filter_reason.append("包含CREATE语句")
            if not self._validate_ids(line):
                filter_reason.append("存在无效ID")
            if not self._validate_weight(line):
                filter_reason.append("权重值非法")
            if filter_reason:
                print(f"{line_num}行被过滤: {line[:50]}... | 原因: {', '.join(filter_reason)}")
                continue
            valid_lines.append(line)
        return '\n'.join(valid_lines)

    def run(self) -> Tuple[bool, str, str]:
        """Stream the completion and return the result.

        Returns:
            A ``(success, text, cypher)`` triple on every path:

            * success -- ``(True, <full generated text>, <sanitized Cypher
              extracted from it, possibly "">)``
            * failure -- ``(False, <reason or error message>, "")``

            Failures cover missing node data, an empty completion and any
            exception raised while streaming; none of them propagate.
        """
        if not self.existing_knowledge or not self.existing_ability:
            print("❌ 知识库或能力点为空,请检查数据库")
            return False, "节点数据为空", ""
        start_time = time.time()
        spinner = self._SPINNER
        content_buffer = []
        try:
            print(f"🚀 开始生成(知识点:{len(self.existing_knowledge)}个,能力点:{len(self.existing_ability)}个)")
            stream = self._generate_stream()
            for idx, chunk in enumerate(stream):
                print(f"\r{spinner[idx % len(spinner)]} 生成中({int(time.time() - start_time)}秒)", end="")
                if chunk.choices and chunk.choices[0].delta.content:
                    content_chunk = chunk.choices[0].delta.content
                    content_buffer.append(content_chunk)
                    if len(content_buffer) == 1:
                        print("\n\n📝 内容生成开始:")
                    print(content_chunk, end="", flush=True)
            if content_buffer:
                full_content = ''.join(content_buffer)
                print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}")
                # Third element matches the declared return type and the
                # failure paths, which all return "" in this position.
                return True, full_content, self._extract_cypher(full_content)
            print("\n⚠️ 生成完成但未获取到有效内容")
            return False, "空内容", ""
        except Exception as e:
            # Top-level boundary: report the failure instead of raising.
            error_msg = str(e)
            print(f"\n\n❌ 生成失败:{error_msg}")
            return False, error_msg, ""