"""Gradio demo serving a fine-tuned 3-class text-classification checkpoint."""
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
import torch
import gradio as gr

# Local fine-tuned checkpoint to serve.
model_dir = "./experiments/checkpoint-382"

# num_labels=3 forces a 3-way classification head; the checkpoint is assumed
# to have been trained with the same label count (TODO confirm).
config = AutoConfig.from_pretrained(
    model_dir, num_labels=3, finetuning_task="text-classification"
)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)
# Inference only: make sure dropout and similar layers are disabled.
model.eval()


def inference(input_text):
    """Classify ``input_text`` and return the predicted label name.

    Args:
        input_text: Raw text from the UI textbox.

    Returns:
        The label string from ``model.config.id2label`` for the highest-scoring
        class, or a short prompt if the input is empty/whitespace.
    """
    if not input_text or not input_text.strip():
        # Avoid returning an arbitrary label for an empty request.
        return "Please enter some text."

    # A single sequence needs no padding; padding to max_length wastes compute
    # and fails for tokenizers that define no pad token. Truncate long inputs
    # to the model's 512-token window.
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    predicted_class = int(torch.argmax(logits, dim=-1).item())
    return model.config.id2label[predicted_class]


# NOTE(review): these selectors rely on Svelte-generated class hashes and an
# auto-assigned "#component-21" id, both of which change between Gradio
# versions/layouts. Prefer elem_id/elem_classes on the components instead.
with gr.Blocks(css=""".message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 1.5em; padding: 1em; text-align: center;} #component-21 > div.wrap.svelte-w6rprc {height: 600px;}""") as demo:
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", scale=2, container=False)
            # Read-only textbox to display the predicted label
            # (gr.Output is not a Gradio component).
            answer = gr.Textbox(label="Output", interactive=False)
            # Button takes its caption via `value` and its style via `variant`.
            generate_btn = gr.Button(value="Generate", variant="primary", scale=2)

            examples = [
                ["I love this movie!"],
                ["I hate this movie!"],
                ["I feel neutral about this movie!"],
            ]
            # Clicking an example fills the input textbox.
            gr.Examples(examples=examples, inputs=[input_text])

    inputs = [input_text]
    outputs = [answer]
    generate_btn.click(
        fn=inference, inputs=inputs, outputs=outputs, show_progress="full"
    )


if __name__ == "__main__":
    demo.queue()
    demo.launch()