import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# 1) Load the tokenizer and the base model, pinned to CPU
tokenizer = AutoTokenizer.from_pretrained("finnish-nlp/ahma-3b")
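# Many causal-LM tokenizers ship without a pad token; reusing EOS avoids
# padding/generation warnings.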
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    "finnish-nlp/ahma-3b",
    torch_dtype=torch.float32,
    device_map={"": "cpu"}
)
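# float32 is the safe dtype on CPU; PyTorch's fp16 CPU kernel coverage is limited.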

# 2) Apply your fine-tuned LoRA adapter
model = PeftModel.from_pretrained(
    base_model,
    "testi123456789/elektromart"
)
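# Optionally, model = model.merge_and_unload() would bake the LoRA weights
# into the base model for slightly faster inference.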
model.to("cpu")
model.eval()

# 3) Instruction the adapter was fine-tuned on (Finnish for: "Answer the
#    customer's query in a friendly way as ElektroMart's customer service.")
INSTRUCTION = "Vastaa asiakkaan kyselyyn ystävällisesti ElektroMartin asiakaspalveluna."

def chat_fn(user_question: str, max_new_tokens: int = 100,
            temperature: float = 0.7, repetition_penalty: float = 1.25) -> str:
    # 4) Build the prompt exactly as during training
    prompt = f"[INST] {INSTRUCTION}\n{user_question} [/INST]\n"

    # 5) Tokenize; drop token_type_ids, which generate() does not accept
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs.pop("token_type_ids", None)
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    # 6) Generate; temperature takes effect because do_sample=True
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # 7) Decode only the newly generated part
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    answer = tokenizer.decode(generated, skip_special_tokens=True)
    return answer.strip()
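
# Quick smoke test without the UI (hypothetical example question,
# Finnish for "When will my order arrive?"):
#   print(chat_fn("Milloin tilaukseni saapuu?"))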

# 8) Expose the Gradio interface (Finnish UI labels: "Kysy jotain…" = "Ask
#    something…", "Vastaus" = "Answer", title = "ElektroMart Chatbot")
iface = gr.Interface(
    fn=chat_fn,
    inputs=[
        gr.Textbox(label="Kysy jotain…", placeholder="Kirjoita kysymyksesi tähän"),
    ],
    outputs=gr.Textbox(label="Vastaus"),
    title="ElektroMartin Chatbotti"
)

if __name__ == "__main__":
    iface.launch()
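
# Example local run (assumption: this file is saved as app.py and the
# dependencies are installed):
#   pip install gradio torch transformers peft
#   python app.py
# Gradio serves the UI at http://127.0.0.1:7860 by default.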