rahul7star committed · verified · Commit 5daa352 · 1 parent: 9c7141f

Update app_flash1.py

Files changed (1):
  1. app_flash1.py +13 -10
app_flash1.py CHANGED

@@ -21,9 +21,17 @@ print(f"🔧 Using device: {device} (CPU-only mode)")
 # ============================================================
 # 1️⃣ Model Definition
 # ============================================================
+# ============================================================
+# 1️⃣ Fixed Model Definition (FlashPack-compatible)
+# ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
+    def __init__(self):
         super().__init__()
+        # Move input/output dimensions inside
+        input_dim = 1536
+        hidden_dim = 1024
+        output_dim = 1536
+
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
@@ -38,6 +46,7 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         return x
 
 
+
 # ============================================================
 # 2️⃣ Encoder Setup
 # ============================================================
@@ -131,10 +140,9 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     # 1️⃣ Try local first
     if os.path.exists(local_path):
         print("✅ Found local model.flashpack — loading it directly.")
-        model = GemmaTrainer().from_flashpack(local_path)
+        model = GemmaTrainer().from_flashpack(local_path)  # no args needed
         model.eval()
         tokenizer, embed_model, _ = build_encoder("gpt2")
-        print("✅ FlashPack model loaded successfully (local).")
         return model, tokenizer, embed_model, None, None
 
     # 2️⃣ Otherwise, check HF repo
@@ -143,19 +151,13 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
         if "model.flashpack" in files:
             print("✅ Found model.flashpack in repo — downloading and loading it.")
             local_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
-
-            # Correct loading: no extra args passed
-            model = GemmaTrainer().from_flashpack(local_path)
+            model = GemmaTrainer().from_flashpack(local_path)  # no args
             model.eval()
-
             tokenizer, embed_model, _ = build_encoder("gpt2")
-            print("✅ FlashPack model loaded successfully (from repo).")
             return model, tokenizer, embed_model, None, None
-
         else:
             print("🚫 model.flashpack not found — starting training.")
             return train_flashpack_model(hf_repo=hf_repo)
-
     except Exception as e:
         print(f"⚠️ Error checking repo: {e}")
         print("⏬ Training new model instead.")
@@ -164,6 +166,7 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
 
 
 
+
 # ============================================================
 # 6️⃣ Encode & Enhance Functions
 # ============================================================
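The substance of the commit: the loader builds the module with a bare GemmaTrainer() before calling from_flashpack, so __init__ can no longer require input_dim; the dimensions are hard-coded inside the constructor instead. A minimal self-contained sketch of the resulting class, assuming the parts not visible in these hunks (the flashpack import path, an fc3 output projection, and the forward body) simply continue the visible fc1/fc2 pattern:

import torch.nn as nn
from flashpack import FlashPackMixin  # import path assumed; not shown in this diff

class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self):
        super().__init__()
        # Dimensions hard-coded so GemmaTrainer() needs no arguments,
        # matching the "no args" calls in get_flashpack_model().
        input_dim, hidden_dim, output_dim = 1536, 1024, 1536
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)  # assumed output layer, not in the hunks

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)  # assumed; the diff only shows "return x"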
 
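For reference, the load path the commit converges on, as a hedged usage sketch: hf_hub_download and the GemmaTrainer().from_flashpack(...) assignment mirror the diff exactly, while the dummy input is illustrative (its width follows the hard-coded input_dim = 1536):

import torch
from huggingface_hub import hf_hub_download

# Fetch the packed weights and load them into a fresh no-arg instance,
# the same sequence get_flashpack_model() now runs.
path = hf_hub_download(repo_id="rahul7star/FlashPack", filename="model.flashpack")
model = GemmaTrainer().from_flashpack(path)  # same call as in the diff
model.eval()

with torch.no_grad():
    emb = torch.randn(1, 1536)  # input_dim is hard-coded to 1536
    out = model(emb)            # enhanced embedding, shape (1, 1536)

Hard-coding the dimensions is the simplest way to keep the class reconstructible from the packed file alone; the trade-off is that changing the encoder width now means editing the class rather than passing an argument.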