import os
import gc
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository, list_repo_files, hf_hub_download

# ===========================
# Device
# ===========================
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only mode)")


# ===========================
# Model Definition
# ===========================
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self):
        super().__init__()
        input_dim = 1536
        hidden_dim = 1024
        output_dim = 1536
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# ===========================
# Encoder
# ===========================
def build_encoder(model_name="gpt2", max_length: int = 128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode(prompt: str) -> torch.Tensor:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length,
        ).to(device)
        last_hidden = embed_model(**inputs).last_hidden_state
        mean_pool = last_hidden.mean(dim=1)
        max_pool, _ = last_hidden.max(dim=1)
        return torch.cat([mean_pool, max_pool], dim=1).cpu()  # doubled embedding

    return tokenizer, embed_model, encode


# ===========================
# Push model to HF
# ===========================
def push_flashpack_model_to_hf(model, hf_repo, log_fn):
    with tempfile.TemporaryDirectory() as tmp_dir:
        log_fn(f"📦 Preparing repository {hf_repo}...")
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
        with open(os.path.join(tmp_dir, "README.md"), "w") as f:
            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
        log_fn("⏳ Pushing model to Hugging Face...")
        repo.push_to_hub()
        log_fn(f"✅ Model pushed to {hf_repo}")


# ===========================
# Training
# ===========================
def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
                          hf_repo="rahul7star/FlashPack",
                          max_encode=1000):
    logs = []

    def log_fn(msg):
        logs.append(msg)
        print(msg)

    log_fn("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
    log_fn(f"✅ Loaded {len(dataset)} samples")

    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    # Encode embeddings
    s_list, l_list = [], []
    for i, item in enumerate(dataset):
        s_list.append(encode_fn(item["short_prompt"]))
        l_list.append(encode_fn(item["long_prompt"]))
        if (i + 1) % 50 == 0:
            log_fn(f" → Encoded {i + 1}/{len(dataset)}")
            gc.collect()
    short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)

    # Save embeddings & prompts for nearest-neighbor retrieval
    train_prompts = [item["long_prompt"] for item in dataset]

    model = GemmaTrainer()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CosineSimilarity(dim=1)

    log_fn("🚀 Training model...")
    for epoch in range(20):
        model.train()
        optimizer.zero_grad()
        preds = model(short_emb)
        loss = 1 - loss_fn(preds, long_emb).mean()
        loss.backward()
        optimizer.step()
        log_fn(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
        if loss.item() < 0.01:
            log_fn("🎯 Early stopping.")
            break

    push_flashpack_model_to_hf(model, hf_repo, log_fn)

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        short_emb_input = encode_fn(prompt)
        mapped_emb = model(short_emb_input).cpu()

        # Nearest neighbor
        sims = F.cosine_similarity(mapped_emb, long_emb)
        best_idx = sims.argmax().item()
        long_prompt = train_prompts[best_idx]

        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
        return chat

    return model, tokenizer, embed_model, enhance_fn, logs


# ===========================
# Lazy Load / Get Model
# ===========================
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    local_model_path = "model.flashpack"
    if os.path.exists(local_model_path):
        print("✅ Loading local model")
    else:
        try:
            files = list_repo_files(hf_repo)
            if "model.flashpack" in files:
                print("✅ Downloading model from HF")
                local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
            else:
                print("🚫 No pretrained model found")
                return None, None, None, None
        except Exception as e:
            print(f"⚠️ Error accessing HF: {e}")
            return None, None, None, None

    model = GemmaTrainer().from_flashpack(local_model_path)
    model.eval()
    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    # Dummy placeholders for nearest-neighbor retrieval (replace with actual dataset if available)
    long_emb = torch.randn(10, 1536)  # placeholder embeddings
    train_prompts = [f"Example long prompt {i}" for i in range(10)]

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        short_emb_input = encode_fn(prompt)
        mapped_emb = model(short_emb_input).cpu()
        sims = F.cosine_similarity(mapped_emb, long_emb)
        best_idx = sims.argmax().item()
        long_prompt = train_prompts[best_idx]
        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": f"🌟 Enhanced prompt: {long_prompt}"})
        return chat

    return model, tokenizer, embed_model, enhance_fn


# ===========================
# Gradio UI
# ===========================
with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
    gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort → Long prompt expander")

    chatbot = gr.Chatbot(height=400, type="messages")
    user_input = gr.Textbox(label="Your prompt")
    send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
    clear_btn = gr.Button("🧹 Clear")
    train_btn = gr.Button("🧩 Train Model", variant="secondary")
    log_output = gr.Textbox(label="Logs", lines=15)

    # Lazy load model
    model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
    logs = []

    if enhance_fn is None:
        def enhance_fn(prompt, chat):
            chat = chat or []
            chat.append({"role": "assistant",
                         "content": "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
            return chat

        logs.append("⚠️ No pretrained model found. Ready to train.")
Ready to train.") else: logs.append("โœ… Model loaded โ€” ready to enhance.") # Button callbacks send_btn.click(enhance_fn, [user_input, chatbot], chatbot) user_input.submit(enhance_fn, [user_input, chatbot], chatbot) clear_btn.click(lambda: [], None, chatbot) def retrain(): global model, tokenizer, embed_model, enhance_fn, logs logs = ["๐Ÿš€ Training model, please wait..."] model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model() logs.extend(train_logs) return gr.Textbox.update(value="\n".join(logs)) train_btn.click(retrain, None, log_output) if __name__ == "__main__": demo.launch(show_error=True)