thai-sentiment / infer.py
Dusit-P's picture
Upload 13 files
48e0979 verified
# infer.py
import os, sys, json, torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from safetensors.torch import load_file
# Use the shared model architecture: put the "common" directory that sits
# next to this file on the import path so `models.create_model_by_name`
# (imported on the next line) can be found.
sys.path.append(os.path.join(os.path.dirname(__file__), "common"))
from models import create_model_by_name
def load_model(model_dir: str):
    """Load the classifier, its tokenizer and its config from *model_dir*.

    *model_dir* must contain two files:

    - ``config.json``: parsed as JSON. The keys ``base_model`` (tokenizer
      name) and ``arch`` (architecture name) are read here.
    - ``model.safetensors``: the weights, loaded as a state dict.

    Args:
        model_dir: Path to the model folder (e.g. "baseline" or "cnn_bilstm").

    Returns:
        A ``(model, tok, cfg)`` tuple with these members:

        - ``model``: the weights are loaded and it is switched to eval mode.
        - ``tok``: the ``AutoTokenizer`` named by ``cfg["base_model"]``.
        - ``cfg``: the parsed config dict.

    Raises:
        FileNotFoundError: if either required file is missing. The message
            lists the full path of every missing file.
    """
    cfg_path = os.path.join(model_dir, "config.json")
    w_path = os.path.join(model_dir, "model.safetensors")
    # Name exactly which file(s) are absent so a wrong model_dir is easy to spot.
    missing = [path for path in (cfg_path, w_path) if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(
            "config.json หรือ model.safetensors ไม่ครบ (missing: "
            + ", ".join(missing)
            + ")"
        )
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    tok = AutoTokenizer.from_pretrained(cfg["base_model"])
    model = create_model_by_name(cfg["arch"])
    # load_file returns a name -> tensor dict. load_state_dict is strict by
    # default, so key mismatches with the chosen architecture fail loudly.
    state = load_file(w_path)
    model.load_state_dict(state)
    # Inference-only use: put modules such as dropout into evaluation behavior.
    model.eval()
    return model, tok, cfg
def predict(texts, model, tok, cfg):
enc = tok(texts, padding=True, truncation=True, max_length=cfg["max_len"], return_tensors="pt")
with torch.no_grad():
logits = model(enc["input_ids"], enc["attention_mask"])
prob = F.softmax(logits, dim=1).cpu().numpy()
pred = prob.argmax(1)
return pred, prob
if __name__ == "__main__":
    # Choose the model folder to load: "baseline" or "cnn_bilstm".
    # It comes from the first command-line argument and defaults to "cnn_bilstm".
    MODEL_DIR = "cnn_bilstm" if len(sys.argv) <= 1 else sys.argv[1]
    model, tok, cfg = load_model(MODEL_DIR)

    # Two sample Thai reviews for a quick smoke test
    # ("very tasty food, good service" / "not impressed at all, very slow").
    xs = ["อาหารอร่อยมาก บริการดี", "ไม่ประทับใจเลย ช้ามาก"]
    y, p = predict(xs, model, tok, cfg)

    # Class-index -> name mapping. This is assumed to match the training-time
    # label encoding (0 = negative, 1 = positive); TODO confirm.
    labels = ["negative", "positive"]
    for text, class_idx, class_probs in zip(xs, y, p):
        label_name = labels[class_idx]
        print(f"{text} => {label_name} | prob={class_probs}")