Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| import gradio as gr | |
| import numpy as np | |
| import requests | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, pipeline | |
| from langdetect import detect | |
| from matplotlib import pyplot as plt | |
| import imageio | |
# Load the fine-tuned classifier and its tokenizer from the local "saved_model"
# directory and wrap them in a text-classification pipeline used by attribution().
model = AutoModelForSequenceClassification.from_pretrained("saved_model")
tokenizer = AutoTokenizer.from_pretrained("saved_model")
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)

# langdetect's algorithm is non-deterministic by default: the same short or
# ambiguous text can be detected as different languages on different calls.
# Seeding the factory once, before the first detect(), makes the routing
# (translate / classify / reject) reproducible for a given input.
from langdetect import DetectorFactory  # noqa: E402

DetectorFactory.seed = 0
import functools

# Department abbreviations, in the order of the classifier's output scores.
BARS_FR = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC", "DFAE", "Parl", "ChF", "DFF", "AF", "TF")
BARS_DE = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK", "EDA", "Parl", "BK", "EFD", "BV", "BGer")

# Only the first characters of the input are translated / classified.
MAX_INPUT_CHARS = 1000
# Below this top probability the model is considered "not sure".
CONFIDENCE_THRESHOLD = 0.6
# In the "not sure" case, departments at or above this probability are listed.
MIN_LISTED_RATE = 0.1


@functools.lru_cache(maxsize=1)
def _get_translator():
    """Build the French-to-German translation pipeline once and reuse it."""
    return pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")


def describe_prediction(rates, bars, language):
    """Return the user-facing text summarising the department probabilities.

    If the highest rate is at least CONFIDENCE_THRESHOLD, only the winning
    department is shown. Otherwise a "not sure" header (German when
    *language* is 'de', French otherwise) is followed by every department
    whose rate is at least MIN_LISTED_RATE, highest first; equal rates keep
    their label order.

    Args:
        rates: one probability per department, in the same order as *bars*.
        bars: department abbreviations used as labels.
        language: detected language code; 'de' selects the German header.
    """
    # First index of the maximum, like np.argmax.
    best = max(range(len(rates)), key=rates.__getitem__)
    if rates[best] >= CONFIDENCE_THRESHOLD:
        return f"{rates[best]:.0%}\t\t\t\t\t\t{bars[best]}"

    if language == "de":
        header = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
    else:
        header = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"
    # sorted(..., reverse=True) is stable, so ties stay in label order.
    ranked = sorted(range(len(rates)), key=rates.__getitem__, reverse=True)
    listed = [
        f"\t{rates[i]:.0%}\t\t\t\t\t{bars[i]}\n"
        for i in ranked
        if rates[i] >= MIN_LISTED_RATE
    ]
    return header + "".join(listed)


def render_bar_chart(rates, bars):
    """Draw a horizontal bar chart of *rates* labelled with *bars*.

    Returns the chart as the image array decoded from an in-memory PNG.
    A private Figure is used instead of the global pyplot figure and a
    shared file on disk, so concurrent requests cannot overwrite each
    other's chart.
    """
    import io

    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    y_pos = np.arange(len(bars))
    ax.barh(y_pos, rates)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(bars)
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    return imageio.imread(buffer.getvalue())


def attribution(text):
    """Predict which department an input text should be attributed to.

    The language is detected first. French text is translated to German
    before classification; any language other than French or German is
    rejected. The first MAX_INPUT_CHARS characters are classified.

    Args:
        text: the title and submitted text entered in the UI.

    Returns:
        A tuple (message, image). On success, message is the text built by
        describe_prediction and image is the bar chart of all department
        probabilities. If the language is unsupported or cannot be detected
        (e.g. empty input), message is an error text and image is None.
    """
    from langdetect.lang_detect_exception import LangDetectException

    unsupported = "The language is not recognized, it must be either in German or in French."
    try:
        language = detect(text)
    except LangDetectException:
        # Raised for empty or feature-less input (e.g. only digits/punctuation).
        return unsupported, None

    if language == "fr":
        translated = _get_translator()(text[:MAX_INPUT_CHARS])
        text = translated[0]["translation_text"]
        bars = BARS_FR
    elif language == "de":
        bars = BARS_DE
    else:
        return unsupported, None

    # return_all_scores=True yields one score per label, in label order.
    results = pipe(text[:MAX_INPUT_CHARS], return_all_scores=True)
    rates = [row["score"] for row in results[0]]
    return describe_prediction(rates, bars, language), render_bar_chart(rates, bars)
# Build and launch the web UI: one multi-line text box as input; the
# attribution text and the bar-chart image returned by attribution() as outputs.
# gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4;
# the top-level gr.Textbox component works in both.
interface = gr.Interface(
    fn=attribution,
    inputs=[
        gr.Textbox(
            lines=20,
            placeholder="Geben Sie bitte den Titel und den Submitted Text des Vorstoss ein.\nVeuillez entrer le titre et le Submitted Text de la requête.",
        )
    ],
    outputs=["text", "image"],
)
interface.launch()