Spaces status: Runtime error

Commit: Update app.py
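In summary, the commit removes two debug print() calls from encode(), makes the device placement of the flow model and autoencoder in FluxEditor.__init__ offload-aware, moves the text encoders, flow model, and autoencoder between CPU and GPU around each stage of edit(), adds the img.save() call that writes the edited image with its EXIF metadata, and passes args.name and args.device to create_demo().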
app.py CHANGED
@@ -38,9 +38,6 @@ def encode(init_image, torch_device, ae):
     init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
     init_image = init_image.unsqueeze(0)
     init_image = init_image.to(torch_device)
-    print("!!!!!!!init_image!!!!!!",init_image.device)
-    print("!!!!!!!ae!!!!!!",next(ae.parameters()).device)
-
     with torch.no_grad():
         init_image = ae.encode(init_image.to()).to(torch.bfloat16)
     return init_image
@@ -65,20 +62,22 @@ class FluxEditor:
         # init all components
         self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
         self.clip = load_clip(self.device)
-        self.model = load_flow_model(self.name, device=
-        self.ae = load_ae(self.name, device=
+        self.model = load_flow_model(self.name, device="cpu" if self.offload else self.device)
+        self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
         self.t5.eval()
         self.clip.eval()
         self.ae.eval()
         self.model.eval()
-
-        self.
-
-
+
+        if self.offload:
+            self.model.cpu()
+            torch.cuda.empty_cache()
+            self.ae.encoder.to(self.device)

     @torch.inference_mode()
     @spaces.GPU(duration=60)
     def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
+        torch.cuda.empty_cache()
         seed = None
         # if seed == -1:
         #     seed = None
@@ -112,6 +111,11 @@ class FluxEditor:
         t0 = time.perf_counter()

         opts.seed = None
+        if self.offload:
+            self.ae = self.ae.cpu()
+            torch.cuda.empty_cache()
+            self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+
         #############inverse#######################
         info = {}
         info['feature'] = {}
@@ -125,6 +129,12 @@ class FluxEditor:
         inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
         timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))

+        # offload TEs to CPU, load model to gpu
+        if self.offload:
+            self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
+            torch.cuda.empty_cache()
+            self.model = self.model.to(self.device)
+
         # inversion initial noise
         with torch.no_grad():
             z, info = denoise(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
@@ -136,6 +146,12 @@ class FluxEditor:
         # denoise initial noise
         x, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)

+        # offload model, load autoencoder to gpu
+        if self.offload:
+            self.model.cpu()
+            torch.cuda.empty_cache()
+            self.ae.decoder.to(x.device)
+
         # decode latents to pixel space
         x = unpack(x.float(), opts.width, opts.height)

@@ -171,7 +187,7 @@ class FluxEditor:
         exif_data[ExifTags.Base.Model] = self.name
         if self.add_sampling_metadata:
             exif_data[ExifTags.Base.ImageDescription] = source_prompt
-
+        img.save(fn, exif=exif_data, quality=95, subsampling=0)


         print("End Edit")
@@ -226,5 +242,5 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=41035)
     args = parser.parse_args()

-    demo = create_demo(
-    demo.launch()
+    demo = create_demo(args.name, args.device)
+    demo.launch()
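
Taken together, the additions implement a simple offloading choreography: only the component needed for the current stage (autoencoder encoder, text encoders, flow model, autoencoder decoder) occupies GPU memory, and everything else waits on the CPU. That matters on ZeroGPU-style Spaces, where @spaces.GPU(duration=60) attaches a GPU only for the decorated call. Below is a minimal, self-contained sketch of the pattern; the run_stage() helper and the nn.Linear stand-ins are illustrative only, not part of app.py:

    import torch
    import torch.nn as nn

    def run_stage(module: nn.Module, x: torch.Tensor, device: str) -> torch.Tensor:
        """Move `module` onto `device`, run it, then park it back on the CPU."""
        module.to(device)
        with torch.no_grad():
            out = module(x.to(device))
        module.cpu()                  # drop the weights from GPU memory
        torch.cuda.empty_cache()      # hand the cached blocks back to the driver
        return out

    # Toy stand-ins for the encoder -> flow model -> decoder pipeline.
    stages = [nn.Linear(16, 16).eval() for _ in range(3)]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1, 16)
    for stage in stages:
        x = run_stage(stage, x, device)

The torch.cuda.empty_cache() after each hand-off is the detail that makes the ping-pong effective: module.cpu() only returns the freed VRAM to PyTorch's caching allocator, and empty_cache() releases those cached blocks so the next component actually fits.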