import gc
import os
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository
from typing import Tuple
# ============================================================
# 🖥 Device setup (CPU-only)
# ============================================================
device = torch.device("cpu")
torch.set_num_threads(4)
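# A small thread cap helps avoid oversubscription on shared CPU hosts.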
print(f"🔧 Using device: {device} (CPU-only)")
# ============================================================
# 1️⃣ FlashPack mapper: a small MLP with two hidden layers
# ============================================================
class GemmaTrainer(nn.Module, FlashPackMixin):
def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
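# Shape sanity check (hedged sketch; the 1536 dims assume the gpt2 encoder
# defined below, whose mean+max pooling yields 1536-dim vectors):
#   mapper = GemmaTrainer(input_dim=1536, output_dim=1536)
#   mapper(torch.randn(4, 1536)).shape  # -> torch.Size([4, 1536])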
# ============================================================
# 2️⃣ Encoder using mean+max pooling (for richer embeddings)
# ============================================================
def build_encoder(model_name: str = "gpt2", max_length: int = 128):
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
embed_model = AutoModel.from_pretrained(model_name).to(device)
embed_model.eval()
@torch.no_grad()
def encode(prompt: str) -> torch.Tensor:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
padding="max_length", max_length=max_length).to(device)
last_hidden = embed_model(**inputs).last_hidden_state
mean_pool = last_hidden.mean(dim=1)
max_pool, _ = last_hidden.max(dim=1)
return torch.cat([mean_pool, max_pool], dim=1).cpu() # doubled embedding
return tokenizer, embed_model, encode
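# Example usage (hedged): gpt2's hidden size is 768, so concatenating mean and
# max pooling doubles it to a 1536-dim embedding per prompt:
#   tokenizer, embed_model, encode = build_encoder("gpt2")
#   encode("a cat on a mat").shape  # -> torch.Size([1, 1536])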
# ============================================================
# 3️⃣ Push FlashPack model to Hugging Face
# ============================================================
def push_flashpack_model_to_hf(model, hf_repo: str):
logs = []
with tempfile.TemporaryDirectory() as tmp_dir:
logs.append(f"📂 Temporary directory: {tmp_dir}")
repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
pack_path = os.path.join(tmp_dir, "model.flashpack")
model.save_flashpack(pack_path, target_dtype=torch.float32)
readme_path = os.path.join(tmp_dir, "README.md")
with open(readme_path, "w") as f:
f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
repo.push_to_hub()
logs.append(f"✅ Model pushed to HF: {hf_repo}")
return logs
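# Note: Repository is huggingface_hub's legacy git-based workflow. A hedged
# alternative sketch using the HTTP API (uploads the same files, no git clone):
#   from huggingface_hub import HfApi
#   HfApi().upload_folder(folder_path=tmp_dir, repo_id=hf_repo)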
# ============================================================
# 4️⃣ Train FlashPack model
# ============================================================
def train_flashpack_model(
dataset_name: str = "rahul7star/prompt-enhancer-dataset",
max_encode: int = 1000,
hidden_dim: int = 1024,
push_to_hub: bool = True,
hf_repo: str = "rahul7star/FlashPack"
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
print("📦 Loading dataset...")
dataset = load_dataset(dataset_name, split="train")
limit = min(max_encode, len(dataset))
dataset = dataset.select(range(limit))
print(f"⚡ Using {len(dataset)} prompts for training")
tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
# Encode prompts
short_list, long_list = [], []
for i, item in enumerate(dataset):
short_list.append(encode_fn(item["short_prompt"]))
long_list.append(encode_fn(item["long_prompt"]))
if (i+1) % 50 == 0 or (i+1) == len(dataset):
print(f" → Encoded {i+1}/{limit} prompts")
gc.collect()
short_embeddings = torch.vstack(short_list)
long_embeddings = torch.vstack(long_list)
print(f"✅ Encoded embeddings shape: short {short_embeddings.shape}, long {long_embeddings.shape}")
input_dim = short_embeddings.shape[1] # should match concatenated mean+max
output_dim = long_embeddings.shape[1]
model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
criterion = nn.CosineSimilarity(dim=1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
max_epochs = 50
batch_size = 32
n = short_embeddings.shape[0]
print("🚀 Training model...")
for epoch in range(max_epochs):
model.train()
epoch_loss = 0.0
perm = torch.randperm(n)
for start in range(0, n, batch_size):
idx = perm[start:start+batch_size]
inputs = short_embeddings[idx].to(device)
targets = long_embeddings[idx].to(device)
optimizer.zero_grad()
outputs = model(inputs)
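            # 1 - cosine similarity: the loss is 0 when mapped short-prompt
            # embeddings align perfectly with their long-prompt targets.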
loss = 1 - criterion(outputs, targets).mean()
loss.backward()
optimizer.step()
epoch_loss += loss.item() * inputs.size(0)
epoch_loss /= n
if epoch % 5 == 0 or epoch == max_epochs-1:
print(f"Epoch {epoch+1}/{max_epochs}, Loss={epoch_loss:.6f}")
print("✅ Training finished!")
if push_to_hub:
logs = push_flashpack_model_to_hf(model, hf_repo)
for log in logs:
print(log)
return model, dataset, embed_model, tokenizer, long_embeddings
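# Example call (hedged): train on a smaller slice without pushing to the Hub:
#   model, ds, embed, tok, long_emb = train_flashpack_model(max_encode=200, push_to_hub=False)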
# ============================================================
# 5️⃣ Load or train model
# ============================================================
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    try:
        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
        model = GemmaTrainer.from_flashpack(hf_repo)
        model.eval()
        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
        # enhance_prompt retrieves from the dataset, so the load path must also
        # provide the dataset and its long-prompt embeddings (mirroring the
        # 1000-example cap used during training); otherwise the 5-way unpack at
        # startup would fail whenever the Hub load succeeded.
        dataset = load_dataset("rahul7star/prompt-enhancer-dataset", split="train")
        dataset = dataset.select(range(min(1000, len(dataset))))
        long_embeddings = torch.vstack([encode_fn(item["long_prompt"]) for item in dataset])
        return model, tokenizer, embed_model, dataset, long_embeddings
    except Exception as e:
        print(f"⚠️ Load failed: {e}")
        print("⏬ Training a new FlashPack model locally...")
        # train_flashpack_model already pushes to the Hub by default, so a
        # second push here would be redundant.
        model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
        return model, tokenizer, embed_model, dataset, long_embeddings
# ============================================================
# 6️⃣ Load or train at startup
# ============================================================
model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
# ============================================================
# 7️⃣ Inference helpers
# ============================================================
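# encode_for_inference must mirror build_encoder's encode() exactly (same
# pooling, same max_length), otherwise inference embeddings would not live in
# the training embedding space.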
@torch.no_grad()
def encode_for_inference(prompt: str) -> torch.Tensor:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
padding="max_length", max_length=128).to(device)
last_hidden = embed_model(**inputs).last_hidden_state
mean_pool = last_hidden.mean(dim=1)
max_pool, _ = last_hidden.max(dim=1)
return torch.cat([mean_pool, max_pool], dim=1).cpu()
def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
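    # temperature and max_tokens come from the UI sliders but are unused by the
    # retrieval-based mapper; they are kept for interface parity.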
chat_history = chat_history or []
short_emb = encode_for_inference(user_prompt)
mapped = model(short_emb.to(device)).cpu()
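    # Retrieval, not generation: score every long-prompt embedding against the
    # mapped vector by cosine similarity and return the best dataset match.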
sims = (long_embeddings @ mapped.t()).squeeze(1)
long_norms = long_embeddings.norm(dim=1)
mapped_norm = mapped.norm()
sims = sims / (long_norms * (mapped_norm + 1e-12))
best_idx = int(sims.argmax().item())
enhanced_prompt = dataset[best_idx]["long_prompt"]
chat_history.append({"role": "user", "content": user_prompt})
chat_history.append({"role": "assistant", "content": enhanced_prompt})
return chat_history
# ============================================================
# 8️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# ✨ Prompt Enhancer (FlashPack mapper)
        Enter a short prompt, and the app will **retrieve a detailed, enhanced version** of it from the prompt dataset.
(CPU-only mode.)
"""
)
with gr.Row():
chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
with gr.Column(scale=1):
user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
clear_btn = gr.Button("🧹 Clear Chat")
send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
clear_btn.click(lambda: [], None, chatbot)
# ============================================================
# 9️⃣ Launch
# ============================================================
if __name__ == "__main__":
demo.launch(show_error=True)