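# Gradio demo: paraphrase the input text with Pegasus (tuner007/pegasus_paraphrase),
# then use a T5 question-generation model to produce a question whose answer is the paraphrase.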
import gradio as gr
import torch
import random
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelWithLMHead
from sentence_splitter import SentenceSplitter, split_text_into_sentences

splitter = SentenceSplitter(language='en')
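
# Run the models on GPU when one is available, otherwise fall back to CPU.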
if torch.cuda.is_available():
    torch_device = "cuda:0"
else:
    torch_device = "cpu"
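
# Paraphrasing model: Pegasus fine-tuned for paraphrase generation.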
ptokenizer = AutoTokenizer.from_pretrained("tuner007/pegasus_paraphrase")
pmodel = AutoModelForSeq2SeqLM.from_pretrained("tuner007/pegasus_paraphrase").to(torch_device)
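
# Return num_return_sequences paraphrases of input_text, generated with beam search.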
def get_answer(input_text, num_return_sequences, num_beams):
    batch = ptokenizer([input_text], truncation=True, padding='longest', max_length=60, return_tensors="pt").to(torch_device)
    translated = pmodel.generate(**batch, max_length=60, num_beams=num_beams, num_return_sequences=num_return_sequences, temperature=1.5)
    tgt_text = ptokenizer.batch_decode(translated, skip_special_tokens=True)
    return tgt_text
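
# Question-generation model: T5 fine-tuned to write a question given an answer and its context.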
qtokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
qmodel = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap").to(torch_device)
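
# Build the "answer: ... context: ..." prompt the T5 model expects and decode a single generated question.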
def get_question(answer, context, max_length=64):
    input_text = "answer: %s context: %s </s>" % (answer, context)
    features = qtokenizer([input_text], return_tensors='pt').to(torch_device)
    output = qmodel.generate(input_ids=features['input_ids'],
                             attention_mask=features['attention_mask'],
                             max_length=max_length)
    return qtokenizer.decode(output[0])
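
# Gradio callback: paraphrase the input (sentence by sentence for longer texts),
# then generate a question whose answer is that paraphrase.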
def getqna(text):
    sentences = split_text_into_sentences(text=text, language='en')
    if len(sentences) <= 1:
        answer = get_answer(text, 10, 10)[random.randint(0, 9)]
    else:
        paraphrased = [get_answer(sentence, 10, 10)[random.randint(0, 9)] for sentence in sentences]
        answer = " ".join(paraphrased)
        answer = get_answer(answer, 10, 10)[random.randint(0, 9)]
    question = get_question(answer, text).replace("<pad>", "").replace("</s>", "").strip()
    return "%s\nanswer: %s" % (question, answer)
app = gr.Interface(fn=getqna, inputs="text", outputs="text")
app.launch()