Update app_flash.py
app_flash.py CHANGED (+3 -54)
@@ -9,7 +9,6 @@ from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModel
 from flashpack import FlashPackMixin
 from huggingface_hub import Repository
-from huggingface_hub import Repository, list_repo_files, hf_hub_download
 from typing import Tuple

 # ============================================================
@@ -23,17 +22,14 @@ print(f"🔧 Using device: {device} (CPU-only)")
 # 1️⃣ FlashPack model with better hidden layers
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self):
+    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
         super().__init__()
-        input_dim = 1536
-        hidden_dim = 1024
-        output_dim = 1536
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
         self.fc3 = nn.Linear(hidden_dim, output_dim)

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc1(x)
         x = self.relu(x)
         x = self.fc2(x)
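The point of this hunk: the old constructor hard-coded `input_dim`/`hidden_dim`/`output_dim` inside `__init__`, so a checkpoint trained with different dimensions could not be loaded without editing the class; the new signature takes them as arguments, with defaults matching the old values. A minimal sketch of the parameterized class in use — the dummy batch and the tail of `forward` (which the hunk truncates after `fc2`) are assumptions, not part of this commit:

```python
import torch
import torch.nn as nn
from flashpack import FlashPackMixin

class GemmaTrainer(nn.Module, FlashPackMixin):
    # Dimensions now arrive as arguments; defaults mirror the old hard-coded values.
    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # The hunk ends here; assuming the usual activation + final projection.
        return self.fc3(self.relu(x))

model = GemmaTrainer(input_dim=1536)   # matches the value the old code hard-coded
out = model(torch.randn(4, 1536))      # illustrative dummy batch
assert out.shape == (4, 1536)
```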
@@ -157,53 +153,6 @@ def train_flashpack_model(
 # 5️⃣ Load or train model
 # ============================================================
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
-    input_dim = 1536  # must match the input_dim used during training
-    try:
-        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
-
-        # 1️⃣ Try local model first
-        local_model_path = "model.flashpack"
-        if os.path.exists(local_model_path):
-            print("✅ Loading local model")
-        else:
-            # 2️⃣ Try Hugging Face
-            files = list_repo_files(hf_repo)
-            if "model.flashpack" in files:
-                print("✅ Downloading model from HF")
-                from huggingface_hub import hf_hub_download
-                local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
-            else:
-                print("🚫 No pretrained model found")
-                return None, None, None, None
-
-        # 3️⃣ Load model with correct input_dim
-        model = GemmaTrainer(input_dim=input_dim).from_flashpack(local_model_path)
-        model.eval()
-
-        # 4️⃣ Build encoder
-        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
-
-        # 5️⃣ Enhancement function
-        @torch.no_grad()
-        def enhance_fn(prompt, chat):
-            chat = chat or []
-            short_emb = encode_fn(prompt).to(device)
-            mapped = model(short_emb).cpu()
-            long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
-            chat.append({"role": "user", "content": prompt})
-            chat.append({"role": "assistant", "content": long_prompt})
-            return chat
-
-        return model, tokenizer, embed_model, enhance_fn
-
-    except Exception as e:
-        print(f"⚠️ Load failed: {e}")
-        print("⏬ Training a new FlashPack model locally...")
-        model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
-        push_flashpack_model_to_hf(model, hf_repo, log_fn=print)
-        return model, tokenizer, embed_model, None
-
-def get_flashpack_model1(hf_repo="rahul7star/FlashPack"):
     try:
         print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
         model = GemmaTrainer.from_flashpack(hf_repo)
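The deleted body probed for a local `model.flashpack`, fell back to `list_repo_files` plus `hf_hub_download`, and trained from scratch on any failure; the surviving function (the former `get_flashpack_model1`) hands resolution to FlashPack itself. A sketch contrasting the two call shapes that appear in this diff, assuming the `GemmaTrainer` class above is in scope — whether `from_flashpack` accepts a bare repo id as well as a file path is inferred from this code, not from FlashPack's documented API:

```python
from huggingface_hub import hf_hub_download

REPO = "rahul7star/FlashPack"

# New path: hand the repo id straight to the mixin's loader.
model = GemmaTrainer.from_flashpack(REPO)

# Old path (now removed): resolve the checkpoint file explicitly, then load it
# into a freshly constructed instance whose input_dim had to match training by hand.
path = hf_hub_download(repo_id=REPO, filename="model.flashpack")
model = GemmaTrainer(input_dim=1536).from_flashpack(path)
model.eval()
```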
@@ -280,4 +229,4 @@ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft
 # 9️⃣ Launch
 # ============================================================
 if __name__ == "__main__":
-    demo.launch(show_error=True)
+    demo.launch(show_error=True)