main
黄海 5 months ago
parent 6567442144
commit 3ff11d93cf

@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
import time
from pathlib import Path
from typing import Iterator, Optional
from dashscope import Generation
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
from openai import OpenAI, APIError, APITimeoutError
from Config import *
class MarkdownGenerator:
"""Markdown教学大纲生成器"""
"""Markdown教学大纲生成器OpenAI版本"""
def __init__(
self,
@ -27,11 +25,12 @@ class MarkdownGenerator:
self.course_name = course_name
self.template_path = template_path or DEFAULT_TEMPLATE
self.output_path = output_path or DEFAULT_OUTPUT_DIR / f"{course_name}.md"
self.client = OpenAI(api_key=API_KEY, base_url=MODEL_URL) # 初始化OpenAI客户端
self._validate_paths()
def _validate_paths(self):
"""路径验证"""
"""路径验证(保持不变)"""
if not self.template_path.exists():
raise FileNotFoundError(f"模板文件不存在: {self.template_path}")
@ -40,11 +39,11 @@ class MarkdownGenerator:
raise NotADirectoryError(f"无效的输出目录: {self.output_path.parent}")
def _load_template(self) -> str:
"""加载模板内容"""
"""加载模板内容(保持不变)"""
return self.template_path.read_text(encoding='utf-8')
def _generate_stream(self) -> Iterator[DashScopeAPIResponse]:
"""流式生成内容"""
def _generate_stream(self) -> Iterator:
"""流式生成内容修改为OpenAI格式"""
system_prompt = (
"请严格按照以下Markdown格式生成教学大纲\n"
f"{self._load_template()}\n"
@ -55,49 +54,55 @@ class MarkdownGenerator:
"4. 内容中不要包含图片\n"
)
return Generation.call(
return self.client.chat.completions.create(
model=MODEL_R1,
api_key=API_KEY,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"请生成《{self.course_name}》教学大纲"}
],
result_format='message',
stream=True,
incremental_output=True
timeout=30 # 设置超时时间
)
def run(self) -> bool:
"""执行生成流程"""
"""执行生成流程(调整流式处理逻辑)"""
start_time = time.time()
spinner = ['', '', '', '', '', '', '', '', '', '']
content_buffer = []
try:
print(f"🚀 开始生成【{self.course_name}】教学大纲")
responses = self._generate_stream()
stream = self._generate_stream()
for idx, response in enumerate(responses):
# 显示进度
for idx, chunk in enumerate(stream):
print(f"\r{spinner[idx % 10]} 生成中({int(time.time() - start_time)}秒)", end="")
if response.status_code == 200 and response.output:
if chunk := response.output.choices[0].message.content:
content_buffer.append(chunk)
if len(content_buffer) == 1:
print("\n\n📝 内容生成开始:")
print(chunk, end="", flush=True)
# 处理OpenAI流式响应
if chunk.choices and chunk.choices[0].delta.content:
content_chunk = chunk.choices[0].delta.content
content_buffer.append(content_chunk)
if len(content_buffer) == 1:
print("\n\n📝 内容生成开始:")
print(content_chunk, end="", flush=True)
# 保存结果
if content_buffer:
self.output_path.write_text(''.join(content_buffer), encoding='utf-8')
full_content = ''.join(content_buffer)
self.output_path.write_text(full_content, encoding='utf-8')
print(f"\n\n✅ 生成成功!耗时 {int(time.time() - start_time)}")
print(f"📂 文件已保存至:{self.output_path}")
return True
return False
except APITimeoutError as e:
print(f"\n\n❌ 请求超时:{str(e)}")
return False
except APIError as e:
print(f"\n\n❌ API错误{str(e)}")
return False
except Exception as e:
print(f"\n\n❌ 生成失败:{str(e)}")
print(f"\n\n未处理异常{str(e)}")
return False
@ -107,11 +112,7 @@ def generate_document(
template_path: Optional[str] = None
):
"""
生成文档入口函数
:param course_name: 课程名称
:param output_path: 输出路径可选
:param template_path: 模板路径可选
生成文档入口函数保持不变
"""
try:
generator = MarkdownGenerator(
@ -122,9 +123,4 @@ def generate_document(
return generator.run()
except Exception as e:
print(f"❌ 初始化失败:{str(e)}")
return False
return False

@ -1,30 +1,20 @@
# -*- coding: utf-8 -*-
import time
from typing import Iterator
from typing import Iterator, Tuple
from openai import OpenAI
from openai.types.chat import ChatCompletionChunk
from Config import API_KEY, MODEL_R1, MODEL_URL # 确保导入配置
from Config import API_KEY, MODEL_R1, MODEL_URL
class KnowledgeGraph:
def __init__(self, shiti_content: str):
"""
初始化生成器
"""
self.shiti_content = shiti_content
self.client = OpenAI(
api_key=API_KEY,
base_url=MODEL_URL
)
self.client = OpenAI(api_key=API_KEY, base_url=MODEL_URL)
def _generate_stream(self) -> Iterator[ChatCompletionChunk]:
"""流式生成内容"""
system_prompt = '''回答以下内容:
1. 这道题目有哪些知识点哪些能力点
2. 生成Neo4j 5.26.2的插入语句'''
return self.client.chat.completions.create(
model=MODEL_R1,
messages=[
@ -35,8 +25,8 @@ class KnowledgeGraph:
timeout=300
)
def run(self) -> bool:
"""执行生成流程"""
def run(self) -> Tuple[bool, str]:
"""执行生成流程(返回状态和完整内容)"""
start_time = time.time()
spinner = ['', '', '', '', '', '', '', '', '', '']
content_buffer = []
@ -57,13 +47,17 @@ class KnowledgeGraph:
print(content_chunk, end="", flush=True)
if content_buffer:
full_content = ''.join(content_buffer)
print(f"\n\n✅ 生成成功!耗时 {int(time.time() - start_time)}")
return True
return False
print("\n================ 完整结果 ================")
print(full_content)
print("========================================")
return True, full_content
return False, ""
except Exception as e:
print(f"\n\n❌ 生成失败:{str(e)}")
return False
return False, str(e)
if __name__ == '__main__':
@ -72,4 +66,9 @@ if __name__ == '__main__':
把7个完全相同的小长方形拼成如图的样子已知每个小长方形的长是10厘米则拼成的大长方形的周长是多少厘米
'''
kg = KnowledgeGraph(shiti_content)
kg.run()
success, result = kg.run() # 获取返回结果
# 如果需要进一步处理结果
if success:
with open("result.txt", "w", encoding="utf-8") as f:
f.write(result)
Loading…
Cancel
Save