Hugging Face Spaces status: Runtime error
import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
import torch
import gradio as gr

# Directory of the local fine-tuned checkpoint; the config, tokenizer and
# model weights are all read from here.
model_dir = "./experiments/checkpoint-382"

# Build the config explicitly so the classification settings (3 labels,
# text-classification task) are applied, then hand it to the model loader.
config = AutoConfig.from_pretrained(
    model_dir,
    num_labels=3,
    finetuning_task="text-classification",
)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# Loaded once at import time; `inference` reuses these module-level objects.
model = AutoModelForSequenceClassification.from_pretrained(
    model_dir,
    config=config,
)
def inference(input_text):
    """Classify one piece of text with the fine-tuned sequence classifier.

    Args:
        input_text: The raw string entered in the UI. ``None`` is treated
            as an empty string.

    Returns:
        The label name that ``model.config.id2label`` assigns to the
        highest-scoring class index.
    """
    # Encode just this one string. Padding to a fixed 512 tokens (the old
    # ``padding="max_length"``) forced a full-length forward pass on every
    # request even for a three-word sentence; with a single sequence there
    # is nothing to pad against, so only truncation to the model limit is
    # needed. The attention mask makes the prediction independent of padding.
    encoded = tokenizer(
        input_text if input_text is not None else "",
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoded)["logits"]
    # logits has one row (a batch of one), so take the argmax of that row.
    predicted_class = int(torch.argmax(logits, dim=-1)[0])
    return model.config.id2label[predicted_class]
# Build the web UI: an input box, a read-only output box and a button that
# runs `inference` on the entered text, plus clickable example sentences.
# NOTE(review): the CSS targets Svelte-generated class hashes
# (`svelte-w6rprc`) and an auto-numbered element id (`#component-21`);
# both are build/layout specific, so these rules may silently stop matching
# after a Gradio upgrade or a layout change.
with gr.Blocks(css=""".message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 1.5em; padding: 1em; text-align: center;} #component-21 > div.wrap.svelte-w6rprc {height: 600px;}""") as demo:
    with gr.Row():
        with gr.Column():
            # NOTE(review): container=False drops the component frame, which
            # may also hide the "Input Text" label -- confirm this is intended.
            input_text = gr.Textbox(label="Input Text", scale=2, container=False)
            # `gr.Output` is not a Gradio component (attribute lookup fails
            # at startup). The prediction is a plain label string, so show
            # it in a read-only Textbox.
            answer = gr.Textbox(label="Output", interactive=False)
            # A Button's caption is its first (`value`) argument and its
            # colour scheme is `variant`; `text=` / `type=` are not Button
            # parameters.
            generate_btn = gr.Button("Generate", variant="primary", scale=2)
    inputs = [input_text]
    outputs = [answer]
    # Progress display is on by default, so no `show_progress` argument is
    # passed (newer Gradio expects a string mode there, not a bool).
    generate_btn.click(fn=inference, inputs=inputs, outputs=outputs)
    examples = [
        ["I love this movie!"],
        ["I hate this movie!"],
        ["I feel neutral about this movie!"],
    ]
    # Render the sample sentences; clicking one fills the input box.
    gr.Examples(examples=examples, inputs=inputs)
# Enable the request queue, then start the server.
demo.queue()
demo.launch()