import requests
import numpy as np
# from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder, ProtTransT5XLU50Embedder
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
import concurrent.futures
from tqdm.auto import tqdm
import multiprocessing
from multiprocessing import Pool
import spaces  # Hugging Face ZeroGPU support (this Space runs on Zero)
ENCODERS = {
    # "seqvec": SeqVecEmbedder(),
    # "prottrans_bert_bfd": ProtTransBertBFDEmbedder(),
    # "prottrans_t5_xl_u50": ProtTransT5XLU50Embedder(),
    "prot_t5": {
        "tokenizer": T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False),
        "model": T5EncoderModel.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc')
    },
    "prost_t5": {
        "tokenizer": T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False),
        "model": T5EncoderModel.from_pretrained("Rostlab/ProstT5")
    }
}
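# Both active encoders are T5-XL encoder-only checkpoints with d_model = 1024,
# which is what the shape asserts further down rely on. Note the models are
# loaded eagerly at import time, so the Space pays the download/startup cost once.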
def drugbank2smiles(drugbank_id):
    url = f"https://go.drugbank.com/drugs/{drugbank_id}.smiles"
    response = requests.get(url)
    if response.status_code == 200:
        return response.text
    else:
        # print(f"Failed to get SMILES for {drugbank_id}")
        return None
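# Usage sketch (requires network access; DB00945 is DrugBank's entry for aspirin):
#   smiles = drugbank2smiles("DB00945")
#   print(smiles)  # SMILES string, or None if the lookup failed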
def uniprot2sequence(uniprot_id):
    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
    response = requests.get(url)
    if response.status_code == 200:
        # Extract the sequence from the FASTA response (drop the header line)
        sequence = "".join(response.text.split("\n")[1:])
        return sequence
    else:
        # print(f"Failed to get sequence for {uniprot_id}")
        return None
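# Usage sketch (requires network access; P00533 is the UniProt accession for human EGFR):
#   seq = uniprot2sequence("P00533")
#   print(seq[:30] if seq else "lookup failed")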
def encode_sequences(sequences: list, encoder: str):
    if encoder not in ENCODERS:
        raise ValueError(f"Invalid encoder: {encoder}")
    model = ENCODERS[encoder]["model"]
    tokenizer = ENCODERS[encoder]["tokenizer"]
    # Cache for storing encoded sequences
    cache = {}

    def encode_sequence(sequence: str):
        # Only used by the disabled bio_embeddings encoders, which expose .embed()
        if sequence is None:
            return None
        if len(sequence) <= 3:
            raise ValueError(f"Invalid sequence: {sequence}")
        # Check if the sequence is already in the cache
        if sequence in cache:
            return cache[sequence]
        else:
            # Encode the sequence and store it in the cache
            try:
                encoded_sequence = model.embed(sequence)
                encoded_sequence = np.mean(encoded_sequence, axis=0)
                cache[sequence] = encoded_sequence
                return encoded_sequence
            except Exception as e:
                print(f"Failed to encode sequence: {sequence}")
                print(e)
                return None
    def encode_sequence_device_failover(sequence: str, function, timeout: int = 120):
        if sequence is None:
            return None
        if sequence in cache:
            return cache[sequence]
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        torch.cuda.empty_cache()
        try:
            # Try to process on the GPU first (falls back below on OOM)
            result = function(sequence, device)
        except RuntimeError as e:
            print(e)
            if "CUDA out of memory." in str(e):
                print("Trying on CPU instead.")
                device = torch.device("cpu")
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    future = executor.submit(function, sequence, device)
                    try:
                        result = future.result(timeout=timeout)
                    except concurrent.futures.TimeoutError:
                        print("CPU encoding timed out.")
                        cache[sequence] = None
                        return None
            else:
                cache[sequence] = None
                raise
        except Exception:
            print(f"Failed to encode sequence: {sequence}")
            cache[sequence] = None
            return None
        cache[sequence] = result
        return result
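    # Failover policy, in short: GPU first; on "CUDA out of memory." retry on CPU
    # in a worker thread bounded by `timeout`; timeouts and other failures cache
    # None for that sequence so it is not retried. A timed-out CPU worker thread
    # cannot be killed (Python threads are not interruptible), only abandoned.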
    def encode_sequence_hf_3d(sequence, device):
        # ProstT5 variant: the input carries a <AA2fold>/<fold2AA> direction prefix
        sequence_1d_list = [sequence]
        # Half precision is only supported on GPU; use full (float32) precision on CPU
        model.float() if device.type == "cpu" else model.half()
        model.to(device)
        ids = tokenizer.batch_encode_plus(
            sequence_1d_list,
            add_special_tokens=True,
            padding="longest",
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            embedding = model(
                ids.input_ids,
                attention_mask=ids.attention_mask
            )
        # Skip the first token (the direction prefix) and the trailing EOS,
        # then mean-pool the per-residue embeddings
        assert embedding.last_hidden_state.shape[0] == 1
        encoded_sequence = embedding.last_hidden_state[0, 1:-1, :]
        encoded_sequence = encoded_sequence.mean(dim=0).cpu().numpy().flatten()
        assert encoded_sequence.shape[0] == 1024
        return encoded_sequence
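    # Shape walk-through for a hypothetical 5-residue ProstT5 input "<AA2fold> M K T V A":
    #   ids.input_ids               -> (1, 7)        prefix + 5 residues + EOS
    #   embedding.last_hidden_state -> (1, 7, 1024)
    #   [0, 1:-1, :]                -> (5, 1024)     per-residue embeddings
    #   .mean(dim=0)                -> (1024,)       the returned vector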
    def encode_sequence_hf(sequence, device):
        sequence_1d_list = [sequence]
        # Half precision is only supported on GPU; use full (float32) precision on CPU
        model.float() if device.type == "cpu" else model.half()
        model.to(device)
        ids = tokenizer.batch_encode_plus(
            sequence_1d_list,
            add_special_tokens=True,
            padding="longest",
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            embedding = model(
                ids.input_ids,
                attention_mask=ids.attention_mask
            )
        # Drop the trailing EOS token and mean-pool the per-residue embeddings
        assert embedding.last_hidden_state.shape[0] == 1
        encoded_sequence = embedding.last_hidden_state[0, :-1, :]
        encoded_sequence = encoded_sequence.mean(dim=0).cpu().numpy().flatten()
        assert encoded_sequence.shape[0] == 1024
        return encoded_sequence
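    # Same walk-through for ProtT5, which has no direction prefix: a 5-residue
    # input tokenizes to (1, 6) = 5 residues + EOS, [0, :-1, :] keeps (5, 1024),
    # and the mean over residues gives the final (1024,) vector.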
    # Encode all sequences with a list comprehension, utilizing the cache
    if encoder == "seqvec":
        raise NotImplementedError("SeqVec is not supported")
        # Unreachable bio_embeddings path, kept for reference:
        # seq = encoder_function.embed(list(sequences))
        # seq = np.sum(seq, axis=0)
    elif encoder == "prost_t5":
        sequences = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequences]
        # The direction of the translation is indicated by two special tokens:
        # if you go from AAs to 3Di (or if you want to embed AAs), you need to prepend "<AA2fold>"
        # if you go from 3Di to AAs (or if you want to embed 3Di), you need to prepend "<fold2AA>"
        sequences = ["<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in sequences]
        seq = [encode_sequence_device_failover(sequence, encode_sequence_hf_3d) for sequence in tqdm(sequences, desc="Encoding sequences")]
    elif encoder == "prot_t5":
        sequences = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequences]
        seq = [encode_sequence_device_failover(sequence, encode_sequence_hf) for sequence in tqdm(sequences, desc="Encoding sequences")]
    else:
        raise NotImplementedError(f"Encoder {encoder} is not supported")
        # Unreachable fallback, kept for reference:
        # seq = [encode_sequence(sequence) for sequence in sequences]
    return np.array(seq)
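# Usage sketch (hypothetical fragment; the first call downloads the model weights):
#   vecs = encode_sequences(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"], "prot_t5")
#   vecs.shape  # (1, 1024) when encoding succeeds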
class SequenceEncoder:
    def __init__(self, encoder: str):
        if encoder not in ENCODERS:
            raise ValueError(f"Invalid encoder: {encoder}")
        self.encoder = encoder
        self.model = ENCODERS[encoder]["model"]
        self.tokenizer = ENCODERS[encoder]["tokenizer"]
        self.cache = {}

    def encode_sequence(self, sequence: str):
        # Only used by the disabled bio_embeddings encoders, which expose .embed()
        if sequence is None:
            return None
        if len(sequence) <= 3:
            raise ValueError(f"Invalid sequence: {sequence}")
        if sequence in self.cache:
            return self.cache[sequence]
        try:
            encoded_sequence = self.model.embed(sequence)
            encoded_sequence = np.mean(encoded_sequence, axis=0)
            self.cache[sequence] = encoded_sequence
            return encoded_sequence
        except Exception as e:
            print(f"Failed to encode sequence: {sequence}")
            print(e)
            return None
    def encode_sequence_device_failover(self, sequence: str, function, timeout: int = 5):
        if sequence is None:
            return None
        if sequence in self.cache:
            return self.cache[sequence]
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        torch.cuda.empty_cache()
        try:
            result = function(sequence, device)
        except RuntimeError as e:
            print(e)
            if "CUDA out of memory." in str(e):
                print("Trying on CPU instead.")
                device = torch.device("cpu")
                executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
                future = executor.submit(function, sequence, device)
                try:
                    result = future.result(timeout=timeout)
                except concurrent.futures.TimeoutError:
                    print("CPU encoding timed out.")
                    self.cache[sequence] = None
                    return None
                finally:
                    # Don't wait for a possibly still-running worker thread;
                    # a `with` block would hang here, since __exit__ calls
                    # shutdown(wait=True)
                    executor.shutdown(wait=False)
            else:
                self.cache[sequence] = None
                return None
        except Exception:
            print(f"Failed to encode sequence: {sequence}")
            self.cache[sequence] = None
            return None
        self.cache[sequence] = result
        return result
    def encode_sequence_hf_3d(self, sequence, device):
        # ProstT5 variant: the input carries a <AA2fold>/<fold2AA> direction prefix
        sequence_1d_list = [sequence]
        # Half precision is only supported on GPU; use full (float32) precision on CPU
        self.model.float() if device.type == "cpu" else self.model.half()
        self.model.to(device)
        ids = self.tokenizer.batch_encode_plus(
            sequence_1d_list,
            add_special_tokens=True,
            padding="longest",
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            embedding = self.model(
                ids.input_ids,
                attention_mask=ids.attention_mask
            )
        # Skip the direction prefix and trailing EOS, then mean-pool the residues
        assert embedding.last_hidden_state.shape[0] == 1
        encoded_sequence = embedding.last_hidden_state[0, 1:-1, :]
        encoded_sequence = encoded_sequence.mean(dim=0).cpu().numpy().flatten()
        assert encoded_sequence.shape[0] == 1024
        return encoded_sequence
    def encode_sequence_hf(self, sequence, device):
        sequence_1d_list = [sequence]
        # Half precision is only supported on GPU; use full (float32) precision on CPU
        self.model.float() if device.type == "cpu" else self.model.half()
        self.model.to(device)
        ids = self.tokenizer.batch_encode_plus(
            sequence_1d_list,
            add_special_tokens=True,
            padding="longest",
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            embedding = self.model(
                ids.input_ids,
                attention_mask=ids.attention_mask
            )
        # Drop the trailing EOS token and mean-pool the per-residue embeddings
        assert embedding.last_hidden_state.shape[0] == 1
        encoded_sequence = embedding.last_hidden_state[0, :-1, :]
        encoded_sequence = encoded_sequence.mean(dim=0).cpu().numpy().flatten()
        assert encoded_sequence.shape[0] == 1024
        return encoded_sequence
    def encode_sequences(self, sequences: list):
        if self.encoder == "seqvec":
            raise NotImplementedError("SeqVec is not supported")
            # Unreachable bio_embeddings path, kept for reference:
            # seq = self.encoder_function.embed(list(sequences))
            # seq = np.sum(seq, axis=0)
        elif self.encoder == "prost_t5":
            sequences = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequences]
            sequences = ["<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in sequences]
            seq = [self.encode_sequence_device_failover(sequence, self.encode_sequence_hf_3d) for sequence in tqdm(sequences, desc="Encoding sequences")]
        elif self.encoder == "prot_t5":
            sequences = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequences]
            seq = [self.encode_sequence_device_failover(sequence, self.encode_sequence_hf) for sequence in tqdm(sequences, desc="Encoding sequences")]
        else:
            raise NotImplementedError(f"Encoder {self.encoder} is not supported")
            # Unreachable fallback, kept for reference:
            # seq = [self.encode_sequence(sequence) for sequence in sequences]
        # Return the plain list if any encoding failed, so callers can detect the Nones
        if any(x is None for x in seq):
            return seq
        else:
            return np.array(seq)
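# A minimal usage sketch (the sequence is a hypothetical fragment; the first call
# downloads the model weights and runs on GPU when one is available):
#
# if __name__ == "__main__":
#     enc = SequenceEncoder("prot_t5")
#     embeddings = enc.encode_sequences(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"])
#     print(embeddings.shape)  # (1, 1024) when every sequence encodes successfully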