rahul7star committed · Commit f2309a4 · verified · 1 Parent(s): 0323fb2

Update app_flash1.py

Files changed (1): app_flash1.py (+13 −3)
app_flash1.py CHANGED

```diff
@@ -140,9 +140,11 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     local_model_path = "model.flashpack"
 
+    # 1️⃣ Try local
     if os.path.exists(local_model_path):
         print("✅ Loading local model")
     else:
+        # 2️⃣ Try HF
         try:
             files = list_repo_files(hf_repo)
             if "model.flashpack" in files:
@@ -155,16 +157,24 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
             print(f"⚠️ Error accessing HF: {e}")
             return None, None, None, None
 
+    # Load the model
     model = GemmaTrainer().from_flashpack(local_model_path)
     model.eval()
+
+    # Load encoder
     tokenizer, embed_model, encode_fn = build_encoder("gpt2")
 
+    # Enhancement function (without dataset)
     @torch.no_grad()
     def enhance_fn(prompt, chat):
         chat = chat or []
-        short_emb = encode_fn(prompt)
-        mapped = model(short_emb.to(device)).cpu()
-        long_prompt = f"✅ Enhanced long prompt for: {prompt}"
+        short_emb = encode_fn(prompt).to(device)
+        mapped = model(short_emb).cpu()
+
+        # Convert the model output tensor to a string representation for demonstration
+        # In practice, you could use a small language head on top of mapped embeddings
+        long_prompt = f"✅ Enhanced long prompt generated for: {prompt}"
+
         chat.append({"role": "user", "content": prompt})
         chat.append({"role": "assistant", "content": long_prompt})
         return chat
```
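
The second hunk elides original lines 149–156, where the `model.flashpack` file is presumably fetched from the Hub when it isn't on disk. A minimal sketch of that fallback using `huggingface_hub` (hypothetical; `fetch_flashpack` and the copy step are not the commit's actual code) might look like:

```python
import shutil
from huggingface_hub import hf_hub_download, list_repo_files

# Hypothetical sketch of the elided download branch: if model.flashpack exists
# in the repo, pull it into the HF cache and copy it to the local path that
# GemmaTrainer().from_flashpack() expects. The commit's real lines are not shown.
def fetch_flashpack(hf_repo="rahul7star/FlashPack",
                    local_model_path="model.flashpack"):
    files = list_repo_files(hf_repo)
    if "model.flashpack" in files:
        cached = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
        shutil.copy(cached, local_model_path)
        return local_model_path
    return None
```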
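`build_encoder("gpt2")` is defined elsewhere in the file, so its body is not part of this diff. A plausible shape for it, assuming the prompt embedding is a mean-pooled GPT-2 hidden state (both the pooling and the return order here are assumptions), is:

```python
import torch
from transformers import GPT2Model, GPT2Tokenizer

# Sketch of what build_encoder("gpt2") could return: a tokenizer, a frozen
# GPT-2 backbone, and an encode_fn that mean-pools the last hidden state
# into a single (1, 768) prompt embedding. Assumed, not taken from the diff.
def build_encoder(name="gpt2"):
    tokenizer = GPT2Tokenizer.from_pretrained(name)
    tokenizer.pad_token = tokenizer.eos_token
    embed_model = GPT2Model.from_pretrained(name).eval()

    @torch.no_grad()
    def encode_fn(text):
        ids = tokenizer(text, return_tensors="pt", truncation=True)
        hidden = embed_model(**ids).last_hidden_state  # (1, seq_len, 768)
        return hidden.mean(dim=1)                      # (1, 768)

    return tokenizer, embed_model, encode_fn
```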
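The new comment notes that a small language head on top of the mapped embeddings could replace the placeholder string. One hypothetical way to realize that with GPT-2's existing LM head, treating `mapped` as a prefix embedding and greedily decoding (every name here is illustrative, not the commit's code):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Hypothetical "language head": seed GPT-2 with the mapped embedding as a
# prefix token and greedily decode an expansion of the short prompt.
@torch.no_grad()
def decode_mapped(mapped, tokenizer, lm, max_new=40):
    embeds = mapped.unsqueeze(1)                  # (1, 1, 768) prefix "token"
    out_ids = []
    for _ in range(max_new):
        logits = lm(inputs_embeds=embeds).logits[:, -1, :]
        next_id = logits.argmax(dim=-1)           # greedy pick
        out_ids.append(next_id.item())
        next_emb = lm.transformer.wte(next_id).unsqueeze(1)
        embeds = torch.cat([embeds, next_emb], dim=1)
    return tokenizer.decode(out_ids)

# Usage sketch: long_prompt = decode_mapped(mapped, tokenizer,
#     GPT2LMHeadModel.from_pretrained("gpt2").eval())
```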
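Since app_flash1.py is a Space entry point, `enhance_fn` presumably feeds a chat UI. A minimal Gradio sketch matching its `(prompt, chat) -> chat` signature (assumptions: the app uses `gr.Chatbot` with message dicts, and the unpacking order of `get_flashpack_model()`'s four return values, which the diff does not show):

```python
import gradio as gr

# Illustrative wiring only; the UI code is outside this diff, and the
# return order of get_flashpack_model() is assumed.
model, tokenizer, embed_model, enhance_fn = get_flashpack_model()

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")  # expects [{"role": ..., "content": ...}]
    prompt_box = gr.Textbox(placeholder="Short prompt to enhance")

    def respond(prompt, chat):
        # enhance_fn appends the user/assistant turns and returns the chat list
        return "", enhance_fn(prompt, chat)

    prompt_box.submit(respond, [prompt_box, chatbot], [prompt_box, chatbot])

demo.launch()
```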