File size: 7,720 Bytes
25b932e f3268bd 31737db 25b932e d9e93e9 25b932e f3268bd 25b932e d9e93e9 5ee9a29 25b932e 5daa352 d9e93e9 5daa352 d9e93e9 25b932e 31737db f3268bd 25b932e d9e93e9 5ee9a29 f3268bd 9f53409 25b932e 9f53409 d9e93e9 9f53409 25b932e f3268bd 5ee9a29 9f53409 d9e93e9 5ee9a29 9490127 25b932e 9490127 25b932e a69d346 f3268bd 9490127 25b932e 9490127 f3268bd 5ee9a29 f3268bd 9490127 5ee9a29 9490127 f3268bd 9f53409 b2330bc 6c0c98e d9e93e9 f3268bd d9e93e9 9490127 f3268bd d9e93e9 f3268bd 5ee9a29 f3268bd 9490127 f3268bd 9490127 25b932e 9490127 a69d346 b2330bc a69d346 9490127 d9e93e9 5ee9a29 a69d346 5ee9a29 f3268bd 6c0c98e 8f4e2a0 8143e5c a348f79 5ee9a29 a69d346 5ee9a29 a69d346 8143e5c 6c0c98e f3268bd 25b932e 6c0c98e f3268bd f2309a4 b2330bc f3268bd 8143e5c f3268bd 6c0c98e 9f53409 5ee9a29 f3268bd d9e93e9 49fa7d4 f3268bd 25b932e f3268bd 9490127 25b932e a69d346 6c0c98e 9490127 f3268bd a69d346 b2330bc a69d346 9490127 a69d346 9490127 a69d346 6c0c98e f3268bd 5ee9a29 9490127 5ee9a29 9490127 25b932e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository, list_repo_files, hf_hub_download
# Everything in this app runs on CPU: the GPT-2 encoder and the mapper model
# are both moved to / created on this device.
device = torch.device("cpu")
# Cap intra-op parallelism used by PyTorch CPU kernels.
torch.set_num_threads(4)
print(f"🧠 Using device: {device} (CPU-only mode)")
# ===========================
# Model Definition
# ===========================
class GemmaTrainer(nn.Module, FlashPackMixin):
    """Three-layer MLP mapping a 1536-d embedding to a 1536-d embedding.

    The width 1536 presumably matches ``build_encoder``'s mean+max concat of
    GPT-2 hidden states (2 x 768) — confirm if the encoder model changes.

    NOTE(review): the attribute names ``fc1``/``relu``/``fc2``/``fc3`` are
    presumably the keys FlashPack stores weights under; keep them stable so
    previously saved ``model.flashpack`` files still load.
    """

    def __init__(self):
        super().__init__()
        in_features, width, out_features = 1536, 1024, 1536
        # Registered in the same order as before: fc1, relu, fc2, fc3.
        self.fc1 = nn.Linear(in_features, width)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, out_features)

    def forward(self, x: torch.Tensor):
        """Apply Linear -> ReLU -> Linear -> ReLU -> Linear to ``x``."""
        # The same stateless ReLU module is reused after both hidden layers.
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)
# ===========================
# Encoder
# ===========================
def build_encoder(model_name="gpt2", max_length=128):
    """Load a frozen Hugging Face backbone and return a prompt encoder.

    Args:
        model_name: Hub id passed to ``AutoTokenizer`` / ``AutoModel``.
        max_length: prompts are truncated, then right-padded, to this many
            tokens.

    Returns:
        ``(tokenizer, embed_model, encode)``. ``encode(prompt)`` returns a CPU
        tensor of shape ``(1, 2 * hidden_size)``: the mean of the last hidden
        state over the prompt's real tokens, concatenated with the
        element-wise max over those same tokens.

    Pad positions are excluded from both pools via the tokenizer's
    ``attention_mask``. Previously they were included, which let the padding
    states dominate short prompts. A prompt with no tokens at all encodes to
    zeros.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # GPT-2 has no pad token; reuse EOS so fixed-length padding works.
        tokenizer.pad_token = tokenizer.eos_token
    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode(prompt: str) -> torch.Tensor:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                           padding="max_length", max_length=max_length).to(device)
        hidden = embed_model(**inputs).last_hidden_state
        # 1.0 for real tokens, 0.0 for padding; trailing dim broadcasts
        # across the hidden features.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        n_real = mask.sum(dim=1)
        mean_pool = (hidden * mask).sum(dim=1) / n_real.clamp(min=1.0)
        # Push pad positions to the dtype minimum so they never win the max.
        floor = torch.finfo(hidden.dtype).min
        max_pool = hidden.masked_fill(mask == 0, floor).max(dim=1).values
        # With no real tokens every position was masked; return zeros
        # instead of the dtype minimum.
        max_pool = torch.where(n_real > 0, max_pool, torch.zeros_like(max_pool))
        return torch.cat([mean_pool, max_pool], dim=1).cpu()

    return tokenizer, embed_model, encode
# ===========================
# Push model to HF
# ===========================
def push_flashpack_model_to_hf(model, hf_repo, log_fn):
    """Serialize ``model`` with FlashPack and upload it to a Hub repository.

    Args:
        model: object exposing ``save_flashpack(path, target_dtype=...)``
            (e.g. a ``GemmaTrainer``); weights are written as float32.
        hf_repo: target repo id such as ``"user/name"``. It is created if it
            does not exist yet.
        log_fn: callable that receives one progress message per step.

    ``model.flashpack`` and a short ``README.md`` are uploaded in a single
    commit. Authentication uses the locally cached Hugging Face token (or
    ``HF_TOKEN``).
    """
    # The HTTP-based HfApi replaces the deprecated git-based ``Repository``
    # (and its deprecated ``use_auth_token`` flag). It also avoids cloning the
    # repo, including any existing LFS weights, just to add two files.
    from huggingface_hub import HfApi

    api = HfApi()
    with tempfile.TemporaryDirectory() as tmp_dir:
        log_fn(f"📦 Preparing repository {hf_repo}...")
        api.create_repo(repo_id=hf_repo, exist_ok=True)
        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"),
                             target_dtype=torch.float32)
        with open(os.path.join(tmp_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
        log_fn("⏳ Pushing model to Hugging Face...")
        api.upload_folder(folder_path=tmp_dir, repo_id=hf_repo,
                          commit_message="Upload FlashPack model")
    log_fn(f"✅ Model pushed to {hf_repo}")
# ===========================
# Training
# ===========================
def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
                          hf_repo="rahul7star/FlashPack",
                          max_encode=1000):
    """Train a GemmaTrainer that maps short-prompt to long-prompt embeddings.

    Loads the ``train`` split of ``dataset_name`` and embeds at most
    ``max_encode`` (``short_prompt``, ``long_prompt``) pairs with the GPT-2
    encoder from ``build_encoder``. It then fits the model for up to 20
    full-batch Adam epochs on ``1 - mean cosine similarity`` and pushes it to
    ``hf_repo``.

    Args:
        dataset_name: dataset id whose rows have ``short_prompt`` and
            ``long_prompt`` fields.
        hf_repo: Hub repo id the trained model is uploaded to.
        max_encode: upper bound on training pairs; clamped to the dataset size.

    Returns:
        ``(model, tokenizer, embed_model, enhance_fn, logs)``.
        ``enhance_fn(prompt, chat)`` appends a user/assistant message pair to
        ``chat`` (a new list if falsy) and returns it. ``logs`` holds every
        progress message printed during the run.

    Raises:
        ValueError: if no training pairs would be selected.
    """
    logs = []

    def log_fn(msg):
        logs.append(msg)
        print(msg)

    log_fn("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    # ``select`` raises IndexError when asked for more rows than exist.
    n_samples = min(max_encode, len(dataset))
    if n_samples <= 0:
        raise ValueError(
            f"No training samples: dataset {dataset_name!r} has "
            f"{len(dataset)} rows and max_encode={max_encode}"
        )
    dataset = dataset.select(range(n_samples))
    log_fn(f"✅ Loaded {len(dataset)} samples")

    # Loaded once and reused for both training and the returned enhance_fn.
    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    # Only encode short+long embeddings
    s_list, l_list = [], []
    for i, item in enumerate(dataset):
        s_list.append(encode_fn(item["short_prompt"]))
        l_list.append(encode_fn(item["long_prompt"]))
        if (i + 1) % 50 == 0:
            log_fn(f" → Encoded {i + 1}/{len(dataset)}")
            gc.collect()
    short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)

    model = GemmaTrainer()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    cos_sim = nn.CosineSimilarity(dim=1)
    log_fn("🚀 Training model...")
    model.train()
    for epoch in range(20):
        optimizer.zero_grad()
        preds = model(short_emb)
        loss = 1 - cos_sim(preds, long_emb).mean()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        log_fn(f"Epoch {epoch+1}/20 | Loss: {loss_value:.5f}")
        if loss_value < 0.01:
            log_fn("🎯 Early stopping.")
            break
    # Training is done; serve and save in inference mode.
    model.eval()

    push_flashpack_model_to_hf(model, hf_repo, log_fn)

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        short_emb = encode_fn(prompt)
        # NOTE(review): ``mapped`` is not used yet. The reply below is a fixed
        # template, so the trained model does not affect the output.
        mapped = model(short_emb.to(device)).cpu()
        long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": long_prompt})
        return chat

    return model, tokenizer, embed_model, enhance_fn, logs
# ===========================
# Lazy Load / Get Model
# ===========================
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    """Load a trained GemmaTrainer and build a chat ``enhance_fn`` around it.

    A local ``model.flashpack`` in the working directory takes precedence.
    Otherwise the file is downloaded from ``hf_repo`` if the repo lists it.

    Args:
        hf_repo: Hub repo id to fall back to when no local file exists.

    Returns:
        ``(model, tokenizer, embed_model, enhance_fn)``. The result is
        ``(None, None, None, None)`` when there is no local file and the repo
        has no ``model.flashpack`` or cannot be reached.
    """
    local_model_path = "model.flashpack"
    if os.path.exists(local_model_path):
        print("✅ Loading local model")
    else:
        try:
            files = list_repo_files(hf_repo)
            if "model.flashpack" not in files:
                print("🚫 No pretrained model found")
                return None, None, None, None
            print("✅ Downloading model from HF")
            local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
        except Exception as e:
            # Deliberate best-effort at app startup: report the failure and
            # let the UI fall back to "train a model first".
            print(f"⚠️ Error accessing HF: {e}")
            return None, None, None, None

    # NOTE(review): ``GemmaTrainer()`` creates a throwaway instance only to call
    # ``from_flashpack``. If FlashPack defines it as a classmethod,
    # ``GemmaTrainer.from_flashpack(...)`` would skip that — confirm.
    model = GemmaTrainer().from_flashpack(local_model_path)
    model.eval()
    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        short_emb = encode_fn(prompt).to(device)
        # NOTE(review): ``mapped`` is not used yet. The reply below is a fixed
        # template, so the loaded model does not affect the output.
        mapped = model(short_emb).cpu()
        long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": long_prompt})
        return chat

    return model, tokenizer, embed_model, enhance_fn
# ===========================
# Gradio UI
# ===========================
# Gradio app: the chat box calls the current ``enhance_fn``; "Train Model"
# retrains, pushes to the Hub, and swaps in the new model.
with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
    gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort → Long prompt expander")
    chatbot = gr.Chatbot(height=400, type="messages")
    user_input = gr.Textbox(label="Your prompt")
    send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
    clear_btn = gr.Button("🧹 Clear")
    train_btn = gr.Button("🧩 Train Model", variant="secondary")
    log_output = gr.Textbox(label="Logs", lines=15)

    # Load once at startup: local model.flashpack first, then the Hub.
    model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
    logs = []
    if enhance_fn is None:
        def enhance_fn(prompt, chat):
            chat = chat or []
            chat.append({"role": "assistant",
                         "content": "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
            return chat
        logs.append("⚠️ No pretrained model found. Ready to train.")
    else:
        logs.append("✅ Model loaded — ready to enhance.")

    def enhance_dispatch(prompt, chat):
        """Forward to whatever module-level ``enhance_fn`` is current.

        The handlers must not capture the ``enhance_fn`` object directly:
        ``retrain`` rebinds the global, and a captured reference would keep
        calling the stale placeholder after training.
        """
        return enhance_fn(prompt, chat)

    # Button callbacks
    send_btn.click(enhance_dispatch, [user_input, chatbot], chatbot)
    user_input.submit(enhance_dispatch, [user_input, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)
    # Show the startup status collected above, which was never displayed
    # before. ``logs`` is read at call time, so a reload shows the latest run.
    demo.load(lambda: "\n".join(logs), None, log_output)

    def retrain():
        """Train a new model, make it the active one, and return the log text."""
        global model, tokenizer, embed_model, enhance_fn, logs
        logs = ["🔄 Training model, please wait..."]
        model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model()
        logs.extend(train_logs)
        # Return the plain value: ``gr.Textbox.update`` was removed in Gradio
        # 4+, which ``Chatbot(type="messages")`` above already requires.
        return "\n".join(logs)

    train_btn.click(retrain, None, log_output)

if __name__ == "__main__":
    demo.launch(show_error=True)
|