# Confucius3-Math streaming inference demo: OCR a math-problem image, then
# stream a locally hosted, 4-bit-quantized LLM's step-by-step solution.
# pip install bitsandbytes
import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import BitsAndBytesConfig

from api.doubao_client import DoubaoClient
from core.ocr_service import OCRService

# OCR service backed by the Doubao client; it extracts the question text
# from the input image so the language model can solve it.
_doubao_client = DoubaoClient()
_ocr_service = OCRService(_doubao_client)
# Path to the image containing the math problem to solve.
image_path = r'c:\math1.jpg'

# 4-bit NF4 quantization with nested (double) quantization; bfloat16 compute
# keeps numeric quality while shrinking the model to fit in GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Run OCR up front; the recognized text becomes the user question below.
ocr_result = _ocr_service.get_ocr(image_path)

# model_name = "netease-youdao/Confucius3-Math"
# Local checkpoint of the Hugging Face model above.
model_name = r"D:\Confucius3-Math\netease-youdao\Confucius3-Math"
# System prompt constraining the tutor persona: middle-school-level methods
# only, guided step-by-step reasoning. (Content is in Chinese by design —
# it is sent verbatim to the model, so it must not be altered.)
SYSTEM_PROMPT_TEMPLATE = """
# 数学解题助手规则(必须严格遵守)

1. 【核心要求】必须使用初中生知识范围内可以理解的方法解题,绝对禁止使用向量、微积分、复数等高中及以上数学知识。
2. 如果是动点问题要注意可能有多个解。
3. 使用简单易懂的语言,将复杂问题分解成简单的步骤。
4. 引导学生思考,而不是直接给出答案。

# 角色信息
- language: 中文
- description: 数学解题助手是一个专门为初中生提供数学问题解答的角色。
- expertise: 初中数学解题、数学教育
- target_audience: 初中生
"""

# User prompt is just the raw question text.
USER_PROMPT_TEMPLATE = """{question}"""
# question = "1+1=?"
# The OCR'd problem text becomes the user question.
question = ocr_result

# Load the model with the 4-bit quantization config; device_map="auto"
# lets accelerate place layers on the available GPU(s)/CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # torch_dtype="auto",
    quantization_config=bnb_config,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [
    {'role': 'system', 'content': SYSTEM_PROMPT_TEMPLATE},
    {'role': 'user', 'content': USER_PROMPT_TEMPLATE.format(question=question)},
]

# Render the chat messages into the model's prompt format and tokenize.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Iterator that yields decoded text chunks as generate() produces tokens.
# NOTE(review): timeout=10.0 raises queue.Empty if generation stalls for
# more than 10 s between tokens — confirm this is acceptable for slow GPUs.
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, timeout=10.0)

# Generation arguments; the streamer receives tokens incrementally.
generation_kwargs = {
    **model_inputs,
    "streamer": streamer,
    "max_new_tokens": 32768,
}

# model.generate() blocks until finished, so run it on a worker thread and
# consume the streamer on the main thread.
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

# Print chunks as they arrive.
print("流式输出开始:")
for chunk in streamer:
    if chunk:
        print(chunk, end="", flush=True)

thread.join()
print("\n流式输出结束")