import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
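# NOTE: Repository is deprecated in recent huggingface_hub releases (HfApi's
# upload_folder is the modern path); this script assumes a version where the
# git-based Repository workflow is still available.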
from huggingface_hub import Repository, list_repo_files, hf_hub_download

device = torch.device("cpu")
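# Cap intra-op threads; a small value behaves better on shared CPU hosts.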
torch.set_num_threads(4)
print(f"πŸ”§ Using device: {device} (CPU-only mode)")

# ===========================
# Model Definition
# ===========================
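# A small MLP that maps the pooled GPT-2 embedding of a short prompt (768-dim
# mean pool + 768-dim max pool = 1536 dims) toward the embedding of its long
# counterpart. FlashPackMixin supplies save_flashpack / from_flashpack for
# fast serialization.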
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self):
        super().__init__()
        input_dim = 1536
        hidden_dim = 1024
        output_dim = 1536
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

# ===========================
# Encoder
# ===========================
def build_encoder(model_name="gpt2", max_length=128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode(prompt: str) -> torch.Tensor:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                           padding="max_length", max_length=max_length).to(device)
        hidden = embed_model(**inputs).last_hidden_state
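        # Pool over the sequence dimension. Padding tokens are included in
        # both pools; masking them out with inputs["attention_mask"] would
        # give cleaner embeddings.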
        mean_pool = hidden.mean(dim=1)
        max_pool, _ = hidden.max(dim=1)
        return torch.cat([mean_pool, max_pool], dim=1).cpu()
    
    return tokenizer, embed_model, encode
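# Minimal sanity check (a sketch; "gpt2" has hidden size 768, so mean+max
# pooling yields 1536 features, matching GemmaTrainer's input_dim):
#   _, _, encode = build_encoder("gpt2")
#   assert encode("a cat").shape == (1, 1536)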

# ===========================
# Push model to HF
# ===========================
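# Assumes you are already authenticated to the Hub (e.g. `huggingface-cli login`
# or an HF_TOKEN environment variable) with write access to hf_repo.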
def push_flashpack_model_to_hf(model, hf_repo, log_fn):
    with tempfile.TemporaryDirectory() as tmp_dir:
        log_fn(f"πŸ“¦ Preparing repository {hf_repo}...")
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
        with open(os.path.join(tmp_dir, "README.md"), "w") as f:
            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
        log_fn("⏳ Pushing model to Hugging Face...")
        repo.push_to_hub()
        log_fn(f"βœ… Model pushed to {hf_repo}")

# ===========================
# Training
# ===========================
def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
                          hf_repo="rahul7star/FlashPack",
                          max_encode=1000):
    logs = []

    def log_fn(msg):
        logs.append(msg)
        print(msg)
    
    log_fn("πŸ“¦ Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.select(range(min(max_encode, len(dataset))))
    log_fn(f"βœ… Loaded {len(dataset)} samples")

    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    # Encode the short and long prompts into pooled embeddings
    s_list, l_list = [], []
    for i, item in enumerate(dataset):
        s_list.append(encode_fn(item["short_prompt"]))
        l_list.append(encode_fn(item["long_prompt"]))
        if (i + 1) % 50 == 0:
            log_fn(f"  β†’ Encoded {i + 1}/{len(dataset)}")
            gc.collect()
    short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)

    model = GemmaTrainer()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    cos_sim = nn.CosineSimilarity(dim=1)  # loss below is 1 - mean cosine similarity

    log_fn("πŸš€ Training model...")
    for epoch in range(20):
        model.train()
        optimizer.zero_grad()
        preds = model(short_emb)
        loss = 1 - cos_sim(preds, long_emb).mean()
        loss.backward()
        optimizer.step()
        log_fn(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
        if loss.item() < 0.01:
            log_fn("🎯 Early stopping.")
            break

    push_flashpack_model_to_hf(model, hf_repo, log_fn)
    model.eval()

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        emb = encode_fn(prompt)
        mapped = model(emb.to(device)).cpu()  # mapped embedding (not decoded below)
        # Placeholder response: the mapped embedding is not decoded back to
        # text yet, so the "enhanced" prompt is a simple template.
        long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": long_prompt})
        return chat

    return model, tokenizer, embed_model, enhance_fn, logs
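# Example usage (a sketch; runs a full training pass, so it is commented out):
#   model, tok, emb, enhance, logs = train_flashpack_model(max_encode=200)
#   print(enhance("a cat", [])[-1]["content"])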

# ===========================
# Lazy Load / Get Model
# ===========================
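# Resolution order: local model.flashpack β†’ Hub download β†’ None (the UI then
# falls back to a "train first" stub).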
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    local_model_path = "model.flashpack"

    if os.path.exists(local_model_path):
        print("βœ… Loading local model")
    else:
        try:
            files = list_repo_files(hf_repo)
            if "model.flashpack" in files:
                print("βœ… Downloading model from HF")
                local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
            else:
                print("🚫 No pretrained model found")
                return None, None, None, None
        except Exception as e:
            print(f"⚠️ Error accessing HF: {e}")
            return None, None, None, None

    model = GemmaTrainer.from_flashpack(local_model_path)
    model.eval()
    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        chat = chat or []
        emb = encode_fn(prompt).to(device)
        mapped = model(emb).cpu()  # mapped embedding (not decoded below)
        long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
        chat.append({"role": "user", "content": prompt})
        chat.append({"role": "assistant", "content": long_prompt})
        return chat

    return model, tokenizer, embed_model, enhance_fn

# ===========================
# Gradio UI
# ===========================
with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
    gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort β†’ Long prompt expander")

    chatbot = gr.Chatbot(height=400, type="messages")
    user_input = gr.Textbox(label="Your prompt")
    send_btn = gr.Button("πŸš€ Enhance Prompt", variant="primary")
    clear_btn = gr.Button("🧹 Clear")
    train_btn = gr.Button("🧩 Train Model", variant="secondary")
    log_output = gr.Textbox(label="Logs", lines=15)

    # Lazy load model
    model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
    logs = []

    if enhance_fn is None:
        def enhance_fn(prompt, chat):
            chat = chat or []
            chat.append({"role": "assistant",
                         "content": "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
            return chat
        logs.append("⚠️ No pretrained model found. Ready to train.")
    else:
        logs.append("βœ… Model loaded β€” ready to enhance.")

    # Button callbacks. Route events through a thin wrapper so retraining,
    # which rebinds enhance_fn via `global`, takes effect without re-wiring
    # the handlers.
    def enhance_dispatch(prompt, chat):
        return enhance_fn(prompt, chat)

    send_btn.click(enhance_dispatch, [user_input, chatbot], chatbot)
    user_input.submit(enhance_dispatch, [user_input, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    # Surface the startup status in the log box once the page loads.
    demo.load(lambda: "\n".join(logs), None, log_output)

    def retrain():
        global model, tokenizer, embed_model, enhance_fn, logs
        logs = ["πŸš€ Training model, please wait..."]
        model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model()
        logs.extend(train_logs)
        # Return the value directly; gr.Textbox.update was removed in Gradio 4.
        return "\n".join(logs)

    train_btn.click(retrain, None, log_output)

if __name__ == "__main__":
    demo.launch(show_error=True)