AnnyNguyen committed on
Commit 04be1da · verified · 1 Parent(s): 92254a1

Upload models.py with huggingface_hub
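
For reference, an upload like this is typically done with the huggingface_hub client. A minimal sketch, assuming a placeholder repo id (the repo id below is not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="models.py",
    path_in_repo="models.py",  # keep it at the repo root so config.auto_map entries like "models.TransformerForABSA" resolve
    repo_id="<username>/<repo-name>",  # placeholder
    repo_type="model",
    commit_message="Upload models.py with huggingface_hub",
)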

Files changed (1)
  1. models.py +813 -0
models.py ADDED
@@ -0,0 +1,813 @@
"""
Model architectures for Aspect-Based Sentiment Analysis.
Supports multiple architectures: Transformer-based, CNN, LSTM, and hybrid models.
"""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    RobertaPreTrainedModel, RobertaModel,
    BertPreTrainedModel, BertModel,
    XLMRobertaPreTrainedModel, XLMRobertaModel,
    BartPreTrainedModel, BartModel, BartForSequenceClassification,
    T5PreTrainedModel, T5EncoderModel,
    AutoConfig, AutoModel, AutoTokenizer,
    PreTrainedModel
)
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional


class BaseABSA(PreTrainedModel):
    """Base class for all ABSA models."""
    def __init__(self, config):
        super().__init__(config)
        self.num_aspects = config.num_aspects
        self.num_sentiments = config.num_sentiments

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None):
        raise NotImplementedError

    def get_sentiment_classifiers(self, hidden_size):
        """Create one sentiment classifier per aspect."""
        return nn.ModuleList([
            nn.Linear(hidden_size, self.num_sentiments + 1)  # +1 for the "none" class
            for _ in range(self.num_aspects)
        ])


# ========== Transformer-based Models ==========

class TransformerForABSA(RobertaPreTrainedModel):
    """RoBERTa-based model (for PhoBERT, ViSoBERT, and other RoBERTa checkpoints)."""
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
        # RoBERTa-based models don't use token_type_ids, so we ignore it
        kwargs.pop('token_type_ids', None)
        # Filter kwargs to only include valid arguments for RobertaModel
        model_kwargs = {
            k: v for k, v in kwargs.items()
            if k in ['position_ids', 'head_mask', 'inputs_embeds',
                     'output_attentions', 'output_hidden_states']
        }
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(input_ids, attention_mask=attention_mask, return_dict=return_dict, **model_kwargs)
        pooled = self.dropout(outputs.pooler_output)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # Pass through optional backbone outputs if they were requested
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        # Ensure directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save backbone
        self.roberta.save_pretrained(save_directory, **kwargs)

        # Update and save config with custom attributes
        config = self.roberta.config
        config.num_aspects = len(self.sentiment_classifiers)
        config.num_sentiments = self.sentiment_classifiers[0].out_features - 1  # -1 because the "none" class is not counted
        # auto_map lets AutoModel load the correct class automatically;
        # models.py is uploaded to the root of the repo
        config.auto_map = {
            "AutoModel": "models.TransformerForABSA",
            "AutoModelForSequenceClassification": "models.TransformerForABSA"
        }
        # Store extra information in the config to make reloading easier
        if not hasattr(config, 'custom_model_type'):
            config.custom_model_type = 'TransformerForABSA'
        config.save_pretrained(save_directory, **kwargs)

        # Save the full state_dict (including the sentiment_classifiers)
        sd = kwargs.get("state_dict", None) or self.state_dict()
        torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments

        model = cls(config)

        # Load backbone weights
        model.roberta = RobertaModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load the full state_dict if available (includes the sentiment_classifiers)
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


class BERTForABSA(BertPreTrainedModel):
    """BERT-based model (for mBERT)."""
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, token_type_ids=None, **kwargs):
        # BERT models can use token_type_ids, but for single-sentence tasks they are usually all zeros
        # Filter kwargs to only include valid arguments for BertModel
        model_kwargs = {
            k: v for k, v in kwargs.items()
            if k in ['position_ids', 'head_mask', 'inputs_embeds',
                     'output_attentions', 'output_hidden_states']
        }
        # BERT expects token_type_ids; if not provided, it defaults to all zeros
        if token_type_ids is not None:
            model_kwargs['token_type_ids'] = token_type_ids

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(input_ids, attention_mask=attention_mask, return_dict=return_dict, **model_kwargs)
        pooled = self.dropout(outputs.pooler_output)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # Pass through optional backbone outputs if they were requested
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model with custom attributes."""
        os.makedirs(save_directory, exist_ok=True)
        self.bert.save_pretrained(save_directory, **kwargs)
        config = self.bert.config
        config.num_aspects = len(self.sentiment_classifiers)
        config.num_sentiments = self.sentiment_classifiers[0].out_features - 1
        config.auto_map = {
            "AutoModel": "models.BERTForABSA",
            "AutoModelForSequenceClassification": "models.BERTForABSA"
        }
        if not hasattr(config, 'custom_model_type'):
            config.custom_model_type = 'BERTForABSA'
        config.save_pretrained(save_directory, **kwargs)
        sd = kwargs.get("state_dict", None) or self.state_dict()
        torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.bert = BertModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load full state_dict if available
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


class XLMRobertaForABSA(XLMRobertaPreTrainedModel):
    """XLM-RoBERTa-based model."""
    def __init__(self, config):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(input_ids, attention_mask=attention_mask, return_dict=return_dict)
        pooled = self.dropout(outputs.pooler_output)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # Pass through optional backbone outputs if they were requested
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int, num_sentiments: int, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.roberta = XLMRobertaModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config",)},
        )
        return model


class RoBERTaGRUForABSA(RobertaPreTrainedModel):
    """RoBERTa + GRU hybrid model."""
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.gru = nn.GRU(
            config.hidden_size, config.hidden_size,
            num_layers=2, batch_first=True, bidirectional=True, dropout=0.2
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size * 2, config.num_sentiments + 1)  # *2 because the GRU is bidirectional
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(input_ids, attention_mask=attention_mask, return_dict=return_dict)

        # Use last_hidden_state instead of pooler_output
        sequence_output = outputs.last_hidden_state  # [B, L, H]

        # GRU layer
        gru_out, _ = self.gru(sequence_output)  # [B, L, H*2]
        # Take the last timestep
        pooled = self.dropout(gru_out[:, -1, :])  # [B, H*2]

        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # Pass through optional backbone outputs if they were requested
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int, num_sentiments: int, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.roberta = RobertaModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config",)},
        )
        return model


class BartForABSA(BartPreTrainedModel):
    """BART-based model (for BartPho)."""
    def __init__(self, config):
        super().__init__(config)
        self.model = BartModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.d_model, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
        # BART models don't use token_type_ids, so we ignore it
        kwargs.pop('token_type_ids', None)
        # Filter kwargs to only include valid arguments for BartModel
        # (drop training-specific arguments that BartModel doesn't accept)
        model_kwargs = {
            k: v for k, v in kwargs.items()
            if k in ['position_ids', 'head_mask', 'inputs_embeds',
                     'output_attentions', 'output_hidden_states']
        }
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # IMPORTANT: for BART we need the encoder output directly.
        # BartModel.forward() returns the decoder output in last_hidden_state,
        # so we call the encoder separately (and only once, to avoid double computation).
        encoder_outputs = self.model.get_encoder()(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **model_kwargs
        )
        sequence_output = encoder_outputs.last_hidden_state  # [B, L, H] - encoder output

        # Mean pooling with attention mask (weighted mean to avoid padding tokens)
        if attention_mask is not None:
            # Expand attention mask to match sequence_output dimensions
            attention_mask_expanded = attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
            # Sum over sequence length, divide by the number of non-padding tokens
            sum_embeddings = torch.sum(sequence_output * attention_mask_expanded, dim=1)
            sum_mask = torch.clamp(attention_mask_expanded.sum(dim=1), min=1e-9)
            pooled = sum_embeddings / sum_mask  # [B, H]
        else:
            pooled = sequence_output.mean(dim=1)  # [B, H]

        pooled = self.dropout(pooled)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + (encoder_outputs.hidden_states, encoder_outputs.attentions)) if loss is not None else (all_logits,)

        # Use encoder outputs for hidden_states and attentions
        hidden_states = getattr(encoder_outputs, 'hidden_states', None)
        attentions = getattr(encoder_outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model with custom attributes."""
        os.makedirs(save_directory, exist_ok=True)
        self.model.save_pretrained(save_directory, **kwargs)
        config = self.model.config
        config.num_aspects = len(self.sentiment_classifiers)
        config.num_sentiments = self.sentiment_classifiers[0].out_features - 1
        config.auto_map = {
            "AutoModel": "models.BartForABSA",
            "AutoModelForSequenceClassification": "models.BartForABSA"
        }
        if not hasattr(config, 'custom_model_type'):
            config.custom_model_type = 'BartForABSA'
        config.save_pretrained(save_directory, **kwargs)
        sd = kwargs.get("state_dict", None) or self.state_dict()
        torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.model = BartModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load full state_dict if available
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


class T5ForABSA(T5PreTrainedModel):
    """T5-based model (for ViT5) - uses the encoder only."""
    def __init__(self, config):
        super().__init__(config)
        self.encoder = T5EncoderModel(config)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.d_model, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
        # T5 models don't use token_type_ids, so we ignore it
        kwargs.pop('token_type_ids', None)
        # Filter kwargs to only include valid arguments for T5EncoderModel
        model_kwargs = {
            k: v for k, v in kwargs.items()
            if k in ['position_ids', 'head_mask', 'inputs_embeds',
                     'output_attentions', 'output_hidden_states']
        }
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.encoder(input_ids, attention_mask=attention_mask, return_dict=return_dict, **model_kwargs)

        # Mean pooling with attention mask (weighted mean to avoid padding tokens)
        sequence_output = outputs.last_hidden_state  # [B, L, H]
        if attention_mask is not None:
            # Expand attention mask to match sequence_output dimensions
            attention_mask_expanded = attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
            # Sum over sequence length, divide by the number of non-padding tokens
            sum_embeddings = torch.sum(sequence_output * attention_mask_expanded, dim=1)
            sum_mask = torch.clamp(attention_mask_expanded.sum(dim=1), min=1e-9)
            pooled = sum_embeddings / sum_mask  # [B, H]
        else:
            pooled = sequence_output.mean(dim=1)  # [B, H]

        pooled = self.dropout(pooled)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # The T5 encoder returns a BaseModelOutput; pass through its optional fields
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model with custom attributes."""
        os.makedirs(save_directory, exist_ok=True)
        self.encoder.save_pretrained(save_directory, **kwargs)
        config = self.encoder.config
        config.num_aspects = len(self.sentiment_classifiers)
        config.num_sentiments = self.sentiment_classifiers[0].out_features - 1
        config.auto_map = {
            "AutoModel": "models.T5ForABSA",
            "AutoModelForSequenceClassification": "models.T5ForABSA"
        }
        if not hasattr(config, 'custom_model_type'):
            config.custom_model_type = 'T5ForABSA'
        config.save_pretrained(save_directory, **kwargs)
        sd = kwargs.get("state_dict", None) or self.state_dict()
        torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.encoder = T5EncoderModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load full state_dict if available
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


# ========== Non-Transformer Models ==========

class TextCNNForABSA(nn.Module):
    """TextCNN model - does not use transformers."""
    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, num_aspects, num_sentiments, max_length=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])
        self.dropout = nn.Dropout(0.5)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(len(filter_sizes) * num_filters, num_sentiments + 1)
            for _ in range(num_aspects)
        ])

    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True):
        # input_ids: [B, L]
        x = self.embedding(input_ids)  # [B, L, E]
        x = x.permute(0, 2, 1)  # [B, E, L]

        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(x))  # [B, F, L']
            pooled = F.max_pool1d(conv_out, kernel_size=conv_out.size(2))  # [B, F, 1]
            conv_outputs.append(pooled.squeeze(2))  # [B, F]

        x = torch.cat(conv_outputs, dim=1)  # [B, F*len(filter_sizes)]
        x = self.dropout(x)

        all_logits = torch.stack([cls(x) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss, logits=all_logits,
                hidden_states=None, attentions=None
            )
        return (loss, all_logits) if loss is not None else (all_logits,)


class BiLSTMForABSA(nn.Module):
    """BiLSTM model - does not use transformers."""
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_aspects, num_sentiments, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, bidirectional=True, dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(hidden_dim * 2, num_sentiments + 1)  # *2 because the LSTM is bidirectional
            for _ in range(num_aspects)
        ])

    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True):
        x = self.embedding(input_ids)  # [B, L, E]
        lstm_out, (h_n, c_n) = self.lstm(x)  # [B, L, H*2]

        # Use the last non-padding hidden state instead of always taking the last timestep,
        # because padding tokens can sit at the end of the sequence
        if attention_mask is not None:
            # Find the last non-padding token for each sequence
            # attention_mask: [B, L] where 1 = real token, 0 = padding
            seq_lengths = attention_mask.sum(dim=1) - 1  # -1 for 0-indexing
            # Ensure seq_lengths are within the valid range
            seq_lengths = torch.clamp(seq_lengths, min=0, max=lstm_out.size(1) - 1)
            # Get the last hidden state for each sequence: [B, H*2]
            batch_size = lstm_out.size(0)
            pooled = lstm_out[torch.arange(batch_size, device=lstm_out.device), seq_lengths, :]
        else:
            # Fallback: use the last timestep if no attention mask is given
            pooled = lstm_out[:, -1, :]  # [B, H*2]

        pooled = self.dropout(pooled)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss, logits=all_logits,
                hidden_states=None, attentions=None
            )
        return (loss, all_logits) if loss is not None else (all_logits,)


# ========== Model Factory ==========

def get_model_class(model_name: str):
    """Factory function that returns the model class for a given model name."""
    model_name_lower = model_name.lower()

    # RoBERTa-GRU (check first to avoid confusion)
    if 'roberta' in model_name_lower and 'gru' in model_name_lower:
        return RoBERTaGRUForABSA

    # XLM-RoBERTa (must come before the generic RoBERTa branch,
    # since 'xlm-roberta' also contains 'roberta')
    if 'xlm-roberta' in model_name_lower or 'xlm_roberta' in model_name_lower:
        return XLMRobertaForABSA

    # RoBERTa-based (PhoBERT v1/v2, ViSoBERT)
    if 'phobert' in model_name_lower or 'visobert' in model_name_lower or 'roberta' in model_name_lower:
        return TransformerForABSA

    # BERT
    elif 'bert' in model_name_lower and 'roberta' not in model_name_lower:
        return BERTForABSA

    # BART
    elif 'bart' in model_name_lower:
        return BartForABSA

    # T5
    elif 't5' in model_name_lower or 'vit5' in model_name_lower:
        return T5ForABSA

    # TextCNN
    elif 'textcnn' in model_name_lower or 'cnn' in model_name_lower:
        return TextCNNForABSA

    # BiLSTM
    elif 'bilstm' in model_name_lower or 'lstm' in model_name_lower:
        return BiLSTMForABSA

    # Default: fall back to the RoBERTa-based class
    else:
        return TransformerForABSA


def create_model(model_name: str, num_aspects: int, num_sentiments: int, vocab_size=None, **kwargs):
    """
    Create a model instance based on the model name.

    Args:
        model_name: Model name or Hugging Face model ID
        num_aspects: Number of aspects
        num_sentiments: Number of sentiment classes
        vocab_size: Vocabulary size (only required for TextCNN/BiLSTM)
        **kwargs: Additional arguments
    """
    model_class = get_model_class(model_name)

    # RoBERTa-GRU needs its own base model
    if model_class == RoBERTaGRUForABSA:
        # Use roberta-base as the base model for RoBERTa-GRU
        base_model_name = 'roberta-base'
        return model_class.from_pretrained(
            base_model_name,
            num_aspects=num_aspects,
            num_sentiments=num_sentiments,
            trust_remote_code=True,
            **kwargs
        )

    # Non-transformer models
    if model_class in [TextCNNForABSA, BiLSTMForABSA]:
        if vocab_size is None:
            raise ValueError(f"vocab_size is required for {model_class.__name__}")

        if model_class == TextCNNForABSA:
            return TextCNNForABSA(
                vocab_size=vocab_size,
                embed_dim=kwargs.get('embed_dim', 300),
                num_filters=kwargs.get('num_filters', 100),
                filter_sizes=kwargs.get('filter_sizes', [3, 4, 5]),
                num_aspects=num_aspects,
                num_sentiments=num_sentiments,
                max_length=kwargs.get('max_length', 256)
            )
        elif model_class == BiLSTMForABSA:
            return BiLSTMForABSA(
                vocab_size=vocab_size,
                embed_dim=kwargs.get('embed_dim', 300),
                hidden_dim=kwargs.get('hidden_dim', 256),
                num_layers=kwargs.get('num_layers', 2),
                num_aspects=num_aspects,
                num_sentiments=num_sentiments,
                dropout=kwargs.get('dropout', 0.3)
            )

    # Transformer models
    else:
        return model_class.from_pretrained(
            model_name,
            num_aspects=num_aspects,
            num_sentiments=num_sentiments,
            trust_remote_code=True,
            **kwargs
        )
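
A minimal usage sketch of the factory above, for reference only; the model id, aspect count, and sentiment count here are illustrative assumptions, not values defined in this file:

import torch
from transformers import AutoTokenizer
from models import create_model  # this uploaded file

# Illustrative label counts; the real values come from the training configuration.
model = create_model("vinai/phobert-base", num_aspects=8, num_sentiments=3)
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

inputs = tokenizer("Món ăn ngon nhưng phục vụ hơi chậm", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# logits shape: [batch, num_aspects, num_sentiments + 1]; the extra class is the "none" label
predictions = outputs.logits.argmax(dim=-1)

Note that PhoBERT expects word-segmented Vietnamese input, so in practice the text would be segmented (for example with VnCoreNLP) before tokenization.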