projectlosangeles commited on
Commit
1423669
·
verified ·
1 Parent(s): 8dbc50f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -13
app.py CHANGED
@@ -41,16 +41,14 @@ from huggingface_hub import hf_hub_download
41
  print('=' * 70)
42
  print('Loading models...')
43
  print('=' * 70)
44
- print('Loading small model...')
45
  print('=' * 70)
46
 
47
  SEQ_LEN = 2048
48
  PAD_IDX = 721
49
- DEVICE = 'cuda' # 'cpu'
50
-
51
- # instantiate the model
52
 
53
- model = TransformerWrapper(
54
  num_tokens = PAD_IDX+1,
55
  max_seq_len = SEQ_LEN,
56
  attn_layers = Decoder(dim = 2048,
@@ -61,9 +59,9 @@ model = TransformerWrapper(
61
  )
62
  )
63
 
64
- model = AutoregressiveWrapper(model, ignore_index=PAD_IDX)
65
 
66
- model.to(DEVICE)
67
  print('=' * 70)
68
 
69
  print('Loading model checkpoint...')
@@ -73,9 +71,11 @@ checkpoint = hf_hub_download(
73
  filename='Chordified_Piano_Transformer_Texturing_Trained_Model_18092_steps_0.7058_loss_0.7977_acc.pth'
74
  )
75
 
76
- model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))
77
 
78
- model.eval()
 
 
79
 
80
  print('=' * 70)
81
  print('Done!')
@@ -83,10 +83,49 @@ print('=' * 70)
83
 
84
  # =================================================================================================
85
 
86
- if DEVICE == 'cpu':
87
- dtype = torch.bfloat16
88
- else:
89
- dtype = torch.bfloat16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
92
 
 
41
  print('=' * 70)
42
  print('Loading models...')
43
  print('=' * 70)
44
+ print('Loading chords texturing model...')
45
  print('=' * 70)
46
 
47
  SEQ_LEN = 2048
48
  PAD_IDX = 721
49
+ DEVICE = 'cuda'
 
 
50
 
51
+ tex_model = TransformerWrapper(
52
  num_tokens = PAD_IDX+1,
53
  max_seq_len = SEQ_LEN,
54
  attn_layers = Decoder(dim = 2048,
 
59
  )
60
  )
61
 
62
+ tex_model = AutoregressiveWrapper(tex_model, ignore_index=PAD_IDX)
63
 
64
+ tex_model.to(DEVICE)
65
  print('=' * 70)
66
 
67
  print('Loading model checkpoint...')
 
71
  filename='Chordified_Piano_Transformer_Texturing_Trained_Model_18092_steps_0.7058_loss_0.7977_acc.pth'
72
  )
73
 
74
+ tex_model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))
75
 
76
+ tex_model.eval()
77
+
78
+ tex_model = torch.compile(tex_model)
79
 
80
  print('=' * 70)
81
  print('Done!')
 
83
 
84
  # =================================================================================================
85
 
86
+ print('Loading chords progressions model...')
87
+ print('=' * 70)
88
+
89
+ SEQ_LEN = 380
90
+ PAD_IDX = 324
91
+ DEVICE = 'cuda' # 'cpu'
92
+
93
+ # instantiate the model
94
+
95
+ prg_model = TransformerWrapper(
96
+ num_tokens = PAD_IDX+1,
97
+ max_seq_len = SEQ_LEN,
98
+ attn_layers = Decoder(dim = 2048,
99
+ depth = 6,
100
+ heads = 16,
101
+ rotary_pos_emb = True,
102
+ attn_flash = True
103
+ )
104
+ )
105
+
106
+ prg_model = AutoregressiveWrapper(prg_model, ignore_index=PAD_IDX)
107
+
108
+ prg_model.to(DEVICE)
109
+ print('=' * 70)
110
+
111
+ print('Loading model checkpoint...')
112
+
113
+ checkpoint = hf_hub_download(
114
+ repo_id='asigalov61/Chordified-Piano-Transformer',
115
+ filename='Chordified_Piano_Transformer_Texturing_Trained_Model_18092_steps_0.7058_loss_0.7977_acc.pth'
116
+ )
117
+
118
+ prg_model.load_state_dict(torch.load(checkpoint, map_location=DEVICE, weights_only=True))
119
+
120
+ prg_model.eval()
121
+
122
+ prg_model = torch.compile(prg_model)
123
+
124
+ print('=' * 70)
125
+
126
+ # =================================================================================================
127
+
128
+ dtype = torch.bfloat16
129
 
130
  ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
131