Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
import os
|
| 3 |
import tempfile
|
|
@@ -96,6 +97,7 @@ def adjust_location(x0, y0, x1, y1, input_image):
|
|
| 96 |
draw.rectangle([(x0,y0),(x1,y1)], outline="red", width=5)
|
| 97 |
return x_0, y_0, x_1, y_1, concat_img
|
| 98 |
|
|
|
|
| 99 |
def prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text):
|
| 100 |
if input_image.size[0] != 256 or input_image.size[1] != 256:
|
| 101 |
input_image = input_image.resize((256, 256))
|
|
@@ -127,6 +129,7 @@ def prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text):
|
|
| 127 |
return batch
|
| 128 |
|
| 129 |
|
|
|
|
| 130 |
@torch.no_grad()
|
| 131 |
def run_generation(sampler, model, device, input_image, x0, y0, x1, y1, polar, azimuth, text, seed):
|
| 132 |
seed_everything(seed)
|
|
@@ -182,6 +185,7 @@ def load_example(input_image, x0, y0, x1, y1, polar, azimuth, prompt):
|
|
| 182 |
# print(type(polar))
|
| 183 |
return input_image, x0, y0, x1, y1, polar, azimuth, prompt
|
| 184 |
|
|
|
|
| 185 |
@torch.no_grad()
|
| 186 |
def main(args):
|
| 187 |
# load model
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
import argparse
|
| 3 |
import os
|
| 4 |
import tempfile
|
|
|
|
| 97 |
draw.rectangle([(x0,y0),(x1,y1)], outline="red", width=5)
|
| 98 |
return x_0, y_0, x_1, y_1, concat_img
|
| 99 |
|
| 100 |
+
@spaces.GPU
|
| 101 |
def prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text):
|
| 102 |
if input_image.size[0] != 256 or input_image.size[1] != 256:
|
| 103 |
input_image = input_image.resize((256, 256))
|
|
|
|
| 129 |
return batch
|
| 130 |
|
| 131 |
|
| 132 |
+
@spaces.GPU
|
| 133 |
@torch.no_grad()
|
| 134 |
def run_generation(sampler, model, device, input_image, x0, y0, x1, y1, polar, azimuth, text, seed):
|
| 135 |
seed_everything(seed)
|
|
|
|
| 185 |
# print(type(polar))
|
| 186 |
return input_image, x0, y0, x1, y1, polar, azimuth, prompt
|
| 187 |
|
| 188 |
+
@spaces.GPU
|
| 189 |
@torch.no_grad()
|
| 190 |
def main(args):
|
| 191 |
# load model
|