Spaces:
Sleeping
Sleeping
| from datasets import load_dataset | |
| from torch_geometric.transforms import ToUndirected | |
| import torch | |
| from torch.nn import Linear | |
| from torch_geometric.nn import HGTConv, MLP | |
| import pandas as pd | |
class ProtHGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that scores (Protein, target node) pairs.

    Pipeline: a per-node-type linear projection followed by ReLU, a stack of
    ``HGTConv`` layers, and an ``MLP`` applied to the concatenation of the
    Protein embedding and the target-node embedding of every candidate pair.

    Args:
        data: Graph object providing ``node_types`` and ``metadata()``.
        hidden_channels: Width of the projected features and of every
            ``HGTConv`` layer.
        num_heads: Number of attention heads handed to each ``HGTConv``.
        num_layers: Number of stacked ``HGTConv`` layers.
        mlp_hidden_layers: Forwarded as the first argument of ``MLP``.
            NOTE(review): presumably a channel list starting at
            ``2 * hidden_channels`` (two embeddings are concatenated in
            ``forward``) and ending at 1 -- confirm against the training config.
        mlp_dropout: Dropout value forwarded to ``MLP``.
    """

    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
        super().__init__()

        # ``torch.nn.Linear`` cannot be built with ``in_features=-1`` (it tries
        # to allocate a tensor with a negative dimension). PyG's ``Linear``
        # accepts -1 and infers the input width lazily, which is what the
        # per-node-type projection needs. Imported locally so that the
        # module-level ``Linear`` name stays untouched for other code.
        from torch_geometric.nn import Linear as LazyLinear

        # Attribute names (lin_dict / convs / mlp) define the state_dict keys
        # and must not change, otherwise saved checkpoints stop loading.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = LazyLinear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList([
            HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads, group='sum')
            for _ in range(num_layers)
        ])

        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)

    def generate_embeddings(self, x_dict, edge_index_dict):
        """Return updated node embeddings, keyed by node type.

        Every feature matrix in ``x_dict`` is projected by the linear layer of
        its node type and passed through an in-place ReLU; the result is then
        fed through each ``HGTConv`` layer in order.
        """
        projected = {}
        for node_type, features in x_dict.items():
            projected[node_type] = self.lin_dict[node_type](features).relu_()

        for conv in self.convs:
            projected = conv(projected, edge_index_dict)
        return projected

    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
        """Score candidate (Protein, ``target_type``) pairs.

        Args:
            x_dict: Input features keyed by node type.
            edge_index_dict: Edge indices keyed by edge type (message passing).
            tr_edge_label_index: Unpacked as ``(row, col)``: ``row`` indexes
                Protein embeddings, ``col`` indexes ``target_type`` embeddings.
            target_type: Node type of the second endpoint of every pair.
            test: Unused; kept for interface compatibility.

        Returns:
            A tuple ``(scores, x_dict)``: the MLP output flattened to 1-D (one
            value per pair) and the updated embedding dictionary.
        """
        x_dict = self.generate_embeddings(x_dict, edge_index_dict)

        row, col = tr_edge_label_index
        protein_emb = x_dict["Protein"][row]
        target_emb = x_dict[target_type][col]
        z = torch.cat([protein_emb, target_emb], dim=-1)
        return self.mlp(z).view(-1), x_dict
def _load_data(protein_id, go_category=None, heterodata_path=''):
    """Load the graph and attach candidate protein-function edges for one protein.

    Existing ``protein_function`` edges (both directions) and every edge type
    whose relation name contains ``'rev'`` are dropped. Then, for each
    requested GO category, one candidate edge is added from the query protein
    to every term of that category.

    Args:
        protein_id: Key looked up in ``heterodata['Protein']['id_mapping']``
            to obtain the protein's node index.
        go_category: A single GO node type (e.g. ``'GO_term_F'``); when falsy,
            candidate edges are created for all three GO node types.
        heterodata_path: Passed unchanged to ``load_dataset``.

    Returns:
        The result of applying ``ToUndirected(merge=False)`` to the edited graph.
    """
    # NOTE(review): the rest of the file treats this object like a PyG
    # HeteroData (``x_dict``, ``edge_index_dict``, ``.to(device)``); confirm
    # that ``load_dataset`` is the intended loader for ``heterodata_path``.
    heterodata = load_dataset(heterodata_path)

    # Drop the known protein-function edges so they cannot leak into inference.
    edge_types_to_remove = [
        ('Protein', 'protein_function', 'GO_term_F'),
        ('Protein', 'protein_function', 'GO_term_P'),
        ('Protein', 'protein_function', 'GO_term_C'),
        ('GO_term_F', 'rev_protein_function', 'Protein'),
        ('GO_term_P', 'rev_protein_function', 'Protein'),
        ('GO_term_C', 'rev_protein_function', 'Protein'),
    ]
    for edge_type in edge_types_to_remove:
        if edge_type in heterodata:
            del heterodata[edge_type]

    # Remove every remaining reverse edge type (relation name contains 'rev').
    # NOTE(review): this rebuilds the container as a plain dict, which is then
    # handed to ToUndirected below -- verify that this is what is intended.
    heterodata = {
        key: value
        for key, value in heterodata.items()
        if not isinstance(key, tuple) or 'rev' not in key[1]
    }

    protein_index = heterodata['Protein']['id_mapping'][protein_id]

    categories = [go_category] if go_category else ['GO_term_F', 'GO_term_P', 'GO_term_C']
    for category in categories:
        # One candidate per GO term. The count must come from the id mapping
        # (the same mapping _create_prediction_df pairs with the predictions),
        # not from len() of the node store itself.
        num_terms = len(heterodata[category]['id_mapping'])
        term_indices = torch.arange(num_terms, dtype=torch.long)
        protein_indices = torch.full_like(term_indices, protein_index)
        # Shape (2, num_terms): row 0 = protein index, row 1 = term index.
        # ProtHGT.forward unpacks this as ``row, col`` and indexes embedding
        # tensors with it, so a list of (protein, term) tuples cannot work.
        edge_index = torch.stack([protein_indices, term_indices])
        heterodata['Protein', 'protein_function', category] = {'edge_index': edge_index}

    return ToUndirected(merge=False)(heterodata)
def get_available_proteins(protein_list_file='data/available_proteins.txt'):
    """Return the protein identifiers listed one per line in a text file.

    Surrounding whitespace is stripped from every line. Blank or
    whitespace-only lines are skipped so that no empty-string identifier is
    returned.

    Args:
        protein_list_file: Path of the text file to read.

    Returns:
        A list of identifier strings in file order.
    """
    with open(protein_list_file, 'r') as file:
        # Iterate the file object directly; no need to materialise readlines().
        stripped_lines = (line.strip() for line in file)
        return [protein_id for protein_id in stripped_lines if protein_id]
def _generate_predictions(heterodata, model_path, model_config, target_type):
    """Run a saved ProtHGT model on the candidate edges of one GO category.

    Args:
        heterodata: Graph object; must provide ``x_dict``, ``edge_index_dict``,
            ``to(device)`` and the ``('Protein', 'protein_function',
            target_type)`` edge store.
        model_path: Path of the saved ``state_dict``.
        model_config: Mapping with the keys ``hidden_channels``, ``num_heads``,
            ``num_layers``, ``mlp_hidden_layers`` and ``mlp_dropout``.
        target_type: GO node type whose candidate edges are scored.

    Returns:
        The 1-D score tensor returned by the model, one value per candidate edge.
        NOTE(review): these are the raw MLP outputs; no sigmoid is applied in
        this file although callers label them 'Probability' -- confirm.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = ProtHGT(
        heterodata,
        model_config['hidden_channels'],
        model_config['num_heads'],
        model_config['num_layers'],
        model_config['mlp_hidden_layers'],
        model_config['mlp_dropout'],
    )
    print('Loading model from', model_path)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    heterodata.to(device)

    edge_store = heterodata[("Protein", "protein_function", target_type)]
    # Prefer an explicit edge_label_index (unchanged behaviour). _load_data only
    # writes 'edge_index' for the candidate edges, so fall back to it instead
    # of failing with AttributeError when no label index is present.
    edge_label_index = getattr(edge_store, 'edge_label_index', None)
    if edge_label_index is None:
        edge_label_index = edge_store['edge_index']

    with torch.no_grad():
        predictions, _ = model(
            heterodata.x_dict,
            heterodata.edge_index_dict,
            edge_label_index,
            target_type,
        )
    return predictions
def _create_prediction_df(predictions, heterodata, protein_id, go_category):
    """Build a ranked table of GO-term scores for one protein and category.

    Args:
        predictions: Object with a ``tolist()`` method yielding one score per
            GO term, in the iteration order of the category's ``id_mapping``.
        heterodata: Mapping-like graph; ``heterodata[go_category]['id_mapping']``
            supplies the GO-term identifiers through its keys.
        protein_id: Value broadcast into the ``Protein`` column.
        go_category: Value broadcast into the ``GO_category`` column.

    Returns:
        A DataFrame with columns ``Protein``, ``GO_category``, ``GO_term`` and
        ``Probability``, sorted by ``Probability`` in descending order with a
        fresh 0..n-1 index.
    """
    go_terms = heterodata[go_category]['id_mapping'].keys()
    columns = {
        'Protein': protein_id,
        'GO_category': go_category,
        'GO_term': go_terms,
        'Probability': predictions.tolist(),
    }
    ranked = pd.DataFrame(columns).sort_values(by='Probability', ascending=False)
    return ranked.reset_index(drop=True)
def generate_prediction_df(protein_id, heterodata_path, model_path, model_config, go_category=None):
    """Predict GO-term scores for a protein and return them as a DataFrame.

    Args:
        protein_id: Identifier of the query protein.
        heterodata_path: Location of the graph data, forwarded to ``_load_data``.
        model_path: Path of the saved model weights.
        model_config: Model hyper-parameters, forwarded to
            ``_generate_predictions``.
        go_category: A single GO node type to predict for; when falsy,
            ``GO_term_F``, ``GO_term_P`` and ``GO_term_C`` are all predicted.

    Returns:
        For a single category, that category's ranked table. Otherwise the
        per-category tables concatenated in F, P, C order with a fresh index.
    """
    heterodata = _load_data(protein_id, go_category, heterodata_path)

    requested = [go_category] if go_category else ['GO_term_F', 'GO_term_P', 'GO_term_C']

    frames = []
    for category in requested:
        scores = _generate_predictions(heterodata, model_path, model_config, category)
        frames.append(_create_prediction_df(scores, heterodata, protein_id, category))

    if go_category:
        # Single-category request: hand back that table as-is.
        return frames[0]
    return pd.concat(frames, ignore_index=True)