Spaces:
Runtime error
Runtime error
File size: 4,320 Bytes
42b2835 d4ef182 9e69dbf d4ef182 ad84f1e 9e69dbf d4ef182 ad84f1e 9e69dbf cd4ae85 9e69dbf d4ef182 ccda969 9e69dbf ccda969 d4ef182 9e69dbf ccda969 d4ef182 42b2835 bee7c66 42b2835 ee59b50 85583fc d4ef182 1b02242 ee59b50 d4ef182 85583fc 9e69dbf 85583fc 9e69dbf ee59b50 782f163 ac0ccdd ee59b50 94303e5 85583fc d4ef182 42b2835 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import streamlit as st
import torch
import pandas as pd
import numpy as np
import torch
import transformers
import json
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaModel, RobertaTokenizer
import transformers
# Class index -> arXiv-style category label, in the order of the
# classifier's output units.
idx_to_tag = dict(enumerate([
    'cs',
    'stat',
    'physics',
    'math',
    'q-bio',
    'eess',
    'economics, finances',
    'gr-qc',
    'hep-ex',
    'hep-lat',
]))

# Reverse lookup (label -> class index), derived from the forward mapping
# so the two tables always agree.
tag_to_idx = {label: index for index, label in idx_to_tag.items()}
class RobertaClass(torch.nn.Module):
    """RoBERTa encoder followed by a small feed-forward classification head.

    The head takes the hidden state of the first token of the sequence,
    passes it through ``Linear -> ReLU -> Dropout -> Linear`` and returns
    raw (un-normalised) class scores.

    NOTE: the attribute names ``l1``, ``pre_classifier``, ``dropout`` and
    ``classifier`` must stay as they are, and no new attributes should be
    read in ``forward``: the app loads a whole pickled module (presumably an
    instance of this class), and such an instance only carries the
    attributes that existed when it was saved.
    """

    def __init__(self, num_classes: int = 10, dropout: float = 0.3):
        """Build the encoder and the classification head.

        Args:
            num_classes: Number of output scores produced by the head.
            dropout: Dropout probability applied inside the head.
        """
        super().__init__()
        hidden_size = 768  # width of the roberta-base hidden states
        self.l1 = RobertaModel.from_pretrained("roberta-base")
        self.pre_classifier = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(dropout)
        self.classifier = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return class scores for a batch of tokenised texts.

        Args:
            input_ids: Token ids, indexed as (batch, sequence).
            attention_mask: Mask passed through to the encoder.
            token_type_ids: Optional segment ids passed through to the
                encoder; ``None`` lets the encoder use its own default.

        Returns:
            Tensor of raw class scores, one row per batch element.
        """
        encoder_output = self.l1(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden_state = encoder_output[0]
        # Use the representation of the first token as the sequence summary.
        pooled = hidden_state[:, 0]
        pooled = self.pre_classifier(pooled)
        # Functional ReLU: avoids allocating a new module on every call and
        # does not add state that an already-pickled model would lack.
        pooled = torch.nn.functional.relu(pooled)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)
def load_model():
    """Load the fine-tuned classifier and its tokenizer from ``model/``.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is the unpickled
        module placed on CPU and switched to evaluation mode, and
        ``tokenizer`` is a ``RobertaTokenizer`` built from the local
        vocabulary and merges files.
    """
    tokenizer = RobertaTokenizer.from_pretrained(
        'roberta-base', truncation=True, do_lower_case=True,
        vocab_file='model/vocab.json',
        merges_file='model/merges.txt',
    )
    # NOTE(security): torch.load unpickles arbitrary Python objects; only
    # ever point this at a checkpoint file you trust.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which rejects whole-module pickles like this one - confirm against the
    # deployed torch version.
    model = torch.load('model/pytorch_roberta_sentiment.bin', map_location=torch.device('cpu'))
    # A pickled module keeps the train/eval flag it was saved with. Force
    # eval mode so dropout is disabled and predictions are deterministic.
    model.eval()
    return model, tokenizer
model, tokenizer = load_model()
# Inference only: make sure dropout is off so the same input always gives
# the same answer (idempotent if the model is already in eval mode).
model.eval()

# Token budget per request: longer inputs are truncated, shorter are padded.
MAX_LENGTH = 256
# The "top" table lists classes until this much probability mass is covered.
TOP_PROBABILITY_MASS = 0.95


def _predict_probabilities(title, abstract):
    """Return a 1-D tensor of class probabilities for one paper.

    The title and abstract are joined with " : ", tokenised to exactly
    ``MAX_LENGTH`` tokens and run through the model without tracking
    gradients.
    """
    text = title + " : " + abstract
    inputs = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        # Explicit padding/truncation replaces the deprecated
        # pad_to_max_length flag and stops long abstracts from exceeding
        # the model's maximum sequence length.
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )
    ids = torch.tensor(inputs['input_ids'], dtype=torch.long).unsqueeze(0)
    mask = torch.tensor(inputs['attention_mask'], dtype=torch.long).unsqueeze(0)
    token_type_ids = torch.tensor(inputs['token_type_ids'], dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        logits = model(ids, mask, token_type_ids)
    return torch.nn.functional.softmax(logits[0], dim=0)


st.markdown("### Угадыватель")
title = st.text_area("Title здесь")
abstract = st.text_area("Abstract здесь")
# Both fields are required before any prediction is attempted.
has_input = bool(title) and bool(abstract)

if st.button('Предположить'):
    if not has_input:
        st.write("Вы ничего не ввели =(")
    else:
        probabilities = _predict_probabilities(title, abstract)
        best_idx = probabilities.argmax().item()
        st.markdown(f'{idx_to_tag[best_idx]}')

if st.button("Посмотреть топ"):
    # Validate first so the model is never run on empty input.
    if not has_input:
        st.write("Вы ничего не ввели =(")
    else:
        probabilities = _predict_probabilities(title, abstract)
        ranked = probabilities.argsort(descending=True).tolist()
        values = probabilities.numpy()

        # Take the most likely classes until their cumulative probability
        # reaches TOP_PROBABILITY_MASS.
        top_tags = []
        top_probs = []
        covered = 0
        for class_idx in ranked:
            if covered >= TOP_PROBABILITY_MASS:
                break
            top_tags.append(idx_to_tag[class_idx])
            top_probs.append(values[class_idx])
            covered += values[class_idx]

        st.write(pd.DataFrame({
            'Направление': top_tags,
            'Вероятность': top_probs,
        }))
|