Update chatNT.py
chatNT.py (CHANGED)
@@ -417,10 +417,6 @@ class TorchBioBrainDecoder(nn.Module):
 
         # Insert the bio embeddings at the SEQ token positions
         processed_tokens_ids = english_token_ids.clone()
-        print("(debug) Before call tokens embeddings shape : ", tokens_embeddings.shape)
-        print("(debug) Before call Processed tokens ids shape : ", processed_tokens_ids.shape)
-        print("(debug) Before call Projected bio embeddings shape : ", projected_bio_embeddings.shape)
-        print("num bio sequences : ", num_bio_sequences)
         for bio_seq_num in range(num_bio_sequences):
             tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
                 processed_tokens_ids,
@@ -428,7 +424,6 @@ class TorchBioBrainDecoder(nn.Module):
                 projected_bio_embeddings[:, bio_seq_num, :, :],
                 bio_seq_num=bio_seq_num,
             )
-        print("After call : ", tokens_embeddings.shape)
 
         # Regular GPT pass through
         embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
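The loop in this first hunk splices one bio sequence per iteration into the English embedding stream, so the sequence axis grows on every pass. A minimal shape-bookkeeping sketch under assumed dimensions (B, T, D, N, R are illustrative; the stand-in below only appends, whereas the real insert_embeddings splices at the SEQ position):

```python
# Illustrative sketch, not the repository code: with batch B, prompt length T,
# and N bio sequences of R resampled embeddings each, the sequence axis ends
# up at T + N * R after the loop.
import torch

B, T, D, N, R = 2, 10, 16, 3, 4
tokens_embeddings = torch.randn(B, T, D)
projected_bio_embeddings = torch.randn(B, N, R, D)

for bio_seq_num in range(N):
    # Stand-in for self.insert_embeddings: append-only here, just to show the
    # shape bookkeeping; the real method splices at the SEQ token position.
    tokens_embeddings = torch.cat(
        [tokens_embeddings, projected_bio_embeddings[:, bio_seq_num, :, :]],
        dim=1,
    )

print(tokens_embeddings.shape)  # torch.Size([2, 22, 16]) == (B, T + N * R, D)
```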
@@ -471,8 +466,6 @@ class TorchBioBrainDecoder(nn.Module):
             - input_embeddings with resampled_embeddings inserted at the SEQ token
             - tokens with the SEQ token set to -1
         """
-        print("Tokens : ", list(tokens))
-        print("seq_token_id : ", self.seq_token_id)
 
         def _insert(
             tokens_1d: torch.Tensor,
@@ -488,7 +481,6 @@ class TorchBioBrainDecoder(nn.Module):
             """
             indices = torch.where(tokens_1d == self.seq_token_id)[0]
             if indices.numel() > 0:
-                print("going in if")
                 idx = indices[0].item()
                 insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
                 x = torch.cat(
@@ -505,7 +497,6 @@ class TorchBioBrainDecoder(nn.Module):
                 tokens_1d[idx] = -1
                 return x, tokens_1d
             else:
-                print("going in else")
                 return (
                     input_embeddings,
                     tokens_1d,
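These hunks cover the per-row helper _insert: find the first remaining SEQ placeholder, splice the resampled embeddings in with torch.cat, and mark the placeholder consumed by setting its token to -1. Below is a self-contained single-row sketch of the same pattern (SEQ_TOKEN_ID, R, and insert_one are illustrative names, not the repository API). Note the insertion_pos offset of R * bio_seq_num: it is needed because the embedding stream has already grown by R per earlier bio sequence while tokens_1d keeps its original length.

```python
# Minimal single-row sketch of the _insert logic shown above (assumed names).
import torch

SEQ_TOKEN_ID = 32000  # hypothetical placeholder id
R = 4                 # resampled embeddings per bio sequence

def insert_one(tokens_1d, input_embeddings, resampled, bio_seq_num):
    indices = torch.where(tokens_1d == SEQ_TOKEN_ID)[0]
    if indices.numel() == 0:
        return input_embeddings, tokens_1d  # nothing left to insert
    idx = indices[0].item()
    # Token indices lag the embedding stream by R per previous insertion.
    insertion_pos = idx + resampled.shape[-2] * bio_seq_num
    x = torch.cat(
        [input_embeddings[:insertion_pos], resampled, input_embeddings[insertion_pos:]],
        dim=0,
    )
    tokens_1d = tokens_1d.clone()
    tokens_1d[idx] = -1  # consume this SEQ token so the next pass finds the next one
    return x, tokens_1d

tokens = torch.tensor([1, SEQ_TOKEN_ID, 2, SEQ_TOKEN_ID, 3])
emb = torch.randn(5, 8)
emb, tokens = insert_one(tokens, emb, torch.randn(R, 8), bio_seq_num=0)
emb, tokens = insert_one(tokens, emb, torch.randn(R, 8), bio_seq_num=1)
print(emb.shape)        # torch.Size([13, 8]) -> 5 + 2 * R
print(tokens.tolist())  # [1, -1, 2, -1, 3] -> both SEQ positions consumed
```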
@@ -680,6 +671,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
             Shape (batch_size, num_bio_sequences, ?, embed_dim)
         """
         english_token_ids, bio_token_ids = multi_omics_tokens_ids
+        english_token_ids = english_token_ids.clone()
+        bio_token_ids = bio_token_ids.clone()
+        projection_english_tokens_ids = projection_english_tokens_ids.clone()
+        if projected_bio_embeddings is not None:
+            projected_bio_embeddings = projected_bio_embeddings.clone()
 
         # Replace config.vocab_size value in english tokens
         # We do this because the default vocab size (32000) doesn't match with the
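The lines added here make the model work on private copies of its inputs. That matters because the tensors are mutated in place further down, for example the vocab-size clamp on english_token_ids visible in the next hunk. A toy demonstration of the aliasing behavior this avoids (plain PyTorch on dummy tensors, not the repository code):

```python
import torch

vocab_size = 10
caller_ids = torch.tensor([3, 12, 5, 12])

# Without clone: the in-place clamp writes through to the caller's tensor.
ids = caller_ids
ids[ids >= vocab_size] = vocab_size - 1
print(caller_ids.tolist())  # [3, 9, 5, 9] -> caller sees the mutation

# With clone: the method operates on its own copy.
caller_ids = torch.tensor([3, 12, 5, 12])
ids = caller_ids.clone()
ids[ids >= vocab_size] = vocab_size - 1
print(caller_ids.tolist())  # [3, 12, 5, 12] -> caller unaffected
```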
@@ -698,8 +694,6 @@ class TorchMultiOmicsModel(PreTrainedModel):
             vocab_size - 1
         )
 
-        print("seq token id : ", self.seq_token_id)
-        print("Tokens at step 1 in multiomics : ", list(english_token_ids))
         if bio_token_ids is None:
             projected_bio_embeddings = None
         else:
@@ -724,9 +718,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
             ]
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
-        # decode
-        print("Tokens at step 2 in multiomics : ", list(english_token_ids))
-
+        # decode
         logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,