lunarflu HF Staff commited on
Commit
54b8e0e
·
1 Parent(s): 8b9d958

[deepfloydif.py] stage2 testing

Browse files
Files changed (1) hide show
  1. deepfloydif.py +123 -48
deepfloydif.py CHANGED
@@ -10,8 +10,6 @@ from PIL import Image
10
  import asyncio
11
  import glob
12
 
13
- MY_GUILD_ID = 1077674588122648679 if os.getenv("TEST_ENV", False) else 879548962464493619
14
- MY_GUILD = discord.Object(id=MY_GUILD_ID)
15
  HF_TOKEN = os.getenv('HF_TOKEN')
16
  DISCORD_TOKEN = os.environ.get("DISCORD_TOKEN", None)
17
  deepfloyd_client = Client("huggingface-projects/IF", HF_TOKEN)
@@ -19,86 +17,84 @@ deepfloyd_client = Client("huggingface-projects/IF", HF_TOKEN)
19
  BOT_USER_ID = 1086256910572986469
20
  DEEPFLOYD_CHANNEL_ID = 1121834257959092234
21
 
22
- #-------------------------------------------------------------------------------------------------------------------------------------
23
- # deepfloydif stage 1 generation
24
- def inference(prompt):
25
  negative_prompt = ''
26
  seed = random.randint(0, 1000)
27
- #seed = 1
28
  number_of_images = 4
29
  guidance_scale = 7
30
  custom_timesteps_1 = 'smart50'
31
  number_of_inference_steps = 50
32
 
33
- stage_1_results, stage_1_param_path, stage_1_result_path = deepfloyd_client.predict(
34
- prompt, negative_prompt, seed, number_of_images, guidance_scale, custom_timesteps_1, number_of_inference_steps, api_name='/generate64')
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- return [stage_1_results, stage_1_param_path, stage_1_result_path]
37
- #-------------------------------------------------------------------------------------------------------------------------------------
38
- async def try_deepfloydif(interaction, prompt, client):
39
  thread = None
40
  try:
41
  global BOT_USER_ID
42
  global DEEPFLOYD_CHANNEL_ID
43
  if interaction.user.id != BOT_USER_ID:
44
  if interaction.channel.id == DEEPFLOYD_CHANNEL_ID:
45
- await interaction.response.send_message(f"Working on it!")
46
  channel = interaction.channel
47
- message = await channel.send(f"DeepfloydIF Thread")
48
- #await message.add_reaction('🔁')
 
49
  thread = await message.create_thread(name=f'{prompt}', auto_archive_duration=60)
50
- await thread.send(f"[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; Additional information" \
51
- f" on the DeepfloydIF model can be found here: https://huggingface.co/spaces/DeepFloyd/IF")
52
 
53
  dfif_command_message_id = message.id # used for updating the 'status' of our generations using reaction emojis
54
 
55
- await thread.send(f'{interaction.user.mention}Generating images in thread, can take ~1 minute...')
56
 
57
- # generation
58
  loop = asyncio.get_running_loop()
59
- result = await loop.run_in_executor(None, inference, prompt)
60
  stage_1_results = result[0]
61
  stage_1_result_path = result[2]
62
- partialpath = stage_1_result_path[5:]
63
  png_files = list(glob.glob(f"{stage_1_results}/**/*.png"))
64
 
65
- img1 = None
66
- img2 = None
67
- img3 = None
68
- img4 = None
69
-
70
  if png_files:
 
71
  png_file_index = 0
72
  images = load_image(png_files, stage_1_results, png_file_index)
73
- img1 = images[0]
74
- img2 = images[1]
75
- img3 = images[2]
76
- img4 = images[3]
77
-
78
- combined_image = Image.new('RGB', (img1.width * 2, img1.height * 2))
79
-
80
- combined_image.paste(img1, (0, 0))
81
- combined_image.paste(img2, (img1.width, 0))
82
- combined_image.paste(img3, (0, img1.height))
83
- combined_image.paste(img4, (img1.width, img1.height))
84
-
85
- combined_image_path = os.path.join(stage_1_results, f'{partialpath}{dfif_command_message_id}.png')
86
  combined_image.save(combined_image_path)
87
 
88
  with open(combined_image_path, 'rb') as f:
89
- combined_image_dfif = await thread.send(f'{interaction.user.mention}React with the image number you want to upscale!', file=discord.File(
90
- f, f'{partialpath}{dfif_command_message_id}.png')) # named something like: tmpgtv4qjix1111269940599738479.png
91
 
92
  emoji_list = ['↖️', '↗️', '↙️', '↘️']
93
- await react1234(emoji_list, combined_image_dfif)
94
-
95
- #await message.remove_reaction('🔁', client.user)
96
- #await message.add_reaction('✔️')
 
97
 
98
  except Exception as e:
99
  print(f"Error: {e}")
100
- #-------------------------------------------------------------------------------------------------------------------------------------
101
  def load_image(png_files, stage_1_results, png_file_index):
 
102
  for file in png_files:
103
  png_file = png_files[png_file_index]
104
  png_path = os.path.join(stage_1_results, png_file)
@@ -112,9 +108,88 @@ def load_image(png_files, stage_1_results, png_file_index):
112
  img4 = Image.open(png_path)
113
  png_file_index = png_file_index + 1
114
  return [img1, img2, img3, img4]
115
- #-------------------------------------------------------------------------------------------------------------------------------------
116
- #-------------------------------------------------------------------------------------------------------------------------------------
117
- async def react1234(reaction_emojis, combined_image_dfif):
118
  for emoji in reaction_emojis:
119
  await combined_image_dfif.add_reaction(emoji)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
 
10
  import asyncio
11
  import glob
12
 
 
 
13
  HF_TOKEN = os.getenv('HF_TOKEN')
14
  DISCORD_TOKEN = os.environ.get("DISCORD_TOKEN", None)
15
  deepfloyd_client = Client("huggingface-projects/IF", HF_TOKEN)
 
17
  BOT_USER_ID = 1086256910572986469
18
  DEEPFLOYD_CHANNEL_ID = 1121834257959092234
19
 
20
+
21
+ def deepfloyd_stage_1_inference(prompt):
22
+ """Generates an image based on a prompt"""
23
  negative_prompt = ''
24
  seed = random.randint(0, 1000)
 
25
  number_of_images = 4
26
  guidance_scale = 7
27
  custom_timesteps_1 = 'smart50'
28
  number_of_inference_steps = 50
29
 
30
+ stage_1_results, stage_1_param_path, stage_1_result_path = deepfloyd_client.predict(prompt, negative_prompt, seed, number_of_images, guidance_scale, custom_timesteps_1, number_of_inference_steps, api_name='/generate64')
31
+
32
+ return [stage_1_results, stage_1_param_path, stage_1_result_path]
33
+
34
+ def deepfloyd_stage_2_inference(index, stage_1_result_path):
35
+ """Upscales one of the images from deepfloyd_stage_1_inference based on the chosen index"""
36
+ selected_index_for_stage_2 = index
37
+ seed_2 = 0
38
+ guidance_scale_2 = 4
39
+ custom_timesteps_2 = 'smart50'
40
+ number_of_inference_steps_2 = 50
41
+ result_path = deepfloyd_client.predict(stage_1_result_path, selected_index_for_stage_2, seed_2, guidance_scale_2, custom_timesteps_2, number_of_inference_steps_2, api_name='/upscale256')
42
+ return result_path
43
 
44
+ async def deepfloydif_stage_1(interaction, prompt, client):
45
+ """DeepfloydIF command (generate images with realistic text using slash commands)"""
 
46
  thread = None
47
  try:
48
  global BOT_USER_ID
49
  global DEEPFLOYD_CHANNEL_ID
50
  if interaction.user.id != BOT_USER_ID:
51
  if interaction.channel.id == DEEPFLOYD_CHANNEL_ID:
52
+ await interaction.response.send_message("Working on it!")
53
  channel = interaction.channel
54
+ # interaction.response message can't be used to create a thread, so we create another message
55
+ message = await channel.send("DeepfloydIF Thread")
56
+ #await message.add_reaction('<a:loading:1114111677990981692>')
57
  thread = await message.create_thread(name=f'{prompt}', auto_archive_duration=60)
58
+ await thread.send("[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; Additional information on the DeepfloydIF model can be found here: https://huggingface.co/spaces/DeepFloyd/IF")
 
59
 
60
  dfif_command_message_id = message.id # used for updating the 'status' of our generations using reaction emojis
61
 
62
+ await thread.send(f'{interaction.user.mention} Generating images in thread, can take ~1 minute...')
63
 
 
64
  loop = asyncio.get_running_loop()
65
+ result = await loop.run_in_executor(None, deepfloyd_stage_1_inference, prompt)
66
  stage_1_results = result[0]
67
  stage_1_result_path = result[2]
68
+ partial_path = pathlib.Path(stage_1_result_path).name
69
  png_files = list(glob.glob(f"{stage_1_results}/**/*.png"))
70
 
 
 
 
 
 
71
  if png_files:
72
+ # take all 4 images and combine them into one large 2x2 image (similar to Midjourney)
73
  png_file_index = 0
74
  images = load_image(png_files, stage_1_results, png_file_index)
75
+ combined_image = Image.new('RGB', (images[0].width * 2, images[0].height * 2))
76
+ combined_image.paste(images[0], (0, 0))
77
+ combined_image.paste(images[1], (images[0].width, 0))
78
+ combined_image.paste(images[2], (0, images[0].height))
79
+ combined_image.paste(images[3], (images[0].width, images[0].height))
80
+ combined_image_path = os.path.join(stage_1_results, f'{partial_path}{dfif_command_message_id}.png')
 
 
 
 
 
 
 
81
  combined_image.save(combined_image_path)
82
 
83
  with open(combined_image_path, 'rb') as f:
84
+ combined_image_dfif = await thread.send(f'{interaction.user.mention} React with the image number you want to upscale!', file=discord.File(f, f'{partial_path}{dfif_command_message_id}.png'))
 
85
 
86
  emoji_list = ['↖️', '↗️', '↙️', '↘️']
87
+ await react_1234(emoji_list, combined_image_dfif)
88
+ #await message.remove_reaction('<a:loading:1114111677990981692>', client.user)
89
+ #await message.add_reaction('<:agree:1098629085955113011>')
90
+ else:
91
+ await thread.send(f'{interaction.user.mention} No PNG files were found, cannot post them!')
92
 
93
  except Exception as e:
94
  print(f"Error: {e}")
95
+
96
  def load_image(png_files, stage_1_results, png_file_index):
97
+ """Opens images as variables so we can combine them later"""
98
  for file in png_files:
99
  png_file = png_files[png_file_index]
100
  png_path = os.path.join(stage_1_results, png_file)
 
108
  img4 = Image.open(png_path)
109
  png_file_index = png_file_index + 1
110
  return [img1, img2, img3, img4]
111
+
112
+ async def react_1234(reaction_emojis, combined_image_dfif):
113
+ """Sets up 4 reaction emojis so the user can choose an image to upscale for deepfloydif"""
114
  for emoji in reaction_emojis:
115
  await combined_image_dfif.add_reaction(emoji)
116
+
117
+ async def deepfloydif_stage_2(index: int, stage_1_result_path, thread, dfif_command_message_id):
118
+ """upscaling function for images generated using /deepfloydif"""
119
+ try:
120
+ parent_channel = thread.parent
121
+ dfif_command_message = await parent_channel.fetch_message(dfif_command_message_id)
122
+ #await dfif_command_message.remove_reaction('<:agree:1098629085955113011>', client.user)
123
+ #await dfif_command_message.add_reaction('<a:loading:1114111677990981692>')
124
+ if index == 0:
125
+ position = "top left"
126
+ elif index == 1:
127
+ position = "top right"
128
+ elif index == 2:
129
+ position = "bottom left"
130
+ elif index == 3:
131
+ position = "bottom right"
132
+ await thread.send(f"Upscaling the {position} image...")
133
+
134
+ # run blocking function in executor
135
+ loop = asyncio.get_running_loop()
136
+ result_path = await loop.run_in_executor(None, deepfloyd_stage_2_inference, index, stage_1_result_path)
137
+
138
+ with open(result_path, 'rb') as f:
139
+ await thread.send('Here is the upscaled image!', file=discord.File(f, 'result.png'))
140
+
141
+ #await dfif_command_message.remove_reaction('<a:loading:1114111677990981692>', client.user)
142
+ #await dfif_command_message.add_reaction('<:agree:1098629085955113011>')
143
+ await thread.edit(archived=True)
144
+
145
+ except Exception as e:
146
+ print(f"Error: {e}")
147
+ parent_channel = thread.parent
148
+ dfif_command_message = await parent_channel.fetch_message(dfif_command_message_id)
149
+ #await dfif_command_message.remove_reaction('<a:loading:1114111677990981692>', client.user)
150
+ #await dfif_command_message.add_reaction('<:disagree:1098628957521313892>')
151
+ await thread.send(f"Error during stage 2 upscaling, {e}")
152
+ await thread.edit(archived=True)
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+ async def deepfloyd_stage_2_react_check(reaction, user):
163
+ """Checks for a reaction in order to call dfif2"""
164
+ try:
165
+ global BOT_USER_ID
166
+ global DEEPFLOYD_CHANNEL_ID
167
+ if user.id != BOT_USER_ID: #
168
+ thread = reaction.message.channel
169
+ thread_parent_id = thread.parent.id
170
+ if thread_parent_id == DEEPFLOYD_CHANNEL_ID:
171
+ if reaction.message.attachments:
172
+ if user.id == reaction.message.mentions[0].id:
173
+ attachment = reaction.message.attachments[0]
174
+ image_name = attachment.filename
175
+ partial_path_message_id = image_name[:-4]
176
+ partial_path = partial_path_message_id[:11]
177
+ message_id = partial_path_message_id[11:]
178
+ full_path = "/tmp/" + partial_path
179
+ emoji = reaction.emoji
180
+ if emoji == "↖️":
181
+ index = 0
182
+ elif emoji == "↗️":
183
+ index = 1
184
+ elif emoji == "↙️":
185
+ index = 2
186
+ elif emoji == "↘️":
187
+ index = 3
188
+ stage_1_result_path = full_path
189
+ thread = reaction.message.channel
190
+ dfif_command_message_id = message_id
191
+ await dfif2(index, stage_1_result_path, thread, dfif_command_message_id)
192
+
193
+ except Exception as e:
194
+ print(f"Error: {e} (known error, does not cause issues, low priority)")
195