antonjeran commited on
Commit
b53f252
·
1 Parent(s): fd529ef

Upload 40 files

Browse files
Files changed (41) hide show
  1. .gitattributes +1 -0
  2. LICENSE +661 -0
  3. Poster.pdf +0 -0
  4. README.md +133 -3
  5. code_new/RT60.py +131 -0
  6. code_new/RTS.py +87 -0
  7. code_new/__init__py +2 -0
  8. code_new/__pycache__/RT60.cpython-36.pyc +0 -0
  9. code_new/__pycache__/RT60.cpython-38.pyc +0 -0
  10. code_new/__pycache__/RTS.cpython-38.pyc +0 -0
  11. code_new/__pycache__/model.cpython-36.pyc +0 -0
  12. code_new/__pycache__/model.cpython-38.pyc +0 -0
  13. code_new/__pycache__/trainer.cpython-36.pyc +0 -0
  14. code_new/__pycache__/trainer.cpython-38.pyc +0 -0
  15. code_new/cfg/RIR_eval.yml +25 -0
  16. code_new/cfg/RIR_s1.yml +32 -0
  17. code_new/cfg/RIR_s1_temp.yml +32 -0
  18. code_new/main.py +72 -0
  19. code_new/miscc/__init__.py +2 -0
  20. code_new/miscc/__init__.pyc +0 -0
  21. code_new/miscc/__pycache__/__init__.cpython-36.pyc +0 -0
  22. code_new/miscc/__pycache__/__init__.cpython-38.pyc +0 -0
  23. code_new/miscc/__pycache__/config.cpython-36.pyc +0 -0
  24. code_new/miscc/__pycache__/config.cpython-38.pyc +0 -0
  25. code_new/miscc/__pycache__/datasets.cpython-36.pyc +0 -0
  26. code_new/miscc/__pycache__/datasets.cpython-38.pyc +0 -0
  27. code_new/miscc/__pycache__/utils.cpython-36.pyc +0 -0
  28. code_new/miscc/__pycache__/utils.cpython-38.pyc +0 -0
  29. code_new/miscc/config.py +97 -0
  30. code_new/miscc/config.pyc +0 -0
  31. code_new/miscc/datasets.py +113 -0
  32. code_new/miscc/datasets.pyc +0 -0
  33. code_new/miscc/utils.py +239 -0
  34. code_new/miscc/utils.pyc +0 -0
  35. code_new/model.py +413 -0
  36. code_new/single_copy.py +46 -0
  37. code_new/trainer.py +392 -0
  38. download_data.sh +3 -0
  39. download_generate.sh +2 -0
  40. example1.py +20 -0
  41. slides.pptx +3 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ slides.pptx filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU Affero General Public License is a free, copyleft license for
11
+ software and other kinds of works, specifically designed to ensure
12
+ cooperation with the community in the case of network server software.
13
+
14
+ The licenses for most software and other practical works are designed
15
+ to take away your freedom to share and change the works. By contrast,
16
+ our General Public Licenses are intended to guarantee your freedom to
17
+ share and change all versions of a program--to make sure it remains free
18
+ software for all its users.
19
+
20
+ When we speak of free software, we are referring to freedom, not
21
+ price. Our General Public Licenses are designed to make sure that you
22
+ have the freedom to distribute copies of free software (and charge for
23
+ them if you wish), that you receive source code or can get it if you
24
+ want it, that you can change the software or use pieces of it in new
25
+ free programs, and that you know you can do these things.
26
+
27
+ Developers that use our General Public Licenses protect your rights
28
+ with two steps: (1) assert copyright on the software, and (2) offer
29
+ you this License which gives you legal permission to copy, distribute
30
+ and/or modify the software.
31
+
32
+ A secondary benefit of defending all users' freedom is that
33
+ improvements made in alternate versions of the program, if they
34
+ receive widespread use, become available for other developers to
35
+ incorporate. Many developers of free software are heartened and
36
+ encouraged by the resulting cooperation. However, in the case of
37
+ software used on network servers, this result may fail to come about.
38
+ The GNU General Public License permits making a modified version and
39
+ letting the public access it on a server without ever releasing its
40
+ source code to the public.
41
+
42
+ The GNU Affero General Public License is designed specifically to
43
+ ensure that, in such cases, the modified source code becomes available
44
+ to the community. It requires the operator of a network server to
45
+ provide the source code of the modified version running there to the
46
+ users of that server. Therefore, public use of a modified version, on
47
+ a publicly accessible server, gives the public access to the source
48
+ code of the modified version.
49
+
50
+ An older license, called the Affero General Public License and
51
+ published by Affero, was designed to accomplish similar goals. This is
52
+ a different license, not a version of the Affero GPL, but Affero has
53
+ released a new version of the Affero GPL which permits relicensing under
54
+ this license.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ TERMS AND CONDITIONS
60
+
61
+ 0. Definitions.
62
+
63
+ "This License" refers to version 3 of the GNU Affero General Public License.
64
+
65
+ "Copyright" also means copyright-like laws that apply to other kinds of
66
+ works, such as semiconductor masks.
67
+
68
+ "The Program" refers to any copyrightable work licensed under this
69
+ License. Each licensee is addressed as "you". "Licensees" and
70
+ "recipients" may be individuals or organizations.
71
+
72
+ To "modify" a work means to copy from or adapt all or part of the work
73
+ in a fashion requiring copyright permission, other than the making of an
74
+ exact copy. The resulting work is called a "modified version" of the
75
+ earlier work or a work "based on" the earlier work.
76
+
77
+ A "covered work" means either the unmodified Program or a work based
78
+ on the Program.
79
+
80
+ To "propagate" a work means to do anything with it that, without
81
+ permission, would make you directly or secondarily liable for
82
+ infringement under applicable copyright law, except executing it on a
83
+ computer or modifying a private copy. Propagation includes copying,
84
+ distribution (with or without modification), making available to the
85
+ public, and in some countries other activities as well.
86
+
87
+ To "convey" a work means any kind of propagation that enables other
88
+ parties to make or receive copies. Mere interaction with a user through
89
+ a computer network, with no transfer of a copy, is not conveying.
90
+
91
+ An interactive user interface displays "Appropriate Legal Notices"
92
+ to the extent that it includes a convenient and prominently visible
93
+ feature that (1) displays an appropriate copyright notice, and (2)
94
+ tells the user that there is no warranty for the work (except to the
95
+ extent that warranties are provided), that licensees may convey the
96
+ work under this License, and how to view a copy of this License. If
97
+ the interface presents a list of user commands or options, such as a
98
+ menu, a prominent item in the list meets this criterion.
99
+
100
+ 1. Source Code.
101
+
102
+ The "source code" for a work means the preferred form of the work
103
+ for making modifications to it. "Object code" means any non-source
104
+ form of a work.
105
+
106
+ A "Standard Interface" means an interface that either is an official
107
+ standard defined by a recognized standards body, or, in the case of
108
+ interfaces specified for a particular programming language, one that
109
+ is widely used among developers working in that language.
110
+
111
+ The "System Libraries" of an executable work include anything, other
112
+ than the work as a whole, that (a) is included in the normal form of
113
+ packaging a Major Component, but which is not part of that Major
114
+ Component, and (b) serves only to enable use of the work with that
115
+ Major Component, or to implement a Standard Interface for which an
116
+ implementation is available to the public in source code form. A
117
+ "Major Component", in this context, means a major essential component
118
+ (kernel, window system, and so on) of the specific operating system
119
+ (if any) on which the executable work runs, or a compiler used to
120
+ produce the work, or an object code interpreter used to run it.
121
+
122
+ The "Corresponding Source" for a work in object code form means all
123
+ the source code needed to generate, install, and (for an executable
124
+ work) run the object code and to modify the work, including scripts to
125
+ control those activities. However, it does not include the work's
126
+ System Libraries, or general-purpose tools or generally available free
127
+ programs which are used unmodified in performing those activities but
128
+ which are not part of the work. For example, Corresponding Source
129
+ includes interface definition files associated with source files for
130
+ the work, and the source code for shared libraries and dynamically
131
+ linked subprograms that the work is specifically designed to require,
132
+ such as by intimate data communication or control flow between those
133
+ subprograms and other parts of the work.
134
+
135
+ The Corresponding Source need not include anything that users
136
+ can regenerate automatically from other parts of the Corresponding
137
+ Source.
138
+
139
+ The Corresponding Source for a work in source code form is that
140
+ same work.
141
+
142
+ 2. Basic Permissions.
143
+
144
+ All rights granted under this License are granted for the term of
145
+ copyright on the Program, and are irrevocable provided the stated
146
+ conditions are met. This License explicitly affirms your unlimited
147
+ permission to run the unmodified Program. The output from running a
148
+ covered work is covered by this License only if the output, given its
149
+ content, constitutes a covered work. This License acknowledges your
150
+ rights of fair use or other equivalent, as provided by copyright law.
151
+
152
+ You may make, run and propagate covered works that you do not
153
+ convey, without conditions so long as your license otherwise remains
154
+ in force. You may convey covered works to others for the sole purpose
155
+ of having them make modifications exclusively for you, or provide you
156
+ with facilities for running those works, provided that you comply with
157
+ the terms of this License in conveying all material for which you do
158
+ not control copyright. Those thus making or running the covered works
159
+ for you must do so exclusively on your behalf, under your direction
160
+ and control, on terms that prohibit them from making any copies of
161
+ your copyrighted material outside their relationship with you.
162
+
163
+ Conveying under any other circumstances is permitted solely under
164
+ the conditions stated below. Sublicensing is not allowed; section 10
165
+ makes it unnecessary.
166
+
167
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168
+
169
+ No covered work shall be deemed part of an effective technological
170
+ measure under any applicable law fulfilling obligations under article
171
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172
+ similar laws prohibiting or restricting circumvention of such
173
+ measures.
174
+
175
+ When you convey a covered work, you waive any legal power to forbid
176
+ circumvention of technological measures to the extent such circumvention
177
+ is effected by exercising rights under this License with respect to
178
+ the covered work, and you disclaim any intention to limit operation or
179
+ modification of the work as a means of enforcing, against the work's
180
+ users, your or third parties' legal rights to forbid circumvention of
181
+ technological measures.
182
+
183
+ 4. Conveying Verbatim Copies.
184
+
185
+ You may convey verbatim copies of the Program's source code as you
186
+ receive it, in any medium, provided that you conspicuously and
187
+ appropriately publish on each copy an appropriate copyright notice;
188
+ keep intact all notices stating that this License and any
189
+ non-permissive terms added in accord with section 7 apply to the code;
190
+ keep intact all notices of the absence of any warranty; and give all
191
+ recipients a copy of this License along with the Program.
192
+
193
+ You may charge any price or no price for each copy that you convey,
194
+ and you may offer support or warranty protection for a fee.
195
+
196
+ 5. Conveying Modified Source Versions.
197
+
198
+ You may convey a work based on the Program, or the modifications to
199
+ produce it from the Program, in the form of source code under the
200
+ terms of section 4, provided that you also meet all of these conditions:
201
+
202
+ a) The work must carry prominent notices stating that you modified
203
+ it, and giving a relevant date.
204
+
205
+ b) The work must carry prominent notices stating that it is
206
+ released under this License and any conditions added under section
207
+ 7. This requirement modifies the requirement in section 4 to
208
+ "keep intact all notices".
209
+
210
+ c) You must license the entire work, as a whole, under this
211
+ License to anyone who comes into possession of a copy. This
212
+ License will therefore apply, along with any applicable section 7
213
+ additional terms, to the whole of the work, and all its parts,
214
+ regardless of how they are packaged. This License gives no
215
+ permission to license the work in any other way, but it does not
216
+ invalidate such permission if you have separately received it.
217
+
218
+ d) If the work has interactive user interfaces, each must display
219
+ Appropriate Legal Notices; however, if the Program has interactive
220
+ interfaces that do not display Appropriate Legal Notices, your
221
+ work need not make them do so.
222
+
223
+ A compilation of a covered work with other separate and independent
224
+ works, which are not by their nature extensions of the covered work,
225
+ and which are not combined with it such as to form a larger program,
226
+ in or on a volume of a storage or distribution medium, is called an
227
+ "aggregate" if the compilation and its resulting copyright are not
228
+ used to limit the access or legal rights of the compilation's users
229
+ beyond what the individual works permit. Inclusion of a covered work
230
+ in an aggregate does not cause this License to apply to the other
231
+ parts of the aggregate.
232
+
233
+ 6. Conveying Non-Source Forms.
234
+
235
+ You may convey a covered work in object code form under the terms
236
+ of sections 4 and 5, provided that you also convey the
237
+ machine-readable Corresponding Source under the terms of this License,
238
+ in one of these ways:
239
+
240
+ a) Convey the object code in, or embodied in, a physical product
241
+ (including a physical distribution medium), accompanied by the
242
+ Corresponding Source fixed on a durable physical medium
243
+ customarily used for software interchange.
244
+
245
+ b) Convey the object code in, or embodied in, a physical product
246
+ (including a physical distribution medium), accompanied by a
247
+ written offer, valid for at least three years and valid for as
248
+ long as you offer spare parts or customer support for that product
249
+ model, to give anyone who possesses the object code either (1) a
250
+ copy of the Corresponding Source for all the software in the
251
+ product that is covered by this License, on a durable physical
252
+ medium customarily used for software interchange, for a price no
253
+ more than your reasonable cost of physically performing this
254
+ conveying of source, or (2) access to copy the
255
+ Corresponding Source from a network server at no charge.
256
+
257
+ c) Convey individual copies of the object code with a copy of the
258
+ written offer to provide the Corresponding Source. This
259
+ alternative is allowed only occasionally and noncommercially, and
260
+ only if you received the object code with such an offer, in accord
261
+ with subsection 6b.
262
+
263
+ d) Convey the object code by offering access from a designated
264
+ place (gratis or for a charge), and offer equivalent access to the
265
+ Corresponding Source in the same way through the same place at no
266
+ further charge. You need not require recipients to copy the
267
+ Corresponding Source along with the object code. If the place to
268
+ copy the object code is a network server, the Corresponding Source
269
+ may be on a different server (operated by you or a third party)
270
+ that supports equivalent copying facilities, provided you maintain
271
+ clear directions next to the object code saying where to find the
272
+ Corresponding Source. Regardless of what server hosts the
273
+ Corresponding Source, you remain obligated to ensure that it is
274
+ available for as long as needed to satisfy these requirements.
275
+
276
+ e) Convey the object code using peer-to-peer transmission, provided
277
+ you inform other peers where the object code and Corresponding
278
+ Source of the work are being offered to the general public at no
279
+ charge under subsection 6d.
280
+
281
+ A separable portion of the object code, whose source code is excluded
282
+ from the Corresponding Source as a System Library, need not be
283
+ included in conveying the object code work.
284
+
285
+ A "User Product" is either (1) a "consumer product", which means any
286
+ tangible personal property which is normally used for personal, family,
287
+ or household purposes, or (2) anything designed or sold for incorporation
288
+ into a dwelling. In determining whether a product is a consumer product,
289
+ doubtful cases shall be resolved in favor of coverage. For a particular
290
+ product received by a particular user, "normally used" refers to a
291
+ typical or common use of that class of product, regardless of the status
292
+ of the particular user or of the way in which the particular user
293
+ actually uses, or expects or is expected to use, the product. A product
294
+ is a consumer product regardless of whether the product has substantial
295
+ commercial, industrial or non-consumer uses, unless such uses represent
296
+ the only significant mode of use of the product.
297
+
298
+ "Installation Information" for a User Product means any methods,
299
+ procedures, authorization keys, or other information required to install
300
+ and execute modified versions of a covered work in that User Product from
301
+ a modified version of its Corresponding Source. The information must
302
+ suffice to ensure that the continued functioning of the modified object
303
+ code is in no case prevented or interfered with solely because
304
+ modification has been made.
305
+
306
+ If you convey an object code work under this section in, or with, or
307
+ specifically for use in, a User Product, and the conveying occurs as
308
+ part of a transaction in which the right of possession and use of the
309
+ User Product is transferred to the recipient in perpetuity or for a
310
+ fixed term (regardless of how the transaction is characterized), the
311
+ Corresponding Source conveyed under this section must be accompanied
312
+ by the Installation Information. But this requirement does not apply
313
+ if neither you nor any third party retains the ability to install
314
+ modified object code on the User Product (for example, the work has
315
+ been installed in ROM).
316
+
317
+ The requirement to provide Installation Information does not include a
318
+ requirement to continue to provide support service, warranty, or updates
319
+ for a work that has been modified or installed by the recipient, or for
320
+ the User Product in which it has been modified or installed. Access to a
321
+ network may be denied when the modification itself materially and
322
+ adversely affects the operation of the network or violates the rules and
323
+ protocols for communication across the network.
324
+
325
+ Corresponding Source conveyed, and Installation Information provided,
326
+ in accord with this section must be in a format that is publicly
327
+ documented (and with an implementation available to the public in
328
+ source code form), and must require no special password or key for
329
+ unpacking, reading or copying.
330
+
331
+ 7. Additional Terms.
332
+
333
+ "Additional permissions" are terms that supplement the terms of this
334
+ License by making exceptions from one or more of its conditions.
335
+ Additional permissions that are applicable to the entire Program shall
336
+ be treated as though they were included in this License, to the extent
337
+ that they are valid under applicable law. If additional permissions
338
+ apply only to part of the Program, that part may be used separately
339
+ under those permissions, but the entire Program remains governed by
340
+ this License without regard to the additional permissions.
341
+
342
+ When you convey a copy of a covered work, you may at your option
343
+ remove any additional permissions from that copy, or from any part of
344
+ it. (Additional permissions may be written to require their own
345
+ removal in certain cases when you modify the work.) You may place
346
+ additional permissions on material, added by you to a covered work,
347
+ for which you have or can give appropriate copyright permission.
348
+
349
+ Notwithstanding any other provision of this License, for material you
350
+ add to a covered work, you may (if authorized by the copyright holders of
351
+ that material) supplement the terms of this License with terms:
352
+
353
+ a) Disclaiming warranty or limiting liability differently from the
354
+ terms of sections 15 and 16 of this License; or
355
+
356
+ b) Requiring preservation of specified reasonable legal notices or
357
+ author attributions in that material or in the Appropriate Legal
358
+ Notices displayed by works containing it; or
359
+
360
+ c) Prohibiting misrepresentation of the origin of that material, or
361
+ requiring that modified versions of such material be marked in
362
+ reasonable ways as different from the original version; or
363
+
364
+ d) Limiting the use for publicity purposes of names of licensors or
365
+ authors of the material; or
366
+
367
+ e) Declining to grant rights under trademark law for use of some
368
+ trade names, trademarks, or service marks; or
369
+
370
+ f) Requiring indemnification of licensors and authors of that
371
+ material by anyone who conveys the material (or modified versions of
372
+ it) with contractual assumptions of liability to the recipient, for
373
+ any liability that these contractual assumptions directly impose on
374
+ those licensors and authors.
375
+
376
+ All other non-permissive additional terms are considered "further
377
+ restrictions" within the meaning of section 10. If the Program as you
378
+ received it, or any part of it, contains a notice stating that it is
379
+ governed by this License along with a term that is a further
380
+ restriction, you may remove that term. If a license document contains
381
+ a further restriction but permits relicensing or conveying under this
382
+ License, you may add to a covered work material governed by the terms
383
+ of that license document, provided that the further restriction does
384
+ not survive such relicensing or conveying.
385
+
386
+ If you add terms to a covered work in accord with this section, you
387
+ must place, in the relevant source files, a statement of the
388
+ additional terms that apply to those files, or a notice indicating
389
+ where to find the applicable terms.
390
+
391
+ Additional terms, permissive or non-permissive, may be stated in the
392
+ form of a separately written license, or stated as exceptions;
393
+ the above requirements apply either way.
394
+
395
+ 8. Termination.
396
+
397
+ You may not propagate or modify a covered work except as expressly
398
+ provided under this License. Any attempt otherwise to propagate or
399
+ modify it is void, and will automatically terminate your rights under
400
+ this License (including any patent licenses granted under the third
401
+ paragraph of section 11).
402
+
403
+ However, if you cease all violation of this License, then your
404
+ license from a particular copyright holder is reinstated (a)
405
+ provisionally, unless and until the copyright holder explicitly and
406
+ finally terminates your license, and (b) permanently, if the copyright
407
+ holder fails to notify you of the violation by some reasonable means
408
+ prior to 60 days after the cessation.
409
+
410
+ Moreover, your license from a particular copyright holder is
411
+ reinstated permanently if the copyright holder notifies you of the
412
+ violation by some reasonable means, this is the first time you have
413
+ received notice of violation of this License (for any work) from that
414
+ copyright holder, and you cure the violation prior to 30 days after
415
+ your receipt of the notice.
416
+
417
+ Termination of your rights under this section does not terminate the
418
+ licenses of parties who have received copies or rights from you under
419
+ this License. If your rights have been terminated and not permanently
420
+ reinstated, you do not qualify to receive new licenses for the same
421
+ material under section 10.
422
+
423
+ 9. Acceptance Not Required for Having Copies.
424
+
425
+ You are not required to accept this License in order to receive or
426
+ run a copy of the Program. Ancillary propagation of a covered work
427
+ occurring solely as a consequence of using peer-to-peer transmission
428
+ to receive a copy likewise does not require acceptance. However,
429
+ nothing other than this License grants you permission to propagate or
430
+ modify any covered work. These actions infringe copyright if you do
431
+ not accept this License. Therefore, by modifying or propagating a
432
+ covered work, you indicate your acceptance of this License to do so.
433
+
434
+ 10. Automatic Licensing of Downstream Recipients.
435
+
436
+ Each time you convey a covered work, the recipient automatically
437
+ receives a license from the original licensors, to run, modify and
438
+ propagate that work, subject to this License. You are not responsible
439
+ for enforcing compliance by third parties with this License.
440
+
441
+ An "entity transaction" is a transaction transferring control of an
442
+ organization, or substantially all assets of one, or subdividing an
443
+ organization, or merging organizations. If propagation of a covered
444
+ work results from an entity transaction, each party to that
445
+ transaction who receives a copy of the work also receives whatever
446
+ licenses to the work the party's predecessor in interest had or could
447
+ give under the previous paragraph, plus a right to possession of the
448
+ Corresponding Source of the work from the predecessor in interest, if
449
+ the predecessor has it or can get it with reasonable efforts.
450
+
451
+ You may not impose any further restrictions on the exercise of the
452
+ rights granted or affirmed under this License. For example, you may
453
+ not impose a license fee, royalty, or other charge for exercise of
454
+ rights granted under this License, and you may not initiate litigation
455
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
456
+ any patent claim is infringed by making, using, selling, offering for
457
+ sale, or importing the Program or any portion of it.
458
+
459
+ 11. Patents.
460
+
461
+ A "contributor" is a copyright holder who authorizes use under this
462
+ License of the Program or a work on which the Program is based. The
463
+ work thus licensed is called the contributor's "contributor version".
464
+
465
+ A contributor's "essential patent claims" are all patent claims
466
+ owned or controlled by the contributor, whether already acquired or
467
+ hereafter acquired, that would be infringed by some manner, permitted
468
+ by this License, of making, using, or selling its contributor version,
469
+ but do not include claims that would be infringed only as a
470
+ consequence of further modification of the contributor version. For
471
+ purposes of this definition, "control" includes the right to grant
472
+ patent sublicenses in a manner consistent with the requirements of
473
+ this License.
474
+
475
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
476
+ patent license under the contributor's essential patent claims, to
477
+ make, use, sell, offer for sale, import and otherwise run, modify and
478
+ propagate the contents of its contributor version.
479
+
480
+ In the following three paragraphs, a "patent license" is any express
481
+ agreement or commitment, however denominated, not to enforce a patent
482
+ (such as an express permission to practice a patent or covenant not to
483
+ sue for patent infringement). To "grant" such a patent license to a
484
+ party means to make such an agreement or commitment not to enforce a
485
+ patent against the party.
486
+
487
+ If you convey a covered work, knowingly relying on a patent license,
488
+ and the Corresponding Source of the work is not available for anyone
489
+ to copy, free of charge and under the terms of this License, through a
490
+ publicly available network server or other readily accessible means,
491
+ then you must either (1) cause the Corresponding Source to be so
492
+ available, or (2) arrange to deprive yourself of the benefit of the
493
+ patent license for this particular work, or (3) arrange, in a manner
494
+ consistent with the requirements of this License, to extend the patent
495
+ license to downstream recipients. "Knowingly relying" means you have
496
+ actual knowledge that, but for the patent license, your conveying the
497
+ covered work in a country, or your recipient's use of the covered work
498
+ in a country, would infringe one or more identifiable patents in that
499
+ country that you have reason to believe are valid.
500
+
501
+ If, pursuant to or in connection with a single transaction or
502
+ arrangement, you convey, or propagate by procuring conveyance of, a
503
+ covered work, and grant a patent license to some of the parties
504
+ receiving the covered work authorizing them to use, propagate, modify
505
+ or convey a specific copy of the covered work, then the patent license
506
+ you grant is automatically extended to all recipients of the covered
507
+ work and works based on it.
508
+
509
+ A patent license is "discriminatory" if it does not include within
510
+ the scope of its coverage, prohibits the exercise of, or is
511
+ conditioned on the non-exercise of one or more of the rights that are
512
+ specifically granted under this License. You may not convey a covered
513
+ work if you are a party to an arrangement with a third party that is
514
+ in the business of distributing software, under which you make payment
515
+ to the third party based on the extent of your activity of conveying
516
+ the work, and under which the third party grants, to any of the
517
+ parties who would receive the covered work from you, a discriminatory
518
+ patent license (a) in connection with copies of the covered work
519
+ conveyed by you (or copies made from those copies), or (b) primarily
520
+ for and in connection with specific products or compilations that
521
+ contain the covered work, unless you entered into that arrangement,
522
+ or that patent license was granted, prior to 28 March 2007.
523
+
524
+ Nothing in this License shall be construed as excluding or limiting
525
+ any implied license or other defenses to infringement that may
526
+ otherwise be available to you under applicable patent law.
527
+
528
+ 12. No Surrender of Others' Freedom.
529
+
530
+ If conditions are imposed on you (whether by court order, agreement or
531
+ otherwise) that contradict the conditions of this License, they do not
532
+ excuse you from the conditions of this License. If you cannot convey a
533
+ covered work so as to satisfy simultaneously your obligations under this
534
+ License and any other pertinent obligations, then as a consequence you may
535
+ not convey it at all. For example, if you agree to terms that obligate you
536
+ to collect a royalty for further conveying from those to whom you convey
537
+ the Program, the only way you could satisfy both those terms and this
538
+ License would be to refrain entirely from conveying the Program.
539
+
540
+ 13. Remote Network Interaction; Use with the GNU General Public License.
541
+
542
+ Notwithstanding any other provision of this License, if you modify the
543
+ Program, your modified version must prominently offer all users
544
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
Poster.pdf ADDED
Binary file (740 kB). View file
 
README.md CHANGED
@@ -1,3 +1,133 @@
1
- ---
2
- license: cc-by-nc-4.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FAST-RIR: FAST NEURAL DIFFUSE ROOM IMPULSE RESPONSE GENERATOR (ICASSP 2022)
2
+ This is the official implementation of our neural-network-based fast diffuse room impulse response generator ([**FAST-RIR**](https://arxiv.org/pdf/2110.04057.pdf)) for generating room impulse responses (RIRs) for a given rectangular acoustic environment. Our model is inspired by [**StackGAN**](https://github.com/hanzhanggit/StackGAN-Pytorch) architecture. The audio examples and spectrograms of the generated RIRs are available [here](https://anton-jeran.github.io/FRIR/).
3
+
4
+ **NEWS : We have genaralized our FAST-RIR to generate RIRs for any 3D indoor scenes represented using meshes. Official code of our network [**MESH2IR**](https://anton-jeran.github.io/M2IR/) is available.**
5
+
6
+ ## Requirements
7
+
8
+ ```
9
+ Python3.6
10
+ Pytorch
11
+ python-dateutil
12
+ easydict
13
+ pandas
14
+ torchfile
15
+ gdown
16
+ librosa
17
+ soundfile
18
+ acoustics
19
+ wavefile
20
+ wavfile
21
+ pyyaml==5.4.1
22
+ pickle
23
+ ```
24
+
25
+
26
+ ## Embedding
27
+
28
+ Each normalized embedding is created as follows: If you are using our trained model, you may need to use extra parameter Correction(CRR).
29
+
30
+ ```
31
+ Listener Position = LP
32
+ Source Position = SP
33
+ Room Dimension = RD
34
+ Reverberation Time = T60
35
+ Correction = CRR
36
+
37
+ CRR = 0.1 if 0.5<T60<0.6
38
+ CRR = 0.2 if T60>0.6
39
+ CRR = 0 otherwise
40
+
41
+ Embedding = ([LP_X,LP_Y,LP_Z,SP_X,SP_Y,SP_Z,RD_X,RD_Y,RD_Z,(T60+CRR)] /5) - 1
42
+ ```
43
+
44
+
45
+ ## Generete RIRs using trained model
46
+
47
+ Download the trained model using this command
48
+
49
+ ```
50
+ source download_generate.sh
51
+ ```
52
+
53
+ Create normalized embeddings list in pickle format. You can run following command to generate an example embedding list
54
+ ```
55
+ python3 example1.py
56
+ ```
57
+
58
+ Run the following command inside **code_new** to generate RIRs corresponding to the normalized embeddings list. You can find generated RIRs inside **code_new/Generated_RIRs**
59
+
60
+ ```
61
+ python3 main.py --cfg cfg/RIR_eval.yml --gpu 0
62
+ ```
63
+
64
+ ## Range
65
+
66
+ Our trained NN-DAS is capable of generating RIRs with the following range accurately.
67
+ ```
68
+ Room Dimension X --> 8m to 11m
69
+ Room Dimesnion Y --> 6m to 8m
70
+ Room Dimension Z --> 2.5m to 3.5m
71
+ Listener Position --> Any position within the room
72
+ Speaker Position --> Any position within the room
73
+ Reverberation time --> 0.2s to 0.7s
74
+ ```
75
+
76
+ ## Training the Model
77
+
78
+ Run the following command to download the training dataset we created using a [**Diffuse Acoustic Simulator**](https://github.com/GAMMA-UMD/pygsound). You also can train the model using your dataset.
79
+
80
+ ```
81
+ source download_data.sh
82
+ ```
83
+
84
+ Run the following command to train the model. You can pass what GPUs to be used for training as an input argument. In this example, I am using 2 GPUs.
85
+
86
+ ```
87
+ python3 main.py --cfg cfg/RIR_s1.yml --gpu 0,1
88
+ ```
89
+
90
+
91
+ ## Related Works
92
+ 1) [**IR-GAN: Room Impulse Response Generator for Far-field Speech Recognition (INTERSPEECH2021)**](https://github.com/anton-jeran/IR-GAN)
93
+ 2) [**TS-RIR: Translated synthetic room impulse responses for speech augmentation (IEEE ASRU 2021)**](https://github.com/GAMMA-UMD/TS-RIR)
94
+
95
+
96
+ ## Citations
97
+ If you use our **FAST-RIR** for your research, please consider citing
98
+
99
+ ```
100
+ @INPROCEEDINGS{9747846,
101
+ author={Ratnarajah, Anton and Zhang, Shi-Xiong and Yu, Meng and Tang, Zhenyu and Manocha, Dinesh and Yu, Dong},
102
+ booktitle={ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
103
+ title={Fast-Rir: Fast Neural Diffuse Room Impulse Response Generator},
104
+ year={2022},
105
+ volume={},
106
+ number={},
107
+ pages={571-575},
108
+ doi={10.1109/ICASSP43922.2022.9747846}}
109
+ ```
110
+
111
+ Our work is inspired by
112
+ ```
113
+ @inproceedings{han2017stackgan,
114
+ Author = {Han Zhang and Tao Xu and Hongsheng Li and Shaoting Zhang and Xiaogang Wang and Xiaolei Huang and Dimitris Metaxas},
115
+ Title = {StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks},
116
+ Year = {2017},
117
+ booktitle = {{ICCV}},
118
+ }
119
+ ```
120
+
121
+ If you use our training dataset generated using [**Diffuse Acoustic Simulator**](https://github.com/GAMMA-UMD/pygsound) in your research, please consider citing
122
+ ```
123
+ @inproceedings{9052932,
124
+ author={Z. {Tang} and L. {Chen} and B. {Wu} and D. {Yu} and D. {Manocha}},
125
+ booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
126
+ title={Improving Reverberant Speech Training Using Diffuse Acoustic Simulation},
127
+ year={2020},
128
+ volume={},
129
+ number={},
130
+ pages={6969-6973},
131
+ }
132
+ ```
133
+
code_new/RT60.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+ from scipy.io import wavfile
4
+ from scipy import stats
5
+
6
+ from acoustics.utils import _is_1d
7
+ from acoustics.signal import bandpass
8
+ from acoustics.bands import (_check_band_type, octave_low, octave_high, third_low, third_high)
9
+
10
+ import soundfile as sf
11
+ from multiprocessing import Pool
12
+
13
+ def t60_impulse(raw_signal,fs): # pylint: disable=too-many-locals
14
+ """
15
+ Reverberation time from a WAV impulse response.
16
+ :param file_name: name of the WAV file containing the impulse response.
17
+ :param bands: Octave or third bands as NumPy array.
18
+ :param rt: Reverberation time estimator. It accepts `'t30'`, `'t20'`, `'t10'` and `'edt'`.
19
+ :returns: Reverberation time :math:`T_{60}`
20
+ """
21
+ bands =np.array([62.5 ,125, 250, 500,1000, 2000])
22
+
23
+ if np.max(raw_signal)==0 and np.min(raw_signal)==0:
24
+ print('came 1')
25
+ return .5
26
+
27
+ # fs, raw_signal = wavfile.read(file_name)
28
+ band_type = _check_band_type(bands)
29
+
30
+ # if band_type == 'octave':
31
+ low = octave_low(bands[0], bands[-1])
32
+ high = octave_high(bands[0], bands[-1])
33
+ # elif band_type == 'third':
34
+ # low = third_low(bands[0], bands[-1])
35
+ # high = third_high(bands[0], bands[-1])
36
+
37
+
38
+ init = -0.0
39
+ end = -60.0
40
+ factor = 1.0
41
+ bands =bands[3:5]
42
+ low = low[3:5]
43
+ high = high[3:5]
44
+
45
+ t60 = np.zeros(bands.size)
46
+
47
+ for band in range(bands.size):
48
+ # Filtering signal
49
+ filtered_signal = bandpass(raw_signal, low[band], high[band], fs, order=8)
50
+ abs_signal = np.abs(filtered_signal) / np.max(np.abs(filtered_signal))
51
+
52
+ # Schroeder integration
53
+ sch = np.cumsum(abs_signal[::-1]**2)[::-1]
54
+ sch_db = 10.0 * np.log10(sch / np.max(sch))
55
+ if math.isnan(sch_db[1]):
56
+ print('came 2')
57
+ return .5
58
+ # print("leng sch_db ",sch_db.size)
59
+ # print("sch_db ",sch_db)
60
+ # Linear regression
61
+ sch_init = sch_db[np.abs(sch_db - init).argmin()]
62
+ sch_end = sch_db[np.abs(sch_db - end).argmin()]
63
+ init_sample = np.where(sch_db == sch_init)[0][0]
64
+ end_sample = np.where(sch_db == sch_end)[0][0]
65
+ x = np.arange(init_sample, end_sample + 1) / fs
66
+ y = sch_db[init_sample:end_sample + 1]
67
+ slope, intercept = stats.linregress(x, y)[0:2]
68
+
69
+ # Reverberation time (T30, T20, T10 or EDT)
70
+ db_regress_init = (init - intercept) / slope
71
+ db_regress_end = (end - intercept) / slope
72
+ t60[band] = factor * (db_regress_end - db_regress_init)
73
+ mean_t60 =(t60[1]+t60[0])/2
74
+ # print("meant60 is ", mean_t60)
75
+ if math.isnan(mean_t60):
76
+ print('came 3')
77
+ return .5
78
+ return mean_t60
79
+
80
+ def t60_error(filename1,filename2):
81
+ real_wave,fs = sf.read(filename1)
82
+ fake_wave,fs = sf.read(filename2)
83
+
84
+ channel = int(real_wave.size/len(real_wave))
85
+ pool = Pool(processes=8)
86
+
87
+ results =[]
88
+ for n in range(channel):
89
+ results.append(pool.apply_async(t60_parallel, args=(n,real_wave,fake_wave,fs,)))
90
+
91
+ T60_error =0
92
+ for result in results:
93
+ T60_error = T60_error + result.get()
94
+
95
+ T60_error = T60_error/channel
96
+
97
+ pool.close()
98
+ pool.join()
99
+
100
+
101
+ # T60_error = Parallel(n_jobs=64)(delayed(t60_parallel)(n, real_wave,fake_wave,fs) for n in range(channel))#np.random.randint(0,1023,size=channel))#
102
+ # T60_error = sum(results)/channel
103
+
104
+ # for n in range(channel):
105
+ # real_wave_single = real_wave[:,n]
106
+ # fake_wave_single = fake_wave[:,n]
107
+ # Real_T60_val = t60_impulse(real_wave_single,fs)
108
+ # Fake_T60_val = t60_impulse(fake_wave_single,fs)
109
+ # T60_diff = abs(Real_T60_val-Fake_T60_val)
110
+ # T60_error = T60_error + T60_diff
111
+ # T60_error = T60_error/channel
112
+ return str(T60_error)
113
+
114
+ def t60_parallel(n,real_wave,fake_wave,fs):
115
+ real_wave_single = real_wave[n,:]
116
+ fake_wave_single = fake_wave[n,:]
117
+ Real_T60_val = t60_impulse(real_wave_single,fs)
118
+ Fake_T60_val = t60_impulse(fake_wave_single,fs)
119
+ T60_diff = abs(Real_T60_val-Fake_T60_val)
120
+
121
+ return T60_diff
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+ if __name__ == '__main__':
130
+ t60_impulse('/home/anton/Desktop/gamma101/data/evaluation_all/SF1/Hotel_SkalskyDvur_ConferenceRoom2-MicID01-SpkID01_20170906_S-09-RIR-IR_sweep_15s_45Hzto22kHz_FS16kHz.v00.wav')
131
+
code_new/RTS.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ # import librosa
3
+
4
+ from scipy.io import wavfile
5
+ from scipy import stats
6
+ import soundfile as sf
7
+
8
+ from acoustics.utils import _is_1d
9
+ from acoustics.signal import bandpass
10
+ from acoustics.bands import (_check_band_type, octave_low, octave_high, third_low, third_high)
11
+
12
+ def t60_impulse(file_name): # pylint: disable=too-many-locals
13
+ """
14
+ Reverberation time from a WAV impulse response.
15
+ :param file_name: name of the WAV file containing the impulse response.
16
+ :param bands: Octave or third bands as NumPy array.
17
+ :param rt: Reverberation time estimator. It accepts `'t30'`, `'t20'`, `'t10'` and `'edt'`.
18
+ :returns: Reverberation time :math:`T_{60}`
19
+ """
20
+ bands =np.array([62.5 ,125, 250, 500,1000, 2000])
21
+
22
+ fs =16000;
23
+ # raw_signal, _ = librosa.load(file_name, sr=fs, mono=True, duration=1)
24
+
25
+ # fs, raw_signal = wavfile.read(file_name)
26
+ raw_signal,fs = sf.read(file_name)
27
+ band_type = _check_band_type(bands)
28
+
29
+ # if band_type == 'octave':
30
+ low = octave_low(bands[0], bands[-1])
31
+ high = octave_high(bands[0], bands[-1])
32
+ # elif band_type == 'third':
33
+ # low = third_low(bands[0], bands[-1])
34
+ # high = third_high(bands[0], bands[-1])
35
+
36
+
37
+ init = -0.0
38
+ end = -60.0
39
+ factor = 1.0
40
+ bands =bands[3:5]
41
+ low = low[3:5]
42
+ high = high[3:5]
43
+
44
+ t60 = np.zeros(bands.size)
45
+
46
+ for band in range(bands.size):
47
+ # Filtering signal
48
+ filtered_signal = bandpass(raw_signal, low[band], high[band], fs, order=8)
49
+ abs_signal = np.abs(filtered_signal) / np.max(np.abs(filtered_signal))
50
+
51
+ # Schroeder integration
52
+ sch = np.cumsum(abs_signal[::-1]**2)[::-1]
53
+ sch_db = 10.0 * np.log10(sch / np.max(sch))
54
+
55
+ # Linear regression
56
+ sch_init = sch_db[np.abs(sch_db - init).argmin()]
57
+ sch_end = sch_db[np.abs(sch_db - end).argmin()]
58
+ init_sample = np.where(sch_db == sch_init)[0][0]
59
+ end_sample = np.where(sch_db == sch_end)[0][0]
60
+ x = np.arange(init_sample, end_sample + 1) / fs
61
+ y = sch_db[init_sample:end_sample + 1]
62
+ slope, intercept = stats.linregress(x, y)[0:2]
63
+
64
+ # Reverberation time (T30, T20, T10 or EDT)
65
+ db_regress_init = (init - intercept) / slope
66
+ db_regress_end = (end - intercept) / slope
67
+ t60[band] = factor * (db_regress_end - db_regress_init)
68
+ mean_t60 =(t60[1]+t60[0])/2
69
+ return mean_t60
70
+
71
+ def t60_error(file_name1,file_name2):
72
+ RT_real = t60_impulse(file_name1)
73
+ RT_fake = t60_impulse(file_name2)
74
+ RT_diff = abs(RT_real-RT_fake)
75
+ return str(RT_diff)
76
+
77
+ if __name__ == '__main__':
78
+ t60_impulse('/home/anton/Anton/data/vcc2016_training/SF1/VUT_FIT_D105-MicID01-SpkID04_20170901_S-12-RIR-IR_sweep_15s_45Hzto22kHz_FS16kHz.v00.wav')
79
+ # t60_impulse('/home/anton/Desktop/data/vcc2016_training/SF1/2.wav')
80
+ # t60_impulse('/home/anton/Desktop/data/vcc2016_training/SF1/3.wav')
81
+ # t60_impulse('/home/anton/Desktop/data/vcc2016_training/SF1/4.wav')
82
+ # t60_impulse('/home/anton/Desktop/data/vcc2016_training/SF1/5.wav')
83
+ # t60_impulse('/home/anton/Desktop/data/vcc2016_training/SF1/6.wav')
84
+ # t60_impulse('/home/anton/Desktop/data/vcc2016_training/SF1/7.wav')
85
+ # t60_impulse('/home/anton/Desktop/data/vcc2016_training/SF1/8.wav')
86
+ # t60_impulse('/home/anton/Desktop/data/vcc2016_training/SF1/9.wav')
87
+ # t60_impulse('/home/anton/Desktop/data/vcc2016_training/SF1/10.wav')
code_new/__init__py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from __future__ import division
2
+ from __future__ import print_function
code_new/__pycache__/RT60.cpython-36.pyc ADDED
Binary file (3.13 kB). View file
 
code_new/__pycache__/RT60.cpython-38.pyc ADDED
Binary file (3.12 kB). View file
 
code_new/__pycache__/RTS.cpython-38.pyc ADDED
Binary file (2.34 kB). View file
 
code_new/__pycache__/model.cpython-36.pyc ADDED
Binary file (9.16 kB). View file
 
code_new/__pycache__/model.cpython-38.pyc ADDED
Binary file (9.05 kB). View file
 
code_new/__pycache__/trainer.cpython-36.pyc ADDED
Binary file (7.79 kB). View file
 
code_new/__pycache__/trainer.cpython-38.pyc ADDED
Binary file (7.31 kB). View file
 
code_new/cfg/RIR_eval.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONFIG_NAME: 'eval'
2
+
3
+ DATASET_NAME: 'RIR'
4
+ EMBEDDING_TYPE: 'cnn-rnn'
5
+ GPU_ID: '0,1'
6
+ # Z_DIM: 100
7
+
8
+ NET_G: '../generate/netG_epoch_242.pth'
9
+
10
+ DATA_DIR: '../data/Medium_Room'
11
+ EVAL_DIR: '../example1.pickle'
12
+ WORKERS: 4
13
+ RIRSIZE: 4096
14
+ STAGE: 1
15
+ TRAIN:
16
+ FLAG: False
17
+ BATCH_SIZE: 64
18
+
19
+ GAN:
20
+ CONDITION_DIM: 10
21
+ DF_DIM: 96
22
+ GF_DIM: 256
23
+
24
+ TEXT:
25
+ DIMENSION: 10
code_new/cfg/RIR_s1.yml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONFIG_NAME: 'stageI'
2
+
3
+ DATASET_NAME: 'RIR'
4
+ EMBEDDING_TYPE: 'cnn-rnn'
5
+ GPU_ID: '0,1'
6
+
7
+ DATA_DIR: '../data/Medium_Room'
8
+
9
+ EVAL_DIR: '../generate/embeddings/'
10
+ RIRSIZE: 4096
11
+ WORKERS: 4
12
+ STAGE: 1
13
+ TRAIN:
14
+ FLAG: True
15
+ BATCH_SIZE: 128
16
+ MAX_EPOCH: 2000
17
+ LR_DECAY_EPOCH: 40
18
+ SNAPSHOT_INTERVAL: 50
19
+ # DISCRIMINATOR_LR: 0.0002
20
+ # GENERATOR_LR: 0.0002
21
+ DISCRIMINATOR_LR: 0.00008
22
+ GENERATOR_LR: 0.00008
23
+ COEFF:
24
+ KL: 2.0
25
+
26
+ GAN:
27
+ CONDITION_DIM: 10
28
+ DF_DIM: 96
29
+ GF_DIM: 256
30
+
31
+ TEXT:
32
+ DIMENSION: 10
code_new/cfg/RIR_s1_temp.yml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONFIG_NAME: 'stageI'
2
+
3
+ DATASET_NAME: 'RIR'
4
+ EMBEDDING_TYPE: 'cnn-rnn'
5
+ GPU_ID: '0,1'
6
+
7
+ DATA_DIR: '../data/Medium_Room'
8
+
9
+ EVAL_DIR: '../generate/embeddings/'
10
+ RIRSIZE: 4096
11
+ WORKERS: 4
12
+ STAGE: 1
13
+ TRAIN:
14
+ FLAG: True
15
+ BATCH_SIZE: 128
16
+ MAX_EPOCH: 2000
17
+ LR_DECAY_EPOCH: 40
18
+ SNAPSHOT_INTERVAL: 50
19
+ # DISCRIMINATOR_LR: 0.0002
20
+ # GENERATOR_LR: 0.0002
21
+ DISCRIMINATOR_LR: 0.00008
22
+ GENERATOR_LR: 0.00008
23
+ COEFF:
24
+ KL: 2.0
25
+
26
+ GAN:
27
+ CONDITION_DIM: 10
28
+ DF_DIM: 96
29
+ GF_DIM: 256
30
+
31
+ TEXT:
32
+ DIMENSION: 10
code_new/main.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import torch.backends.cudnn as cudnn
3
+ import torch
4
+ import torchvision.transforms as transforms
5
+
6
+ import argparse
7
+ import os
8
+ import random
9
+ import sys
10
+ import pprint
11
+ import datetime
12
+ import dateutil
13
+ import dateutil.tz
14
+
15
+
16
+ dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))
17
+ sys.path.append(dir_path)
18
+
19
+ from miscc.datasets import TextDataset
20
+ from miscc.config import cfg, cfg_from_file
21
+ from miscc.utils import mkdir_p
22
+ from trainer import GANTrainer
23
+
24
+
25
+ def parse_args():
26
+ parser = argparse.ArgumentParser(description='Train a GAN network')
27
+ parser.add_argument('--cfg', dest='cfg_file',
28
+ help='optional config file',
29
+ default='birds_stage1.yml', type=str)
30
+ parser.add_argument('--gpu', dest='gpu_id', type=str, default='0')
31
+ parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
32
+ parser.add_argument('--manualSeed', type=int, help='manual seed')
33
+ args = parser.parse_args()
34
+ return args
35
+
36
+ if __name__ == "__main__":
37
+ args = parse_args()
38
+ if args.cfg_file is not None:
39
+ cfg_from_file(args.cfg_file)
40
+ if args.gpu_id != -1:
41
+ cfg.GPU_ID = args.gpu_id
42
+ if args.data_dir != '':
43
+ cfg.DATA_DIR = args.data_dir
44
+ print('Using config:')
45
+ pprint.pprint(cfg)
46
+ if args.manualSeed is None:
47
+ args.manualSeed = random.randint(1, 10000)
48
+ random.seed(args.manualSeed)
49
+ torch.manual_seed(args.manualSeed)
50
+ if cfg.CUDA:
51
+ torch.cuda.manual_seed_all(args.manualSeed)
52
+ now = datetime.datetime.now(dateutil.tz.tzlocal())
53
+ timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
54
+ output_dir = '../output/%s_%s_%s' % \
55
+ (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
56
+
57
+ num_gpu = len(cfg.GPU_ID.split(','))
58
+ if cfg.TRAIN.FLAG:
59
+ dataset = TextDataset(cfg.DATA_DIR, 'train',
60
+ rirsize=cfg.RIRSIZE)
61
+ assert dataset
62
+ #commented for temporary
63
+ dataloader = torch.utils.data.DataLoader(
64
+ dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
65
+ drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))
66
+
67
+ algo = GANTrainer(output_dir)
68
+ algo.train(dataloader, cfg.STAGE)
69
+ else:
70
+ file_path = cfg.EVAL_DIR
71
+ algo = GANTrainer(output_dir)
72
+ algo.sample(file_path, cfg.STAGE)
code_new/miscc/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from __future__ import division
2
+ from __future__ import print_function
code_new/miscc/__init__.pyc ADDED
Binary file (241 Bytes). View file
 
code_new/miscc/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (218 Bytes). View file
 
code_new/miscc/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (243 Bytes). View file
 
code_new/miscc/__pycache__/config.cpython-36.pyc ADDED
Binary file (2.11 kB). View file
 
code_new/miscc/__pycache__/config.cpython-38.pyc ADDED
Binary file (2.13 kB). View file
 
code_new/miscc/__pycache__/datasets.cpython-36.pyc ADDED
Binary file (2.49 kB). View file
 
code_new/miscc/__pycache__/datasets.cpython-38.pyc ADDED
Binary file (2.55 kB). View file
 
code_new/miscc/__pycache__/utils.cpython-36.pyc ADDED
Binary file (4.37 kB). View file
 
code_new/miscc/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.4 kB). View file
 
code_new/miscc/config.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ from __future__ import print_function
3
+
4
+ import os.path as osp
5
+ import numpy as np
6
+ from easydict import EasyDict as edict
7
+
8
+
9
+ __C = edict()
10
+ cfg = __C
11
+
12
+ # Dataset name: flowers, birds
13
+ __C.DATASET_NAME = 'birds'
14
+ __C.EMBEDDING_TYPE = 'cnn-rnn'
15
+ __C.CONFIG_NAME = ''
16
+ __C.GPU_ID = '0'
17
+ __C.CUDA = True
18
+ __C.WORKERS = 6
19
+
20
+ __C.NET_G = ''
21
+ __C.NET_D = ''
22
+ __C.STAGE1_G = ''
23
+ __C.DATA_DIR = ''
24
+ __C.EVAL_DIR = ''
25
+ __C.VIS_COUNT = 64
26
+
27
+ __C.Z_DIM = 100
28
+ __C.RIRSIZE = 4096
29
+ __C.STAGE = 1
30
+
31
+
32
+ # Training options
33
+ __C.TRAIN = edict()
34
+ __C.TRAIN.FLAG = True
35
+ __C.TRAIN.BATCH_SIZE = 64
36
+ __C.TRAIN.MAX_EPOCH = 600
37
+ __C.TRAIN.SNAPSHOT_INTERVAL = 50
38
+ __C.TRAIN.PRETRAINED_MODEL = ''
39
+ __C.TRAIN.PRETRAINED_EPOCH = 600
40
+ __C.TRAIN.LR_DECAY_EPOCH = 600
41
+ __C.TRAIN.DISCRIMINATOR_LR = 2e-4
42
+ __C.TRAIN.GENERATOR_LR = 2e-4
43
+
44
+ __C.TRAIN.COEFF = edict()
45
+ __C.TRAIN.COEFF.KL = 2.0
46
+
47
+ # Modal options
48
+ __C.GAN = edict()
49
+ __C.GAN.CONDITION_DIM = 128
50
+ __C.GAN.DF_DIM = 64
51
+ __C.GAN.GF_DIM = 128
52
+ __C.GAN.R_NUM = 4
53
+
54
+ __C.TEXT = edict()
55
+ __C.TEXT.DIMENSION = 1024
56
+
57
+
58
+ def _merge_a_into_b(a, b):
59
+ """Merge config dictionary a into config dictionary b, clobbering the
60
+ options in b whenever they are also specified in a.
61
+ """
62
+ if type(a) is not edict:
63
+ return
64
+
65
+ for k, v in a.items():
66
+ # a must specify keys that are in b
67
+ if k not in b:
68
+ raise KeyError('{} is not a valid config key'.format(k))
69
+
70
+ # the types must match, too
71
+ old_type = type(b[k])
72
+ if old_type is not type(v):
73
+ if isinstance(b[k], np.ndarray):
74
+ v = np.array(v, dtype=b[k].dtype)
75
+ else:
76
+ raise ValueError(('Type mismatch ({} vs. {}) '
77
+ 'for config key: {}').format(type(b[k]),
78
+ type(v), k))
79
+
80
+ # recursively merge dicts
81
+ if type(v) is edict:
82
+ try:
83
+ _merge_a_into_b(a[k], b[k])
84
+ except:
85
+ print('Error under config key: {}'.format(k))
86
+ raise
87
+ else:
88
+ b[k] = v
89
+
90
+
91
+ def cfg_from_file(filename):
92
+ """Load a config file and merge it into the default options."""
93
+ import yaml
94
+ with open(filename, 'r') as f:
95
+ yaml_cfg = edict(yaml.load(f))
96
+
97
+ _merge_a_into_b(yaml_cfg, __C)
code_new/miscc/config.pyc ADDED
Binary file (2.71 kB). View file
 
code_new/miscc/datasets.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+ from __future__ import unicode_literals
5
+
6
+
7
+ import torch.utils.data as data
8
+ # from PIL import Image
9
+ import soundfile as sf
10
+ import PIL
11
+ import os
12
+ import os.path
13
+ import pickle
14
+ import random
15
+ import numpy as np
16
+ import pandas as pd
17
+ from scipy import signal
18
+
19
+ from miscc.config import cfg
20
+
21
+
22
+ class TextDataset(data.Dataset):
23
+ def __init__(self, data_dir, split='train',rirsize=4096): #, transform=None, target_transform=None):
24
+
25
+ # self.transform = transform
26
+ # self.target_transform = target_transform
27
+ self.rirsize = rirsize
28
+ self.data = []
29
+ self.data_dir = data_dir
30
+ self.bbox = None
31
+
32
+ split_dir = os.path.join(data_dir, split)
33
+
34
+ self.filenames = self.load_filenames(split_dir)
35
+ self.embeddings = self.load_embedding(split_dir)
36
+
37
+ def get_RIR(self, RIR_path):
38
+ wav,fs = sf.read(RIR_path) #Image.open(RIR_path).convert('RGB')
39
+ length = wav.size
40
+ # crop_length = int((16384*(80))/(64))
41
+ crop_length = 4096 #int(16384)
42
+ if(length<crop_length):
43
+ zeros = np.zeros(crop_length-length)
44
+ RIR_original = np.concatenate([wav,zeros])
45
+ else:
46
+ RIR_original = wav[0:crop_length]
47
+
48
+ # resample_length = int((self.rirsize*(80))/(64))
49
+ resample_length = int(self.rirsize)
50
+ if(resample_length==16384):
51
+ RIR = RIR_original
52
+ else:
53
+ RIR = RIR_original#signal.resample(RIR_original,resample_length)
54
+ RIR = np.array([RIR]).astype('float32')
55
+
56
+
57
+
58
+ # if bbox is not None:
59
+ # R = int(np.maximum(bbox[2], bbox[3]) * 0.75)
60
+ # center_x = int((2 * bbox[0] + bbox[2]) / 2)
61
+ # center_y = int((2 * bbox[1] + bbox[3]) / 2)
62
+ # y1 = np.maximum(0, center_y - R)
63
+ # y2 = np.minimum(height, center_y + R)
64
+ # x1 = np.maximum(0, center_x - R)
65
+ # x2 = np.minimum(width, center_x + R)
66
+ # RIR = RIR.crop([x1, y1, x2, y2])
67
+ # load_size = int(self.rirsize * 76 / 64)
68
+ # RIR = RIR.resize((load_size, load_size), PIL.Image.BILINEAR)
69
+ # if self.transform is not None:
70
+ # RIR = self.transform(RIR)
71
+ return RIR
72
+
73
+
74
+ def load_embedding(self, data_dir):
75
+ embedding_filename = '/embeddings.pickle'
76
+ with open(data_dir + embedding_filename, 'rb') as f:
77
+ embeddings = pickle.load(f)
78
+ # embeddings = np.array(embeddings)
79
+ # # embedding_shape = [embeddings.shape[-1]]
80
+ # print('embeddings: ', embeddings.shape)
81
+ return embeddings
82
+
83
+ # def load_class_id(self, data_dir, total_num):
84
+ # if os.path.isfile(data_dir + '/class_info.pickle'):
85
+ # with open(data_dir + '/class_info.pickle', 'rb') as f:
86
+ # class_id = pickle.load(f)
87
+ # else:
88
+ # class_id = np.arange(total_num)
89
+ # return class_id
90
+
91
+ def load_filenames(self, data_dir):
92
+ filepath = os.path.join(data_dir, 'filenames.pickle')
93
+ with open(filepath, 'rb') as f:
94
+ filenames = pickle.load(f)
95
+ print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
96
+ return filenames
97
+
98
+ def __getitem__(self, index):
99
+ key = self.filenames[index]
100
+
101
+ data_dir = self.data_dir
102
+
103
+ # captions = self.captions[key]
104
+ embeddings = self.embeddings[key]
105
+ RIR_name = '%s/RIR/%s.wav' % (data_dir, key)
106
+ RIR = self.get_RIR(RIR_name)
107
+ embedding = np.array(embeddings).astype('float32')
108
+ # if self.target_transform is not None:
109
+ # embedding = self.target_transform(embedding)
110
+ return RIR, embedding
111
+
112
+ def __len__(self):
113
+ return len(self.filenames)
code_new/miscc/datasets.pyc ADDED
Binary file (3.16 kB). View file
 
code_new/miscc/utils.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import errno
3
+ import numpy as np
4
+
5
+ from copy import deepcopy
6
+ from miscc.config import cfg
7
+ from scipy.io.wavfile import write
8
+ from torch.nn import init
9
+ import torch
10
+ import torch.nn as nn
11
+ import torchvision.utils as vutils
12
+ from wavefile import WaveWriter, Format
13
+ import RT60
14
+ from multiprocessing import Pool
15
+
16
+
17
+ #############################
18
+ def KL_loss(mu, logvar):
19
+ # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
20
+ KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
21
+ KLD = torch.mean(KLD_element).mul_(-0.5)
22
+ return KLD
23
+
24
+
25
+ def compute_discriminator_loss(netD, real_RIRs, fake_RIRs,
26
+ real_labels, fake_labels,
27
+ conditions, gpus):
28
+ criterion = nn.BCELoss()
29
+ batch_size = real_RIRs.size(0)
30
+ cond = conditions.detach()
31
+ fake = fake_RIRs.detach()
32
+ real_features = nn.parallel.data_parallel(netD, (real_RIRs), gpus)
33
+ fake_features = nn.parallel.data_parallel(netD, (fake), gpus)
34
+ # real pairs
35
+ #print("util conditions ",cond.size())
36
+ inputs = (real_features, cond)
37
+ real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
38
+ errD_real = criterion(real_logits, real_labels)
39
+ # wrong pairs
40
+ inputs = (real_features[:(batch_size-1)], cond[1:])
41
+ wrong_logits = \
42
+ nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
43
+ errD_wrong = criterion(wrong_logits, fake_labels[1:])
44
+ # fake pairs
45
+ inputs = (fake_features, cond)
46
+ fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
47
+ errD_fake = criterion(fake_logits, fake_labels)
48
+
49
+ if netD.get_uncond_logits is not None:
50
+ real_logits = \
51
+ nn.parallel.data_parallel(netD.get_uncond_logits,
52
+ (real_features), gpus)
53
+ fake_logits = \
54
+ nn.parallel.data_parallel(netD.get_uncond_logits,
55
+ (fake_features), gpus)
56
+ uncond_errD_real = criterion(real_logits, real_labels)
57
+ uncond_errD_fake = criterion(fake_logits, fake_labels)
58
+ #
59
+ errD = ((errD_real + uncond_errD_real) / 2. +
60
+ (errD_fake + errD_wrong + uncond_errD_fake) / 3.)
61
+ errD_real = (errD_real + uncond_errD_real) / 2.
62
+ errD_fake = (errD_fake + uncond_errD_fake) / 2.
63
+ else:
64
+ errD = errD_real + (errD_fake + errD_wrong) * 0.5
65
+ return errD, errD_real.data, errD_wrong.data, errD_fake.data
66
+ # return errD, errD_real.data[0], errD_wrong.data[0], errD_fake.data[0]
67
+
68
+
69
+
70
+ def compute_generator_loss(epoch,netD,real_RIRs, fake_RIRs, real_labels, conditions, gpus):
71
+ criterion = nn.BCELoss()
72
+ loss = nn.L1Loss() #nn.MSELoss()
73
+ loss1 = nn.MSELoss()
74
+ RT_error = 0
75
+ # print("num", real_RIRs.size(),real_RIRs.size()[0])
76
+ # input("kk")
77
+
78
+
79
+ cond = conditions.detach()
80
+ fake_features = nn.parallel.data_parallel(netD, (fake_RIRs), gpus)
81
+ # fake pairs
82
+ inputs = (fake_features, cond)
83
+ fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
84
+ MSE_error = loss(real_RIRs,fake_RIRs)
85
+ MSE_error1 = loss1(real_RIRs,fake_RIRs)
86
+ sample_size = real_RIRs.size()[0]
87
+ channel = 12
88
+ fs = 16000
89
+ rn = np.random.randint(sample_size-(channel*2))
90
+ real_wave = np.array(real_RIRs[rn:rn+channel].to("cpu").detach())
91
+ real_wave = real_wave.reshape(channel,4096)
92
+ fake_wave = np.array(fake_RIRs[rn:rn+channel].to("cpu").detach())
93
+ fake_wave = fake_wave.reshape(channel,4096)
94
+
95
+ pool = Pool(processes=12)
96
+
97
+ results =[]
98
+ for n in range(channel):
99
+ results.append(pool.apply_async(RT60.t60_parallel, args=(n,real_wave,fake_wave,fs,)))
100
+
101
+ T60_error =0
102
+ for result in results:
103
+ T60_error = T60_error + result.get()
104
+
105
+ RT_error = T60_error/channel
106
+
107
+ pool.close()
108
+ pool.join()
109
+
110
+
111
+
112
+
113
+ # T60_error =0
114
+ # for m in range(channel):
115
+ # real_wave_single = real_wave[:,(rn+m)]
116
+ # fake_wave_single = fake_wave[:,(rn+m)]
117
+ # Real_T60_val = RT60.t60_impulse(real_wave_single,fs)
118
+ # Fake_T60_val = RT60.t60_impulse(fake_wave_single,fs)
119
+ # T60_diff = abs(Real_T60_val-Fake_T60_val)
120
+ # T60_error = T60_error + T60_diff
121
+
122
+ # RT_error = T60_error/channel
123
+
124
+
125
+ # r = WaveWriter("real.wav", channels=portion, samplerate=fs)
126
+ # r.write(np.array(real_IR))
127
+ # f = WaveWriter("fake.wav", channels=portion, samplerate=fs)
128
+ # f.write(np.array(fake_IR))
129
+
130
+
131
+ # result = call_python_version("3.8", "RT60", "t60_error",
132
+ # ["real.wav","fake.wav"])
133
+ # # print("RT_error ",result)
134
+ # RT_error = float(result)
135
+
136
+
137
+ # print("RT_error ",RT_error)
138
+
139
+ # if(epoch<100):
140
+ # errD_fake = criterion(fake_logits, real_labels)# + 2* 4096 * MSE_error
141
+ # else:
142
+ # errD_fake = criterion(fake_logits, real_labels) + 2* 4096 * MSE_error
143
+ errD_fake = criterion(fake_logits, real_labels) + 5* 4096 * MSE_error1 + 40 * RT_error
144
+ if netD.get_uncond_logits is not None:
145
+ fake_logits = \
146
+ nn.parallel.data_parallel(netD.get_uncond_logits,
147
+ (fake_features), gpus)
148
+ uncond_errD_fake = criterion(fake_logits, real_labels)
149
+ errD_fake += uncond_errD_fake
150
+ return errD_fake, MSE_error,RT_error
151
+
152
+
153
+ #############################
154
+ def weights_init(m):
155
+ classname = m.__class__.__name__
156
+ if classname.find('Conv') != -1:
157
+ m.weight.data.normal_(0.0, 0.02)
158
+ elif classname.find('BatchNorm') != -1:
159
+ m.weight.data.normal_(1.0, 0.02)
160
+ m.bias.data.fill_(0)
161
+ elif classname.find('Linear') != -1:
162
+ m.weight.data.normal_(0.0, 0.02)
163
+ if m.bias is not None:
164
+ m.bias.data.fill_(0.0)
165
+
166
+
167
+ #############################
168
+ def save_RIR_results(data_RIR, fake, epoch, RIR_dir):
169
+ num = cfg.VIS_COUNT
170
+ fake = fake[0:num]
171
+ # data_RIR is changed to [0,1]
172
+ if data_RIR is not None:
173
+ data_RIR = data_RIR[0:num]
174
+ for i in range(num):
175
+ # #print("came 1")
176
+ real_RIR_path = RIR_dir+"/real_sample"+str(i)+".wav"
177
+ fake_RIR_path = RIR_dir+"/fake_sample"+str(i)+"_epoch_"+str(epoch)+".wav"
178
+ fs =16000
179
+
180
+ real_IR = np.array(data_RIR[i].to("cpu").detach())
181
+ fake_IR = np.array(fake[i].to("cpu").detach())
182
+ # #print("fake_IR ", fake_IR.size)
183
+ # #print("real_IR ", real_IR.size)
184
+ # #print("max real_IR ", max(real_IR[0]))
185
+ # #print("min real_IR ", min(real_IR[0]))
186
+ r = WaveWriter(real_RIR_path, channels=1, samplerate=fs)
187
+ r.write(np.array(real_IR))
188
+ f = WaveWriter(fake_RIR_path, channels=1, samplerate=fs)
189
+ f.write(np.array(fake_IR))
190
+
191
+
192
+ # write(real_RIR_path,fs,real_IR)
193
+ # write(fake_RIR_path,fs,fake_IR)
194
+
195
+
196
+ # write(real_RIR_path,fs,real_IR)
197
+ # write(fake_RIR_path,fs,fake_IR)
198
+
199
+ # vutils.save_image(
200
+ # data_RIR, '%s/real_samples.png' % RIR_dir,
201
+ # normalize=True)
202
+ # # fake.data is still [-1, 1]
203
+ # vutils.save_image(
204
+ # fake.data, '%s/fake_samples_epoch_%03d.png' %
205
+ # (RIR_dir, epoch), normalize=True)
206
+ else:
207
+ for i in range(num):
208
+ # #print("came 2")
209
+ fake_RIR_path = RIR_dir+"/small_fake_sample"+str(i)+"_epoch_"+str(epoch)+".wav"
210
+ fs =16000
211
+ fake_IR = np.array(fake[i].to("cpu").detach())
212
+ f = WaveWriter(fake_RIR_path, channels=1, samplerate=fs)
213
+ f.write(np.array(fake_IR))
214
+
215
+ # write(fake_RIR_path,fs,fake[i].astype(np.float32))
216
+
217
+ # vutils.save_image(
218
+ # fake.data, '%s/lr_fake_samples_epoch_%03d.png' %
219
+ # (RIR_dir, epoch), normalize=True)
220
+
221
+
222
+ def save_model(netG, netD, epoch, model_dir):
223
+ torch.save(
224
+ netG.state_dict(),
225
+ '%s/netG_epoch_%d.pth' % (model_dir, epoch))
226
+ torch.save(
227
+ netD.state_dict(),
228
+ '%s/netD_epoch_last.pth' % (model_dir))
229
+ #print('Save G/D models')
230
+
231
+
232
+ def mkdir_p(path):
233
+ try:
234
+ os.makedirs(path)
235
+ except OSError as exc: # Python >2.5
236
+ if exc.errno == errno.EEXIST and os.path.isdir(path):
237
+ pass
238
+ else:
239
+ raise
code_new/miscc/utils.pyc ADDED
Binary file (5.71 kB). View file
 
code_new/model.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.parallel
4
+ from miscc.config import cfg
5
+ from torch.autograd import Variable
6
+
7
+
8
+ def conv3x1(in_planes, out_planes, stride=1):
9
+ "3x1 convolution with padding"
10
+ kernel_length = 41
11
+ return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_length, stride=stride,
12
+ padding=20, bias=False)
13
+
14
+ def old_conv3x1(in_planes, out_planes, stride=1):
15
+ "3x1 convolution with padding"
16
+ kernel_length = 3
17
+ return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_length, stride=stride,
18
+ padding=1, bias=False)
19
+ # def convn3x1(in_planes, out_planes, stride=1):
20
+ # "3x1 convolution with padding"
21
+ # return nn.Conv1d(in_planes, out_planes, kernel_size=9, stride=stride,
22
+ # padding=4, bias=False)
23
+
24
+
25
+ # Upsale the spatial size by a factor of 2
26
+ def upBlock4(in_planes, out_planes):
27
+ kernel_length = 41
28
+ stride = 4
29
+ block = nn.Sequential(
30
+ # nn.Upsample(scale_factor=4, mode='nearest'),
31
+ # conv3x1(in_planes, out_planes),
32
+ nn.ConvTranspose1d(in_planes,out_planes,kernel_size=kernel_length,stride=stride, padding=19,output_padding=1),
33
+ nn.BatchNorm1d(out_planes),
34
+ # nn.ReLU(True)
35
+ nn.PReLU())
36
+ return block
37
+ def upBlock2(in_planes, out_planes):
38
+ kernel_length = 41
39
+ stride = 2
40
+ block = nn.Sequential(
41
+ # nn.Upsample(scale_factor=4, mode='nearest'),
42
+ # conv3x1(in_planes, out_planes),
43
+ nn.ConvTranspose1d(in_planes,out_planes,kernel_size=kernel_length,stride=stride, padding=20,output_padding=1),
44
+ nn.BatchNorm1d(out_planes),
45
+ # nn.ReLU(True)
46
+ nn.PReLU())
47
+ return block
48
+
49
+ def sameBlock(in_planes, out_planes):
50
+ block = nn.Sequential(
51
+ # nn.Upsample(scale_factor=4, mode='nearest'),
52
+ conv3x1(in_planes, out_planes),
53
+ nn.BatchNorm1d(out_planes),
54
+ # nn.ReLU(True)
55
+ nn.PReLU())
56
+ return block
57
+
58
+
59
+ class ResBlock(nn.Module):
60
+ def __init__(self, channel_num):
61
+ super(ResBlock, self).__init__()
62
+ self.block = nn.Sequential(
63
+ conv3x1(channel_num, channel_num),
64
+ nn.BatchNorm1d(channel_num),
65
+ # nn.ReLU(True),
66
+ nn.PReLU(),
67
+ conv3x1(channel_num, channel_num),
68
+ nn.BatchNorm1d(channel_num))
69
+ self.relu = nn.PReLU()#nn.ReLU(inplace=True)
70
+
71
+ def forward(self, x):
72
+ residual = x
73
+ out = self.block(x)
74
+ out += residual
75
+ out = self.relu(out)
76
+ return out
77
+
78
+
79
+ # class CA_NET(nn.Module): #not chnaged yet
80
+ # # some code is modified from vae examples
81
+ # # (https://github.com/pytorch/examples/blob/master/vae/main.py)
82
+ # def __init__(self):
83
+ # super(CA_NET, self).__init__()
84
+ # self.t_dim = cfg.TEXT.DIMENSION
85
+ # self.c_dim = cfg.GAN.CONDITION_DIM
86
+ # self.fc = nn.Linear(self.t_dim, self.c_dim * 2, bias=True)
87
+ # self.relu = nn.ReLU()
88
+
89
+ # def encode(self, text_embedding):
90
+ # x = self.relu(self.fc(text_embedding))
91
+ # mu = x[:, :self.c_dim]
92
+ # logvar = x[:, self.c_dim:]
93
+ # return mu, logvar
94
+
95
+ # def reparametrize(self, mu, logvar):
96
+ # std = logvar.mul(0.5).exp_()
97
+ # if cfg.CUDA:
98
+ # eps = torch.cuda.FloatTensor(std.size()).normal_()
99
+ # else:
100
+ # eps = torch.FloatTensor(std.size()).normal_()
101
+ # eps = Variable(eps)
102
+ # return eps.mul(std).add_(mu)
103
+
104
+ # def forward(self, text_embedding):
105
+ # mu, logvar = self.encode(text_embedding)
106
+ # c_code = self.reparametrize(mu, logvar)
107
+ # return c_code, mu, logvar
108
+
109
+ class COND_NET(nn.Module): #not chnaged yet
110
+ # some code is modified from vae examples
111
+ # (https://github.com/pytorch/examples/blob/master/vae/main.py)
112
+ def __init__(self):
113
+ super(COND_NET, self).__init__()
114
+ self.t_dim = cfg.TEXT.DIMENSION
115
+ self.c_dim = cfg.GAN.CONDITION_DIM
116
+ self.fc = nn.Linear(self.t_dim, self.c_dim, bias=True)
117
+ self.relu = nn.PReLU()#nn.ReLU()
118
+
119
+ def encode(self, text_embedding):
120
+ x = self.relu(self.fc(text_embedding))
121
+ # mu = x[:, :self.c_dim]
122
+ # logvar = x[:, self.c_dim:]
123
+ return x
124
+
125
+ # def reparametrize(self, mu, logvar):
126
+ # std = logvar.mul(0.5).exp_()
127
+ # if cfg.CUDA:
128
+ # eps = torch.cuda.FloatTensor(std.size()).normal_()
129
+ # else:
130
+ # eps = torch.FloatTensor(std.size()).normal_()
131
+ # eps = Variable(eps)
132
+ # return eps.mul(std).add_(mu)
133
+
134
+ def forward(self, text_embedding):
135
+ c_code = self.encode(text_embedding)
136
+ # c_code = self.reparametrize(mu, logvar)
137
+ return c_code #, mu, logvar
138
+
139
+
140
+ class D_GET_LOGITS(nn.Module): #not chnaged yet
141
+ def __init__(self, ndf, nef, bcondition=True):
142
+ super(D_GET_LOGITS, self).__init__()
143
+ self.df_dim = ndf
144
+ self.ef_dim = nef
145
+ self.bcondition = bcondition
146
+ kernel_length =41
147
+ if bcondition:
148
+ self.convd1d = nn.ConvTranspose1d(ndf*8,ndf //2,kernel_size=kernel_length,stride=1, padding=20)
149
+ # self.outlogits = nn.Sequential(
150
+ # old_conv3x1(ndf * 8 + nef, ndf * 8),
151
+ # nn.BatchNorm1d(ndf * 8),
152
+ # nn.LeakyReLU(0.2, inplace=True),
153
+ # nn.Conv1d(ndf * 8, 1, kernel_size=16, stride=4),
154
+ # # nn.Conv1d(1, 1, kernel_size=16, stride=4),
155
+ # nn.Sigmoid()
156
+ # )
157
+ self.outlogits = nn.Sequential(
158
+ old_conv3x1(ndf //2 + nef, ndf //2 ),
159
+ nn.BatchNorm1d(ndf //2 ),
160
+ nn.LeakyReLU(0.2, inplace=True),
161
+ nn.Conv1d(ndf //2 , 1, kernel_size=16, stride=4),
162
+ # nn.Conv1d(1, 1, kernel_size=16, stride=4),
163
+ nn.Sigmoid()
164
+ )
165
+ else:
166
+ # self.outlogits = nn.Sequential(
167
+ # nn.Conv1d(ndf * 8, 1, kernel_size=16, stride=4),
168
+ # # nn.Conv1d(1, 1, kernel_size=16, stride=4),
169
+ # nn.Sigmoid())
170
+ self.convd1d = nn.ConvTranspose1d(ndf*8,ndf //2,kernel_size=kernel_length,stride=1, padding=20)
171
+ self.outlogits = nn.Sequential(
172
+ nn.Conv1d(ndf // 2 , 1, kernel_size=16, stride=4),
173
+ # nn.Conv1d(1, 1, kernel_size=16, stride=4),
174
+ nn.Sigmoid())
175
+
176
+ def forward(self, h_code, c_code=None):
177
+ # conditioning output
178
+ h_code = self.convd1d(h_code)
179
+ if self.bcondition and c_code is not None:
180
+ #print("mode c_code1 ",c_code.size())
181
+ c_code = c_code.view(-1, self.ef_dim, 1)
182
+ #print("mode c_code2 ",c_code.size())
183
+
184
+ c_code = c_code.repeat(1, 1, 16)
185
+ # state size (ngf+egf) x 16
186
+ #print("mode c_code ",c_code.size())
187
+ #print("mode h_code ",h_code.size())
188
+
189
+ h_c_code = torch.cat((h_code, c_code), 1)
190
+ else:
191
+ h_c_code = h_code
192
+
193
+ output = self.outlogits(h_c_code)
194
+
195
+ return output.view(-1)
196
+
197
+
198
+ # ############# Networks for stageI GAN #############
199
+ class STAGE1_G(nn.Module):
200
+ def __init__(self):
201
+ super(STAGE1_G, self).__init__()
202
+ self.gf_dim = cfg.GAN.GF_DIM * 8
203
+ self.ef_dim = cfg.GAN.CONDITION_DIM
204
+ # self.z_dim = cfg.Z_DIM
205
+ self.define_module()
206
+
207
+ def define_module(self):
208
+ kernel_length = 41
209
+ ninput = self.ef_dim #self.z_dim + self.ef_dim
210
+ ngf = self.gf_dim
211
+ # TEXT.DIMENSION -> GAN.CONDITION_DIM
212
+ # self.ca_net = CA_NET()
213
+ self.cond_net = COND_NET()
214
+ # -> ngf x 16
215
+ self.fc = nn.Sequential(
216
+ nn.Linear(ninput, ngf * 16, bias=False),
217
+ nn.BatchNorm1d(ngf * 16),
218
+ # nn.ReLU(True)
219
+ nn.PReLU())
220
+
221
+ # ngf x 16 -> ngf/2 x 64
222
+ self.upsample1 = upBlock4(ngf, ngf // 2)
223
+ # -> ngf/4 x 256
224
+ self.upsample2 = upBlock4(ngf // 2, ngf // 4)
225
+ # -> ngf/8 x 1024
226
+ self.upsample3 = upBlock4(ngf // 4, ngf // 8)
227
+ # -> ngf/16 x 4096
228
+ self.upsample4 = upBlock2(ngf // 8, ngf // 16)
229
+ self.upsample5 = upBlock2(ngf // 16, ngf // 16)
230
+ # -> 1 x 4096
231
+ self.RIR = nn.Sequential(
232
+ nn.ConvTranspose1d(ngf // 16,1,kernel_size=kernel_length,stride=1, padding=20),
233
+ # old_conv3x1(ngf // 16, 1), # conv3x3(ngf // 16, 3),
234
+ nn.Tanh())
235
+
236
+ def forward(self, text_embedding):
237
+ # c_code, mu, logvar = self.ca_net(text_embedding)
238
+ c_code = self.cond_net(text_embedding)
239
+ # z_c_code = torch.cat((noise, c_code), 1)
240
+ h_code = self.fc(c_code)
241
+
242
+ h_code = h_code.view(-1, self.gf_dim, 16)
243
+ # #print("h_code 1 ",h_code.size())
244
+ h_code = self.upsample1(h_code)
245
+ # #print("h_code 2 ",h_code.size())
246
+ h_code = self.upsample2(h_code)
247
+ # #print("h_code 3 ",h_code.size())
248
+ h_code = self.upsample3(h_code)
249
+ # #print("h_code 4 ",h_code.size())
250
+ h_code = self.upsample4(h_code)
251
+ h_code = self.upsample5(h_code)
252
+ # #print("h_code 5 ",h_code.size())
253
+ # state size 3 x 64 x 64
254
+ fake_RIR = self.RIR(h_code)
255
+ # return None, fake_RIR, mu, logvar
256
+ #print("generator ", text_embedding.size())
257
+ return None, fake_RIR, text_embedding #c_code
258
+
259
+
260
+ class STAGE1_D(nn.Module):
261
+ def __init__(self):
262
+ super(STAGE1_D, self).__init__()
263
+ self.df_dim = cfg.GAN.DF_DIM
264
+ self.ef_dim = cfg.GAN.CONDITION_DIM
265
+ self.define_module()
266
+
267
+ def define_module(self):
268
+ ndf, nef = self.df_dim, self.ef_dim
269
+ kernel_length =41
270
+ self.encode_RIR = nn.Sequential(
271
+ nn.Conv1d(1, ndf, kernel_length, 4, 20, bias=False),
272
+ nn.LeakyReLU(0.2, inplace=True),
273
+ # state size. (ndf) x 1024
274
+ nn.Conv1d(ndf, ndf * 2, kernel_length, 4, 20, bias=False),
275
+ nn.BatchNorm1d(ndf * 2),
276
+ nn.LeakyReLU(0.2, inplace=True),
277
+ # state size (ndf*2) x 256
278
+ nn.Conv1d(ndf*2, ndf * 4, kernel_length, 4, 20, bias=False),
279
+ nn.BatchNorm1d(ndf * 4),
280
+ nn.LeakyReLU(0.2, inplace=True),
281
+ # # state size (ndf*4) x 64
282
+ nn.Conv1d(ndf*4, ndf * 8, kernel_length, 4, 20, bias=False),
283
+ nn.BatchNorm1d(ndf * 8),
284
+ # state size (ndf * 8) x 16)
285
+ nn.LeakyReLU(0.2, inplace=True)
286
+ )
287
+
288
+ self.get_cond_logits = D_GET_LOGITS(ndf, nef)
289
+ self.get_uncond_logits = None
290
+
291
+ def forward(self, RIRs):
292
+ #print("model RIRs ",RIRs.size())
293
+ RIR_embedding = self.encode_RIR(RIRs)
294
+ #print("models RIR_embedding ",RIR_embedding.size())
295
+
296
+ return RIR_embedding
297
+
298
+
299
+ # ############# Networks for stageII GAN #############
300
+ class STAGE2_G(nn.Module):
301
+ def __init__(self, STAGE1_G):
302
+ super(STAGE2_G, self).__init__()
303
+ self.gf_dim = cfg.GAN.GF_DIM
304
+ self.ef_dim = cfg.GAN.CONDITION_DIM
305
+ # self.z_dim = cfg.Z_DIM
306
+ self.STAGE1_G = STAGE1_G
307
+ # fix parameters of stageI GAN
308
+ for param in self.STAGE1_G.parameters():
309
+ param.requires_grad = False
310
+ self.define_module()
311
+
312
+ def _make_layer(self, block, channel_num):
313
+ layers = []
314
+ for i in range(cfg.GAN.R_NUM):
315
+ layers.append(block(channel_num))
316
+ return nn.Sequential(*layers)
317
+
318
+ def define_module(self):
319
+ ngf = self.gf_dim
320
+ # TEXT.DIMENSION -> GAN.CONDITION_DIM
321
+ # self.ca_net = CA_NET()
322
+ self.cond_net = COND_NET()
323
+ # --> 4ngf x 16 x 16
324
+ self.encoder = nn.Sequential(
325
+ conv3x1(1, ngf),
326
+ nn.ReLU(True),
327
+ nn.Conv1d(ngf, ngf * 2, 16, 4, 6, bias=False),
328
+ nn.BatchNorm1d(ngf * 2),
329
+ nn.ReLU(True),
330
+ nn.Conv1d(ngf * 2, ngf * 4, 16, 4, 6, bias=False),
331
+ nn.BatchNorm1d(ngf * 4),
332
+ nn.ReLU(True))
333
+ self.hr_joint = nn.Sequential(
334
+ conv3x1(self.ef_dim + ngf * 4, ngf * 4),
335
+ nn.BatchNorm1d(ngf * 4),
336
+ nn.ReLU(True))
337
+ self.residual = self._make_layer(ResBlock, ngf * 4)
338
+ # --> 2ngf x 1024
339
+ self.upsample1 = upBlock4(ngf * 4, ngf * 2)
340
+ # --> ngf x 4096
341
+ self.upsample2 = upBlock4(ngf * 2, ngf)
342
+ # --> ngf // 2 x 16384
343
+ self.upsample3 = upBlock4(ngf, ngf // 2)
344
+ # --> ngf // 4 x 16384
345
+ self.upsample4 = sameBlock(ngf // 2, ngf // 4)
346
+ # --> 1 x 16384
347
+ self.RIR = nn.Sequential(
348
+ conv3x1(ngf // 4, 1),
349
+ nn.Tanh())
350
+
351
+ def forward(self, text_embedding):
352
+ _, stage1_RIR, _= self.STAGE1_G(text_embedding)
353
+ stage1_RIR = stage1_RIR.detach()
354
+ encoded_RIR = self.encoder(stage1_RIR)
355
+
356
+ # c_code, mu, logvar = self.ca_net(text_embedding)
357
+ c_code1 = self.cond_net(text_embedding)
358
+ c_code = c_code1.view(-1, self.ef_dim, 1)
359
+ c_code = c_code.repeat(1, 1, 256) # c_code.repeat(1, 1, 16, 16)
360
+ i_c_code = torch.cat([encoded_RIR, c_code], 1)
361
+ h_code = self.hr_joint(i_c_code)
362
+ h_code = self.residual(h_code)
363
+
364
+ h_code = self.upsample1(h_code)
365
+ h_code = self.upsample2(h_code)
366
+ h_code = self.upsample3(h_code)
367
+ h_code = self.upsample4(h_code)
368
+
369
+ fake_RIR = self.RIR(h_code)
370
+ return stage1_RIR, fake_RIR, c_code1 #mu, logvar
371
+
372
+
373
+ class STAGE2_D(nn.Module):
374
+ def __init__(self):
375
+ super(STAGE2_D, self).__init__()
376
+ self.df_dim = cfg.GAN.DF_DIM
377
+ self.ef_dim = cfg.GAN.CONDITION_DIM
378
+ self.define_module()
379
+
380
+ def define_module(self):
381
+ ndf, nef = self.df_dim, self.ef_dim
382
+ self.encode_RIR = nn.Sequential(
383
+ nn.Conv1d(1, ndf, 3, 1, 1, bias=False), # 16384 * ndf
384
+ nn.LeakyReLU(0.2, inplace=True),
385
+ nn.Conv1d(ndf, ndf * 2, 16, 4, 6, bias=False),
386
+ nn.BatchNorm1d(ndf * 2),
387
+ nn.LeakyReLU(0.2, inplace=True), # 4096 * ndf * 2
388
+ nn.Conv1d(ndf * 2, ndf * 4, 16, 4, 6, bias=False),
389
+ nn.BatchNorm1d(ndf * 4),
390
+ nn.LeakyReLU(0.2, inplace=True), # 1024 * ndf * 4
391
+ nn.Conv1d(ndf * 4, ndf * 8, 16, 4, 6, bias=False),
392
+ nn.BatchNorm1d(ndf * 8),
393
+ nn.LeakyReLU(0.2, inplace=True), # 256 * ndf * 8
394
+ nn.Conv1d(ndf * 8, ndf * 16, 16, 4, 6, bias=False),
395
+ nn.BatchNorm1d(ndf * 16),
396
+ nn.LeakyReLU(0.2, inplace=True), # 64 * ndf * 16
397
+ nn.Conv1d(ndf * 16, ndf * 32, 16, 4, 6, bias=False),
398
+ nn.BatchNorm1d(ndf * 32),
399
+ nn.LeakyReLU(0.2, inplace=True), # 16 * ndf * 32
400
+ conv3x1(ndf * 32, ndf * 16),
401
+ nn.BatchNorm1d(ndf * 16),
402
+ nn.LeakyReLU(0.2, inplace=True), # 16 * ndf * 16
403
+ conv3x1(ndf * 16, ndf * 8),
404
+ nn.BatchNorm1d(ndf * 8),
405
+ nn.LeakyReLU(0.2, inplace=True) # 16 * ndf * 8
406
+ )
407
+
408
+ self.get_cond_logits = D_GET_LOGITS(ndf, nef, bcondition=True)
409
+ self.get_uncond_logits = D_GET_LOGITS(ndf, nef, bcondition=False)
410
+
411
+ def forward(self, RIRs):
412
+ RIR_embedding = self.encode_RIR(RIRs)
413
+ return RIR_embedding
code_new/single_copy.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, fnmatch
2
+ import numpy as np
3
+ import random
4
+ import soundfile as sf
5
+ from scipy.io.wavfile import write
6
+ # import librosa
7
+ import RT60
8
+
9
+ folder_path = "/cephfs/anton/room-impulse-responses/AIR/RWCP_REVERB_AACHEN/real_rirs_isotropic_noises/"
10
+ final_path = "/cephfs/anton/room-impulse-responses/AIR/RWCP_REVERB_AACHEN/AACHEN/"
11
+ tfs =16000
12
+ file_label = open("RT60.txt","w")
13
+
14
+ for root, dirnames, filenames in os.walk(folder_path):
15
+ for filename in filenames:
16
+ if filename.endswith(".wav"):
17
+ ACE_Path = os.path.join(root, filename)
18
+ wave,fs = sf.read(ACE_Path)
19
+ channel = int(wave.size/len(wave))
20
+
21
+ if(channel == 1):
22
+ wave_single = wave #librosa.resample(wave, fs, tfs)
23
+ max_loc = np.where(wave_single == np.amax(wave_single))
24
+ min_loc = np.where(wave_single == np.amin(wave_single))
25
+ start = min(max_loc[0][0],min_loc[0][0])
26
+ wave_single =wave_single[start:len(wave_single)]
27
+ T60_val = RT60.t60_impulse(wave_single,tfs)
28
+
29
+ if(T60_val<1):
30
+ file_label.write(str(T60_val)+"\n")
31
+ save_path = final_path+ filename
32
+ write(save_path,tfs,wave_single.astype(np.float32))
33
+ else:
34
+ for n in range(channel):
35
+ wave_single = wave[:,n]#librosa.resample(wave[:,n], fs, tfs)
36
+ max_loc = np.where(wave_single == np.amax(wave_single))
37
+ min_loc = np.where(wave_single == np.amin(wave_single))
38
+ start = min(max_loc[0][0],min_loc[0][0])
39
+ wave_single =wave_single[start:len(wave_single)]
40
+ T60_val = RT60.t60_impulse(wave_single,tfs)
41
+
42
+ if(T60_val<1):
43
+ file_label.write(str(T60_val)+"\n")
44
+ save_path = final_path+filename+str(n)+".wav"
45
+ write(save_path,tfs,wave_single.astype(np.float32))
46
+
code_new/trainer.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ from six.moves import range
3
+ from PIL import Image
4
+
5
+ import torch.backends.cudnn as cudnn
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.autograd import Variable
9
+ import torch.optim as optim
10
+ import os
11
+ import time
12
+
13
+ import numpy as np
14
+ import torchfile
15
+ import pickle
16
+
17
+ import soundfile as sf
18
+ import re
19
+ import math
20
+ from wavefile import WaveWriter, Format
21
+
22
+ from miscc.config import cfg
23
+ from miscc.utils import mkdir_p
24
+ from miscc.utils import weights_init
25
+ from miscc.utils import save_RIR_results, save_model
26
+ from miscc.utils import KL_loss
27
+ from miscc.utils import compute_discriminator_loss, compute_generator_loss
28
+
29
+ # from torch.utils.tensorboard import summary
30
+ # from torch.utils.tensorboard import FileWriter
31
+
32
+
33
+ class GANTrainer(object):
34
+ def __init__(self, output_dir):
35
+ if cfg.TRAIN.FLAG:
36
+ self.model_dir = os.path.join(output_dir, 'Model')
37
+ self.model_dir_RT = os.path.join(output_dir, 'Model_RT')
38
+ self.RIR_dir = os.path.join(output_dir, 'RIR')
39
+ self.log_dir = os.path.join(output_dir, 'Log')
40
+ mkdir_p(self.model_dir)
41
+ mkdir_p(self.model_dir_RT)
42
+ mkdir_p(self.RIR_dir)
43
+ mkdir_p(self.log_dir)
44
+ # self.summary_writer = FileWriter(self.log_dir)
45
+
46
+ self.max_epoch = cfg.TRAIN.MAX_EPOCH
47
+ self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
48
+
49
+ s_gpus = cfg.GPU_ID.split(',')
50
+ self.gpus = [int(ix) for ix in s_gpus]
51
+ self.num_gpus = len(self.gpus)
52
+ self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
53
+ torch.cuda.set_device(self.gpus[0])
54
+ cudnn.benchmark = True
55
+
56
+ # ############# For training stageI GAN #############
57
+ def load_network_stageI(self):
58
+ from model import STAGE1_G, STAGE1_D
59
+ netG = STAGE1_G()
60
+ netG.apply(weights_init)
61
+ print(netG)
62
+ netD = STAGE1_D()
63
+ netD.apply(weights_init)
64
+ print(netD)
65
+
66
+ if cfg.NET_G != '':
67
+ state_dict = \
68
+ torch.load(cfg.NET_G,
69
+ map_location=lambda storage, loc: storage)
70
+ netG.load_state_dict(state_dict)
71
+ print('Load from: ', cfg.NET_G)
72
+ if cfg.NET_D != '':
73
+ state_dict = \
74
+ torch.load(cfg.NET_D,
75
+ map_location=lambda storage, loc: storage)
76
+ netD.load_state_dict(state_dict)
77
+ print('Load from: ', cfg.NET_D)
78
+ if cfg.CUDA:
79
+ netG.cuda()
80
+ netD.cuda()
81
+ return netG, netD
82
+
83
+ # ############# For training stageII GAN #############
84
+ def load_network_stageII(self):
85
+ from model import STAGE1_G, STAGE2_G, STAGE2_D
86
+
87
+ Stage1_G = STAGE1_G()
88
+ netG = STAGE2_G(Stage1_G)
89
+ netG.apply(weights_init)
90
+ print(netG)
91
+ if cfg.NET_G != '':
92
+ state_dict = \
93
+ torch.load(cfg.NET_G,
94
+ map_location=lambda storage, loc: storage)
95
+ netG.load_state_dict(state_dict)
96
+ print('Load from: ', cfg.NET_G)
97
+ elif cfg.STAGE1_G != '':
98
+ state_dict = \
99
+ torch.load(cfg.STAGE1_G,
100
+ map_location=lambda storage, loc: storage)
101
+ netG.STAGE1_G.load_state_dict(state_dict)
102
+ print('Load from: ', cfg.STAGE1_G)
103
+ else:
104
+ print("Please give the Stage1_G path")
105
+ return
106
+
107
+ netD = STAGE2_D()
108
+ netD.apply(weights_init)
109
+ if cfg.NET_D != '':
110
+ state_dict = \
111
+ torch.load(cfg.NET_D,
112
+ map_location=lambda storage, loc: storage)
113
+ netD.load_state_dict(state_dict)
114
+ print('Load from: ', cfg.NET_D)
115
+ print(netD)
116
+
117
+ if cfg.CUDA:
118
+ netG.cuda()
119
+ netD.cuda()
120
+ return netG, netD
121
+
122
+ def train(self, data_loader, stage=1):
123
+ if stage == 1:
124
+ netG, netD = self.load_network_stageI()
125
+ else:
126
+ netG, netD = self.load_network_stageII()
127
+
128
+ # nz = cfg.Z_DIM
129
+ batch_size = self.batch_size
130
+ # noise = Variable(torch.FloatTensor(batch_size, nz))
131
+ # fixed_noise = \
132
+ # Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
133
+ # volatile=True)
134
+ real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
135
+ fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
136
+ if cfg.CUDA:
137
+ # noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
138
+ real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
139
+
140
+ generator_lr = cfg.TRAIN.GENERATOR_LR
141
+ discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
142
+ lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
143
+ # optimizerD = \
144
+ # optim.Adam(netD.parameters(),
145
+ # lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
146
+ optimizerD = \
147
+ optim.RMSprop(netD.parameters(),
148
+ lr=cfg.TRAIN.DISCRIMINATOR_LR)
149
+ netG_para = []
150
+ for p in netG.parameters():
151
+ if p.requires_grad:
152
+ netG_para.append(p)
153
+ # optimizerG = optim.Adam(netG_para,
154
+ # lr=cfg.TRAIN.GENERATOR_LR,
155
+ # betas=(0.5, 0.999))
156
+ optimizerG = optim.RMSprop(netG_para,
157
+ lr=cfg.TRAIN.GENERATOR_LR)
158
+ count = 0
159
+ least_RT=10
160
+ for epoch in range(self.max_epoch):
161
+ start_t = time.time()
162
+ if epoch % lr_decay_step == 0 and epoch > 0:
163
+ generator_lr *= 0.7#0.5
164
+ for param_group in optimizerG.param_groups:
165
+ param_group['lr'] = generator_lr
166
+ discriminator_lr *= 0.7#0.5
167
+ for param_group in optimizerD.param_groups:
168
+ param_group['lr'] = discriminator_lr
169
+
170
+ for i, data in enumerate(data_loader, 0):
171
+ ######################################################
172
+ # (1) Prepare training data
173
+ ######################################################
174
+ real_RIR_cpu, txt_embedding = data
175
+ real_RIRs = Variable(real_RIR_cpu)
176
+ txt_embedding = Variable(txt_embedding)
177
+ if cfg.CUDA:
178
+ real_RIRs = real_RIRs.cuda()
179
+ txt_embedding = txt_embedding.cuda()
180
+ #print("trianer RIRs ",real_RIRs.size())
181
+ #print("trianer embedding ",txt_embedding.size())
182
+
183
+ #######################################################
184
+ # (2) Generate fake images
185
+ ######################################################
186
+ # noise.data.normal_(0, 1)
187
+ # inputs = (txt_embedding, noise)
188
+ inputs = (txt_embedding)
189
+ # _, fake_RIRs, mu, logvar = \
190
+ # nn.parallel.data_parallel(netG, inputs, self.gpus)
191
+ _, fake_RIRs,c_code = nn.parallel.data_parallel(netG, inputs, self.gpus)
192
+
193
+ ############################
194
+ # (3) Update D network
195
+ ###########################
196
+ netD.zero_grad()
197
+ errD, errD_real, errD_wrong, errD_fake = \
198
+ compute_discriminator_loss(netD, real_RIRs, fake_RIRs,
199
+ real_labels, fake_labels,
200
+ c_code, self.gpus)
201
+
202
+ errD_total = errD*5
203
+ errD_total.backward()
204
+ optimizerD.step()
205
+ ############################
206
+ # (2) Update G network
207
+ ###########################
208
+ # kl_loss = KL_loss(mu, logvar)
209
+ netG.zero_grad()
210
+ errG,MSE_error,RT_error= compute_generator_loss(epoch,netD,real_RIRs, fake_RIRs,
211
+ real_labels, c_code, self.gpus)
212
+ errG_total = errG *5#+ kl_loss * cfg.TRAIN.COEFF.KL
213
+ errG_total.backward()
214
+ optimizerG.step()
215
+ for p in range(2):
216
+ inputs = (txt_embedding)
217
+ # _, fake_RIRs, mu, logvar = \
218
+ # nn.parallel.data_parallel(netG, inputs, self.gpus)
219
+ _, fake_RIRs,c_code = nn.parallel.data_parallel(netG, inputs, self.gpus)
220
+ netG.zero_grad()
221
+ errG,MSE_error,RT_error = compute_generator_loss(epoch,netD,real_RIRs, fake_RIRs,
222
+ real_labels, c_code, self.gpus)
223
+ # kl_loss = KL_loss(mu, logvar)
224
+ errG_total = errG *5#+ kl_loss * cfg.TRAIN.COEFF.KL
225
+ errG_total.backward()
226
+ optimizerG.step()
227
+
228
+ count = count + 1
229
+ if i % 100 == 0:
230
+ # summary_D = summary.scalar('D_loss', errD.data[0])
231
+ # summary_D_r = summary.scalar('D_loss_real', errD_real)
232
+ # summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
233
+ # summary_D_f = summary.scalar('D_loss_fake', errD_fake)
234
+ # summary_G = summary.scalar('G_loss', errG.data[0])
235
+ # summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
236
+ # summary_D = summary.scalar('D_loss', errD.data)
237
+ # summary_D_r = summary.scalar('D_loss_real', errD_real)
238
+ # summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
239
+ # summary_D_f = summary.scalar('D_loss_fake', errD_fake)
240
+ # summary_G = summary.scalar('G_loss', errG.data)
241
+ # summary_KL = summary.scalar('KL_loss', kl_loss.data)
242
+
243
+ # self.summary_writer.add_summary(summary_D, count)
244
+ # self.summary_writer.add_summary(summary_D_r, count)
245
+ # self.summary_writer.add_summary(summary_D_w, count)
246
+ # self.summary_writer.add_summary(summary_D_f, count)
247
+ # self.summary_writer.add_summary(summary_G, count)
248
+ # self.summary_writer.add_summary(summary_KL, count)
249
+
250
+ # save the image result for each epoch
251
+ inputs = (txt_embedding)
252
+ lr_fake, fake, _ = \
253
+ nn.parallel.data_parallel(netG, inputs, self.gpus)
254
+ if(epoch%self.snapshot_interval==0):
255
+ save_RIR_results(real_RIR_cpu, fake, epoch, self.RIR_dir)
256
+ if lr_fake is not None:
257
+ save_RIR_results(None, lr_fake, epoch, self.RIR_dir)
258
+ end_t = time.time()
259
+ # print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
260
+ # Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
261
+ # Total Time: %.2fsec
262
+ # '''
263
+ # % (epoch, self.max_epoch, i, len(data_loader),
264
+ # errD.data[0], errG.data[0], kl_loss.data[0],
265
+ # errD_real, errD_wrong, errD_fake, (end_t - start_t)))
266
+ # print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
267
+ # Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
268
+ # Total Time: %.2fsec
269
+ # '''
270
+ # % (epoch, self.max_epoch, i, len(data_loader),
271
+ # errD.data, errG.data, kl_loss.data,
272
+ # errD_real, errD_wrong, errD_fake, (end_t - start_t)))
273
+ print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f
274
+ Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f MSE_ERROR %.4f RT_error %.4f
275
+ Total Time: %.2fsec
276
+ '''
277
+ % (epoch, self.max_epoch, i, len(data_loader),
278
+ errD.data, errG.data,
279
+ errD_real, errD_wrong, errD_fake,MSE_error*4096, RT_error,(end_t - start_t)))
280
+
281
+ store_to_file ="[{}/{}][{}/{}] Loss_D: {:.4f} Loss_G: {:.4f} Loss_real: {:.4f} Loss_wrong:{:.4f} Loss_fake {:.4f} MSE Error:{:.4f} RT_error{:.4f} Total Time: {:.2f}sec".format(epoch, self.max_epoch, i, len(data_loader),
282
+ errD.data, errG.data, errD_real, errD_wrong, errD_fake,MSE_error*4096,RT_error, (end_t - start_t))
283
+ store_to_file =store_to_file+"\n"
284
+ with open("errors.txt", "a") as myfile:
285
+ myfile.write(store_to_file)
286
+
287
+ if (RT_error<least_RT):
288
+ least_RT = RT_error
289
+ save_model(netG, netD, epoch, self.model_dir_RT)
290
+ if epoch % self.snapshot_interval == 0:
291
+ save_model(netG, netD, epoch, self.model_dir)
292
+ #
293
+ save_model(netG, netD, self.max_epoch, self.model_dir)
294
+ #
295
+ # self.summary_writer.close()
296
+
297
+
298
+ def sample(self,file_path,stage=1):
299
+ if stage == 1:
300
+ netG, _ = self.load_network_stageI()
301
+ else:
302
+ netG, _ = self.load_network_stageII()
303
+ netG.eval()
304
+
305
+ time_list =[]
306
+
307
+
308
+
309
+
310
+ embedding_path = file_path
311
+ with open(embedding_path, 'rb') as f:
312
+ embeddings_pickle = pickle.load(f)
313
+
314
+
315
+
316
+ embeddings_list =[]
317
+ num_embeddings = len(embeddings_pickle)
318
+ for b in range (num_embeddings):
319
+ embeddings_list.append(embeddings_pickle[b])
320
+
321
+ embeddings = np.array(embeddings_list)
322
+
323
+ save_dir_GAN = "Generated_RIRs"
324
+ mkdir_p(save_dir_GAN)
325
+
326
+
327
+
328
+ normalize_embedding = []
329
+
330
+
331
+ batch_size = np.minimum(num_embeddings, self.batch_size)
332
+
333
+
334
+ count = 0
335
+ count_this = 0
336
+ while count < num_embeddings:
337
+
338
+ iend = count + batch_size
339
+ if iend > num_embeddings:
340
+ iend = num_embeddings
341
+ count = num_embeddings - batch_size
342
+ embeddings_batch = embeddings[count:iend]
343
+
344
+
345
+
346
+ txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
347
+ if cfg.CUDA:
348
+ txt_embedding = txt_embedding.cuda()
349
+
350
+ #######################################################
351
+ # (2) Generate fake images
352
+ ######################################################
353
+ start_t = time.time()
354
+ inputs = (txt_embedding)
355
+ _, fake_RIRs,c_code = \
356
+ nn.parallel.data_parallel(netG, inputs, self.gpus)
357
+ end_t = time.time()
358
+ diff_t = end_t - start_t
359
+ time_list.append(diff_t)
360
+
361
+ RIR_batch_size = batch_size #int(batch_size/2)
362
+ print("batch_size ", RIR_batch_size)
363
+ channel_size = 64
364
+
365
+ for i in range(channel_size):
366
+ fs =16000
367
+ wave_name = "RIR-"+str(count+i)+".wav"
368
+ save_name_GAN = '%s/%s' % (save_dir_GAN,wave_name)
369
+ print("wave : ",save_name_GAN)
370
+ res = {}
371
+ res_buffer = []
372
+ rate = 16000
373
+ res['rate'] = rate
374
+
375
+ wave_GAN = fake_RIRs[i].data.cpu().numpy()
376
+ wave_GAN = np.array(wave_GAN[0])
377
+
378
+
379
+ res_buffer.append(wave_GAN)
380
+ res['samples'] = np.zeros((len(res_buffer), np.max([len(ps) for ps in res_buffer])))
381
+ for i, c in enumerate(res_buffer):
382
+ res['samples'][i, :len(c)] = c
383
+
384
+ w = WaveWriter(save_name_GAN, channels=np.shape(res['samples'])[0], samplerate=int(res['rate']))
385
+ w.write(np.array(res['samples']))
386
+
387
+ print("counter = ",count)
388
+ count = count+64
389
+ count_this = count_this+1
390
+
391
+
392
+
download_data.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gdown https://drive.google.com/uc?id=17NF1MVtXaWe9zhqWJqmG5tFUZb_9X0M5
2
+ unzip data.zip
3
+ mkdir output
download_generate.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gdown https://drive.google.com/uc?id=1XOyzsZD3s_pkZBlWcH3KtCR9YpjRVbHG
2
+ unzip generate.zip
example1.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import random
4
+ import argparse
5
+ import pickle
6
+
7
+ normalize_geometry_embeddings_list =[]
8
+
9
+ for n in range(960):
10
+
11
+ lx = (8/960)*n + 0.5
12
+ geometry_embeddings= [lx,3.5,1.5,8.8,3.5,1.5,9,7,3,0.35]
13
+ max_dimension = 5
14
+ normalize_geometry_embeddings =np.divide(geometry_embeddings,max_dimension)-1
15
+ normalize_geometry_embeddings_list.append(normalize_geometry_embeddings)
16
+
17
+
18
+ embeddings_pickle ="example1.pickle"
19
+ with open(embeddings_pickle, 'wb') as f:
20
+ pickle.dump(normalize_geometry_embeddings_list, f, protocol=2)
slides.pptx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f928a9b1f7bc05d972e51988104bb547e9cce25bb03f7841023807050af65875
3
+ size 4718146