|
|
|
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer |
|
|
from mCLM.model.qwen_based.model import Qwen2ForCausalLM |
|
|
from mCLM.tokenizer.molecule_tokenizer import MoleculeTokenizer |
|
|
|
|
|
|
|
|
# Load the mCLM chat model (Qwen2 architecture) from the Hugging Face Hub.
# NOTE(review): "YOUR_REPO_ID" is a placeholder — replace with the real repo id.
model = Qwen2ForCausalLM.from_pretrained(
    "YOUR_REPO_ID",
    torch_dtype=torch.bfloat16,  # half-precision weights: ~2x less memory, fine for inference
    device_map="auto",           # let accelerate shard/place layers on available devices
)
|
|
|
|
|
# Load the text tokenizer from the same repo as the model so vocab/ids match.
tokenizer = AutoTokenizer.from_pretrained("YOUR_REPO_ID")

# Many chat-model tokenizers ship without a pad token; reuse EOS for padding.
# Guard with `is None` so a tokenizer that already defines a real pad token
# is not silently clobbered (the original assigned unconditionally).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
|
|
# SECURITY: torch.load with weights_only=False runs arbitrary pickle code, and
# add_safe_globals only has an effect when weights_only=True — so the original
# combination left the allowlist as dead code while still doing an unsafe load.
# Allowlist MoleculeTokenizer and load in the restricted (safe) mode instead.
torch.serialization.add_safe_globals([MoleculeTokenizer])
# NOTE(review): if the checkpoint pickles additional custom classes, add them
# to add_safe_globals as well; only fall back to weights_only=False for files
# from a fully trusted source.
molecule_tokenizer = torch.load("molecule_tokenizer.pth", weights_only=True)
|
|
|
|
|
|
|
|
# Build a single-turn chat prompt and generate a reply.
user_input = "What is aspirin used for?"
messages = [{"role": "user", "content": user_input}]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,  # append the assistant header so the model answers
    return_tensors="pt",
)
# Bug fix: with device_map="auto" the model may live on a GPU while the prompt
# ids are created on CPU; move them to the model's device or generate() raises
# a device-mismatch error.
inputs = inputs.to(model.device)

outputs = model.generate(input_ids=inputs, max_new_tokens=256)

# Decode the full sequence (prompt + completion); skip_special_tokens strips
# the chat-template control tokens from the printed text.
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
|
|
|