Update modeling_reasonir_8b.py
modeling_reasonir_8b.py  CHANGED  (+166 -5)
@@ -51,6 +51,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.models.llama.configuration_llama import LlamaConfig
+from typing import Dict, List, Union, cast
+import numpy as np
+from tqdm import tqdm
+from transformers import AutoTokenizer
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -428,7 +432,7 @@ class LlamaFlashAttention2(LlamaAttention):
         dropout=0.0,
         softmax_scale=None,
         use_sliding_windows=False,
-        is_causal=
+        is_causal=True,
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -529,7 +533,7 @@ class LlamaFlashAttention2(LlamaAttention):
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        is_causal: bool =
+        is_causal: bool = True,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
@@ -656,7 +660,7 @@ class LlamaSdpaAttention(LlamaAttention):
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        is_causal: bool =
+        is_causal: bool = True,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
@@ -763,7 +767,7 @@ class LlamaDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        is_causal: bool =
+        is_causal: bool = True,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -948,6 +952,8 @@ LLAMA_INPUTS_DOCSTRING = r"""
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAMA_START_DOCSTRING,
 )
+
+
 class LlamaModel(LlamaPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
@@ -991,7 +997,7 @@ class LlamaModel(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        is_causal: Optional[bool] =
+        is_causal: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1663,3 +1669,158 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+class ReasonIRModel(LlamaModel):
+    """
+    ReasonIRModel is a wrapper around LlamaModel with bi-directional attention for retrieval tasks.
+    """
+
+    def __init__(self, config: LlamaConfig):
+        """
+        Initializes the ReasonIRModel with the given configuration.
+        """
+        super().__init__(config)
+        self.pooling_method = "mean"
+        self.normalized = True
+        self.embed_eos = ""
+        self.reasonir_config = config
+        self.tokenizer = AutoTokenizer.from_pretrained('reasonir/ReasonIR-8B')
+
+    def encode_queries(self, queries: Union[List[str], str], **kwargs) -> np.ndarray:
+        """Used for encoding the queries of retrieval or reranking tasks"""
+        return self.encode(queries, **kwargs)
+
+    def encode_corpus(self, corpus: Union[List[str], str, List[Dict[str, str]]], **kwargs) -> np.ndarray:
+        """Used for encoding the corpus of retrieval tasks"""
+        if isinstance(corpus, dict):
+            corpus = [corpus]
+        if isinstance(corpus, list) and isinstance(corpus[0], dict):
+            corpus = [
+                doc["title"] + " " + doc["text"] if "title" in doc
+                else doc["text"] for doc in corpus
+            ]
+        return self.encode(corpus, **kwargs)
+
+    @torch.inference_mode()
+    def encode(
+        self,
+        sentences: Union[List[str], str],
+        batch_size: int = 256,
+        max_length: int = 512,
+        instruction: str = "",
+        embed_instruction: bool = False,
+        get_cache: bool = False,
+        convert_to_tensor: bool = False,
+        recast: bool = False,
+        add_special_tokens: bool = True,
+        **kwargs,
+    ) -> np.ndarray:
+
+        # get number of gpus
+        num_gpus = torch.cuda.device_count()
+        if num_gpus > 0:
+            batch_size *= num_gpus
+
+        input_was_string = False
+        if isinstance(sentences, str):
+            sentences = [sentences]
+            input_was_string = True
+
+        all_embeddings, all_kv_caches = [], []
+        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256):
+            sentences_batch = [
+                instruction + s + self.embed_eos for s in sentences[start_index:start_index + batch_size]
+            ]
+            # This will prepend the bos token if the tokenizer has `add_bos_token=True`
+            inputs = self.tokenizer(
+                sentences_batch,
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+                max_length=max_length,
+                add_special_tokens=add_special_tokens,
+            ).to(self.device)
+
+            inputs["is_causal"] = False
+            if get_cache:
+                inputs['use_cache'] = True
+            outputs = self(**inputs)
+            last_hidden_state = outputs[0]
+            if get_cache:
+                # Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`
+                assert len(all_kv_caches) == 0, "Can only get cache for one batch at a time"
+                all_kv_caches = outputs[1]
+
+            if (instruction) and (embed_instruction is False) and ("mean" in self.pooling_method):
+                # Remove instruction tokens from the embeddings by masking them
+                instruction_tokens = self.tokenizer(
+                    instruction,
+                    padding=False,
+                    truncation=True,
+                    max_length=max_length,
+                    add_special_tokens=add_special_tokens,
+                )["input_ids"]
+                inputs['attention_mask'][:, :len(instruction_tokens)] = 0
+            embeddings = self.pooling(last_hidden_state, inputs['attention_mask'], recast=recast)
+            # Normalize can change the dtype (https://discuss.pytorch.org/t/tensor-in-float16-is-transformed-into-float32-after-torch-norm/110891)
+            if self.normalized:
+                in_dtype = embeddings.dtype
+                embeddings = torch.nn.functional.normalize(embeddings, dim=-1).to(in_dtype)
+            embeddings = cast(torch.Tensor, embeddings)
+            if convert_to_tensor:
+                all_embeddings.append(embeddings)
+            else:
+                # NumPy does not support bfloat16
+                all_embeddings.append(embeddings.cpu().to(torch.float32).numpy())
+
+        all_embeddings = (
+            torch.cat(all_embeddings, dim=0) if convert_to_tensor else np.concatenate(all_embeddings, axis=0)
+        )
+        if input_was_string:
+            all_embeddings = all_embeddings[0]
+        if get_cache:
+            return all_embeddings, all_kv_caches
+        return all_embeddings
+
+    def pooling(
+        self, hidden_state: torch.Tensor, attention_mask: torch.Tensor = None, recast: bool = False
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_state: [b, n, d]
+            attention_mask: [b, n]
+        """
+        # In case the model is distributed across multiple devices; hidden_state may end up on diff device
+        hidden_state = hidden_state.to(attention_mask.device)
+        if self.pooling_method == 'cls':
+            embedding = hidden_state[:, 0]
+        elif self.pooling_method == 'lasttoken':
+            b, n, d = hidden_state.size()
+            # Get the last `1` in the attention mask of each item
+            # Often it is just `gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1`
+            # except when 1) There's all 1's 2) There's 0's before the 1's
+            reversed_mask = torch.flip(attention_mask, dims=(1,))
+            argmax_reverse = torch.argmax(reversed_mask, dim=1, keepdim=False)
+            gather_indices = attention_mask.size(1) - argmax_reverse - 1
+            # If there are empty sequences, where the index would become -1 it will crash so set them to 0
+            gather_indices = torch.clamp(gather_indices, min=0)
+            # Turn indices from shape [b] -> [b, 1, d]
+            gather_indices = gather_indices.unsqueeze(-1).repeat(1, d)
+            gather_indices = gather_indices.unsqueeze(1)
+            assert gather_indices.shape == (b, 1, d)
+            # Gather along the seq len: [b, n, d] -> [b, d]
+            # Actually no need for the attention mask as we gather the last token where attn_mask=1 but
+            # as some indices (which shouldn't be attended to) may be 0 due to clamp, use mask to ignore them again
+            input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
+            embedding = torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
+        elif self.pooling_method in ['mean', 'weightedmean']:
+            if self.pooling_method == 'weightedmean':
+                attention_mask *= attention_mask.cumsum(dim=1)  # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
+            s = torch.sum(hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
+            d = attention_mask.sum(dim=1, keepdim=True).float()
+            embedding = s / d
+        else: raise NotImplementedError(f"Unknown pooling method: {self.pooling_method}")
+        # Recasting performs slightly worse but saves 50% space
+        if recast: return embedding.to(hidden_state.dtype)
+        return embedding
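For reference, a minimal usage sketch of the embedding API added in this commit: encode_queries and encode_corpus wrap encode, which runs the model with is_causal=False, mean-pools the hidden states, and L2-normalizes the result. Loading through AutoModel with trust_remote_code=True and the cuda/bfloat16 settings below are assumptions for illustration, not part of this diff.

import torch
from transformers import AutoModel

# Assumption: the custom ReasonIRModel is picked up via trust_remote_code;
# dtype and device are illustrative choices.
model = AutoModel.from_pretrained(
    "reasonir/ReasonIR-8B", torch_dtype=torch.bfloat16, trust_remote_code=True
)
model = model.to("cuda").eval()

queries = ["What causes tides on Earth?"]
corpus = [
    {"title": "Tides", "text": "Tides are driven by the gravitational pull of the Moon and Sun."},
    {"title": "Volcanoes", "text": "Volcanoes form where magma reaches the surface."},
]

# Embeddings come back mean-pooled and L2-normalized, so a dot product is a cosine similarity.
q_emb = model.encode_queries(queries, convert_to_tensor=True, max_length=512)
d_emb = model.encode_corpus(corpus, convert_to_tensor=True, max_length=512)
scores = q_emb @ d_emb.T  # shape [num_queries, num_docs]
print(scores)

Because normalized is True in the class, ranking documents by this dot product is equivalent to ranking by cosine similarity.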