HemanM committed on
Commit
0c388fc
·
verified ·
1 Parent(s): b9f30a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +380 -0
app.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import math, json, random, time, threading, io, os
3
+ from dataclasses import dataclass, asdict
4
+ from typing import List, Tuple, Dict, Any
5
+ import numpy as np
6
+ import plotly.graph_objs as go
7
+ import gradio as gr
8
+
9
+ # =========================
10
+ # UX THEME & STYLES
11
+ # =========================
12
# Stylesheet passed to gr.Blocks(css=...) when the UI is built.
# The #header-card, #viz-card, #right-card and #table-card selectors match
# elem_id values assigned to layout containers in the UI section of this file.
# NOTE(review): #stats / .stat / .k / .v are not referenced by any elem_id or
# class in the visible source — confirm whether they are still needed.
CUSTOM_CSS = """
:root {
--radius-2xl: 20px;
}
.gradio-container {max-width: 1400px !important}
#header-card {border-radius: var(--radius-2xl); box-shadow: 0 6px 24px rgba(0,0,0,0.08)}
#viz-card, #right-card, #table-card {border-radius: var(--radius-2xl); box-shadow: 0 6px 24px rgba(0,0,0,0.06)}
#stats {display:flex; gap:16px; flex-wrap:wrap}
.stat {flex:1; min-width:180px; background:#0b1220; color:white; border-radius:16px; padding:14px 16px}
.stat .k {font-size:14px; opacity:0.8}
.stat .v {font-size:22px; font-weight:700}
.gr-button {border-radius:14px}
"""
25
+
26
+ # =========================
27
+ # GENOME & EVOLUTION CORE
28
+ # =========================
29
@dataclass
class Genome:
    """One candidate architecture: six structural fields plus bookkeeping.

    ``species`` is an integer label (used as the marker colour in the sphere
    plot) and ``fitness`` is the last score assigned to this genome; it starts
    at +inf, i.e. "not evaluated yet", and lower values are better.
    """

    d_model: int
    n_layers: int
    n_heads: int
    ffn_mult: float
    memory_tokens: int
    dropout: float
    species: int = 0
    fitness: float = float("inf")

    # (field name, divisor) pairs, in the order the components appear in
    # vector(). Deliberately left un-annotated so the dataclass machinery does
    # not treat it as a field.
    _VECTOR_SCALES = (
        ("d_model", 1024.0),
        ("n_layers", 24.0),
        ("n_heads", 32.0),
        ("ffn_mult", 8.0),
        ("memory_tokens", 64.0),
        ("dropout", 0.5),
    )

    def vector(self) -> np.ndarray:
        """Return the six structural fields, each divided by its fixed divisor.

        The result is a float32 array of shape (6,). ``species`` and
        ``fitness`` are not part of the vector.
        """
        scaled = [getattr(self, name) / divisor for name, divisor in self._VECTOR_SCALES]
        return np.array(scaled, dtype=np.float32)
50
+
51
def random_genome(rng: random.Random) -> Genome:
    """Sample a fresh, unevaluated genome from the fixed option lists.

    Args:
        rng: source of randomness; exactly seven draws are taken from it, in
            the field order used below, so a seeded ``rng`` yields a
            reproducible genome.

    Returns:
        A ``Genome`` whose ``fitness`` is left at its default (+inf).
    """
    # One draw per field, in declaration order.
    # NOTE(review): some (d_model, n_heads) pairs drawn here are not evenly
    # divisible (e.g. 256 / 6) — confirm this is acceptable once genomes are
    # turned into real attention layers.
    d_model = rng.choice([256, 384, 512, 640])
    n_layers = rng.choice([4, 6, 8, 10, 12])
    n_heads = rng.choice([4, 6, 8, 10, 12])
    ffn_mult = rng.choice([2.0, 3.0, 4.0, 6.0])
    memory_tokens = rng.choice([0, 4, 8, 16])
    dropout = rng.choice([0.0, 0.05, 0.1, 0.15])
    species = rng.randrange(5)
    return Genome(
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        ffn_mult=ffn_mult,
        memory_tokens=memory_tokens,
        dropout=dropout,
        species=species,
    )
61
+
62
def mutate(g: Genome, rng: random.Random, rate: float) -> Genome:
    """Return a mutated copy of ``g``; the input genome is not modified.

    Each structural field is independently re-drawn from its option list with
    probability ``rate``; ``species`` is re-drawn with probability
    ``rate * 0.5``. The copy's fitness is reset to +inf because it has not
    been evaluated.

    Args:
        g: genome to copy.
        rng: source of randomness.
        rate: per-field mutation probability.
    """
    # (field, allowed values) in the order the random draws are consumed.
    field_options = (
        ("d_model", [256, 384, 512, 640]),
        ("n_layers", [4, 6, 8, 10, 12]),
        ("n_heads", [4, 6, 8, 10, 12]),
        ("ffn_mult", [2.0, 3.0, 4.0, 6.0]),
        ("memory_tokens", [0, 4, 8, 16]),
        ("dropout", [0.0, 0.05, 0.1, 0.15]),
    )

    child = Genome(**asdict(g))
    for field_name, allowed in field_options:
        if rng.random() < rate:
            setattr(child, field_name, rng.choice(allowed))

    # Species flips at half the structural mutation rate.
    if rng.random() < rate * 0.5:
        child.species = rng.randrange(5)

    child.fitness = float("inf")
    return child
73
+
74
def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
    """Build a child by uniform crossover of two parents.

    For every field (including ``species``) the value is taken from ``a`` with
    probability 0.5 and from ``b`` otherwise, using one ``rng.random()`` draw
    per field. The child starts unevaluated (fitness = +inf).
    """
    inherited_fields = (
        "d_model",
        "n_layers",
        "n_heads",
        "ffn_mult",
        "memory_tokens",
        "dropout",
        "species",
    )
    genes = {}
    for name in inherited_fields:
        donor = a if rng.random() < 0.5 else b
        genes[name] = getattr(donor, name)
    return Genome(fitness=float("inf"), **genes)
85
+
86
+ # =========================
87
+ # FITNESS HOOK (Phase 1: fast surrogate)
88
+ # Swap this later for real PIQA/HellaSwag evaluation
89
+ # =========================
90
def rastrigin(x: np.ndarray) -> float:
    """Evaluate the Rastrigin function ``A*n + sum(x_i**2 - A*cos(2*pi*x_i))``.

    ``A`` is fixed at 10 and ``n`` is the number of elements in ``x``. The
    value is 0 at the origin and at an empty input.

    Args:
        x: array (or array-like) of coordinates. All elements are summed, and
            ``n`` counts all of them, so the result is consistent for any
            input shape.

    Returns:
        The function value as a built-in ``float`` (not a NumPy scalar), as
        the annotation promises.
    """
    coords = np.asarray(x)
    amplitude = 10.0
    # Use .size rather than shape[0] so n always matches the number of terms
    # that np.sum adds up (also makes 0-d input valid).
    n_terms = coords.size
    wave_sum = np.sum(coords**2 - amplitude * np.cos(2 * math.pi * coords))
    return float(amplitude * n_terms + wave_sum)
93
+
94
def fitness_hook(genome: Genome, dataset: str, explore: float) -> float:
    """Score a genome with the fast surrogate objective (lower is better).

    Phase 1 (demo, fast):
    - Rescale ``genome.vector()`` with ``v * 2 - 1`` (maps [0, 1] onto
      [-1, 1]) and score it via Rastrigin.
    - Add a small parsimony penalty and exploration noise.
    Phase 2 (real):
    - Replace with tiny train/eval steps on chosen dataset
      (PIQA/HellaSwag/WikiText-ppl).

    Args:
        genome: candidate to score.
        dataset: currently unused by the surrogate.
        explore: clamped to [0, 1]; scales the standard deviation of the
            Gaussian noise (0.05 at ``explore == 1``).

    Returns:
        The score as a built-in ``float``.
    """
    centered = genome.vector() * 2 - 1
    landscape = rastrigin(centered)

    # Size penalty: larger architectures score slightly worse.
    size_units = (
        genome.d_model
        + 50 * genome.n_layers
        + 20 * genome.n_heads
        + 100 * genome.memory_tokens
    )
    size_penalty = 0.001 * size_units

    # NOTE(review): the noise is drawn from NumPy's global generator, not from
    # a caller-supplied RNG, so it is independent of any random.Random seed the
    # caller uses elsewhere.
    explore_clamped = max(0.0, min(1.0, explore))
    jitter = np.random.normal(scale=0.05 * explore_clamped)

    return float(landscape + size_penalty + jitter)
107
+
108
+ # =========================
109
+ # PROJECTION & VIZ
110
+ # =========================
111
def sphere_project(points: np.ndarray) -> np.ndarray:
    """Project row vectors to 3-D and scale each row to (almost) unit length.

    The projection matrix is regenerated on every call from a RandomState
    seeded with 42, so the mapping is the same for every call with the same
    input width.

    Args:
        points: 2-D array with one point per row.

    Returns:
        Array with the same number of rows and 3 columns.
    """
    n_features = points.shape[1]
    weights = np.random.RandomState(42).normal(size=(n_features, 3)).astype(np.float32)
    projected = points @ weights
    # The epsilon keeps an all-zero row from dividing by zero.
    lengths = np.linalg.norm(projected, axis=1, keepdims=True) + 1e-8
    return projected / lengths
118
+
119
def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
    """Draw the population as markers on a translucent unit sphere.

    Args:
        points3d: array with one row per genome and columns x, y, z.
        genomes: genomes in the same order as the rows of ``points3d``;
            ``species`` drives the marker colour and the remaining fields
            appear in the hover label.
        gen_idx: generation number shown in the title.

    Returns:
        A figure containing the sphere surface and the scatter trace.
    """
    species = np.array([g.species for g in genomes])

    # Hover label: the genome's fields as JSON, then the fitness on its own
    # line. Plotly hover text uses "<br>" for line breaks; a plain "\n" is not
    # rendered as a new line.
    hover_labels = []
    for g in genomes:
        fields = asdict(g)
        fitness = fields.pop("fitness")
        hover_labels.append(f"{json.dumps(fields)}<br>fitness={fitness:.3f}")

    scatter = go.Scatter3d(
        x=points3d[:, 0],
        y=points3d[:, 1],
        z=points3d[:, 2],
        mode="markers",
        marker=dict(size=6, color=species, opacity=0.9),
        text=hover_labels,
        hoverinfo="text",
    )

    # Sphere mesh (unit radius) from a longitude/latitude grid.
    u = np.linspace(0, 2 * np.pi, 48)
    v = np.linspace(0, np.pi, 24)
    xs = np.outer(np.cos(u), np.sin(v))
    ys = np.outer(np.sin(u), np.sin(v))
    zs = np.outer(np.ones_like(u), np.cos(v))
    sphere = go.Surface(x=xs, y=ys, z=zs, opacity=0.15, showscale=False)

    hidden_axis = dict(visible=False)
    layout = go.Layout(
        title=f"Evo Sphere — Generation {gen_idx}",
        scene=dict(xaxis=hidden_axis, yaxis=dict(hidden_axis), zaxis=dict(hidden_axis)),
        margin=dict(l=0, r=0, t=40, b=0),
        showlegend=False,
    )
    return go.Figure(data=[sphere, scatter], layout=layout)
148
+
149
def make_history_figure(history: List[Tuple[int,float]]) -> go.Figure:
    """Plot the best fitness of each generation as a line with markers.

    Args:
        history: ``(generation, best_fitness)`` entries in the order they
            should be drawn; may be empty.
    """
    generations = [entry[0] for entry in history]
    best_scores = [entry[1] for entry in history]

    trace = go.Scatter(x=generations, y=best_scores, mode="lines+markers")
    figure = go.Figure(data=[trace])
    figure.update_layout(
        title="Best Fitness per Generation",
        xaxis_title="Generation",
        yaxis_title="Fitness (lower is better)",
        margin=dict(l=30, r=10, t=40, b=30),
    )
    return figure
157
+
158
def approx_params(g: Genome) -> int:
    """Return a very rough parameter count for the genome's architecture.

    Embeddings and vocabulary are ignored. Each layer is counted as
    ``(4 + 2 * ffn_mult) * d_model**2`` and every memory token adds a flat
    1000 (illustrative only).
    """
    width_squared = g.d_model ** 2
    layer_factor = 4.0 + 2.0 * float(g.ffn_mult)
    estimate = layer_factor * width_squared * g.n_layers
    estimate += 1000 * g.memory_tokens
    return int(estimate)
166
+
167
+ # =========================
168
+ # ORCHESTRATOR
169
+ # =========================
170
class EvoRunner:
    """Run the evolutionary search and publish a snapshot after each generation.

    Attributes:
        lock: guards replacement and reading of ``state`` and the
            check-and-set of ``running`` in ``start()``.
        running: True from the moment a run is accepted until it finishes,
            is stopped, or fails with an exception.
        stop_flag: set by ``stop()``; checked once at the top of every
            generation.
        state: latest snapshot dict with keys ``sphere``, ``history``,
            ``top``, ``best``, ``gen`` and ``dataset``; empty until the first
            generation completes.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.running = False
        self.stop_flag = False
        self.state: Dict[str, Any] = {}

    @staticmethod
    def _fitness_of(genome) -> float:
        """Key function for sort/min: lower fitness is better."""
        return genome.fitness

    def run(self, dataset, pop_size, generations, mutation_rate, explore, exploit, seed, pace_ms):
        """Run the whole search synchronously in the calling thread.

        Args:
            dataset: label forwarded to ``fitness_hook`` and stored in the
                snapshot.
            pop_size: number of genomes per generation (must be >= 1).
            generations: maximum number of generations.
            mutation_rate: per-field mutation probability passed to ``mutate``.
            explore: noise level passed to ``fitness_hook``.
            exploit: selection pressure; tournament size is
                ``max(2, int(2 + exploit * 5))``, capped at the population size.
            seed: seed for the ``random.Random`` that drives sampling,
                crossover and mutation.
            pace_ms: pause between generations in milliseconds.

        Raises:
            ValueError: if ``pop_size`` is smaller than 1.
        """
        self.stop_flag = False
        self.running = True
        self._evolve(dataset, pop_size, generations, mutation_rate, explore, exploit, seed, pace_ms)

    def _evolve(self, dataset, pop_size, generations, mutation_rate, explore, exploit, seed, pace_ms):
        """Evolution loop; always clears ``running`` on exit, even on error."""
        try:
            if pop_size < 1:
                raise ValueError("pop_size must be at least 1")

            rng = random.Random(int(seed))
            pop = [random_genome(rng) for _ in range(pop_size)]
            # initial eval
            for genome in pop:
                genome.fitness = fitness_hook(genome, dataset, explore)

            history: List[Tuple[int, float]] = []

            # Loop invariants. The tournament cannot sample more genomes than
            # the population holds, otherwise rng.sample raises.
            tournament_k = min(pop_size, max(2, int(2 + exploit * 5)))
            elite_n = max(1, pop_size // 10)
            pause_s = max(0.0, pace_ms / 1000.0)

            for gen in range(1, generations + 1):
                if self.stop_flag:
                    break

                # Selection: one tournament winner per population slot.
                parents = [
                    min(rng.sample(pop, k=tournament_k), key=self._fitness_of)
                    for _ in range(pop_size)
                ]

                # Reproduce: each neighbouring parent pair yields two children
                # (wrapping around for an odd population), then trim.
                children = []
                for i in range(0, pop_size, 2):
                    a = parents[i]
                    b = parents[(i + 1) % pop_size]
                    children.append(mutate(crossover(a, b, rng), rng, mutation_rate))
                    children.append(mutate(crossover(b, a, rng), rng, mutation_rate))
                del children[pop_size:]

                # Evaluate kids
                for child in children:
                    child.fitness = fitness_hook(child, dataset, explore)

                # Elitism: the best of the old population overwrite the worst
                # children.
                elites = sorted(pop, key=self._fitness_of)[:elite_n]
                pop = sorted(children, key=self._fitness_of)
                pop[-elite_n:] = elites

                best = min(pop, key=self._fitness_of)
                history.append((gen, best.fitness))

                snapshot = self._snapshot(pop, history, gen, dataset)
                with self.lock:
                    # Replace the dict wholesale so readers never observe a
                    # half-updated snapshot.
                    self.state = snapshot

                time.sleep(pause_s)
        finally:
            # Without this, an exception would leave running=True forever and
            # start() would refuse every later request.
            self.running = False

    def _snapshot(self, pop, history, gen, dataset) -> Dict[str, Any]:
        """Build the figures and the top-genome table for one generation."""
        vectors = np.stack([g.vector() for g in pop], axis=0)
        sphere_fig = make_sphere_figure(sphere_project(vectors), pop, gen)
        hist_fig = make_history_figure(history)

        top = sorted(pop, key=self._fitness_of)[:12]
        top_table = [
            {
                "gen": gen,
                "fitness": round(t.fitness, 4),
                "d_model": t.d_model,
                "layers": t.n_layers,
                "heads": t.n_heads,
                "ffn_mult": t.ffn_mult,
                "mem": t.memory_tokens,
                "dropout": t.dropout,
                "species": t.species,
                "params_approx": approx_params(t),
            }
            for t in top
        ]
        return {
            "sphere": sphere_fig,
            "history": hist_fig,
            "top": top_table,
            "best": top_table[0] if top_table else {},
            "gen": gen,
            "dataset": dataset,
        }

    def start(self, *args, **kwargs):
        """Start a run on a daemon thread; do nothing if one is already active.

        Accepts the same arguments as ``run``.
        """
        # Claim the runner atomically so two quick calls cannot both pass the
        # check before either worker thread has started.
        with self.lock:
            if self.running:
                return
            self.running = True
            self.stop_flag = False

        worker = threading.Thread(target=self._evolve, args=args, kwargs=kwargs, daemon=True)
        try:
            worker.start()
        except BaseException:
            # The thread never ran, so nothing else will release the claim.
            self.running = False
            raise

    def stop(self):
        """Ask the active run to stop before its next generation."""
        self.stop_flag = True
271
+
272
# Single module-level runner instance used by the UI callbacks defined below.
runner = EvoRunner()
273
+
274
+ # =========================
275
+ # GRADIO UI CALLBACKS
276
+ # =========================
277
def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms):
    """Start-button callback: launch a run and flip the button states.

    The raw widget values are coerced to the numeric types the runner expects
    before being handed over.

    Returns:
        Two updates, in order: disable the Start button, enable the Stop
        button.
    """
    run_args = (
        dataset,
        int(pop),
        int(gens),
        float(mut),
        float(explore),
        float(exploit),
        int(seed),
        int(pace_ms),
    )
    runner.start(*run_args)

    start_update = gr.update(interactive=False)
    stop_update = gr.update(interactive=True)
    return (start_update, stop_update)
280
+
281
def stop_evo():
    """Stop-button callback: request a stop and flip the button states back.

    Returns:
        Two updates, in order: enable the Start button, disable the Stop
        button.
    """
    runner.stop()

    start_update = gr.update(interactive=True)
    stop_update = gr.update(interactive=False)
    return (start_update, stop_update)
284
+
285
def poll_state():
    """Polling callback: turn the runner's latest snapshot into widget values.

    Returns:
        A 4-tuple ``(sphere_figure, history_figure, stats_markdown,
        top_genomes_dataframe)``. Before the first generation has been
        published the figures are empty, the table is empty, and the Markdown
        asks the user to start a run.
    """
    with runner.lock:
        snapshot = runner.state.copy()

    # Defaults before first run. Empty figures are only built when needed.
    sphere = snapshot["sphere"] if "sphere" in snapshot else go.Figure()
    history = snapshot["history"] if "history" in snapshot else go.Figure()
    best = snapshot.get("best", {})
    gen = snapshot.get("gen", 0)
    dataset = snapshot.get("dataset", "Demo (Surrogate)")
    top = snapshot.get("top", [])

    # Build stats Markdown
    if best:
        params = best.get("params_approx")
        # A missing value must not crash the "," format spec.
        params_text = f"{params:,}" if params is not None else "–"
        # Each line ends with two spaces before the newline: that is the
        # Markdown hard line break, so every stat renders on its own line.
        stats_md = (
            f"**Dataset:** {dataset}  \n"
            f"**Generation:** {gen}  \n"
            f"**Best fitness:** {best.get('fitness','–')}  \n"
            f"**Config:** d_model={best.get('d_model')} · layers={best.get('layers')} · "
            f"heads={best.get('heads')} · ffn_mult={best.get('ffn_mult')} · mem={best.get('mem')} · "
            f"dropout={best.get('dropout')}  \n"
            f"**~Params (rough):** {params_text}"
        )
    else:
        stats_md = "Waiting… click **Start Evolution**."

    # Dataframe rows (pandas is imported lazily, only when polling runs).
    import pandas as pd

    table = pd.DataFrame(top)
    return sphere, history, stats_md, table
313
+
314
def _snapshot_json_default(obj):
    """``json.dumps`` fallback for snapshot values that are not plain JSON types.

    Plotly figures (anything exposing ``to_plotly_json``) are converted to
    their dict form, NumPy arrays to lists and NumPy scalars to Python
    scalars. Anything else raises ``TypeError``, which is what ``json``
    expects from a ``default`` hook that cannot help.
    """
    to_plotly_json = getattr(obj, "to_plotly_json", None)
    if callable(to_plotly_json):
        return to_plotly_json()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def export_snapshot():
    """Export-button callback: write the latest snapshot to a JSON file.

    Returns:
        A Gradio update that sets the file component's value to the written
        path and makes the component visible, so the download link actually
        appears (the component is created hidden).
    """
    # Only take a reference copy under the lock; serializing figures can be
    # slow and must not stall the evolution thread. The runner replaces its
    # state dict wholesale, so the objects captured here are not modified
    # afterwards.
    with runner.lock:
        snapshot = dict(runner.state)

    # An identity ``default`` hook makes json re-encode the same object and
    # fail with "Circular reference detected" as soon as the snapshot holds a
    # figure; convert such values explicitly instead.
    payload = json.dumps(snapshot, default=_snapshot_json_default, indent=2)

    # Written to the working directory so the user can download it.
    path = "evo_snapshot.json"
    with open(path, "w", encoding="utf-8") as f:
        f.write(payload)
    return gr.update(value=path, visible=True)
322
+
323
+ # =========================
324
+ # BUILD UI
325
+ # =========================
326
# Build the Gradio interface: a header, a left column with run controls,
# live stats and snapshot export, and a right column with the two plots and
# the top-genomes table. Event wiring happens inside the Blocks context.
# NOTE(review): container nesting below is reconstructed from a flattened
# listing — confirm against the deployed layout.
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    with gr.Column(elem_id="header-card"):
        gr.Markdown(
            "# Evo Playground — Live Evolving Transformer Architectures\n"
            "Watch the population **mutate, recombine, and converge** in real time. "
            "Choose a dataset and search behavior; the 3D sphere shows the architecture landscape (species = colors)."
        )

    with gr.Row():
        # LEFT: Controls
        with gr.Column(scale=1):
            with gr.Group():
                dataset = gr.Dropdown(
                    label="Dataset",
                    choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)", "WikiText Perplexity (Phase 2)"],
                    value="Demo (Surrogate)",
                    info="Demo is instant. Phase 2 datasets will do tiny train/eval steps per genome."
                )
                pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
                gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
                mut = gr.Slider(0.05, 0.9, value=0.25, step=0.01, label="Mutation rate")
                with gr.Row():
                    explore = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Exploration")
                    exploit = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Exploitation")
                seed = gr.Number(value=42, label="Seed", precision=0)
                pace = gr.Slider(0, 1000, value=120, step=10, label="Pace (ms between gens)")
                with gr.Row():
                    start = gr.Button("▶ Start Evolution", variant="primary")
                    stop = gr.Button("⏹ Stop", variant="secondary")

            with gr.Group(elem_id="right-card"):
                stats_md = gr.Markdown("Waiting…")

            export_btn = gr.Button("Export Snapshot (JSON)")
            # Created hidden; it receives the path returned by export_snapshot.
            export_file = gr.File(label="Download snapshot", visible=False)

        # RIGHT: Viz + Table
        with gr.Column(scale=2):
            # NOTE(review): both plot groups use elem_id="viz-card", so the
            # id is not unique in the page — confirm this is intended.
            with gr.Group(elem_id="viz-card"):
                sphere_plot = gr.Plot(label="Evolution Sphere")
            with gr.Group(elem_id="viz-card"):
                hist_plot = gr.Plot(label="Best Fitness History")
            with gr.Group(elem_id="table-card"):
                top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)

    # Wiring: each callback's outputs map positionally onto the listed
    # components (start_evo / stop_evo return updates for [start, stop]).
    start.click(start_evo, [dataset, pop, gens, mut, explore, exploit, seed, pace], [start, stop])
    stop.click(stop_evo, [], [start, stop])
    export_btn.click(export_snapshot, [], [export_file])

    # Live polling: poll_state is registered on page load with every=0.7.
    # NOTE(review): support for `every=` on load events (and whether the
    # queue must be enabled for it) depends on the Gradio version — confirm.
    demo.load(poll_state, None, [sphere_plot, hist_plot, stats_md, top_df], every=0.7)
378
+
379
# Launch the app only when this file is executed as a script (not on import).
# NOTE(review): the UI registers a polling event with `every=`; on some Gradio
# versions that may require enabling the queue before launch — confirm.
if __name__ == "__main__":
    demo.launch()