Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
def safe_import(module_name):
    """Import a module by name, returning ``None`` when it is unavailable.

    Args:
        module_name: Absolute module name, e.g. ``"torch"`` or a dotted
            path such as ``"xml.dom"``.

    Returns:
        The module named by ``module_name`` (for a dotted path, the
        submodule itself rather than its top-level package), or ``None``
        if the import raises ``ImportError``.
    """
    # Imported locally so this helper does not depend on the file's
    # top-level import block.
    import importlib

    try:
        # import_module is the documented replacement for calling
        # __import__ directly; unlike __import__ it returns the module that
        # was actually asked for when the name is dotted.
        return importlib.import_module(module_name)
    except ImportError:
        return None
# torch is resolved at runtime rather than with a plain import statement so
# that a missing installation produces an on-page message instead of an
# ImportError traceback.
torch = safe_import("torch")
if torch is None:
    st.error("Torch is not installed yet. Please wait a moment for the dependencies to install.")
    # Nothing below can work without torch, so end this script run here.
    st.stop()
| import torch.nn as nn | |
| # architecture | |
| class AddModel(nn.Module): | |
| def __init__(self): | |
| super(AddModel, self).__init__() | |
| self.fc1 = nn.Linear(2, 32) | |
| self.relu1 = nn.ReLU() | |
| self.fc2 = nn.Linear(32, 64) | |
| self.relu2 = nn.ReLU() | |
| self.fc3 = nn.Linear(64, 1) | |
| def forward(self, x): | |
| x = self.relu1(self.fc1(x)) | |
| x = self.relu2(self.fc2(x)) | |
| x = self.fc3(x) | |
| return x | |
def load_model(model_path):
    """Build an ``AddModel`` and populate it from a saved state dict.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.

    Returns:
        The ``AddModel`` instance with the loaded weights, switched to
        evaluation mode.
    """
    network = AddModel()
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects -- only point this at trusted checkpoints.
    cpu = torch.device('cpu')
    weights = torch.load(model_path, map_location=cpu)
    network.load_state_dict(weights)
    # eval() returns the module itself, so the instance can be handed back
    # directly once it is in evaluation mode.
    return network.eval()
def predict_sum(model, x1, x2):
    """Feed the pair ``(x1, x2)`` to ``model`` and return its scalar output.

    Args:
        model: Callable that accepts the input tensor built here and returns
            a single-element tensor.
        x1: First number.
        x2: Second number.

    Returns:
        The model output converted to a Python number via ``Tensor.item()``.
    """
    # One row holding both operands, as float32.
    features = torch.tensor([[x1, x2]], dtype=torch.float32)
    # Gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        output = model(features)
    return output.item()
| # Streamlit app | |
def main():
    """Render the page: load the model, read two numbers, show the prediction.

    If the checkpoint file is missing, an error is displayed and nothing else
    is rendered. Otherwise two numeric inputs and a "Predict" button are
    shown; pressing the button writes both the model's predicted sum and the
    arithmetic sum.
    """
    st.title("Sum Predictor using Neural Network (Using version M3)")

    model_path = "MA3T.pth"  # Update with your model path if necessary

    # Bail out early when there is no checkpoint to load.
    if not os.path.exists(model_path):
        st.error("Model file not found. Please upload the model.")
        return

    model = load_model(model_path)
    st.success("Model loaded successfully.")

    x1 = st.number_input("Enter the first number:", value=0.0)
    x2 = st.number_input("Enter the second number:", value=0.0)

    if st.button("Predict"):
        predicted_sum = predict_sum(model, x1, x2)
        # Ground truth, shown next to the prediction for comparison.
        correct_sum = x1 + x2
        st.write(f"The predicted sum of {x1} and {x2} is: {predicted_sum:.2f}")
        st.write(f"The correct sum of {x1} and {x2} is: {correct_sum:.2f}")


if __name__ == "__main__":
    main()