You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

126 lines
4.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import time
from pathlib import Path
from typing import Iterator, Optional
from openai import OpenAI, APIError, APITimeoutError
from Config import *
class MarkdownGenerator:
    """Markdown teaching-syllabus generator backed by an OpenAI-style chat API.

    The generator reads a Markdown template, asks the chat-completions
    endpoint to produce a syllabus for ``course_name`` in that format, echoes
    the streamed text to stdout and finally writes it to ``output_path``.

    ``DEFAULT_TEMPLATE``, ``DEFAULT_OUTPUT_DIR``, ``MODEL_API_KEY``,
    ``MODEL_API_URL`` and ``MODEL_NAME`` are not defined in this class; they
    are presumably supplied by the module-level ``from Config import *``.
    """

    # Frames cycled by the progress indicator while waiting for content.
    # NOTE(review): every frame was an empty string in the reviewed version,
    # so the spinner drew nothing. The ten-frame braille sequence below is
    # assumed to be the intended set -- confirm against the upstream file.
    SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏")

    # Timeout value passed to the chat-completions request.
    REQUEST_TIMEOUT = 100

    def __init__(
        self,
        course_name: str,
        output_path: Optional[Path] = None,
        template_path: Optional[Path] = None
    ):
        """
        Initialise the generator.

        :param course_name: course title used in the prompt and, when no
            ``output_path`` is given, as the output file stem
        :param output_path: destination file (optional); defaults to
            ``DEFAULT_OUTPUT_DIR / f"{course_name}.md"``
        :param template_path: template file (optional); defaults to
            ``DEFAULT_TEMPLATE``
        :raises FileNotFoundError: if the template file does not exist
        :raises NotADirectoryError: if the output directory is not a directory
        """
        self.course_name = course_name
        self.template_path = template_path or DEFAULT_TEMPLATE
        self.output_path = output_path or DEFAULT_OUTPUT_DIR / f"{course_name}.md"
        # Client for the OpenAI-compatible endpoint.
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
        self._validate_paths()

    def _validate_paths(self):
        """Check the template exists and make sure the output directory exists.

        :raises FileNotFoundError: if ``template_path`` does not exist
        :raises NotADirectoryError: if the parent of ``output_path`` is not a
            directory after the attempt to create it
        """
        if not self.template_path.exists():
            raise FileNotFoundError(f"模板文件不存在: {self.template_path}")
        self.output_path.parent.mkdir(parents=True, exist_ok=True)
        if not self.output_path.parent.is_dir():
            raise NotADirectoryError(f"无效的输出目录: {self.output_path.parent}")

    def _load_template(self) -> str:
        """Return the template file's text, decoded as UTF-8."""
        return self.template_path.read_text(encoding='utf-8')

    def _generate_stream(self) -> Iterator:
        """Start a streaming chat completion and return the chunk stream.

        The system prompt embeds the template plus formatting rules; the user
        message asks for the syllabus of ``course_name``.
        """
        system_prompt = (
            "请严格按照以下Markdown格式生成教学大纲\n"
            f"{self._load_template()}\n"
            "注意:\n"
            "1. 使用#/##标记标题\n"
            "2. 每页内容以>开头\n"
            "3. 列表项使用-标记\n"
            "4. 内容中不要包含图片\n"
        )
        return self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"请生成《{self.course_name}》教学大纲"}
            ],
            stream=True,
            timeout=self.REQUEST_TIMEOUT
        )

    def run(self) -> bool:
        """Stream the syllabus, echo it to stdout and save it to ``output_path``.

        :return: ``True`` when at least one piece of content was received and
            the file was written; ``False`` when the stream produced no
            content or any exception was raised (the error is printed, not
            propagated)
        """
        start_time = time.time()
        content_buffer = []
        stream = None
        try:
            print(f"🚀 开始生成【{self.course_name}】教学大纲")
            stream = self._generate_stream()
            for idx, chunk in enumerate(stream):
                if not content_buffer:
                    # Draw the spinner only while waiting for the first piece
                    # of content: its leading "\r" would otherwise overwrite
                    # the line of streamed text being echoed below.
                    frame = self.SPINNER_FRAMES[idx % len(self.SPINNER_FRAMES)]
                    print(
                        f"\r{frame} 生成中({int(time.time() - start_time)}秒)",
                        end="",
                        flush=True,  # no newline is printed, so flush explicitly
                    )
                # Skip chunks that carry no text.
                if not chunk.choices:
                    continue
                content_chunk = chunk.choices[0].delta.content
                if not content_chunk:
                    continue
                if not content_buffer:
                    print("\n\n📝 内容生成开始:")
                content_buffer.append(content_chunk)
                print(content_chunk, end="", flush=True)
            # Nothing is written when the model returned no content.
            if not content_buffer:
                return False
            full_content = ''.join(content_buffer)
            self.output_path.write_text(full_content, encoding='utf-8')
            print(f"\n\n✅ 生成成功!耗时 {int(time.time() - start_time)}")
            print(f"📂 文件已保存至:{self.output_path}")
            return True
        except APITimeoutError as e:
            # Must precede APIError so the timeout-specific message is used.
            print(f"\n\n❌ 请求超时:{str(e)}")
            return False
        except APIError as e:
            print(f"\n\n❌ API错误{str(e)}")
            return False
        except Exception as e:
            print(f"\n\n❌ 未处理异常:{str(e)}")
            return False
        finally:
            # Release the streaming response if the stream object offers
            # close(); plain iterators without it are left alone.
            close = getattr(stream, "close", None)
            if callable(close):
                close()
def generate_document(
    course_name: str,
    output_path: Optional[str] = None,
    template_path: Optional[str] = None
):
    """
    Entry point: build a MarkdownGenerator and run it.

    :param course_name: course title forwarded to the generator
    :param output_path: destination file path as a string (optional); an
        empty or missing value lets the generator pick its default
    :param template_path: template file path as a string (optional); an
        empty or missing value lets the generator pick its default
    :return: the result of ``MarkdownGenerator.run()``, or ``False`` when
        an exception is raised (the error is printed, not propagated)
    """
    try:
        # Turn the optional string paths into Path objects, keeping None
        # for anything falsy so the generator falls back to its defaults.
        target = None
        if output_path:
            target = Path(output_path)
        template = None
        if template_path:
            template = Path(template_path)
        return MarkdownGenerator(
            course_name=course_name,
            output_path=target,
            template_path=template,
        ).run()
    except Exception as err:
        print(f"❌ 初始化失败:{str(err)}")
        return False