projectlosangeles commited on
Commit
8dbc50f
·
verified ·
1 Parent(s): 31492b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -177
app.py CHANGED
@@ -96,111 +96,26 @@ print('=' * 70)
96
  # =================================================================================================
97
 
98
  @spaces.GPU
99
- def generate_drums(chords,
100
- drums_biases,
101
- banned_drums,
102
- input_temperature,
103
- input_top_p_value
104
- ):
105
 
106
- flat_chords = TMIDIX.flatten(chords)
107
 
108
- tries = 0
109
- max_tries = 30
110
- good_seq = False
111
- pitches_set = set()
112
-
113
- while not good_seq and tries < max_tries:
114
-
115
- ci_num = random.choice([37, 42])
116
-
117
- input_seq = [640] + flat_chords + [641] + [0, ci_num+384, 127+512]
118
-
119
- dcount = 0
120
-
121
- for i in tqdm.tqdm(range(len(chords))):
122
-
123
- y = 0
124
- dtime = 0
125
-
126
- if 0 < i < len(chords):
127
- next_dtime = chords[i][0]
128
-
129
- elif i == len(chords):
130
- next_dtime = 1
131
-
132
- else:
133
- next_dtime = 512
134
-
135
- while y < 128 or 384 < y < 642:
136
 
137
- x = torch.LongTensor(input_seq).to(DEVICE)
138
-
139
- with ctx:
140
- out = model.generate_advanced(x,
141
- 1,
142
- temperature=input_temperature,
143
- filter_logits_fn=top_p,
144
- filter_kwargs={'thres': input_top_p_value},
145
- logits_bias=drums_biases,
146
- masked_tokens=banned_drums,
147
- return_prime=False,
148
- verbose=False
149
- )
150
 
151
- y = out.tolist()[0]
152
-
153
- if y < 128 or 384 < y < 642:
154
-
155
- if y < 128 and dtime+y >= next_dtime:
156
- input_seq.extend([chords[i][0]-dtime] + chords[i][1:])
157
- dtime = 0
158
- break
159
-
160
- if y < 128 and dtime+y < next_dtime and next_dtime != 512:
161
- dtime += y
162
-
163
- input_seq.append(y)
164
-
165
- if 384 < y < 512:
166
- dcount += 1
167
- pitches_set.add(y)
168
-
169
- else:
170
- if i != 0:
171
- input_seq[-1] = chords[i][0]-dtime
172
- input_seq.extend(chords[i][1:])
173
- dtime = 0
174
-
175
- if i == 0 and 256 < y < 384:
176
- print('Bad sequence! (Bad start chord)')
177
- tries += 1
178
- print('Retry attempt:', tries)
179
- break
180
-
181
- if i == len(chords) // 4:
182
- if dcount < i // 2:
183
- print('Bad sequence! (Insufficient drums density)')
184
- tries += 1
185
- print('Retry attempt:', tries)
186
- break
187
-
188
- if len(pitches_set) < 3:
189
- print('Bad sequence! (Insufficient drums types)')
190
- tries += 1
191
- print('Retry attempt:', tries)
192
- break
193
-
194
- if i == len(chords)-1:
195
- if dcount > len(chords) // 2:
196
- good_seq = True
197
- print('Generated good seq!')
198
-
199
- else:
200
- print('Bad sequence! (Insufficient total drums count)')
201
-
202
- if not good_seq or tries > max_tries-1:
203
- print('Failed to generate good seq!')
204
 
205
  return input_seq[len(flat_chords)+2:]
206
 
@@ -217,101 +132,41 @@ def tokens_to_escore_notes(tokens):
217
  channel = 0
218
  patch = 0
219
 
220
- for m in tokens:
221
 
 
 
222
  if 0 <= m < 128:
223
- time += m
224
-
225
- elif 128 < m < 256:
226
- dur = (m-128)
227
-
228
- elif 256 < m < 512:
229
- pitch = (m-256) % 128
230
- chan = (m-256) // 128
231
 
232
- if chan == 0:
233
- song_f.append(['note', time, dur, 0, pitch, max(40, pitch), 0])
234
-
235
- elif 512 < m < 640:
236
- vel = m-512
237
-
238
- if chan == 1:
239
- song_f.append(['note', time, 2, 9, pitch, vel, 128])
240
 
241
  return song_f
242
 
243
  # =================================================================================================
244
 
245
- def Generate_Drums(input_midi,
246
- input_model,
247
- input_drums,
248
- input_drums_bias,
249
- input_temperature,
250
- input_top_p_value
251
- ):
252
 
253
  print('=' * 70)
254
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
255
  start_time = reqtime.time()
256
 
257
- fn = os.path.basename(input_midi.name)
258
- fn1 = fn.split('.')[0]
259
-
260
  print('=' * 70)
261
- print('Input file name:', fn)
262
- print('Input model:', input_model)
263
- print('input drums:', input_drums)
264
- print('Req drums bias:', input_drums_bias)
265
  print('Req model temp:', input_temperature)
266
  print('Req top_k value:', input_top_p_value)
267
  print('=' * 70)
268
 
269
- raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
270
- escore_notes = TMIDIX.advanced_score_processor(raw_score,
271
- return_enhanced_score_notes=True,
272
- apply_sustain=True
273
- )
274
-
275
- if escore_notes and escore_notes[0]:
276
-
277
- escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], timings_divider=32)
278
-
279
- escore_notes = TMIDIX.fix_escore_notes_durations(escore_notes, min_notes_gap=0)
280
-
281
- escore_notes = TMIDIX.strip_drums_from_escore_notes(escore_notes)
282
-
283
- if escore_notes:
284
-
285
- escore_notes = TMIDIX.recalculate_score_timings(escore_notes)
286
-
287
- escore_notes = TMIDIX.delta_score_to_abs_score(TMIDIX.delta_score_notes(escore_notes, timings_clip_value=127))
288
-
289
- sp_score = TMIDIX.solo_piano_escore_notes(escore_notes)
290
 
291
- if sp_score:
292
-
293
- cscore = TMIDIX.chordify_score([1000, sp_score])
294
-
295
- score = []
296
-
297
- pc = cscore[0]
298
-
299
- chords = []
300
-
301
- for c in cscore:
302
-
303
- cho = []
304
-
305
- score.append(max(0, min(127, c[0][1]-pc[0][1])))
306
- cho.append(max(0, min(127, c[0][1]-pc[0][1])))
307
-
308
- for n in c:
309
- score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
310
- cho.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
311
-
312
- chords.append(cho)
313
-
314
- pc = c
315
 
316
 
317
  print('Score has', len(chords), 'chords')
 
96
  # =================================================================================================
97
 
98
  @spaces.GPU
99
+ def generate_chords(chords,
100
+ input_temperature,
101
+ input_top_p_value
102
+ ):
 
 
103
 
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ x = torch.LongTensor(input_seq).to(DEVICE)
107
+
108
+ with ctx:
109
+ out = model.generate(x,
110
+ 1,
111
+ temperature=input_temperature,
112
+ filter_logits_fn=top_p,
113
+ filter_kwargs={'thres': input_top_p_value},
114
+ return_prime=False,
115
+ verbose=False
116
+ )
 
 
117
 
118
+ y = out.tolist()[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  return input_seq[len(flat_chords)+2:]
121
 
 
132
  channel = 0
133
  patch = 0
134
 
135
+ patches = [0] * 16
136
 
137
+ for m in tokens:
138
+
139
  if 0 <= m < 128:
140
+ time += m * 32
 
 
 
 
 
 
 
141
 
142
+ elif 461 < m < 589:
143
+ pitch = (m-461)
144
+
145
+ elif 589 < m < 717:
146
+ dur = (m-589) * 32
147
+ song_f.append(['note', time, dur, 0, pitch, max(40, pitch), 0])
 
 
148
 
149
  return song_f
150
 
151
  # =================================================================================================
152
 
153
+ def Generate_Chords(input_midi,
154
+ input_chords,
155
+ input_temperature,
156
+ input_top_p_value
157
+ ):
 
 
158
 
159
  print('=' * 70)
160
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
161
  start_time = reqtime.time()
162
 
 
 
 
163
  print('=' * 70)
164
+ print('Input chords:', input_chords)
 
 
 
165
  print('Req model temp:', input_temperature)
166
  print('Req top_k value:', input_top_p_value)
167
  print('=' * 70)
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
 
172
  print('Score has', len(chords), 'chords')