rahul7star commited on
Commit
b2330bc
·
verified ·
1 Parent(s): f2309a4

Update app_flash1.py

Browse files
Files changed (1) hide show
  1. app_flash1.py +14 -27
app_flash1.py CHANGED
@@ -89,17 +89,16 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
89
 
90
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
91
 
92
- def encode_dataset(ds):
93
- s_list, l_list = [], []
94
- for i, item in enumerate(ds):
95
- s_list.append(encode_fn(item["short_prompt"]))
96
- l_list.append(encode_fn(item["long_prompt"]))
97
- if (i + 1) % 50 == 0:
98
- log_fn(f" → Encoded {i + 1}/{len(ds)}")
99
- gc.collect()
100
- return torch.vstack(s_list), torch.vstack(l_list)
101
-
102
- short_emb, long_emb = encode_dataset(dataset)
103
  model = GemmaTrainer()
104
  optimizer = optim.Adam(model.parameters(), lr=1e-3)
105
  loss_fn = nn.CosineSimilarity(dim=1)
@@ -125,13 +124,11 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
125
  chat = chat or []
126
  short_emb = encode_fn(prompt)
127
  mapped = model(short_emb.to(device)).cpu()
128
- long_prompt = f" Enhanced long prompt for: {prompt}"
129
  chat.append({"role": "user", "content": prompt})
130
  chat.append({"role": "assistant", "content": long_prompt})
131
  return chat
132
 
133
-
134
-
135
  return model, tokenizer, embed_model, enhance_fn, logs
136
 
137
  # ===========================
@@ -140,11 +137,9 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
140
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
141
  local_model_path = "model.flashpack"
142
 
143
- # 1️⃣ Try local
144
  if os.path.exists(local_model_path):
145
  print("✅ Loading local model")
146
  else:
147
- # 2️⃣ Try HF
148
  try:
149
  files = list_repo_files(hf_repo)
150
  if "model.flashpack" in files:
@@ -157,24 +152,16 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
157
  print(f"⚠️ Error accessing HF: {e}")
158
  return None, None, None, None
159
 
160
- # Load the model
161
  model = GemmaTrainer().from_flashpack(local_model_path)
162
  model.eval()
163
-
164
- # Load encoder
165
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
166
 
167
- # Enhancement function (without dataset)
168
  @torch.no_grad()
169
  def enhance_fn(prompt, chat):
170
  chat = chat or []
171
  short_emb = encode_fn(prompt).to(device)
172
  mapped = model(short_emb).cpu()
173
-
174
- # Convert the model output tensor to a string representation for demonstration
175
- # In practice, you could use a small language head on top of mapped embeddings
176
- long_prompt = f"✅ Enhanced long prompt generated for: {prompt}"
177
-
178
  chat.append({"role": "user", "content": prompt})
179
  chat.append({"role": "assistant", "content": long_prompt})
180
  return chat
@@ -201,8 +188,8 @@ with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
201
  if enhance_fn is None:
202
  def enhance_fn(prompt, chat):
203
  chat = chat or []
204
- chat.append({"role": "assistant", "content":
205
- "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
206
  return chat
207
  logs.append("⚠️ No pretrained model found. Ready to train.")
208
  else:
 
89
 
90
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
91
 
92
+ # Only encode short+long embeddings
93
+ s_list, l_list = [], []
94
+ for i, item in enumerate(dataset):
95
+ s_list.append(encode_fn(item["short_prompt"]))
96
+ l_list.append(encode_fn(item["long_prompt"]))
97
+ if (i + 1) % 50 == 0:
98
+ log_fn(f" → Encoded {i + 1}/{len(dataset)}")
99
+ gc.collect()
100
+ short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)
101
+
 
102
  model = GemmaTrainer()
103
  optimizer = optim.Adam(model.parameters(), lr=1e-3)
104
  loss_fn = nn.CosineSimilarity(dim=1)
 
124
  chat = chat or []
125
  short_emb = encode_fn(prompt)
126
  mapped = model(short_emb.to(device)).cpu()
127
+ long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
128
  chat.append({"role": "user", "content": prompt})
129
  chat.append({"role": "assistant", "content": long_prompt})
130
  return chat
131
 
 
 
132
  return model, tokenizer, embed_model, enhance_fn, logs
133
 
134
  # ===========================
 
137
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
138
  local_model_path = "model.flashpack"
139
 
 
140
  if os.path.exists(local_model_path):
141
  print("✅ Loading local model")
142
  else:
 
143
  try:
144
  files = list_repo_files(hf_repo)
145
  if "model.flashpack" in files:
 
152
  print(f"⚠️ Error accessing HF: {e}")
153
  return None, None, None, None
154
 
 
155
  model = GemmaTrainer().from_flashpack(local_model_path)
156
  model.eval()
 
 
157
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
158
 
 
159
  @torch.no_grad()
160
  def enhance_fn(prompt, chat):
161
  chat = chat or []
162
  short_emb = encode_fn(prompt).to(device)
163
  mapped = model(short_emb).cpu()
164
+ long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
 
 
 
 
165
  chat.append({"role": "user", "content": prompt})
166
  chat.append({"role": "assistant", "content": long_prompt})
167
  return chat
 
188
  if enhance_fn is None:
189
  def enhance_fn(prompt, chat):
190
  chat = chat or []
191
+ chat.append({"role": "assistant",
192
+ "content": "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
193
  return chat
194
  logs.append("⚠️ No pretrained model found. Ready to train.")
195
  else: