resilience commited on
Commit
b3b6acd
·
verified ·
1 Parent(s): cf2fadc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+ import gradio as gr
5
+
6
+ model_dir = "./experiments/checkpoint-382"
7
+
8
+ config = AutoConfig.from_pretrained(model_dir, num_labels=3, finetuning_task="text-classification")
9
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
10
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)
11
+
12
+ def inference(input_text):
13
+ inputs = tokenizer.batch_encode_plus([input_text], return_tensors="pt", padding=True, max_length=512, truncation=True, padding="max_length")
14
+
15
+ with torch.no_grad():
16
+ logits = model(**inputs)["logits"]
17
+
18
+ predicted_class = torch.argmax(logits, dim=1).item()
19
+ output = model.config.id2label[predicted_class]
20
+ return output
21
+
22
+ with gr.Blocks(css=""".message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 1.5em; padding: 1em; text-align: center;} #component-21 > div.wrap.svelte-w6rprc {height: 600px;}""") as demo:
23
+ with gr.Row():
24
+ with gr.Column():
25
+ input_text = gr.Textbox(label="Input Text", scale=2, container=False)
26
+ answer = gr.Output(label="Output", lines=0)
27
+ generate_btn = gr.Button(text="Generate", type="primary", scale=2)
28
+
29
+ inputs = [input_text]
30
+ outputs = [answer]
31
+
32
+ generate_btn.click(fn=inference, inputs=inputs, outputs=outputs, show_progress=True)
33
+
34
+ examples = [
35
+ ["I love this movie!"],
36
+ ["I hate this movie!"],
37
+ ["I feel neutral about this movie!"]
38
+ ]
39
+
40
+ demo.queue()
41
+ demo.launch()