import pandas as pd
import streamlit as st
import torch
from transformers import RobertaModel, RobertaTokenizer

# Mapping between class indices and arXiv subject tags.
idx_to_tag = {0: 'cs', 1: 'stat', 2: 'physics', 3: 'math', 4: 'q-bio',
              5: 'eess', 6: 'economics, finances', 7: 'gr-qc',
              8: 'hep-ex', 9: 'hep-lat'}
tag_to_idx = {tag: idx for idx, tag in idx_to_tag.items()}


class RobertaClass(torch.nn.Module):
    """roberta-base encoder with a 10-way classification head."""

    def __init__(self):
        super(RobertaClass, self).__init__()
        self.l1 = RobertaModel.from_pretrained("roberta-base")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 10)

    def forward(self, input_ids, attention_mask, token_type_ids):
        output_1 = self.l1(input_ids=input_ids,
                           attention_mask=attention_mask,
                           token_type_ids=token_type_ids)
        hidden_state = output_1[0]
        # Take the hidden state of the first (<s>) token as the pooled representation.
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output


@st.cache_resource  # load the model and tokenizer once, not on every rerun
def load_model():
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base',
                                                 do_lower_case=True,
                                                 vocab_file='model/vocab.json',
                                                 merges_file='model/merges.txt')
    # The checkpoint is a whole pickled module, so RobertaClass must be
    # defined (as above) for torch.load to unpickle it.
    model = torch.load('model/pytorch_roberta_sentiment.bin',
                       map_location=torch.device('cpu'))
    model.eval()  # disable dropout for inference
    return model, tokenizer


model, tokenizer = load_model()


def predict(title, abstract):
    """Tokenize "title : abstract" and return the logits for the 10 tags."""
    inputs = tokenizer.encode_plus(
        title + " : " + abstract,
        None,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',  # replaces the deprecated pad_to_max_length=True
        truncation=True,
        return_token_type_ids=True,
    )
    ids = torch.tensor(inputs['input_ids'], dtype=torch.long).unsqueeze(0)
    mask = torch.tensor(inputs['attention_mask'], dtype=torch.long).unsqueeze(0)
    token_type_ids = torch.tensor(inputs['token_type_ids'], dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        logits = model(ids, mask, token_type_ids)
    return logits[0]


st.markdown("### Guesser")

title = st.text_area("Title here")
abstract = st.text_area("Abstract here")

# Streamlit reruns the whole script on every interaction, so each button
# recomputes the prediction instead of relying on state from the other button.
if st.button('Guess'):
    if len(title) == 0 or len(abstract) == 0:
        st.write("You didn't enter anything =(")
    else:
        logits = predict(title, abstract)
        idx = torch.nn.functional.softmax(logits, dim=0).argmax().item()
        st.markdown(f'{idx_to_tag[idx]}')

if st.button("Show the top"):
    if len(title) == 0 or len(abstract) == 0:
        st.write("You didn't enter anything =(")
    else:
        logits = predict(title, abstract)
        probs = logits.softmax(dim=0).numpy()
        order = [el.item() for el in logits.argsort(descending=True)]
        # Take the most likely tags until their cumulative probability reaches 95%.
        current_prob = 0.0
        current_elems = []
        current_probs = []
        i = 0
        while current_prob < 0.95 and i < len(order):
            current_elems.append(idx_to_tag[order[i]])
            current_probs.append(probs[order[i]])
            current_prob += probs[order[i]]
            i += 1
        st.write(pd.DataFrame({
            'Category': current_elems,
            'Probability': current_probs,
        }))
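
# Usage note: assuming this script is saved as app.py (hypothetical name)
# next to a model/ directory containing vocab.json, merges.txt and
# pytorch_roberta_sentiment.bin, the app can be launched locally with:
#
#   streamlit run app.py
#
# @st.cache_resource requires Streamlit >= 1.18; on older versions the
# load_model function would need a different caching decorator.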