Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import transformers | |
| import json | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import RobertaModel, RobertaTokenizer | |
| import transformers | |
# Tag labels ordered by class index: the position of a tag in this tuple is the
# index used for it in both lookup tables below.
# NOTE(review): the names look like arXiv archive identifiers -- confirm.
_TAGS = (
    'cs',
    'stat',
    'physics',
    'math',
    'cond-mat',
    'q-bio',
    'eess',
    'quant-ph',
    'astro-ph',
    'nlin',
    'q-fin',
    'gr-qc',
    'hep-th',
    'hep-ex',
    'econ',
    'hep-ph',
    'nucl-th',
    'hep-lat',
    'math-ph',
    'nucl-ex',
)

# Class index -> tag, and tag -> class index. Both tables are derived from
# _TAGS so they cannot drift apart when a tag is added, removed or reordered.
idx_to_tag = dict(enumerate(_TAGS))
tag_to_idx = {tag: idx for idx, tag in idx_to_tag.items()}
class RobertaClass(torch.nn.Module):
    """RoBERTa encoder followed by a small classification head.

    The head takes the encoder output at sequence position 0, passes it
    through a 768 -> 768 linear layer, ReLU and dropout (p=0.3), and projects
    the result to ``num_classes`` logits.

    NOTE: the app in this file restores the model with ``torch.load`` and
    calls the returned object directly, i.e. the checkpoint is a pickled
    module. Unpickling does not run ``__init__``, so the attribute names
    (``l1``, ``pre_classifier``, ``dropout``, ``classifier``) must stay as
    they are and ``forward`` must not depend on attributes that older
    checkpoints do not contain.

    Args:
        num_classes: Width of the output layer. Defaults to 5, the value that
            used to be hard-coded here.
            NOTE(review): the label tables in this file list 20 tags --
            confirm which width the deployed checkpoint actually has.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.l1 = RobertaModel.from_pretrained("roberta-base")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Compute class logits for a batch of encoded texts.

        Args:
            input_ids: Token ids, forwarded to the encoder unchanged.
            attention_mask: Attention mask, forwarded to the encoder unchanged.
            token_type_ids: Optional segment ids, forwarded to the encoder
                unchanged (``None`` leaves the choice to the encoder).

        Returns:
            The output of the final linear layer (unnormalised logits), one
            row per batch item.
        """
        encoder_output = self.l1(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # First element of the encoder output; presumably the last hidden
        # states shaped (batch, seq_len, 768) -- the indexing below and the
        # 768-wide linear layer rely on that layout.
        hidden_state = encoder_output[0]
        # Keep only sequence position 0 of every batch item.
        features = hidden_state[:, 0]
        features = self.pre_classifier(features)
        # Functional ReLU: stateless, so no module is built per call and no
        # new attribute is required on already-pickled models.
        features = torch.nn.functional.relu(features)
        features = self.dropout(features)
        return self.classifier(features)
# Tokenizer for roberta-base; the vocabulary and merges files are read from
# the local model/ directory.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True,
                                             vocab_file='model/vocab.json',
                                             merges_file='model/merges.txt')

# SECURITY NOTE: torch.load unpickles the file, and unpickling can execute
# arbitrary code -- only ever point this at a checkpoint you trust.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects a pickled full module unless weights_only=False is passed; confirm
# against the torch version installed for this app.
model = torch.load('model/pytorch_roberta_sentiment.bin', map_location=torch.device('cpu'))
# Inference only: switch dropout off so the same input always yields the same
# prediction.
model.eval()

# Length (in tokens) every input is padded or truncated to.
_MAX_LENGTH = 256
# Number of entries shown by the "top" button.
_TOP_K = 10


def _predict_logits(title, abstract):
    """Run the classifier on one paper and return the raw model output.

    Args:
        title: Paper title as typed by the user.
        abstract: Paper abstract as typed by the user.

    Returns:
        Whatever ``model`` returns for a batch holding this single text; the
        callers read row 0 as the vector of class logits.
    """
    text = title + " : " + abstract
    inputs = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=_MAX_LENGTH,
        # Explicit replacements for the deprecated pad_to_max_length=True.
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )
    ids = torch.tensor(inputs['input_ids'], dtype=torch.long)
    mask = torch.tensor(inputs['attention_mask'], dtype=torch.long)
    token_type_ids = torch.tensor(inputs['token_type_ids'], dtype=torch.long)
    # No gradients are needed for prediction; skipping them saves memory and
    # keeps grad_fn out of the results.
    with torch.no_grad():
        return model(ids.unsqueeze(0), mask.unsqueeze(0), token_type_ids.unsqueeze(0))


st.markdown("### Hello, world!")
# st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
# ^-- you can show the user text, images and a limited subset of HTML, just like in Jupyter
title = st.text_area("Title HERE")
abstract = st.text_area("Abstract HERE")

# Model output for the current script run; computed at most once and shared
# by both buttons.
ans = None

if st.button('Предположить'):
    ans = _predict_logits(title, abstract)
    # Softmax is monotonic, so the argmax of the logits is the top class.
    idx = ans[0].argmax().item()
    st.markdown(f'{idx_to_tag[idx]}')

if st.button("Посмотреть топ"):
    # `is None` rather than truthiness: the truth value of a multi-element
    # tensor is ambiguous and raises.
    if ans is None:
        ans = _predict_logits(title, abstract)
    probs = ans[0].softmax(dim=0)
    top_indices = probs.argsort(descending=True)[:_TOP_K].tolist()
    # One Markdown bullet per class: a single "\n" inside a paragraph is
    # collapsed by Markdown, so a list is used to get real line breaks.
    # Unknown indices fall back to the raw number instead of failing.
    rows = [
        f"- {idx_to_tag.get(i, i)} : {probs[i].item():.4f}"
        for i in top_indices
    ]
    st.markdown("\n".join(rows))
| from transformers import pipeline | |
| # here is the huggingface.transformers code you already know -- it can be replaced with anything from fairseq to catboost | |
| # print the model's results into a text field, for the user's amusement | |