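"""Streamlit app: guess the arXiv subject category of a paper from its title
and abstract, using a fine-tuned RoBERTa classifier (10 classes)."""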
import streamlit as st
import torch
import pandas as pd
from transformers import RobertaModel, RobertaTokenizer
# Class index <-> arXiv subject tag mappings.
idx_to_tag = {
    0: 'cs',
    1: 'stat',
    2: 'physics',
    3: 'math',
    4: 'q-bio',
    5: 'eess',
    6: 'economics, finances',
    7: 'gr-qc',
    8: 'hep-ex',
    9: 'hep-lat',
}
tag_to_idx = {tag: idx for idx, tag in idx_to_tag.items()}
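# RoBERTa encoder with a classification head: the embedding of the first
# (<s>/CLS) token is fed through Linear(768, 768) -> ReLU -> Dropout ->
# Linear(768, 10) to produce the class logits.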
class RobertaClass(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = RobertaModel.from_pretrained("roberta-base")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 10)

    def forward(self, input_ids, attention_mask, token_type_ids):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_state = output_1[0]    # last hidden states, (batch, seq_len, 768)
        pooler = hidden_state[:, 0]   # embedding of the <s> (CLS) token
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)  # raw logits, (batch, 10)
        return output
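# The tokenizer files (vocab.json, merges.txt) and the fine-tuned weights are
# expected in the local model/ directory.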
@st.cache_resource  # load the model and tokenizer once instead of on every rerun
def load_model():
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True,
                                                 vocab_file='model/vocab.json',
                                                 merges_file='model/merges.txt')
    model = torch.load('model/pytorch_roberta_sentiment.bin', map_location=torch.device('cpu'))
    model.eval()  # disable dropout so predictions are deterministic
    return model, tokenizer


model, tokenizer = load_model()
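# --- UI. Streamlit reruns this script top to bottom on every widget
# interaction, so nothing below persists between button clicks.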
st.markdown("### Guesser")
title = st.text_area("Title here")
abstract = st.text_area("Abstract here")
def predict(text):
    # Tokenize the combined title/abstract text and return logits of shape (1, 10).
    inputs = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )
    ids = torch.tensor(inputs['input_ids'], dtype=torch.long).unsqueeze(0)
    mask = torch.tensor(inputs['attention_mask'], dtype=torch.long).unsqueeze(0)
    token_type_ids = torch.tensor(inputs['token_type_ids'], dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        return model(ids, mask, token_type_ids)


ans = None
if st.button('Guess'):
    if len(title) == 0 or len(abstract) == 0:
        st.write("You didn't enter anything =(")
    else:
        ans = predict(title + " : " + abstract)
        idx = ans[0].softmax(dim=0).argmax().item()
        st.markdown(f'{idx_to_tag[idx]}')
if st.button("Show top"):
    if len(title) == 0 or len(abstract) == 0:
        st.write("You didn't enter anything =(")
    else:
        if ans is None:  # the "Guess" button did not run on this rerun
            ans = predict(title + " : " + abstract)
        probs = ans[0].softmax(dim=0).numpy()
        order = [el.item() for el in ans[0].argsort(descending=True)]
        current_prob = 0
        current_elems = []
        current_probs = []
        idx = 0
        # Collect the most likely tags until they cover 95% of the probability mass.
        while current_prob < 0.95 and idx < len(order):
            current_elems.append(idx_to_tag[order[idx]])
            current_probs.append(probs[order[idx]])
            current_prob += probs[order[idx]]
            idx += 1
        st.write(pd.DataFrame({
            'Category': current_elems,
            'Probability': current_probs,
        }))