Files
dsProject/dsLightRag/Test/Test101.py
2025-08-14 15:45:08 +08:00

90 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# pip install bitsandbytes
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import BitsAndBytesConfig
from api.doubao_client import DoubaoClient
from core.ocr_service import OCRService
import torch
import threading
# Test script: OCR a math-problem image, then stream a step-by-step tutoring
# answer from a locally stored, 4-bit-quantized Confucius3-Math model.

_doubao_client = DoubaoClient()
_ocr_service = OCRService(_doubao_client)

# Image containing the math problem to be solved.
image_path = r'c:\math1.jpg'

# 4-bit NF4 quantization with double quantization keeps the model small enough
# for a single consumer GPU; matmul compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Extract the question text from the image via the project's OCR service.
ocr_result = _ocr_service.get_ocr(image_path)

# model_name = "netease-youdao/Confucius3-Math"  # hub id, kept for reference
model_name = r"D:\Confucius3-Math\netease-youdao\Confucius3-Math"

# System prompt constraining the model to junior-high-level solution methods
# (runtime string — intentionally left in Chinese, the model's working language).
SYSTEM_PROMPT_TEMPLATE = """
# 数学解题助手规则(必须严格遵守)
1. 【核心要求】必须使用初中生知识范围内可以理解的方法解题,绝对禁止使用向量、微积分、复数等高中及以上数学知识。
2. 如果是动点问题要注意可能有多个解。
3. 使用简单易懂的语言,将复杂问题分解成简单的步骤。
4. 引导学生思考,而不是直接给出答案。
# 角色信息
- language: 中文
- description: 数学解题助手是一个专门为初中生提供数学问题解答的角色。
- expertise: 初中数学解题、数学教育
- target_audience: 初中生
"""
USER_PROMPT_TEMPLATE = """{question}"""

# question = "1+1=?"
question = ocr_result

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # torch_dtype="auto",
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

messages = [
    {'role': 'system', 'content': SYSTEM_PROMPT_TEMPLATE},
    {'role': 'user', 'content': USER_PROMPT_TEMPLATE.format(question=question)},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Streaming iterator fed by generate() on a worker thread.
# - skip_prompt=True: without it the streamer first echoes the entire chat
#   prompt (system + user text) before the model's answer.
# - timeout bounds each queue read; the FIRST token from a freshly loaded
#   4-bit model can easily exceed 10 s (the previous value), which would raise
#   queue.Empty mid-stream — so use a generous bound instead.
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
    timeout=300.0,
)

# Generation arguments: tokenized inputs plus the streamer sink.
generation_kwargs = {
    **model_inputs,
    "streamer": streamer,
    "max_new_tokens": 32768,
}

# Run generation on a background thread so the main thread can consume the
# streamer iterator as tokens arrive.
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

# Stream the answer to stdout as it is produced.
print("流式输出开始:")
for chunk in streamer:
    if chunk:
        print(chunk, end="", flush=True)
# NOTE(review): exceptions raised inside model.generate() on the worker thread
# are not propagated here; if output stops early, check the thread's traceback.
thread.join()
print("\n流式输出结束")