ilushado commited on
Commit
e569e59
·
1 Parent(s): 897433d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -19
app.py CHANGED
@@ -17,7 +17,8 @@ idx_to_tag = {0: 'cs',
17
  3: 'math',
18
  4: 'q-bio',
19
  5: 'eess',
20
- 6: 'economics, finances'}
 
21
 
22
 
23
  tag_to_idx = {'cs': 0,
@@ -29,24 +30,6 @@ tag_to_idx = {'cs': 0,
29
  'economics, finances': 6
30
  }
31
 
32
- class RobertaClass(torch.nn.Module):
33
- def __init__(self):
34
- super(RobertaClass, self).__init__()
35
- self.l1 = RobertaModel.from_pretrained("roberta-base")
36
- self.pre_classifier = torch.nn.Linear(768, 768)
37
- self.dropout = torch.nn.Dropout(0.3)
38
- self.classifier = torch.nn.Linear(768, 10)
39
-
40
- def forward(self, input_ids, attention_mask, token_type_ids):
41
- output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
42
- hidden_state = output_1[0]
43
- pooler = hidden_state[:, 0]
44
- pooler = self.pre_classifier(pooler)
45
- pooler = torch.nn.ReLU()(pooler)
46
- pooler = self.dropout(pooler)
47
- output = self.classifier(pooler)
48
- return output
49
-
50
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
51
 
52
  model = AutoModel.from_pretrained('./model')
 
17
  3: 'math',
18
  4: 'q-bio',
19
  5: 'eess',
20
+ 6: 'economics, finances'
21
+ }
22
 
23
 
24
  tag_to_idx = {'cs': 0,
 
30
  'economics, finances': 6
31
  }
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
34
 
35
  model = AutoModel.from_pretrained('./model')