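"""Standalone inference script for the Thai sentiment classifiers in this project.

Loads config.json and model.safetensors from a model directory, rebuilds the
network with create_model_by_name(cfg["arch"]), and prints the predicted label
plus the probability vector for two sample reviews.

Usage:
    python <this script> [MODEL_DIR]    # MODEL_DIR defaults to "cnn_bilstm"
"""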
import os, sys, json, torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from safetensors.torch import load_file

# Make the shared model definitions in ./common importable.
sys.path.append(os.path.join(os.path.dirname(__file__), "common"))
from models import create_model_by_name  # model factory defined in common/models.py


def load_model(model_dir: str):
    cfg_path = os.path.join(model_dir, "config.json")
    w_path = os.path.join(model_dir, "model.safetensors")
    if not (os.path.exists(cfg_path) and os.path.exists(w_path)):
        raise FileNotFoundError("config.json or model.safetensors is missing")

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    # Tokenizer comes from the original base model; the trained weights
    # are restored from the exported safetensors file.
    tok = AutoTokenizer.from_pretrained(cfg["base_model"])
    model = create_model_by_name(cfg["arch"])
    state = load_file(w_path)
    model.load_state_dict(state)
    model.eval()
    return model, tok, cfg


def predict(texts, model, tok, cfg):
    # Tokenize the batch and run one forward pass without gradient tracking.
    enc = tok(texts, padding=True, truncation=True, max_length=cfg["max_len"], return_tensors="pt")
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])
    prob = F.softmax(logits, dim=1).cpu().numpy()
    pred = prob.argmax(1)
    return pred, prob


if __name__ == "__main__":
    # Model directory is the first CLI argument; defaults to "cnn_bilstm".
    MODEL_DIR = sys.argv[1] if len(sys.argv) > 1 else "cnn_bilstm"
    model, tok, cfg = load_model(MODEL_DIR)

    # Two Thai sample reviews: "The food is delicious, good service." /
    # "Not impressed at all, very slow."
    xs = ["อาหารอร่อยมาก บริการดี", "ไม่ประทับใจเลย ช้ามาก"]
    y, p = predict(xs, model, tok, cfg)

    labels = ["negative", "positive"]
    for t, yy, pp in zip(xs, y, p):
        print(f"{t} => {labels[yy]} | prob={pp}")