|
|
|
@ -1,15 +1,13 @@
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
import time
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Iterator, Optional
|
|
|
|
|
|
|
|
|
|
from dashscope import Generation
|
|
|
|
|
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
|
|
|
|
|
|
|
|
|
from openai import OpenAI, APIError, APITimeoutError
|
|
|
|
|
from Config import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MarkdownGenerator:
|
|
|
|
|
"""Markdown教学大纲生成器"""
|
|
|
|
|
"""Markdown教学大纲生成器(OpenAI版本)"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
@ -27,11 +25,12 @@ class MarkdownGenerator:
|
|
|
|
|
self.course_name = course_name
|
|
|
|
|
self.template_path = template_path or DEFAULT_TEMPLATE
|
|
|
|
|
self.output_path = output_path or DEFAULT_OUTPUT_DIR / f"{course_name}.md"
|
|
|
|
|
self.client = OpenAI(api_key=API_KEY, base_url=MODEL_URL) # 初始化OpenAI客户端
|
|
|
|
|
|
|
|
|
|
self._validate_paths()
|
|
|
|
|
|
|
|
|
|
def _validate_paths(self):
|
|
|
|
|
"""路径验证"""
|
|
|
|
|
"""路径验证(保持不变)"""
|
|
|
|
|
if not self.template_path.exists():
|
|
|
|
|
raise FileNotFoundError(f"模板文件不存在: {self.template_path}")
|
|
|
|
|
|
|
|
|
@ -40,11 +39,11 @@ class MarkdownGenerator:
|
|
|
|
|
raise NotADirectoryError(f"无效的输出目录: {self.output_path.parent}")
|
|
|
|
|
|
|
|
|
|
def _load_template(self) -> str:
|
|
|
|
|
"""加载模板内容"""
|
|
|
|
|
"""加载模板内容(保持不变)"""
|
|
|
|
|
return self.template_path.read_text(encoding='utf-8')
|
|
|
|
|
|
|
|
|
|
def _generate_stream(self) -> Iterator[DashScopeAPIResponse]:
|
|
|
|
|
"""流式生成内容"""
|
|
|
|
|
def _generate_stream(self) -> Iterator:
|
|
|
|
|
"""流式生成内容(修改为OpenAI格式)"""
|
|
|
|
|
system_prompt = (
|
|
|
|
|
"请严格按照以下Markdown格式生成教学大纲:\n"
|
|
|
|
|
f"{self._load_template()}\n"
|
|
|
|
@ -55,49 +54,55 @@ class MarkdownGenerator:
|
|
|
|
|
"4. 内容中不要包含图片\n"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return Generation.call(
|
|
|
|
|
return self.client.chat.completions.create(
|
|
|
|
|
model=MODEL_R1,
|
|
|
|
|
api_key=API_KEY,
|
|
|
|
|
messages=[
|
|
|
|
|
{"role": "system", "content": system_prompt},
|
|
|
|
|
{"role": "user", "content": f"请生成《{self.course_name}》教学大纲"}
|
|
|
|
|
],
|
|
|
|
|
result_format='message',
|
|
|
|
|
stream=True,
|
|
|
|
|
incremental_output=True
|
|
|
|
|
timeout=30 # 设置超时时间
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def run(self) -> bool:
|
|
|
|
|
"""执行生成流程"""
|
|
|
|
|
"""执行生成流程(调整流式处理逻辑)"""
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
|
|
|
|
|
content_buffer = []
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
print(f"🚀 开始生成【{self.course_name}】教学大纲")
|
|
|
|
|
responses = self._generate_stream()
|
|
|
|
|
stream = self._generate_stream()
|
|
|
|
|
|
|
|
|
|
for idx, response in enumerate(responses):
|
|
|
|
|
# 显示进度
|
|
|
|
|
for idx, chunk in enumerate(stream):
|
|
|
|
|
print(f"\r{spinner[idx % 10]} 生成中({int(time.time() - start_time)}秒)", end="")
|
|
|
|
|
|
|
|
|
|
if response.status_code == 200 and response.output:
|
|
|
|
|
if chunk := response.output.choices[0].message.content:
|
|
|
|
|
content_buffer.append(chunk)
|
|
|
|
|
if len(content_buffer) == 1:
|
|
|
|
|
print("\n\n📝 内容生成开始:")
|
|
|
|
|
print(chunk, end="", flush=True)
|
|
|
|
|
# 处理OpenAI流式响应
|
|
|
|
|
if chunk.choices and chunk.choices[0].delta.content:
|
|
|
|
|
content_chunk = chunk.choices[0].delta.content
|
|
|
|
|
content_buffer.append(content_chunk)
|
|
|
|
|
|
|
|
|
|
if len(content_buffer) == 1:
|
|
|
|
|
print("\n\n📝 内容生成开始:")
|
|
|
|
|
print(content_chunk, end="", flush=True)
|
|
|
|
|
|
|
|
|
|
# 保存结果
|
|
|
|
|
if content_buffer:
|
|
|
|
|
self.output_path.write_text(''.join(content_buffer), encoding='utf-8')
|
|
|
|
|
full_content = ''.join(content_buffer)
|
|
|
|
|
self.output_path.write_text(full_content, encoding='utf-8')
|
|
|
|
|
print(f"\n\n✅ 生成成功!耗时 {int(time.time() - start_time)}秒")
|
|
|
|
|
print(f"📂 文件已保存至:{self.output_path}")
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
except APITimeoutError as e:
|
|
|
|
|
print(f"\n\n❌ 请求超时:{str(e)}")
|
|
|
|
|
return False
|
|
|
|
|
except APIError as e:
|
|
|
|
|
print(f"\n\n❌ API错误:{str(e)}")
|
|
|
|
|
return False
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"\n\n❌ 生成失败:{str(e)}")
|
|
|
|
|
print(f"\n\n❌ 未处理异常:{str(e)}")
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -107,11 +112,7 @@ def generate_document(
|
|
|
|
|
template_path: Optional[str] = None
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
生成文档入口函数
|
|
|
|
|
|
|
|
|
|
:param course_name: 课程名称
|
|
|
|
|
:param output_path: 输出路径(可选)
|
|
|
|
|
:param template_path: 模板路径(可选)
|
|
|
|
|
生成文档入口函数(保持不变)
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
generator = MarkdownGenerator(
|
|
|
|
@ -122,9 +123,4 @@ def generate_document(
|
|
|
|
|
return generator.run()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"❌ 初始化失败:{str(e)}")
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return False
|