Kiria-Nozan committed
Commit 80ad4cd · 1 Parent(s): c6e45fc

solve same embedding bug

Files changed (48)
  1. DLM_emb_model.py +3 -3
  2. __pycache__/DLM_emb_model.cpython-39.pyc +0 -0
  3. __pycache__/noise_schedule.cpython-39.pyc +0 -0
  4. compare_source_vs_hf.py +108 -0
  5. configs/callbacks/checkpoint_every_n_steps.yaml +8 -0
  6. configs/callbacks/checkpoint_monitor.yaml +10 -0
  7. configs/callbacks/learning_rate_monitor.yaml +3 -0
  8. configs/config.yaml +102 -0
  9. configs/data/ag_news.yaml +6 -0
  10. configs/data/lambada.yaml +6 -0
  11. configs/data/lm1b-gpt2.yaml +6 -0
  12. configs/data/lm1b-streaming.yaml +6 -0
  13. configs/data/lm1b.yaml +6 -0
  14. configs/data/openwebtext-split.yaml +6 -0
  15. configs/data/openwebtext-streaming.yaml +6 -0
  16. configs/data/openwebtext.yaml +6 -0
  17. configs/data/ptb.yaml +6 -0
  18. configs/data/scientific_papers_arxiv.yaml +6 -0
  19. configs/data/scientific_papers_pubmed.yaml +6 -0
  20. configs/data/text8-crop.yaml +7 -0
  21. configs/data/text8.yaml +7 -0
  22. configs/data/wikitext103.yaml +6 -0
  23. configs/data/wikitext2.yaml +6 -0
  24. configs/lr_scheduler/constant_warmup.yaml +2 -0
  25. configs/lr_scheduler/cosine_decay_warmup.yaml +7 -0
  26. configs/model/medium.yaml +10 -0
  27. configs/model/small-ar.yaml +11 -0
  28. configs/model/small.yaml +10 -0
  29. configs/model/tiny-ar.yaml +11 -0
  30. configs/model/tiny-dimamba.yaml +11 -0
  31. configs/model/tiny.yaml +10 -0
  32. configs/noise/ar.yaml +2 -0
  33. configs/noise/linear.yaml +3 -0
  34. configs/noise/loglinear.yaml +3 -0
  35. configs/noise/polynomial.yaml +5 -0
  36. configs/strategy/ddp.yaml +2 -0
  37. configs/strategy/fsdp.yaml +3 -0
  38. models/__pycache__/__init__.cpython-39.pyc +0 -0
  39. models/__pycache__/autoregressive.cpython-39.pyc +0 -0
  40. models/__pycache__/dimamba.cpython-39.pyc +0 -0
  41. models/__pycache__/dit.cpython-39.pyc +0 -0
  42. models/__pycache__/ema.cpython-39.pyc +0 -0
  43. models/dit.py +1 -1
  44. reproduce_issue.py +71 -0
  45. temp_data/monomer_embeddings.npy +0 -0
  46. temp_data/polymers_lit_scraped.csv +57 -0
  47. temp_fangping.py +127 -0
  48. verify_selfies.py +83 -0
DLM_emb_model.py CHANGED
@@ -31,10 +31,10 @@ import ast
 from omegaconf import OmegaConf, DictConfig, ListConfig
 from huggingface_hub import PyTorchModelHubMixin
 
-# current_directory = Path(__file__).parent
-current_directory = Path('/data2/tianang/projects/Synergy')
+current_directory = Path(__file__).parent
+# current_directory = Path('/data2/tianang/projects/Synergy')
 
-with initialize_config_dir(config_dir="/data2/tianang/projects/mdlm/configs"):
+with initialize_config_dir(config_dir=str(current_directory/"configs")):
     config = compose(config_name="config")
 
 class mol_emb_mdlm(nn.Module):
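For context, the new code resolves Hydra's config directory relative to the module itself rather than a hard-coded absolute path, so the checkout works from any location. A minimal standalone sketch of that pattern (the `load_packaged_config` wrapper and the `version_base=None` argument are illustrative additions, not code from the repo):

from pathlib import Path
from hydra import compose, initialize_config_dir

def load_packaged_config(config_name: str = "config"):
    # Resolve the packaged configs/ directory next to this file; initialize_config_dir
    # expects an absolute path, which .resolve() guarantees.
    config_dir = (Path(__file__).parent / "configs").resolve()
    with initialize_config_dir(config_dir=str(config_dir), version_base=None):
        return compose(config_name=config_name)

if __name__ == "__main__":
    cfg = load_packaged_config()
    print(list(cfg.keys()))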
__pycache__/DLM_emb_model.cpython-39.pyc CHANGED
Binary files a/__pycache__/DLM_emb_model.cpython-39.pyc and b/__pycache__/DLM_emb_model.cpython-39.pyc differ
 
__pycache__/noise_schedule.cpython-39.pyc ADDED
Binary file (6.17 kB).
 
compare_source_vs_hf.py ADDED
@@ -0,0 +1,108 @@
+
+import torch
+from transformers import AutoTokenizer
+import sys
+import os
+from hydra import compose, initialize_config_dir
+from pathlib import Path
+import numpy as np
+
+# Add current dir to path
+sys.path.append(os.getcwd())
+
+try:
+    from DLM_emb_model import MolEmbDLM
+except ImportError:
+    print("Could not import MolEmbDLM. Make sure you are running from ApexOracle directory.")
+    exit(1)
+
+def load_source_model():
+    print("Loading Source Model...")
+    current_directory = Path(os.getcwd())
+    # Replicating logic from DLM_emb_model.py
+    with initialize_config_dir(config_dir=str(current_directory/"configs"), version_base=None):
+        config = compose(config_name="config")
+
+    model_name = "ibm-research/materials.selfies-ted"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    DIT_ckpt_path = '/data2/tianang/projects/mdlm/Checkpoints_fangping/1-255000-fine-tune.ckpt'
+    model = MolEmbDLM(config, len(tokenizer.get_vocab()), DIT_ckpt_path, tokenizer.mask_token_id)
+    model.eval()
+    return model, tokenizer
+
+def load_hf_model():
+    print("Loading HF Model...")
+    model_path = "/data2/tianang/projects/mdlm/huggingface/huggingface_model"
+    # We use the same class but loaded via from_pretrained
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        model = MolEmbDLM.from_pretrained(model_path)
+    except Exception as e:
+        print(f"Failed to load HF model: {e}")
+        # Fallback to local if needed, though path is absolute
+        model = MolEmbDLM.from_pretrained(".")
+    model.eval()
+    return model, tokenizer
+
+def main():
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    # Load Source Model
+    source_model, source_tokenizer = load_source_model()
+    source_model.to(device)
+
+    # Load HF Model
+    hf_model, hf_tokenizer = load_hf_model()
+    hf_model.to(device)
+
+    # Test Input (SELFIES)
+    selfies = "[C][C][=O][O]"  # Acetic acid "[C][C][=O][O]"
+    processed_selfies = selfies.replace('][', '] [')
+
+    print(f"Testing with SELFIES: {processed_selfies}")
+
+    # Tokenize (using source tokenizer for both to ensure identical input ids if tokenizers are same)
+    # Note: HF model folder has its own tokenizer files, source uses "ibm-research/materials.selfies-ted".
+    # They should be the same, but let's verify input_ids match too.
+
+    inputs_source = source_tokenizer(processed_selfies, return_tensors="pt", padding=False, truncation=False)
+    inputs_hf = hf_tokenizer(processed_selfies, return_tensors="pt", padding=False, truncation=False)
+
+    print(f"Source Input IDs: {inputs_source['input_ids']}")
+    print(f"HF Input IDs: {inputs_hf['input_ids']}")
+
+    if not torch.equal(inputs_source['input_ids'], inputs_hf['input_ids']):
+        print("WARNING: Tokenizers produced different input IDs!")
+
+    # Run Source Model
+    inputs_s = {k: v.to(device) for k, v in inputs_source.items() if k in ["input_ids", "attention_mask"]}
+    with torch.no_grad():
+        emb_source = source_model(**inputs_s)
+
+    # Run HF Model
+    inputs_h = {k: v.to(device) for k, v in inputs_hf.items() if k in ["input_ids", "attention_mask"]}
+    with torch.no_grad():
+        emb_hf = hf_model(**inputs_h)
+
+    print(f'Huggingface Embeddings: {emb_hf[0][0]}')
+
+    print(f"Source Emb Shape: {emb_source.shape}")
+    print(f"HF Emb Shape: {emb_hf.shape}")
+
+    # Compare
+    diff = torch.abs(emb_source - emb_hf).sum().item()
+    max_diff = torch.abs(emb_source - emb_hf).max().item()
+
+    print(f"Sum of Absolute Differences: {diff}")
+    print(f"Max Absolute Difference: {max_diff}")
+
+    if diff < 1e-5:  # Allow small floating point differences
+        print("SUCCESS: Embeddings are identical (or extremely close).")
+    else:
+        print("FAILURE: Embeddings differ significantly.")
+        print(f"Source Mean: {emb_source.mean().item()}")
+        print(f"HF Mean: {emb_hf.mean().item()}")
+
+if __name__ == "__main__":
+    main()
configs/callbacks/checkpoint_every_n_steps.yaml ADDED
@@ -0,0 +1,8 @@
+checkpoint_every_n_steps:
+  _target_: lightning.pytorch.callbacks.ModelCheckpoint
+  save_top_k: -1  # Do not save any "best" models; this callback is being used to save every n train steps
+  save_last: True  # save model as ${save_dir}/checkpoints/last.ckpt
+  dirpath: ${checkpointing.save_dir}/checkpoints
+  verbose: True
+  auto_insert_metric_name: False
+  every_n_train_steps: 500
configs/callbacks/checkpoint_monitor.yaml ADDED
@@ -0,0 +1,10 @@
+checkpoint_monitor:
+  _target_: lightning.pytorch.callbacks.ModelCheckpoint
+  monitor: val/nll  # name of the logged metric which determines when model is improving
+  mode: min  # can be "max" or "min"
+  save_top_k: 1  # save k best models (determined by above metric)
+  save_last: False  # True = additionally always save model from last epoch
+  dirpath: ${checkpointing.save_dir}/checkpoints
+  filename: best
+  auto_insert_metric_name: False
+  verbose: True
configs/callbacks/learning_rate_monitor.yaml ADDED
@@ -0,0 +1,3 @@
+learning_rate_monitor:
+  _target_: lightning.pytorch.callbacks.LearningRateMonitor
+  logging_interval: step
configs/config.yaml ADDED
@@ -0,0 +1,102 @@
+defaults:
+  - _self_
+  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
+  - /data: openwebtext
+  - /model: small  # small / medium
+  - /strategy: ddp
+  - /noise: loglinear
+  - /lr_scheduler: constant_warmup
+
+mode: sample_eval  # train / ppl_eval / sample_eval
+diffusion: absorbing_state
+backbone: dit  # dit / dimamba / ar
+parameterization: subs  # subs / d3pm / sedd
+time_conditioning: False
+T: 0  # 0 (continuous time) / 1000
+subs_masking: False
+
+seed: 1
+
+loader:
+  global_batch_size: 512
+  eval_global_batch_size: ${.global_batch_size}
+  # Note: batch_size and eval_batch_size are **per machine**
+  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
+  pin_memory: True
+
+sampling:
+  predictor: ddpm_cache  # analytic, ddpm, ddpm_cache
+  steps: 128
+  noise_removal: True
+  # TODO(yair): @subham, why aren't these params under `eval`?
+  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
+  num_sample_log: 2
+  semi_ar: False
+  stride_length: 1
+  num_strides: 1
+
+
+training:
+  ema: 0.9999
+  antithetic_sampling: True
+  importance_sampling: False
+  sampling_eps: 1e-3
+  change_of_variables: False
+
+eval:
+  checkpoint_path: '/data2/tianang/projects/mdlm/Checkpoints_fangping/1-255000-fine-tune.ckpt'  # Used to evaluate a checkpoint after training.
+  disable_ema: False
+  compute_generative_perplexity: False
+  perplexity_batch_size: 8
+  compute_perplexity_on_sanity: False
+  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf
+  generate_samples: True
+
+optim:
+  weight_decay: 0
+  lr: 3e-4
+  beta1: 0.9
+  beta2: 0.999
+  eps: 1e-8
+
+trainer:
+  _target_: lightning.Trainer
+  accelerator: cuda
+  num_nodes: 1
+  devices: ${device_count:}
+  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
+  gradient_clip_val: 1.0
+  precision: 'bf16'
+  num_sanity_val_steps: 2
+  max_steps: 1_000_000
+  log_every_n_steps: 10
+  limit_train_batches: 1.0  # train on full dataset, can be used to toggle quick run
+  limit_val_batches: 1.0  # validate on full dataset, can be used to toggle quick run
+  val_check_interval: 10000
+
+wandb:
+  project: text-diffusion
+  notes: Mulan for text
+  group: null
+  job_type: null
+  name: null
+  id: ${.name}_${seed}
+  tags:
+    - ${noise.type}
+    - ${data.train}
+    - ${data.valid}
+
+hydra:
+  run:
+    dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
+  job:
+    chdir: true
+
+checkpointing:
+  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
+  save_dir: ${cwd:}
+  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
+  resume_from_ckpt: true
+  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
configs/data/ag_news.yaml ADDED
@@ -0,0 +1,6 @@
+train: ag_news
+valid: ag_news
+tokenizer_name_or_path: gpt2
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/lambada.yaml ADDED
@@ -0,0 +1,6 @@
+train: lambada
+valid: lambada
+tokenizer_name_or_path: gpt2
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/lm1b-gpt2.yaml ADDED
@@ -0,0 +1,6 @@
+train: lm1b
+valid: lm1b
+tokenizer_name_or_path: gpt2
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/lm1b-streaming.yaml ADDED
@@ -0,0 +1,6 @@
+train: lm1b
+valid: lm1b
+tokenizer_name_or_path: bert-base-uncased
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: False
+streaming: True
configs/data/lm1b.yaml ADDED
@@ -0,0 +1,6 @@
+train: lm1b
+valid: lm1b
+tokenizer_name_or_path: bert-base-uncased
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: False
+streaming: False
configs/data/openwebtext-split.yaml ADDED
@@ -0,0 +1,6 @@
+train: openwebtext-train
+valid: openwebtext-valid
+tokenizer_name_or_path: gpt2
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/openwebtext-streaming.yaml ADDED
@@ -0,0 +1,6 @@
+train: openwebtext
+valid: wikitext103
+tokenizer_name_or_path: gpt2
+cache_dir: /tmp/data
+wrap: True
+streaming: True
configs/data/openwebtext.yaml ADDED
@@ -0,0 +1,6 @@
+train: openwebtext
+valid: wikitext103
+tokenizer_name_or_path: ibm-research/materials.selfies-ted
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/ptb.yaml ADDED
@@ -0,0 +1,6 @@
+train: ptb
+valid: ptb
+tokenizer_name_or_path: gpt2
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/scientific_papers_arxiv.yaml ADDED
@@ -0,0 +1,6 @@
+train: scientific_papers_arxiv
+valid: scientific_papers_arxiv
+tokenizer_name_or_path: gpt2
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/scientific_papers_pubmed.yaml ADDED
@@ -0,0 +1,6 @@
+train: scientific_papers_pubmed
+valid: scientific_papers_pubmed
+tokenizer_name_or_path: gpt2
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/text8-crop.yaml ADDED
@@ -0,0 +1,7 @@
+# TODO: When using this dataset, set model.length = 256 to match D3PM setup
+train: text8-crop
+valid: text8
+tokenizer_name_or_path: text8
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/text8.yaml ADDED
@@ -0,0 +1,7 @@
+# TODO: When using this dataset, set model.length = 256 to match D3PM setup
+train: text8
+valid: text8
+tokenizer_name_or_path: text8
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/wikitext103.yaml ADDED
@@ -0,0 +1,6 @@
+train: wikitext103
+valid: wikitext103
+tokenizer_name_or_path: gpt2
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/data/wikitext2.yaml ADDED
@@ -0,0 +1,6 @@
+train: wikitext2
+valid: wikitext2
+tokenizer_name_or_path: gpt2
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
configs/lr_scheduler/constant_warmup.yaml ADDED
@@ -0,0 +1,2 @@
+_target_: transformers.get_constant_schedule_with_warmup
+num_warmup_steps: 2500
configs/lr_scheduler/cosine_decay_warmup.yaml ADDED
@@ -0,0 +1,7 @@
+_target_: utils.CosineDecayWarmupLRScheduler
+t_in_epochs: False
+t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
+warmup_prefix: True
+warmup_lr_init: 1e-6
+warmup_t: ${eval:0.1*${trainer.max_steps}}
+lr_min: 1e-6
configs/model/medium.yaml ADDED
@@ -0,0 +1,10 @@
+name: medium
+type: ddit
+hidden_size: 1024
+cond_dim: 128
+length: 1024
+n_blocks: 24
+n_heads: 16
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
configs/model/small-ar.yaml ADDED
@@ -0,0 +1,11 @@
+name: small
+type: ddit
+hidden_size: 768
+cond_dim: 128
+length: 1024
+n_blocks: 12
+n_heads: 12
+scale_by_sigma: True
+dropout: 0.1
+causal: True
+tie_word_embeddings: False
configs/model/small.yaml ADDED
@@ -0,0 +1,10 @@
+name: small
+type: ddit
+hidden_size: 768
+cond_dim: 128
+length: 1024
+n_blocks: 12
+n_heads: 12
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
configs/model/tiny-ar.yaml ADDED
@@ -0,0 +1,11 @@
+name: tiny
+type: ddit
+hidden_size: 512
+cond_dim: 128
+length: 1024
+n_blocks: 8
+n_heads: 8
+scale_by_sigma: True
+dropout: 0.1
+causal: True
+tie_word_embeddings: False
configs/model/tiny-dimamba.yaml ADDED
@@ -0,0 +1,11 @@
+name: tiny
+type: dimamba
+hidden_size: 512
+cond_dim: 128
+length: 1024
+n_blocks: 14
+n_heads: 8
+scale_by_sigma: True
+dropout: 0.1
+temb_strategy: adaln
+tie_word_embeddings: False
configs/model/tiny.yaml ADDED
@@ -0,0 +1,10 @@
+name: tiny
+type: ddit
+hidden_size: 512
+cond_dim: 128
+length: 1024
+n_blocks: 8
+n_heads: 8
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
configs/noise/ar.yaml ADDED
@@ -0,0 +1,2 @@
+type: ar
+scale: 6.0
configs/noise/linear.yaml ADDED
@@ -0,0 +1,3 @@
+type: linear
+sigma_min: 1e-3
+sigma_max: 7.0
configs/noise/loglinear.yaml ADDED
@@ -0,0 +1,3 @@
+type: loglinear
+sigma_min: 1e-4
+sigma_max: 20
configs/noise/polynomial.yaml ADDED
@@ -0,0 +1,5 @@
+type: polynomial
+a: -3
+b: 5
+c: -4
+eps: 1e-3
configs/strategy/ddp.yaml ADDED
@@ -0,0 +1,2 @@
+_target_: lightning.pytorch.strategies.DDPStrategy
+find_unused_parameters: false  # TODO(yair): this seems hacky, I think if things are correct we shouldn't need this
configs/strategy/fsdp.yaml ADDED
@@ -0,0 +1,3 @@
+# TODO(yair): Currently not compatible with grad clipping
+_target_: lightning.pytorch.strategies.FSDPStrategy
+sharding_strategy: SHARD_GRAD_OP
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (265 Bytes).
 
models/__pycache__/autoregressive.cpython-39.pyc ADDED
Binary file (9.53 kB).
 
models/__pycache__/dimamba.cpython-39.pyc ADDED
Binary file (24.7 kB).
 
models/__pycache__/dit.cpython-39.pyc ADDED
Binary file (14.8 kB).
 
models/__pycache__/ema.cpython-39.pyc ADDED
Binary file (4.64 kB).
 
models/dit.py CHANGED
@@ -339,7 +339,7 @@ class DDiTBlock_non_pad(nn.Module):
         qkv = rearrange(qkv, 'b s ... -> (b s) ...')
 
         # --------------------------------
-        mask_flat = attnmask.reshape(-1)
+        mask_flat = attnmask.reshape(-1).bool()
         qkv = qkv[mask_flat]
         seqlens = attnmask.sum(dim=1)
         pad_seq_len = torch.zeros(len(seqlens)+1, dtype=torch.int32, device=qkv.device)
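A note on what the `.bool()` cast above changes: `attnmask` is a 0/1 integer tensor, and indexing `qkv` with an integer tensor gathers rows by index (rows 0 and 1 over and over) instead of filtering out padded positions. A minimal, self-contained illustration, with made-up tensor values:

import torch

rows = torch.arange(8, dtype=torch.float32).reshape(8, 1)   # 8 flattened token positions
attnmask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])                      # 2 sequences x 4 tokens, 0 = padding

mask_int = attnmask.reshape(-1)          # int64 values 0/1
mask_bool = attnmask.reshape(-1).bool()  # boolean mask

# Integer indexing: each 0/1 is treated as a row index, so every position collapses
# onto rows 0 and 1 and different inputs end up selecting the same rows.
print(rows[mask_int].squeeze(1))   # tensor([1., 1., 1., 0., 1., 1., 0., 0.])

# Boolean indexing: keeps exactly the non-padded positions, as intended.
print(rows[mask_bool].squeeze(1))  # tensor([0., 1., 2., 4., 5.])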
reproduce_issue.py ADDED
@@ -0,0 +1,71 @@
+
+import torch
+from transformers import AutoTokenizer
+import sys
+import os
+
+# Add current dir to path
+sys.path.append(os.getcwd())
+
+try:
+    from DLM_emb_model import MolEmbDLM
+except ImportError:
+    print("Could not import MolEmbDLM. Make sure you are running from ApexOracle directory.")
+    exit(1)
+
+model_path = "/data2/tianang/projects/mdlm/huggingface/huggingface_model"
+
+print(f"Loading model from {model_path}...")
+try:
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = MolEmbDLM.from_pretrained(model_path)
+except Exception as e:
+    print(f"Failed to load model: {e}")
+    # Try loading from local directory if remote fails (though path is local)
+    try:
+        print("Trying to load from local directory...")
+        model = MolEmbDLM.from_pretrained(".")
+    except Exception as e2:
+        print(f"Failed to load from local: {e2}")
+        exit(1)
+
+model.eval()
+model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+# Two different molecules
+smiles_list = [
+    "CC(=O)OC1=CC=CC=C1C(=O)O",  # Aspirin
+    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"  # Caffeine
+]
+
+print("Tokenizing inputs...")
+inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True)
+inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items() if k in ["input_ids", "attention_mask"]}
+# Force different inputs to verify model logic, bypassing tokenizer issue
+inputs['input_ids'][1] = inputs['input_ids'][1] + 1
+print(f"Input IDs: {inputs['input_ids']}")
+print(f"Attention Mask: {inputs['attention_mask']}")
+
+print("Running model...")
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# outputs is the tensor returned by forward
+embeddings = outputs
+
+print(f"Embeddings shape: {embeddings.shape}")
+
+emb1 = embeddings[0].cpu().numpy()
+emb2 = embeddings[1].cpu().numpy()
+
+# Compare
+diff = torch.abs(embeddings[0] - embeddings[1]).sum().item()
+print(f"Difference between embeddings: {diff}")
+
+if diff < 1e-6:
+    print("ISSUE REPRODUCED: Embeddings are identical.")
+else:
+    print("Embeddings are different.")
+
+print(f"Emb1 mean: {emb1.mean()}")
+print(f"Emb2 mean: {emb2.mean()}")
temp_data/monomer_embeddings.npy ADDED
Binary file (38 kB).
 
temp_data/polymers_lit_scraped.csv ADDED
@@ -0,0 +1,57 @@
+Notebook reference,Polymer name,monomer A,mol fraction A,monomer B,fraction B,monomer C,fraction C,monomer D,fraction D,monomer E,fraction E,monomer F,fraction F,Distribution,Architecture,Target DP,MIC (E. coli),MIC (S. aureus),MIC (K. pneumoniae),MIC (E. faecium),HC50
+SW1.84.1,L-Ni31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)N1CCOCC1,0.12,,,,,,,statistical,linear,70,>512,>512,,,>2000
+SW1.84.2,L-Ni31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)NCCCOC,0.11,,,,,,,statistical,linear,70,>512,>512,,,>2000
+SW1.84.3,L-Phe31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.50,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)N1CCOCC1,0.13,,,,,,,statistical,linear,70,256,>512,,,>2000
+SW1.89.1,L-Phe31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.51,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)NCCCOC,0.13,,,,,,,statistical,linear,70,256,>512,,,>2000
+SW1.89.2,L-Do31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)N1CCOCC1,0.15,,,,,,,statistical,linear,70,128,32-64,256,512,>2000
+SW1.89.3,L-Do31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)NCCCOC,0.15,,,,,,,statistical,linear,70,128,32,512,512,>2000
+SW1.110.1,L-Ni13Mo4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.73,C=CC(=O)NC(C)C,0.21,C=CC(=O)N1CCOCC1,0.06,,,,,,,statistical,linear,70,>512,32,,,>2000
+SW1.110.2,L-Ni13Mep4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.73,C=CC(=O)NC(C)C,0.21,C=CC(=O)NCCCOC,0.06,,,,,,,statistical,linear,70,>512,64-128,,,>2000
+SW1.110.3,L-Phe13Mo4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.77,C=CC(=O)Nc1ccccc1,0.17,C=CC(=O)N1CCOCC1,0.06,,,,,,,statistical,linear,70,>512,32,,,>2000
+SW1.115.1,L-Phe13Mep4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.77,C=CC(=O)Nc1ccccc1,0.17,C=CC(=O)NCCCOC,0.06,,,,,,,statistical,linear,70,>512,32,,,>2000
+SW1.115.2,L-Do13Mo4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.83,C=CC(=O)NCCCCCCCCCCCC,0.11,C=CC(=O)N1CCOCC1,0.06,,,,,,,statistical,linear,70,256-512,32,,,<50
+SW1.115.3,L-Do13Mep4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.83,C=CC(=O)NCCCCCCCCCCCC,0.11,C=CC(=O)NCCCOC,0.06,,,,,,,statistical,linear,70,256,32,256,256,>2000
+SW1.119.1,H-Ni31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)N1CCOCC1,0.12,,,,,,,statistical,linear,115,>512,128,,,>8000
+SW1.119.2,H-Ni31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)NCCCOC,0.11,,,,,,,statistical,linear,115,>512,>512,,,>8000
+SW1.119.3,H-Phe31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.50,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)N1CCOCC1,0.13,,,,,,,statistical,linear,115,256-512,128-256,64,>512,>8000
+SW1.125.1,H-Phe31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.51,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)NCCCOC,0.13,,,,,,,statistical,linear,115,256,>512,nd,,>8000
+SW1.119.5,H-Do31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)N1CCOCC1,0.15,,,,,,,statistical,linear,115,128,32,128-256,256,>8000
+SW1.119.6,H-Do31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)NCCCOC,0.15,,,,,,,statistical,linear,115,128,32,256,>512,6300
+SW2.3.1,L-Bam31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.48,C=CC(=O)NCCCC,0.40,C=CC(=O)NCCCOC,0.12,,,,,,,statistical,linear,70,>512,>512,,,>8000
+SW2.3.2,L-Bmam31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.52,C=CC(=O)NCOCCCC,0.35,C=CC(=O)NCCCOC,0.13,,,,,,,statistical,linear,70,256,>512,,,6200
+SW2.3.3,L-Tmb31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.54,C=CC(=O)NC(C)(C)CC(C)(C)C,0.32,C=CC(=O)NCCCOC,0.14,,,,,,,statistical,linear,70,64,64,,,<62.5
+SW2.3.4,L-Oct31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.54,C=CC(=O)NCCCCCCCC,0.32,C=CC(=O)NCCCOC,0.14,,,,,,,statistical,linear,70,256-128,64,256,>512,4700
+SW2.3.5,L-Olam31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.63,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.21,C=CC(=O)NCCCOC,0.16,,,,,,,statistical,linear,70,128,64-32,>512,>512,>8000
+SW3.56.1,L-Do30Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.66,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)NCCCOC,0.07,,,,,,,statistical,linear,70,512,128,,,3400
+SW3.56.2,L-Tmb5Mo90,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.04,C=CC(=O)NC(C)(C)CC(C)(C)C,0.04,C=CC(=O)N1CCOCC1,0.93,,,,,,,statistical,linear,70,>512,>512,,,>4000
+SW3.56.3,L-Oct5Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.87,C=CC(=O)NCCCCCCCC,0.05,C=CC(=O)NCCCOC,0.07,,,,,,,statistical,linear,70,>512,>512,,,>4000
+SW3.56.4,L-Phe15Mo30,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.46,C=CC(=O)Nc1ccccc1,0.18,C=CC(=O)N1CCOCC1,0.37,,,,,,,statistical,linear,70,>512,16,,,>4000
+SW4.14.2,L-Aeg5Phe25Mo50Mep20,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.038,C=CC(=O)Nc1ccccc1,0.246,C=CC(=O)N1CCOCC1,0.514,C=CC(=O)NCCCOC,0.203,,,,,statistical,linear,70,>512,>512,,,2200
+SW4.29.1,L-Do5Mo40Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.416,C=CC(=O)NCCCCCCCCCCCC,0.036,C=CC(=O)N1CCOCC1,0.488,C=CC(=O)NCCCOC,0.060,,,,,statistical,linear,70,>512,>512,,,>4000
+SW4.29.2,L-Phe20Olam5Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.645,C=CC(=O)Nc1ccccc1,0.259,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.030,C=CC(=O)NCCCOC,0.067,,,,,statistical,linear,70,128,32,,,>4000
+SW5.20.1,L-Do25,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.777,C=CC(=O)NCCCCCCCCCCCC,0.223,,,,,,,,,statistical,linear,70,64,,,,>4000
+SW5.20.2,L-Aeg10Olam30Mo60,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.091,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.164,C=CC(=O)N1CCOCC1,0.745,,,,,,,statistical,linear,70,>512,,,,>4000
+SW5.20.3,L-Ni25Phe20,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.427,C=CC(=O)NC(C)C,0.355,C=CC(=O)Nc1ccccc1,0.218,,,,,,,statistical,linear,70,>512,,,,>4000
+SW5.20.4,L-Bam40Oct5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.438,C=CC(=O)NCCCC,0.517,C=CC(=O)NCCCCCCCC,0.045,,,,,,,statistical,linear,70,32,,,,<500
+SW5.20.5,L-Phe23Oct5Mo55,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.126,C=CC(=O)Nc1ccccc1,0.239,C=CC(=O)N1CCOCC1,0.038,C=CC(=O)N1CCOCC1,0.597,,,,,statistical,linear,70,>512,,,,>4000
+SW5.24.1,L-Aeg10Phe20Olam25,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.450,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.107,C=CC(=O)Nc1ccccc1,0.281,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.161,,,,,statistical,linear,70,128,,,,1500
+SW5.24.2,L-Aeg20Ni35Tmb10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.266,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.163,C=CC(=O)NC(C)C,0.486,C=CC(=O)NC(C)(C)CC(C)(C)C,0.086,,,,,statistical,linear,70,64,,,,<500
+SW5.24.3,L-Phe35Olam10Mo20,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.292,C=CC(=O)Nc1ccccc1,0.410,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.054,C=CC(=O)N1CCOCC1,0.244,,,,,statistical,linear,70,128,,,,>4000
+SW5.24.4,L-Aeg17Tmb8Mo37,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.319,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.148,C=CC(=O)NC(C)(C)CC(C)(C)C,0.078,C=CC(=O)N1CCOCC1,0.455,,,,,statistical,linear,70,256,,,,<500
+SW5.24.5,L-Aeg20Ni20Olam25Mo5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.269,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.193,C=CC(=O)NC(C)C,0.328,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.144,C=CC(=O)N1CCOCC1,0.066,,,statistical,linear,70,256,,,,>4000
+SW5.41.1,L-Do10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.912,C=CC(=O)NCCCCCCCCCCCC,0.088,,,,,,,,,statistical,linear,70,256,,,,>4000
+SW5.41.2,L-Phe15Do5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.759,C=CC(=O)Nc1ccccc1,0.200,C=CC(=O)NCCCCCCCCCCCC,0.041,,,,,,,statistical,linear,70,256,,,,>4000
+SW5.41.3,L-Aeg5Phe5Olam5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.845,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.053,C=CC(=O)Nc1ccccc1,0.070,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.032,,,,,statistical,linear,70,128,,,,>4000
+SW5.41.4,L-Ni20Do5Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.593,C=CC(=O)NC(C)C,0.309,C=CC(=O)NCCCCCCCCCCCC,0.037,C=CC(=O)NCCCOC,0.061,,,,,statistical,linear,70,256,,,,>4000
+SW5.41.5,L-Phe20Olam5Mo15,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.530,C=CC(=O)Nc1ccccc1,0.248,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.028,C=CC(=O)N1CCOCC1,0.194,,,,,statistical,linear,70,128,,,,>4000
+SW5.42.1,L-Phe5Do5Mo50,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.321,C=CC(=O)Nc1ccccc1,0.056,C=CC(=O)NCCCCCCCCCCCC,0.035,C=CC(=O)N1CCOCC1,0.588,,,,,statistical,linear,70,>512,,,,>4000
+SW5.42.2,L-Aeg10Oct15Tmb5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.678,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.104,C=CC(=O)NCCCCCCCC,0.164,C=CC(=O)NC(C)(C)CC(C)(C)C,0.055,,,,,statistical,linear,70,128-256,,,,<500
+SW5.42.3,L-Do5Bam5Mo20Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.570,C=CC(=O)NCCCCCCCCCCCC,0.038,C=CC(=O)NCCCC,0.071,C=CC(=O)N1CCOCC1,0.257,C=CC(=O)NCCCOC,0.063,,,statistical,linear,70,256,,,,>4000
+SW5.42.4,L-Aeg5Phe15Bam30Mo25,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.183,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.039,C=CC(=O)Nc1ccccc1,0.154,C=CC(=O)NCCCC,0.356,C=CC(=O)N1CCOCC1,0.268,,,statistical,linear,70,512,,,,>4000
+SW5.42.5,L-Phe5Olam10Bmam10Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.674,C=CC(=O)Nc1ccccc1,0.068,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.062,C=CC(=O)NCOCCCC,0.127,C=CC(=O)NCCCOC,0.070,,,statistical,linear,70,64,,,,>4000
+SW5.65.1,L-Aeg5Ni10Phe5Do30Mep15,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.309,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.047,C=CC(=O)NC(C)C,0.161,C=CC(=O)Nc1ccccc1,0.062,C=CC(=O)NCCCCCCCCCCCC,0.229,C=CC(=O)NCCCOC,0.191,statistical,linear,70,64,,,,3300
+SW5.65.5,L-Aeg10Ni15Bam10Olam20Mep20,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.206,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.089,C=CC(=O)NC(C)C,0.226,C=CC(=O)NCCCC,0.134,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.106,C=CC(=O)NCCCOC,0.238,statistical,linear,70,128,,,,1400
+SW5.65.7,L-Do15Bam15Oct10Mo30,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.245,C=CC(=O)NCCCCCCCCCCCC,0.106,C=CC(=O)NCCCC,0.199,C=CC(=O)NCCCCCCCC,0.092,C=CC(=O)N1CCOCC1,0.358,,,statistical,linear,70,128,,,,>4000
+SW5.65.8,L-Aeg10Ni5Do25Tmb10Mep35,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.122,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.088,C=CC(=O)NC(C)C,0.075,C=CC(=O)NCCCCCCCCCCCC,0.211,C=CC(=O)NC(C)(C)CC(C)(C)C,0.092,C=CC(=O)NCCCOC,0.412,statistical,linear,70,>512,,,,<500
+SW5.65.9,L-Ni10Do5Mo60,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.185,C=CC(=O)NC(C)C,0.135,C=CC(=O)NCCCCCCCCCCCC,0.032,C=CC(=O)N1CCOCC1,0.649,,,,,statistical,linear,70,>512,,,,>4000
+SW5.65.10,L-Aeg15Ni10Do10Olam10Mep35,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.167,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.134,C=CC(=O)NC(C)C,0.152,C=CC(=O)NCCCCCCCCCCCC,0.072,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.054,C=CC(=O)NCCCOC,0.421,statistical,linear,70,>512,,,,2500
temp_fangping.py ADDED
@@ -0,0 +1,127 @@
+import pandas as pd
+import numpy as np
+from DLM_emb_model import MolEmbDLM
+from transformers import AutoTokenizer
+import torch
+import selfies as sf
+
+MODEL_DIR = "Kiria-Nozan/ApexOracle"
+
+# Load model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
+model = MolEmbDLM.from_pretrained(MODEL_DIR)
+model.eval()
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = model.to(device)
+
+# Load CSV data
+df = pd.read_csv("temp_data/polymers_lit_scraped.csv")
+
+# Extract all unique monomer SMILES
+monomer_columns = ["monomer A", "monomer B", "monomer C", "monomer D", "monomer E", "monomer F"]
+all_monomers = set()
+
+for col in monomer_columns:
+    if col in df.columns:
+        monomers = df[col].dropna().unique()
+        all_monomers.update(monomers)
+
+print(f"Total unique monomers: {len(all_monomers)}")
+
+# Convert SMILES to SELFIES and prepare for embedding
+monomer_selfies = {}
+valid_monomers = []
+
+for smiles in all_monomers:
+    try:
+        selfies = sf.encoder(smiles)
+        monomer_selfies[smiles] = selfies
+        valid_monomers.append((smiles, selfies))
+    except Exception as e:
+        print(f"Error converting {smiles} to SELFIES: {e}")
+
+print(f"Valid monomers for embedding: {len(valid_monomers)}")
+
+# Generate embeddings for all monomers
+monomer_embeddings = {}
+
+for smiles, selfies in valid_monomers:
+    # Prepare input similar to example.py
+    batch = tokenizer(
+        selfies.replace('][', '] ['),
+        padding="max_length",
+        max_length=1024,
+        truncation=True,
+        return_tensors="pt",
+    )
+
+    batch = {k: v.to(device) for k, v in batch.items()}
+
+    with torch.no_grad():
+        embeddings = model(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"]+1-batch["attention_mask"],
+        )
+
+    # Store the embedding (average pooling over sequence length)
+    monomer_embeddings[smiles] = embeddings[0][0].cpu().numpy()
+
+print(f"Generated embeddings for {len(monomer_embeddings)} monomers")
+print(f"Embedding shape: {list(monomer_embeddings.values())[0].shape}")
+
+# Check for identical embeddings
+print("\nChecking for identical embeddings...")
+embedding_list = list(monomer_embeddings.items())
+identical_pairs = []
+
+for i in range(len(embedding_list)):
+    for j in range(i + 1, len(embedding_list)):
+        smiles1, emb1 = embedding_list[i]
+        smiles2, emb2 = embedding_list[j]
+
+        # Check if embeddings are identical (with small tolerance for floating point precision)
+        if np.allclose(emb1, emb2, rtol=1e-09, atol=1e-09):
+            identical_pairs.append((smiles1, smiles2))
+
+if identical_pairs:
+    print(f"Found {len(identical_pairs)} pairs of identical embeddings:")
+    for smiles1, smiles2 in identical_pairs:
+        print(f"  {smiles1} <-> {smiles2}")
+
+    # Analyze the identical groups
+    print("\nAnalyzing identical embedding groups...")
+
+    # Create groups of molecules with identical embeddings
+    identical_groups = {}
+    processed = set()
+
+    for smiles1, smiles2 in identical_pairs:
+        if smiles1 not in processed and smiles2 not in processed:
+            # Find all molecules identical to smiles1
+            group = {smiles1, smiles2}
+            for other_smiles1, other_smiles2 in identical_pairs:
+                if other_smiles1 in group:
+                    group.add(other_smiles2)
+                elif other_smiles2 in group:
+                    group.add(other_smiles1)
+
+            group_key = frozenset(group)
+            if group_key not in identical_groups:
+                identical_groups[group_key] = group
+            processed.update(group)
+
+    print(f"Found {len(identical_groups)} groups of molecules with identical embeddings:")
+    for i, group in enumerate(identical_groups.values(), 1):
+        print(f"\nGroup {i} ({len(group)} molecules):")
+        for smiles in sorted(group):
+            selfies_repr = monomer_selfies.get(smiles, "N/A")
+            print(f"  SMILES: {smiles}")
+            print(f"  SELFIES: {selfies_repr}")
+            print()
+else:
+    print("No identical embeddings found.")
+
+# Save results
+np.save("temp_data/monomer_embeddings.npy", monomer_embeddings)
+print("Embeddings saved to monomer_embeddings.npy")
verify_selfies.py ADDED
@@ -0,0 +1,83 @@
+
+import torch
+from transformers import AutoTokenizer
+import sys
+import os
+
+# Add current dir to path to find DLM_emb_model
+sys.path.append(os.getcwd())
+
+try:
+    from DLM_emb_model import MolEmbDLM
+except ImportError:
+    print("Could not import MolEmbDLM. Make sure you are running from ApexOracle directory.")
+    exit(1)
+
+# Use local model path where we applied the fix
+model_path = "/data2/tianang/projects/mdlm/huggingface/huggingface_model"
+
+print(f"Loading model from {model_path}...")
+try:
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = MolEmbDLM.from_pretrained(model_path)
+except Exception as e:
+    print(f"Failed to load model: {e}")
+    # Try loading from local directory if remote fails
+    try:
+        model = MolEmbDLM.from_pretrained(".")
+    except Exception as e2:
+        print(f"Failed to load from local: {e2}")
+        exit(1)
+
+model.eval()
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = model.to(device)
+
+# Two different SELFIES
+selfies_list = [
+    "[C][C][O]",  # Ethanol
+    "[C][C][=O][O]"  # Acetic Acid
+]
+
+# Preprocessing from example.py: seq.replace('][', '] [')
+processed_selfies = [s.replace('][', '] [') for s in selfies_list]
+
+print(f"Processed SELFIES: {processed_selfies}")
+
+# Tokenize with padding=True to create a batch (essential to test the bug fix)
+# example.py had padding=False because it was single sequence.
+print("Tokenizing inputs...")
+inputs = tokenizer(
+    processed_selfies,
+    padding=True,
+    truncation=True,
+    return_tensors="pt"
+)
+
+print(f"Input IDs:\n{inputs['input_ids']}")
+print(f"Attention Mask:\n{inputs['attention_mask']}")
+
+inputs = {k: v.to(device) for k, v in inputs.items() if k in ["input_ids", "attention_mask"]}
+
+print("Running model...")
+with torch.no_grad():
+    embeddings = model(**inputs)
+
+print(f"Embeddings shape: {embeddings.shape}")
+
+# Compare embeddings of the two molecules
+# We compare the mean embedding or the first token embedding
+emb1 = embeddings[0]
+emb2 = embeddings[1]
+
+# Calculate difference
+diff = torch.abs(emb1 - emb2).sum().item()
+print(f"Difference between embeddings (sum of abs diff): {diff}")
+
+if diff < 1e-6:
+    print("ISSUE: Embeddings are identical.")
+else:
+    print("SUCCESS: Embeddings are different.")
+
+print(f"Emb1 mean: {emb1.mean().item()}")
+print(f"Emb2 mean: {emb2.mean().item()}")