# Listing metadata from the code viewer: 94 lines, 2.7 KiB, Python
# pip install bitsandbytes
import queue
import sys
import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import BitsAndBytesConfig

from api.doubao_client import DoubaoClient
from core.ocr_service import OCRService
|
|||
|
|
|||
|
# Module-level OCR stack: the OCR service is constructed around the Doubao
# API client.
_doubao_client = DoubaoClient()
_ocr_service = OCRService(_doubao_client)

# Image handed to the OCR service (hard-coded Windows path).
image_path = r"c:\math1.jpg"
|
|||
|
|
|||
|
# bitsandbytes quantization settings passed to from_pretrained() when the
# model is loaded.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # load the weights in 4-bit
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute dtype set to bfloat16
    bnb_4bit_use_double_quant=True,         # enable double quantization
    bnb_4bit_quant_type="nf4",              # NF4 quantization data type
)
|
|||
|
|
|||
|
# Run OCR on the image; the result is later used verbatim as the question.
# NOTE(review): the return type of get_ocr() is not visible in this file --
# presumably the recognized text as a str; confirm against OCRService.
ocr_result = _ocr_service.get_ocr(image_path)

# Location the model and tokenizer are loaded from (a local Windows path).
# The commented-out alternative below is the bare model identifier.
# model_name = "netease-youdao/Confucius3-Math"
model_name = r"D:\Confucius3-Math\netease-youdao\Confucius3-Math"
|
|||
|
|
|||
|
|
|||
|
# System prompt (in Chinese) sent as the "system" chat turn. It instructs the
# model to act as a middle-school maths teacher: five hard rules restricting
# the solution to middle-school knowledge, followed by a five-step answer
# outline. The text is runtime data and must stay free of any listing
# residue -- every character here is sent to the model.
SYSTEM_PROMPT_TEMPLATE = """
# 数学解题助手(严格限定初中知识)

你现在是一名严格按照初中数学课程标准教学的老师,必须遵守以下规则:

【铁律1】绝对禁止使用向量、坐标系、微积分、复数等任何超出初中数学范围的知识。
【铁律2】必须使用初中生熟悉的算术方法、几何定理、代数方程等基础知识解题。
【铁律3】如果题目可能涉及高级知识,你需要重新表述问题,使其符合初中水平。
【铁律4】解题过程中必须明确说明每一步使用的初中数学定理或公式。
【铁律5】如遇无法用初中知识解决的问题,直接说明"本题超出初中数学范围"。

# 解题步骤
1. 用自然语言复述题目,确保理解正确
2. 分析题目涉及的初中数学知识点
3. 列出解题所需的定理、公式
4. 分步解答,每步标注所用知识点
5. 总结答案
"""
|
|||
|
|
|||
|
# User-turn template: the question is inserted as-is, with no extra wrapping.
USER_PROMPT_TEMPLATE = "{question}"
|
|||
|
|
|||
|
# The question sent to the model is the OCR output. The commented-out line is
# a fixed sample question that can be swapped in instead.
# question = "1+1=?"
question = ocr_result
|
|||
|
|
|||
|
# Load the causal language model from model_name using the quantization
# settings in bnb_config; device placement is delegated to the loader via
# device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # torch_dtype="auto",
    quantization_config=bnb_config,
    device_map="auto",
)
|
|||
|
|
|||
|
# The tokenizer is loaded from the same location as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Chat turns: the fixed system instructions followed by the user's question.
messages = [
    {"role": "system", "content": SYSTEM_PROMPT_TEMPLATE},
    {"role": "user", "content": USER_PROMPT_TEMPLATE.format(question=question)},
]

# Render the turns with the tokenizer's chat template into a single prompt
# string (tokenize=False) that ends with the generation prompt.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Tokenize the prompt as a batch of one, as PyTorch tensors, and move the
# tensors to the device the model reports.
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
|||
|
|
|||
|
# Streamer that yields decoded text pieces while generate() is running.
# skip_prompt=True keeps the echoed chat prompt (system instructions plus
# question) out of the stream so only newly generated text is emitted.
# timeout bounds each wait on the streamer's text queue, in seconds.
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
    timeout=10.0,
)

# Keyword arguments for model.generate(): the tokenized prompt tensors, the
# streamer that receives the output, and the cap on newly generated tokens.
generation_kwargs = {
    **model_inputs,
    "streamer": streamer,
    "max_new_tokens": 32768,
}
|
|||
|
|
|||
|
# Run generation on a worker thread so this thread is free to consume the
# streamer while tokens are still being produced.
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

# Print each non-empty text piece as soon as it arrives, without a trailing
# newline, flushing so the output appears incrementally.
print("流式输出开始:")
try:
    for chunk in streamer:
        if chunk:
            print(chunk, end="", flush=True)
except queue.Empty:
    # The streamer raises queue.Empty when no text arrives within its
    # timeout, e.g. because generate() stalled or failed on the worker
    # thread. Say so before the exception propagates.
    print("\n流式输出中断: 等待模型输出超时", file=sys.stderr)
    raise
finally:
    # Always wait for the worker, even when streaming was interrupted.
    thread.join()

print("\n流式输出结束")
|