Files
ocr/Test101.py
2025-08-14 16:04:59 +08:00

94 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# pip install bitsandbytes
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import BitsAndBytesConfig
from api.doubao_client import DoubaoClient
from core.ocr_service import OCRService
import torch
import threading
# OCR service backed by a Doubao client instance.
_doubao_client = DoubaoClient()
_ocr_service = OCRService(_doubao_client)

# Image handed to the OCR service (presumably a photo of a math problem,
# judging by the file name -- confirm).
image_path = r'c:\math1.jpg'

# Quantization settings forwarded to the model loader: 4-bit NF4 weights,
# double quantization enabled, bfloat16 as the compute dtype.
_quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
}
bnb_config = BitsAndBytesConfig(**_quantization_options)

# Run OCR on the image; the result is later used as the user question.
ocr_result = _ocr_service.get_ocr(image_path)
# Model location. The value in use is a local directory; the alternative kept
# for reference is the repository id "netease-youdao/Confucius3-Math".
model_name = r"D:\Confucius3-Math\netease-youdao\Confucius3-Math"

# System prompt (Chinese). It tells the model to act as a teacher limited to
# junior-middle-school mathematics: five hard rules (no vectors, coordinate
# systems, calculus or complex numbers; name the theorem used at each step;
# state when a problem is out of scope) followed by a five-step solving
# procedure. The text starts with a newline and ends with a newline.
SYSTEM_PROMPT_TEMPLATE = (
    "\n"
    "# 数学解题助手(严格限定初中知识)\n"
    "你现在是一名严格按照初中数学课程标准教学的老师,必须遵守以下规则:\n"
    "【铁律1】绝对禁止使用向量、坐标系、微积分、复数等任何超出初中数学范围的知识。\n"
    "【铁律2】必须使用初中生熟悉的算术方法、几何定理、代数方程等基础知识解题。\n"
    "【铁律3】如果题目可能涉及高级知识你需要重新表述问题使其符合初中水平。\n"
    "【铁律4】解题过程中必须明确说明每一步使用的初中数学定理或公式。\n"
    '【铁律5】如遇无法用初中知识解决的问题直接说明"本题超出初中数学范围"\n'
    "# 解题步骤\n"
    "1. 用自然语言复述题目,确保理解正确\n"
    "2. 分析题目涉及的初中数学知识点\n"
    "3. 列出解题所需的定理、公式\n"
    "4. 分步解答,每步标注所用知识点\n"
    "5. 总结答案\n"
)

# The user turn is just the question text, with nothing wrapped around it.
USER_PROMPT_TEMPLATE = "{question}"

# The question comes from the OCR result; for a quick manual check without
# OCR, a literal such as "1+1=?" can be assigned instead.
question = ocr_result
# Load the causal LM with the quantization config; device placement is left to
# device_map="auto". (torch_dtype="auto" was tried and is intentionally unset.)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Two-turn conversation: the fixed system prompt, then the question.
_system_turn = {"role": "system", "content": SYSTEM_PROMPT_TEMPLATE}
_user_turn = {
    "role": "user",
    "content": USER_PROMPT_TEMPLATE.format(question=question),
}
messages = [_system_turn, _user_turn]

# Render the conversation to a single prompt string (not token ids), with the
# generation prompt appended.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Tokenize as a batch of one and move the tensors to the model's device.
_encoded_prompt = tokenizer([text], return_tensors="pt")
model_inputs = _encoded_prompt.to(model.device)
# Streamer that passes decoded text from the generation thread to the loop
# below; special tokens are skipped and each wait for the next chunk is
# limited by ``timeout`` (seconds).
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, timeout=10.0)

# Generation arguments: the tokenized prompt plus the streamer and a cap on
# the number of newly generated tokens.
generation_kwargs = {
    **model_inputs,
    "streamer": streamer,
    "max_new_tokens": 32768,
}

# Filled by the worker thread if generation raises, so the main thread can
# re-raise the real error after joining.
_generation_errors = []


def _run_generation():
    """Run ``model.generate`` on the worker thread.

    If generation raises, the exception is recorded and the streamer is
    closed. Without that, the consuming loop would keep waiting for text that
    never arrives until the streamer timeout expires, and the original error
    would never reach the main thread.
    """
    try:
        model.generate(**generation_kwargs)
    except Exception as exc:
        _generation_errors.append(exc)
        # Flush any cached text and push the end-of-stream marker so the
        # consumer loop terminates immediately.
        streamer.end()


# Generation runs in a background thread so the main thread can print text
# as soon as it is produced.
thread = threading.Thread(target=_run_generation)
thread.start()

# Print chunks as they arrive; empty chunks are skipped.
print("流式输出开始:")
for chunk in streamer:
    if chunk:
        print(chunk, end="", flush=True)
thread.join()

# Surface a generation failure instead of reporting a normal finish.
if _generation_errors:
    raise _generation_errors[0]
print("\n流式输出结束")