kimtaeyeong1229 commited on
Commit
d4bad9d
ยท
verified ยท
1 Parent(s): 9eddf4d

Upload views/dl_lab.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. views/dl_lab.py +203 -0
views/dl_lab.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DL Lab page: PyTorch MLP training with live progress."""
2
+ import streamlit as st
3
+ import plotly.graph_objects as go
4
+ from plotly.subplots import make_subplots
5
+ import plotly.express as px
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from utils.data import get_train_test_data
10
+ from utils.models import train_mlp, build_sklearn_model, train_sklearn_model, XGBOOST_AVAILABLE
11
+
12
+ NEEDS_SCALING = {'SVM (RBF)', 'KNN', 'Logistic Regression'}
13
+
14
+
15
+ def show():
16
+ st.title("๋”ฅ๋Ÿฌ๋‹ ์‹ค์Šต โ€” PyTorch MLP")
17
+ st.markdown("์‹ ๊ฒฝ๋ง ๊ตฌ์กฐ์™€ ํ•™์Šต ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์ง์ ‘ ์„ค์ •ํ•˜๊ณ , ์—ํฌํฌ๋ณ„ ํ•™์Šต ๊ณผ์ •์„ ์‹ค์‹œ๊ฐ„์œผ๋กœ ํ™•์ธํ•˜์„ธ์š”.")
18
+
19
+ # Concept explanation
20
+ with st.expander("ํผ์…‰ํŠธ๋ก  โ†’ MLP ๊ฐœ๋… ์„ค๋ช…"):
21
+ st.markdown("""
22
+ ### ํผ์…‰ํŠธ๋ก  (1957, Rosenblatt)
23
+ ์ƒ๋ฌผ ๋‰ด๋Ÿฐ์„ ๋ชจ๋ฐฉํ•œ ์ตœ์ดˆ์˜ ์ธ๊ณต ๋‰ด๋Ÿฐ:
24
+
25
+ $$z = w_1x_1 + w_2x_2 + \\cdots + w_nx_n + b = w^Tx + b$$
26
+ $$\\text{output} = \\sigma(z)$$
27
+
28
+ **ํ•œ๊ณ„**: XOR ๋ฌธ์ œ ํ•ด๊ฒฐ ๋ถˆ๊ฐ€ (์„ ํ˜• ๋ถ„๋ฆฌ ๋ถˆ๊ฐ€)
29
+
30
+ ### MLP: ํผ์…‰ํŠธ๋ก ์„ ์—ฌ๋Ÿฌ ์ธต์œผ๋กœ ์Œ“๊ธฐ
31
+ ```
32
+ ์ž…๋ ฅ์ธต(8) โ†’ ์€๋‹‰์ธต1(64, ReLU) โ†’ ์€๋‹‰์ธต2(32, ReLU) โ†’ ์ถœ๋ ฅ์ธต(1, Sigmoid)
33
+ ```
34
+
35
+ ### ํ•™์Šต ๊ณผ์ • (์—ญ์ „ํŒŒ)
36
+ 1. **์ˆœ์ „ํŒŒ**: ์ž…๋ ฅ โ†’ ์ถœ๋ ฅ โ†’ ์†์‹ค ๊ณ„์‚ฐ
37
+ 2. **์—ญ์ „ํŒŒ**: ์†์‹ค โ†’ ๊ธฐ์šธ๊ธฐ ๊ณ„์‚ฐ(๋ฏธ๋ถ„) โ†’ ๊ฐ€์ค‘์น˜ ์—…๋ฐ์ดํŠธ
38
+ $$w \\leftarrow w - \\eta \\cdot \\frac{\\partial L}{\\partial w}$$
39
+
40
+ ### ํ™œ์„ฑํ™” ํ•จ์ˆ˜ ๋น„๊ต
41
+ | ํ•จ์ˆ˜ | ์ˆ˜์‹ | ์šฉ๋„ |
42
+ |------|------|------|
43
+ | **ReLU** | $\\max(0, x)$ | ์€๋‹‰์ธต ๊ธฐ๋ณธ๊ฐ’ |
44
+ | **Sigmoid** | $\\frac{1}{1+e^{-x}}$ | ์ด์ง„๋ถ„๋ฅ˜ ์ถœ๋ ฅ์ธต |
45
+ | **BatchNorm** | ๋ฐฐ์น˜ ์ •๊ทœํ™” | ํ•™์Šต ์•ˆ์ •ํ™” |
46
+ | **Dropout** | ๋žœ๋ค ๋‰ด๋Ÿฐ ์ œ๊ฑฐ | ๊ณผ์ ํ•ฉ ๋ฐฉ์ง€ |
47
+ """)
48
+
49
+ st.markdown("---")
50
+ st.subheader("ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ์„ค์ •")
51
+
52
+ col1, col2 = st.columns(2)
53
+
54
+ with col1:
55
+ st.markdown("**๋„คํŠธ์›Œํฌ ๊ตฌ์กฐ**")
56
+ h1 = st.slider("์€๋‹‰์ธต 1 ํฌ๊ธฐ", 16, 256, 64, step=16)
57
+ h2 = st.slider("์€๋‹‰์ธต 2 ํฌ๊ธฐ", 8, 128, 32, step=8)
58
+ add_h3 = st.checkbox("์€๋‹‰์ธต 3 ์ถ”๊ฐ€", value=False)
59
+ h3 = st.slider("์€๋‹‰์ธต 3 ํฌ๊ธฐ", 8, 64, 16, step=8) if add_h3 else None
60
+ dropout = st.slider("Dropout ๋น„์œจ", 0.0, 0.7, 0.3, step=0.05)
61
+
62
+ with col2:
63
+ st.markdown("**ํ•™์Šต ์„ค์ •**")
64
+ epochs = st.slider("Epochs (์—ํฌํฌ ์ˆ˜)", 10, 200, 100, step=10)
65
+ lr = st.select_slider(
66
+ "ํ•™์Šต๋ฅ  (Learning Rate)",
67
+ options=[0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05],
68
+ value=0.001,
69
+ )
70
+ batch_size = st.select_slider("Batch Size", options=[16, 32, 64, 128], value=32)
71
+
72
+ hidden_dims = [h1, h2] + ([h3] if h3 else [])
73
+
74
+ # Model architecture preview
75
+ arch_str = f"์ž…๋ ฅ(8) โ†’ {' โ†’ '.join([str(h) for h in hidden_dims])} โ†’ ์ถœ๋ ฅ(1)"
76
+ st.info(f"**๋„คํŠธ์›Œํฌ ๊ตฌ์กฐ**: {arch_str}")
77
+
78
+ total_params = 8 * hidden_dims[0]
79
+ for i in range(len(hidden_dims) - 1):
80
+ total_params += hidden_dims[i] * hidden_dims[i + 1]
81
+ total_params += hidden_dims[-1] * 1
82
+ st.caption(f"์˜ˆ์ƒ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜ (์„ ํ˜• ๋ ˆ์ด์–ด๋งŒ): ~{total_params:,}๊ฐœ")
83
+
84
+ st.markdown("---")
85
+
86
+ if st.button("ํ•™์Šต ์‹œ์ž‘", type="primary", use_container_width=True):
87
+ X_train, X_test, y_train, y_test, X_tr_sc, X_te_sc, _ = get_train_test_data()
88
+
89
+ # Live progress placeholders
90
+ progress_bar = st.progress(0, text="ํ•™์Šต ์‹œ์ž‘...")
91
+ col_loss, col_acc = st.columns(2)
92
+ loss_placeholder = col_loss.empty()
93
+ acc_placeholder = col_acc.empty()
94
+ metrics_placeholder = st.empty()
95
+
96
+ history = {
97
+ 'epoch': [],
98
+ 'train_loss': [], 'test_loss': [],
99
+ 'train_acc': [], 'test_acc': [],
100
+ }
101
+
102
+ def progress_callback(epoch, total, tr_loss, tr_acc, te_loss, te_acc):
103
+ history['epoch'].append(epoch)
104
+ history['train_loss'].append(tr_loss)
105
+ history['test_loss'].append(te_loss)
106
+ history['train_acc'].append(tr_acc)
107
+ history['test_acc'].append(te_acc)
108
+
109
+ progress_bar.progress(epoch / total, text=f"Epoch {epoch}/{total}")
110
+
111
+ # Loss chart
112
+ fig_l = go.Figure()
113
+ fig_l.add_trace(go.Scatter(x=history['epoch'], y=history['train_loss'],
114
+ name='ํ•™์Šต ์†์‹ค', line=dict(color='#2ecc71')))
115
+ fig_l.add_trace(go.Scatter(x=history['epoch'], y=history['test_loss'],
116
+ name='ํ…Œ์ŠคํŠธ ์†์‹ค', line=dict(color='#e74c3c')))
117
+ fig_l.update_layout(title='์†์‹ค (BCE Loss)', xaxis_title='Epoch',
118
+ yaxis_title='Loss', height=280, margin=dict(t=40, b=30))
119
+ loss_placeholder.plotly_chart(fig_l, use_container_width=True)
120
+
121
+ # Accuracy chart
122
+ fig_a = go.Figure()
123
+ fig_a.add_trace(go.Scatter(x=history['epoch'], y=history['train_acc'],
124
+ name='ํ•™์Šต ์ •ํ™•๋„', line=dict(color='#2ecc71')))
125
+ fig_a.add_trace(go.Scatter(x=history['epoch'], y=history['test_acc'],
126
+ name='ํ…Œ์ŠคํŠธ ์ •ํ™•๋„', line=dict(color='#e74c3c')))
127
+ fig_a.update_layout(title='์ •ํ™•๋„', xaxis_title='Epoch',
128
+ yaxis_title='Accuracy', height=280, margin=dict(t=40, b=30))
129
+ acc_placeholder.plotly_chart(fig_a, use_container_width=True)
130
+
131
+ if epoch % 10 == 0 or epoch == total:
132
+ metrics_placeholder.markdown(
133
+ f"**Epoch {epoch}/{total}** | "
134
+ f"Train Loss: `{tr_loss:.4f}` Acc: `{tr_acc*100:.1f}%` | "
135
+ f"Test Loss: `{te_loss:.4f}` Acc: `{te_acc*100:.1f}%`"
136
+ )
137
+
138
+ result = train_mlp(
139
+ X_tr_sc, X_te_sc, y_train, y_test,
140
+ hidden_dims=hidden_dims,
141
+ epochs=epochs,
142
+ lr=lr,
143
+ batch_size=batch_size,
144
+ dropout=dropout,
145
+ progress_callback=progress_callback,
146
+ )
147
+
148
+ progress_bar.empty()
149
+
150
+ # Final metrics
151
+ st.markdown("---")
152
+ st.subheader("์ตœ์ข… ๊ฒฐ๊ณผ")
153
+ c1, c2, c3 = st.columns(3)
154
+ c1.metric("์ตœ์ข… ํ…Œ์ŠคํŠธ ์ •ํ™•๋„", f"{result['final_acc']*100:.2f}%")
155
+ c2.metric("์ตœ๊ณ  ํ…Œ์ŠคํŠธ ์ •ํ™•๋„", f"{max(result['test_accs'])*100:.2f}%",
156
+ f"Epoch {result['test_accs'].index(max(result['test_accs']))+1}")
157
+ c3.metric("์ˆ˜๋ ด ํŒ์ •", "์ˆ˜๋ ด" if abs(result['test_accs'][-1] - result['test_accs'][-10]) < 0.01
158
+ else "๋ฏธ์ˆ˜๋ ด", help="๋งˆ์ง€๋ง‰ 10 ์—ํฌํฌ ๋ณ€ํ™”๋Ÿ‰ < 1%")
159
+
160
+ # Confusion matrix
161
+ st.subheader("ํ˜ผ๋™ ํ–‰๋ ฌ")
162
+ cm = result['confusion_matrix']
163
+ fig_cm = px.imshow(
164
+ cm, text_auto=True,
165
+ x=['์˜ˆ์ธก: ์‚ฌ๋ง', '์˜ˆ์ธก: ์ƒ์กด'],
166
+ y=['์‹ค์ œ: ์‚ฌ๋ง', '์‹ค์ œ: ์ƒ์กด'],
167
+ color_continuous_scale='Blues',
168
+ title='MLP (PyTorch) โ€” ํ˜ผ๋™ ํ–‰๋ ฌ',
169
+ )
170
+ fig_cm.update_layout(coloraxis_showscale=False)
171
+ st.plotly_chart(fig_cm, use_container_width=True)
172
+
173
+ # Compare with ML models
174
+ st.markdown("---")
175
+ st.subheader("ML ๋ชจ๋ธ๊ณผ ์„ฑ๋Šฅ ๋น„๊ต")
176
+ compare_algos = ['Logistic Regression', 'Random Forest', 'Gradient Boosting']
177
+ if XGBOOST_AVAILABLE:
178
+ compare_algos.append('XGBoost')
179
+
180
+ cmp_results = {'MLP (PyTorch)': result['final_acc']}
181
+ for a in compare_algos:
182
+ use_sc = a in NEEDS_SCALING
183
+ X_tr = X_tr_sc if use_sc else X_train.values
184
+ X_te = X_te_sc if use_sc else X_test.values
185
+ m = build_sklearn_model(a, {})
186
+ r = train_sklearn_model(m, X_tr, X_te, y_train, y_test)
187
+ cmp_results[a] = r['accuracy']
188
+
189
+ cmp_df = pd.DataFrame([
190
+ {'๋ชจ๋ธ': k, '์ •ํ™•๋„': v} for k, v in sorted(cmp_results.items(), key=lambda x: -x[1])
191
+ ])
192
+
193
+ fig_bar = px.bar(
194
+ cmp_df, x='์ •ํ™•๋„', y='๋ชจ๋ธ', orientation='h',
195
+ text_auto='.3f',
196
+ color='์ •ํ™•๋„', color_continuous_scale='RdYlGn',
197
+ title='MLP vs ML ๋ชจ๋ธ ๋น„๊ต',
198
+ range_x=[0.6, 0.95],
199
+ )
200
+ fig_bar.update_layout(coloraxis_showscale=False)
201
+ st.plotly_chart(fig_bar, use_container_width=True)
202
+
203
+ st.success(f"MLP ํ•™์Šต ์™„๋ฃŒ! ์ตœ์ข… ํ…Œ์ŠคํŠธ ์ •ํ™•๋„: **{result['final_acc']*100:.2f}%**")