File size: 6,455 Bytes
5d99e98 39da461 5d99e98 39da461 5d99e98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import os
import random
import numpy as np
import gradio as gr
import spaces
from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler
from lakonlab.ui.gradio.create_img_edit import create_interface_img_edit
from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline
from lakonlab.pipelines.prompt_rewriters.qwen3_vl import Qwen3VLPromptRewriter
# Prompt handed to `create_interface_img_edit` as its `prompt=` argument when the UI is built
# (presumably the initial content of the prompt box -- confirm against the UI helper).
DEFAULT_PROMPT = """Museum-style FIELD GUIDE poster on neutral parchment (#F3EEE3). Use Inter (or Helvetica/Arial). All text #2D3748, thin connector lines 1px #A0AEC0.
Center: full-body original fantasy creature, 3/4 standing pose. Around it: four small inset boxes labeled exactly "EYE DETAIL", "FOOT DETAIL", "SKIN TEXTURE", "SILHOUETTE SCALE" (with a simple human comparison silhouette). Bottom: a short footprint trail diagram. One small habitat vignette (misty rocky shoreline with tide pools).
Exact text (only these, clean print layout):
Top: "FIELD GUIDE"
Sub: "AURORA SHOREWALKER"
Small line: "CLASS: COASTAL DRIFTER"
Under silhouette: "HEIGHT: 1.7 m"
Crisp ink outlines with soft watercolor-like fills, high readability, balanced hierarchy, premium poster aesthetic."""
# Relative paths of the system-prompt text files read at startup and passed to the
# prompt rewriter: one for text-only requests, one for requests that include images.
# Being relative, they resolve against the process working directory.
SYSTEM_PROMPT_TEXT_ONLY_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_text_only.txt'
SYSTEM_PROMPT_WITH_IMAGES_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_with_images.txt'
def _patch_diffusers_bnb_shape_check():
try:
import diffusers.quantizers.bitsandbytes.bnb_quantizer as bnbq
except Exception:
return
def _numel(shape):
if shape is None:
return None
if hasattr(shape, "numel"): # torch.Size
return int(shape.numel())
# plain tuple/list
n = 1
for d in shape:
n *= int(d)
return n
def patched_check(self, param_name, current_param, loaded_param):
cshape = getattr(current_param, "shape", None)
lshape = getattr(loaded_param, "shape", None)
n = _numel(cshape)
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
if tuple(lshape) != tuple(inferred_shape):
raise ValueError(
f"Expected flattened shape mismatch for {param_name}: "
f"loaded={tuple(lshape)} inferred={tuple(inferred_shape)}"
)
return True
# Patch any quantizer class in that module that defines the method
for name, obj in vars(bnbq).items():
if isinstance(obj, type) and hasattr(obj, "check_quantized_param_shape"):
setattr(obj, "check_quantized_param_shape", patched_check)
# Apply the relaxed bnb shape check before the quantized pipeline is loaded below.
_patch_diffusers_bnb_shape_check()


def _read_text_file(path):
    """Return the whole content of the text file at *path*, closing it afterwards."""
    with open(path, 'r') as text_file:
        return text_file.read()


# Load the 4-bit FLUX.2 pipeline, attach the pi-Flow adapter to its transformer,
# and replace the scheduler with a FlowMapSDEScheduler built from the same config.
pipe = PiFlux2Pipeline.from_pretrained(
    'diffusers/FLUX.2-dev-bnb-4bit',
    torch_dtype=torch.bfloat16)
pipe.load_piflow_adapter(
    'Lakonik/pi-FLUX.2',
    subfolder='gmflux2_k8_piid_4step',
    target_module_name='transformer')
pipe.scheduler = FlowMapSDEScheduler.from_config(  # use fixed shift=3.2
    pipe.scheduler.config, shift=3.2, use_dynamic_shifting=False, final_step_size_scale=0.5)
pipe = pipe.to('cuda')

# Prompt rewriter used by `run_rewrite_prompt_gpu`; the system prompts are read
# once at startup through a helper that closes each file handle.
prompt_rewriter = Qwen3VLPromptRewriter(
    device_map="cuda",
    system_prompt_text_only=_read_text_file(SYSTEM_PROMPT_TEXT_ONLY_PATH),
    # NOTE(review): "wigh" looks like a typo for "with", but the keyword must match
    # the Qwen3VLPromptRewriter signature -- confirm there before renaming.
    system_prompt_wigh_images=_read_text_file(SYSTEM_PROMPT_WITH_IMAGES_PATH),
    max_new_tokens_default=512,
)
def set_random_seed(seed: int, deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs and optionally force deterministic cuDNN.

    Args:
        seed: Value passed to every seeding function and stored (as a string)
            in the ``PYTHONHASHSEED`` environment variable.
        deterministic: When true, sets ``torch.backends.cudnn.deterministic``
            to True and ``torch.backends.cudnn.benchmark`` to False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
@spaces.GPU
def run_rewrite_prompt_gpu(seed, prompt, image_list, progress):
    """Rewrite *prompt* with the module-level ``prompt_rewriter`` and return the result.

    Args:
        seed: Seed forwarded to ``set_random_seed`` before rewriting.
        prompt: The user prompt to rewrite.
        image_list: Input images, or None to use the text-only rewriter.
        progress: Callable invoked once as ``progress(0.05, desc=...)``.

    Returns:
        The first (and only) element of the rewriter's batch output.
    """
    set_random_seed(seed)
    progress(0.05, desc="Rewriting prompt...")
    # Both rewriter entry points are batched; wrap the single request in a
    # one-element batch and unwrap the single result afterwards.
    if image_list is not None:
        batch_output = prompt_rewriter.rewrite_edit_batch([image_list], [prompt])
    else:
        batch_output = prompt_rewriter.rewrite_text_batch([prompt])
    return batch_output[0]
def run_rewrite_prompt(seed, prompt, rewrite_prompt, in_image, progress=gr.Progress(track_tqdm=True)):
    """Return ``(rewritten_prompt, None)``; the prompt is ``''`` when rewriting is off.

    Args:
        seed: Seed forwarded to ``run_rewrite_prompt_gpu``.
        prompt: The user prompt.
        rewrite_prompt: Truthy to run the rewriter, falsy to skip it.
        in_image: None, an empty sequence, or a sequence whose items are indexed
            with ``[0]`` to obtain each image (presumably Gradio gallery entries
            -- confirm against the UI helper).
        progress: Gradio progress tracker passed through to the GPU function.
    """
    # The image list is built before the flag check, so malformed `in_image`
    # entries raise even when rewriting is disabled.
    has_images = in_image is not None and len(in_image) > 0
    image_list = [entry[0] for entry in in_image] if has_images else None
    if not rewrite_prompt:
        return '', None
    return run_rewrite_prompt_gpu(seed, prompt, image_list, progress), None
@spaces.GPU
def generate(
        seed, prompt, rewrite_prompt, rewritten_prompt, in_image, width, height, steps,
        progress=gr.Progress(track_tqdm=True)):
    """Run the module-level ``pipe`` once and return the first generated image.

    Args:
        seed: Seed for the ``torch.Generator`` handed to the pipeline.
        prompt: The original user prompt.
        rewrite_prompt: When truthy, ``rewritten_prompt`` is used instead of ``prompt``.
        rewritten_prompt: Prompt produced by the rewriter.
        in_image: None, an empty sequence, or a sequence whose items are indexed
            with ``[0]`` to obtain each image (presumably Gradio gallery entries
            -- confirm against the UI helper).
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        steps: Passed to the pipeline as ``num_inference_steps``.
        progress: Gradio progress tracker; not referenced in the body.
    """
    has_images = in_image is not None and len(in_image) > 0
    image_list = [entry[0] for entry in in_image] if has_images else None
    effective_prompt = rewritten_prompt if rewrite_prompt else prompt
    rng = torch.Generator().manual_seed(seed)
    output = pipe(
        image=image_list,
        prompt=effective_prompt,
        width=width,
        height=height,
        num_inference_steps=steps,
        generator=rng,
    )
    return output.images[0]
# Header markdown shown at the top of the page.
md_txt = (
    '# pi-FLUX.2 Demo\n\n'
    'Official demo of the paper [pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation](https://arxiv.org/abs/2510.14974). '
    '**Base model:** [FLUX.2 dev](https://huggingface.co/black-forest-labs/FLUX.2-dev). **Fast policy:** GMFlow. **Code:** [https://github.com/Lakonik/piFlow](https://github.com/Lakonik/piFlow).\n'
    '<br> Use and distribution of this app are governed by the [FLUX [dev] Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/LICENSE.txt).'
)

# Build the page: header first, then the image-edit interface wired to
# `generate` and to `run_rewrite_prompt` for prompt rewriting.
with gr.Blocks(
        analytics_enabled=False,
        title='pi-FLUX.2 Demo',
        css_paths='lakonlab/ui/gradio/style.css') as demo:
    gr.Markdown(md_txt)
    create_interface_img_edit(
        generate,
        prompt=DEFAULT_PROMPT,
        steps=4,
        guidance_scale=None,
        args=['last_seed', 'prompt', 'rewrite_prompt', 'rewritten_prompt', 'in_image', 'width', 'height', 'steps'],
        rewrite_prompt_api=run_rewrite_prompt,
        rewrite_prompt_args=['last_seed', 'prompt', 'rewrite_prompt', 'in_image'],
        height=1024,
        width=1024,
    )

# Enable the request queue and start serving.
demo.queue().launch()
|