File size: 6,455 Bytes
5d99e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39da461
 
 
 
 
 
 
 
 
 
 
 
5d99e98
 
 
 
 
 
 
39da461
5d99e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import torch

# Allow TF32 for CUDA matmuls and for cuDNN ops. The flags are set right
# after importing torch, ahead of the remaining imports below.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import os
import random
import numpy as np
import gradio as gr
import spaces
from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler
from lakonlab.ui.gradio.create_img_edit import create_interface_img_edit
from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline
from lakonlab.pipelines.prompt_rewriters.qwen3_vl import Qwen3VLPromptRewriter


# Prompt passed as the `prompt` default when the demo interface is created.
DEFAULT_PROMPT = """Museum-style FIELD GUIDE poster on neutral parchment (#F3EEE3). Use Inter (or Helvetica/Arial). All text #2D3748, thin connector lines 1px #A0AEC0.

Center: full-body original fantasy creature, 3/4 standing pose. Around it: four small inset boxes labeled exactly "EYE DETAIL", "FOOT DETAIL", "SKIN TEXTURE", "SILHOUETTE SCALE" (with a simple human comparison silhouette). Bottom: a short footprint trail diagram. One small habitat vignette (misty rocky shoreline with tide pools).

Exact text (only these, clean print layout):
Top: "FIELD GUIDE"
Sub: "AURORA SHOREWALKER"
Small line: "CLASS: COASTAL DRIFTER"
Under silhouette: "HEIGHT: 1.7 m"

Crisp ink outlines with soft watercolor-like fills, high readability, balanced hierarchy, premium poster aesthetic."""

# Files holding the system prompts handed to the prompt rewriter: one for
# text-only rewriting, one for rewriting with input images. The paths are
# relative, so they resolve against the process working directory.
SYSTEM_PROMPT_TEXT_ONLY_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_text_only.txt'
SYSTEM_PROMPT_WITH_IMAGES_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_with_images.txt'


def _patch_diffusers_bnb_shape_check():
    """Swap out the bitsandbytes quantized-parameter shape check in diffusers.

    Every class found in ``diffusers.quantizers.bitsandbytes.bnb_quantizer``
    that exposes ``check_quantized_param_shape`` gets that method replaced by
    a version that compares the shapes as plain tuples. When the module cannot
    be imported, nothing is patched and the function simply returns.
    """
    try:
        import diffusers.quantizers.bitsandbytes.bnb_quantizer as bnbq
    except Exception:
        # Best effort: without the quantizer module there is nothing to patch.
        return

    def _element_count(shape):
        """Return the product of all dimensions in ``shape``; None stays None."""
        if shape is None:
            return None
        if hasattr(shape, "numel"):
            # Objects such as torch.Size report their own element count.
            return int(shape.numel())
        # Plain tuple/list: multiply the dimensions by hand.
        total = 1
        for dim in shape:
            total *= int(dim)
        return total

    def patched_check(self, param_name, current_param, loaded_param):
        """Return True when the loaded shape equals the inferred flattened shape.

        Raises:
            ValueError: if the loaded shape differs from the inferred one.
        """
        current_shape = getattr(current_param, "shape", None)
        loaded_shape = getattr(loaded_param, "shape", None)
        total = _element_count(current_shape)
        if "bias" in param_name:
            # Parameters with "bias" in their name are expected as a flat vector.
            expected_shape = (total,)
        else:
            # Everything else is expected as a column of ceil(total / 2) rows.
            expected_shape = ((total + 1) // 2, 1)
        if tuple(loaded_shape) == tuple(expected_shape):
            return True
        raise ValueError(
            f"Expected flattened shape mismatch for {param_name}: "
            f"loaded={tuple(loaded_shape)} inferred={tuple(expected_shape)}"
        )

    # Patch any quantizer class in that module that defines the method.
    for candidate in vars(bnbq).values():
        if not isinstance(candidate, type):
            continue
        if hasattr(candidate, "check_quantized_param_shape"):
            candidate.check_quantized_param_shape = patched_check


# Install the patched shape check ahead of loading the 4-bit pipeline below.
_patch_diffusers_bnb_shape_check()


# Load the 4-bit FLUX.2 pipeline, attach the pi-Flow adapter to the
# transformer, and swap in the flow-map SDE scheduler.
pipe = PiFlux2Pipeline.from_pretrained(
    'diffusers/FLUX.2-dev-bnb-4bit',
    torch_dtype=torch.bfloat16)
pipe.load_piflow_adapter(
    'Lakonik/pi-FLUX.2',
    subfolder='gmflux2_k8_piid_4step',
    target_module_name='transformer')
pipe.scheduler = FlowMapSDEScheduler.from_config(  # use fixed shift=3.2
    pipe.scheduler.config, shift=3.2, use_dynamic_shifting=False, final_step_size_scale=0.5)
pipe = pipe.to('cuda')

# Read the system prompts through context managers so both file handles are
# closed right away, and pin the encoding so the result does not depend on
# the platform's locale default.
with open(SYSTEM_PROMPT_TEXT_ONLY_PATH, 'r', encoding='utf-8') as _prompt_file:
    _system_prompt_text_only = _prompt_file.read()
with open(SYSTEM_PROMPT_WITH_IMAGES_PATH, 'r', encoding='utf-8') as _prompt_file:
    _system_prompt_with_images = _prompt_file.read()

prompt_rewriter = Qwen3VLPromptRewriter(
    device_map="cuda",
    system_prompt_text_only=_system_prompt_text_only,
    # NOTE(review): the keyword is spelled "wigh"; presumably this matches the
    # constructor's parameter name -- confirm before renaming it.
    system_prompt_wigh_images=_system_prompt_with_images,
    max_new_tokens_default=512,
)


def set_random_seed(seed: int, deterministic: bool = True) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to every seeding function; also written (as a
            string) to the ``PYTHONHASHSEED`` environment variable.
        deterministic: When true, additionally sets
            ``torch.backends.cudnn.deterministic = True`` and
            ``torch.backends.cudnn.benchmark = False``. When false, those
            flags are left as they are.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


@spaces.GPU
def run_rewrite_prompt_gpu(seed, prompt, image_list, progress):
    """Rewrite ``prompt`` with the module-level prompt rewriter.

    Args:
        seed: Seed passed to ``set_random_seed`` before rewriting.
        prompt: The prompt text to rewrite.
        image_list: Input images, or None. With None the text-only rewriter
            is used; otherwise the edit rewriter receives the images too.
        progress: Callable invoked once as ``progress(0.05, desc=...)``.

    Returns:
        The first (and only) element of the rewriter's batch output.
    """
    set_random_seed(seed)
    progress(0.05, desc="Rewriting prompt...")
    if image_list is not None:
        # Images present: rewrite as an edit instruction, batch of one.
        return prompt_rewriter.rewrite_edit_batch([image_list], [prompt])[0]
    # No images: plain text rewriting, batch of one.
    return prompt_rewriter.rewrite_text_batch([prompt])[0]


def run_rewrite_prompt(seed, prompt, rewrite_prompt, in_image, progress=gr.Progress(track_tqdm=True)):
    """Optionally rewrite the prompt and return it for the UI.

    Args:
        seed: Seed forwarded to ``run_rewrite_prompt_gpu``.
        prompt: The prompt text as entered.
        rewrite_prompt: Truthy to rewrite; falsy to skip rewriting.
        in_image: None, empty, or a sequence whose items are indexed with
            ``[0]`` to obtain each image (presumably gallery entries of
            image plus caption -- confirm against the interface code).
        progress: Progress reporter forwarded to ``run_rewrite_prompt_gpu``.

    Returns:
        A 2-tuple ``(text, None)``: ``text`` is the rewritten prompt, or the
        empty string when rewriting is disabled.
    """
    has_images = in_image is not None and len(in_image) > 0
    image_list = [entry[0] for entry in in_image] if has_images else None
    if not rewrite_prompt:
        return '', None
    return run_rewrite_prompt_gpu(seed, prompt, image_list, progress), None


@spaces.GPU
def generate(
        seed, prompt, rewrite_prompt, rewritten_prompt, in_image, width, height, steps,
        progress=gr.Progress(track_tqdm=True)):
    """Run the pipeline once and return the first generated image.

    Args:
        seed: Seed for the ``torch.Generator`` handed to the pipeline.
        prompt: The prompt as entered.
        rewrite_prompt: Truthy to use ``rewritten_prompt`` instead of ``prompt``.
        rewritten_prompt: The rewritten prompt text.
        in_image: None, empty, or a sequence whose items are indexed with
            ``[0]`` to obtain each input image.
        width: Forwarded to the pipeline as ``width``.
        height: Forwarded to the pipeline as ``height``.
        steps: Forwarded to the pipeline as ``num_inference_steps``.
        progress: Not referenced in the body; declared with
            ``track_tqdm=True`` as its default.

    Returns:
        ``images[0]`` of the pipeline output.
    """
    has_images = in_image is not None and len(in_image) > 0
    image_list = [entry[0] for entry in in_image] if has_images else None
    active_prompt = rewritten_prompt if rewrite_prompt else prompt
    rng = torch.Generator().manual_seed(seed)
    output = pipe(
        image=image_list,
        prompt=active_prompt,
        width=width,
        height=height,
        num_inference_steps=steps,
        generator=rng,
    )
    return output.images[0]


# Assemble the page (header text plus the image-edit interface) and launch it.
with gr.Blocks(
    analytics_enabled=False,
    title='pi-FLUX.2 Demo',
    css_paths='lakonlab/ui/gradio/style.css',
) as demo:

    # Markdown header: title, paper link, model/code links, license notice.
    md_txt = (
        '# pi-FLUX.2 Demo\n\n'
        'Official demo of the paper [pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation](https://arxiv.org/abs/2510.14974). '
        '**Base model:** [FLUX.2 dev](https://huggingface.co/black-forest-labs/FLUX.2-dev). **Fast policy:** GMFlow. **Code:** [https://github.com/Lakonik/piFlow](https://github.com/Lakonik/piFlow).\n'
        '<br> Use and distribution of this app are governed by the [FLUX [dev] Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/LICENSE.txt).'
    )
    gr.Markdown(md_txt)

    # The name lists follow the parameter order of `generate` and
    # `run_rewrite_prompt` (presumably UI component names mapped positionally;
    # confirm in create_interface_img_edit).
    generate_args = [
        'last_seed', 'prompt', 'rewrite_prompt', 'rewritten_prompt',
        'in_image', 'width', 'height', 'steps',
    ]
    rewrite_args = ['last_seed', 'prompt', 'rewrite_prompt', 'in_image']
    create_interface_img_edit(
        generate,
        prompt=DEFAULT_PROMPT,
        steps=4,
        guidance_scale=None,
        args=generate_args,
        rewrite_prompt_api=run_rewrite_prompt,
        rewrite_prompt_args=rewrite_args,
        height=1024,
        width=1024,
    )
    demo.queue().launch()