Heshwa commited on
Commit
d78074f
·
verified ·
1 Parent(s): cabbfc7

Upload folder using huggingface_hub

Browse files
.DS_Store ADDED
Binary file (8.2 kB). View file
 
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ output/flappy_bird_evaluation.gif filter=lfs diff=lfs merge=lfs -text
37
+ output/flappy_bird_evaluation.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ output/flappy_bird_training.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Flappy Bird Reinforcement Learning Agent
2
+
3
+ This project implements a reinforcement learning agent for the Flappy Bird game using the REINFORCE (Policy Gradient) algorithm. The agent learns to play Flappy Bird by maximizing the cumulative reward through trial and error.
4
+
5
+ ## Features
6
+
7
+ - Policy Gradient (REINFORCE) algorithm implementation
8
+ - Neural network policy with multiple hidden layers
9
+ - Checkpoint saving and resuming for training
10
+ - Evaluation with statistical analysis
11
+ - Video recording of gameplay
12
+ - Dynamic plotting of training progress
13
+
14
+ ## Dependencies
15
+
16
+ - Python 3.13
17
+ - PyTorch
18
+ - NumPy
19
+ - Matplotlib
20
+ - Gymnasium
21
+ - Flappy Bird Gymnasium
22
+ - ImageIO
23
+ - tqdm
24
+ - PIL (Pillow)
25
+ - Hugging Face Hub (for model uploading)
26
+
27
+ ## Installation
28
+
29
+ 1. Create a virtual environment:
30
+ ```bash
31
+ python -m venv env
32
+ source env/bin/activate # On Windows: env\Scripts\activate
33
+ ```
34
+
35
+ 2. Install dependencies:
36
+ ```bash
37
+ pip install -r requirements.txt
38
+ ```
39
+
40
+ Or manually:
41
+ ```bash
42
+ pip install torch numpy matplotlib gymnasium flappy-bird-gymnasium imageio tqdm pillow huggingface-hub
43
+ ```
44
+
45
+ ## Usage
46
+
47
+ ### Training
48
+
49
+ Run the training script with default settings (trains and evaluates):
50
+ ```bash
51
+ python agent.py
52
+ ```
53
+
54
+ The agent will train for 100,000 episodes by default. Checkpoints are saved every 100 episodes to `checkpoints/checkpoint.pth`, and the final model is saved to `checkpoints/final_model.pth`.
55
+
56
+ ### Configuration
57
+
58
+ Modify the `hyperparams` dictionary in `agent.py` to adjust:
59
+
60
+ - `no_of_episodes`: Number of training episodes
61
+ - `gamma`: Discount factor
62
+ - `lr`: Learning rate
63
+ - `fps`: Frames per second for video output
64
+ - `save_videos`: Whether to save training/evaluation videos
65
+ - `train`: Whether to run training (default: True)
66
+ - `evaluate`: Whether to run evaluation (default: True)
67
+ - `push_hf`: Whether to push model to Hugging Face Hub (default: False)
68
+ - `hf_repo`: Hugging Face repository name (required if push_hf is True)
69
+ - `eval_episodes`: Number of evaluation episodes (default: 5)
70
+
71
+ ### Examples
72
+
73
+ Train only (modify hyperparams):
74
+ ```python
75
+ hyperparams["evaluate"] = False
76
+ ```
77
+
78
+ Evaluate only (modify hyperparams):
79
+ ```python
80
+ hyperparams["train"] = False
81
+ hyperparams["eval_episodes"] = 10
82
+ ```
83
+
84
+ Train and push to HF (modify hyperparams):
85
+ ```python
86
+ hyperparams["push_hf"] = True
87
+ hyperparams["hf_repo"] = "your-username/flappy-bird-rl-agent"
88
+ ```
89
+
90
+ ### Resuming Training
91
+
92
+ If checkpoints exist, the training will automatically resume from the last saved checkpoint.
93
+
94
+ ### Evaluation
95
+
96
+ The script automatically evaluates the trained model after training. To evaluate a saved model separately, modify the `evaluate_policy` call in `agent.py`.
97
+
98
+ ### Uploading to Hugging Face
99
+
100
+ To upload your trained model and project to Hugging Face Hub:
101
+
102
+ 1. **Get your Hugging Face token**:
103
+ - Go to [Hugging Face Settings](https://huggingface.co/settings/tokens)
104
+ - Create a new token with "Write" permissions
105
+ - Copy the token
106
+
107
+ 2. **Set the token as an environment variable**:
108
+ ```bash
109
+ export HF_TOKEN=your_token_here
110
+ ```
111
+
112
+ 3. **Modify hyperparams in `agent.py`**:
113
+ ```python
114
+ hyperparams["push_hf"] = True
115
+ hyperparams["hf_repo"] = "your-username/flappy-bird-rl-agent"
116
+ ```
117
+
118
+ The entire project (excluding unnecessary files) will be uploaded to your repository.
119
+
120
+ ## Configuration
121
+
122
+ Modify the `hyperparams` dictionary in `agent.py` to adjust:
123
+
124
+ - `no_of_episodes`: Number of training episodes
125
+ - `gamma`: Discount factor
126
+ - `lr`: Learning rate
127
+ - `fps`: Frames per second for video output
128
+ - `save_videos`: Whether to save training/evaluation videos
129
+
130
+ ## Output
131
+
132
+ - Training plots: `output/rewards_vs_episodes.png`, `output/loss_vs_episodes.png`
133
+ - Training video: `output/flappy_bird_training.mp4` (if `save_videos=True`)
134
+ - Evaluation video: `output/flappy_bird_evaluation.mp4`
135
+ - Checkpoints: `checkpoints/`
136
+
137
+ ## Results
138
+
139
+ ### Training Progress
140
+
141
+ ![Training Rewards](output/rewards_vs_episodes.png)
142
+ *Figure 1: Total reward per episode during training*
143
+
144
+ ![Training Loss](output/loss_vs_episodes.png)
145
+ *Figure 2: Loss per episode during training*
146
+
147
+ ### Evaluation
148
+
149
+ The evaluation video shows the trained agent playing Flappy Bird. To create a 3-second GIF from the last 3 seconds of the video:
150
+
151
+ ```bash
152
+ ffmpeg -sseof -3 -i output/flappy_bird_evaluation.mp4 -t 3 -vf "fps=10,scale=320:-1:flags=lanczos" output/flappy_bird_evaluation.gif
153
+ ```
154
+
155
+ ![Evaluation GIF](output/flappy_bird_evaluation.gif)
156
+ *Figure 3: GIF of the evaluation gameplay*
157
+
158
+ ### Final Evaluation Statistics
159
+
160
+ Episode 1: Total Reward: 597.100000000053
161
+ Episode 2: Total Reward: 2084.89999999937
162
+ Episode 3: Total Reward: 428.4000000000229
163
+ Episode 4: Total Reward: 22.700000000000063
164
+ Episode 5: Total Reward: 152.89999999999645
165
+
166
+ Evaluation Statistics:
167
+ Mean Reward: 657.20
168
+ Standard Deviation: 741.78
169
+ Max Reward: 2084.90
170
+ Mean Score: 446.00
171
+ Max Score: 446.00
172
+
173
+ ## Architecture
174
+
175
+ The policy network consists of:
176
+ - Input layer: Matches observation space dimension
177
+ - Hidden layers: 256, 128, 64 neurons with ReLU activation
178
+ - Output layer: Softmax over action space
179
+
180
+ ## Algorithm
181
+
182
+ REINFORCE uses Monte Carlo policy gradients to update the policy parameters by maximizing the expected cumulative reward. The agent samples actions from the current policy, collects trajectories, and updates the policy using the discounted returns.
agent.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flappy_bird_gymnasium
2
+ import gymnasium
3
+ from policy import PolicyNetwork
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import imageio
8
+ import tqdm
9
+ import os
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ import matplotlib
12
+ matplotlib.use('TkAgg') # Use interactive backend
13
+ plt.ion() # Turn on interactive mode
14
+ from huggingface_hub import HfApi, upload_file, login, upload_folder
15
+
16
+ hyperparams = {
17
+ "no_of_episodes": 52300,
18
+ "gamma": 0.99,
19
+ "lr": 3e-4, # Lower learning rate for stability
20
+ "fps": 30,
21
+ "out_directory": "output/flappy_bird_evaluation.mp4",
22
+ "save_videos": True,
23
+ "train": False,
24
+ "evaluate": False,
25
+ "push_hf": True,
26
+ "hf_repo": "Heshwa/flappy-bird-reinforce-v1",
27
+ "eval_episodes": 3
28
+ }
29
+
30
+ def discounted_returns(rewards, gamma):
31
+ G = []
32
+ Gt = 0.0
33
+ for r in reversed(rewards):
34
+ Gt = r + gamma * Gt
35
+ G.insert(0, Gt)
36
+
37
+ eps = np.finfo(np.float32).eps.item()
38
+ G = torch.tensor(G, dtype=torch.float32)
39
+ G = (G - G.mean()) / (G.std() + eps)
40
+ return G
41
+
42
+
43
+ def annotate_frame(frame, episode, score):
44
+ """Draw episode number and score onto an RGB ndarray frame and return ndarray."""
45
+ try:
46
+ img = Image.fromarray(frame)
47
+ except Exception:
48
+ # If frame is already a PIL Image
49
+ img = frame
50
+ draw = ImageDraw.Draw(img)
51
+ font = ImageFont.load_default()
52
+ text = f"Ep:{episode} Score:{score}"
53
+ x, y = 8, 8
54
+ # draw black outline for readability
55
+ draw.text((x - 1, y - 1), text, font=font, fill=(0, 0, 0))
56
+ draw.text((x + 1, y - 1), text, font=font, fill=(0, 0, 0))
57
+ draw.text((x - 1, y + 1), text, font=font, fill=(0, 0, 0))
58
+ draw.text((x + 1, y + 1), text, font=font, fill=(0, 0, 0))
59
+ # draw main text
60
+ draw.text((x, y), text, font=font, fill=(255, 255, 255))
61
+ return np.array(img)
62
+
63
+
64
+
65
+
66
+
67
+
68
+ env = gymnasium.make("FlappyBird-v0",render_mode="rgb_array",use_lidar=False)
69
+ policy_net = PolicyNetwork(input_dim=env.observation_space.shape[0], output_dim=env.action_space.n)
70
+ optimizer = torch.optim.Adam(policy_net.parameters(), lr=hyperparams["lr"])
71
+
72
+ ensure_dir = lambda p: os.makedirs(p, exist_ok=True)
73
+
74
+
75
+ def load_checkpoint(policy_net, optimizer, checkpoint_path):
76
+ if os.path.exists(checkpoint_path):
77
+ checkpoint = torch.load(checkpoint_path)
78
+ policy_net.load_state_dict(checkpoint['model_state_dict'])
79
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
80
+ start_episode = checkpoint['episode']
81
+ rewards = checkpoint.get('rewards', [])
82
+ losses = checkpoint.get('losses', [])
83
+ print(f"Resumed from episode {start_episode}")
84
+ return start_episode, rewards, losses
85
+ else:
86
+ print("No checkpoint found, starting from scratch")
87
+ return 0, [], []
88
+
89
+
90
+ def get_latest_checkpoint(checkpoint_dir):
91
+ checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pth")
92
+ if os.path.exists(checkpoint_path):
93
+ return checkpoint_path
94
+ final_path = os.path.join(checkpoint_dir, "final_model.pth")
95
+ if os.path.exists(final_path):
96
+ return final_path
97
+ return None
98
+
99
+
100
+ def reinforce(policy_net, optimizer, env, hyperparams, start_episode=0, rewards=[], losses=[]):
101
+
102
+
103
+
104
+ # Initialize dynamic plots
105
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
106
+ ax1.set_xlabel("Episode")
107
+ ax1.set_ylabel("Total Reward")
108
+ ax1.set_title("Total Reward vs Episode")
109
+ line1, = ax1.plot([], [], 'b-')
110
+
111
+ ax2.set_xlabel("Episode")
112
+ ax2.set_ylabel("Loss")
113
+ ax2.set_title("Loss vs Episode")
114
+ line2, = ax2.plot([], [], 'r-')
115
+
116
+ plt.tight_layout()
117
+
118
+ checkpoint_dir = "checkpoints"
119
+ ensure_dir(checkpoint_dir)
120
+
121
+ for i in tqdm.tqdm(range(start_episode, hyperparams["no_of_episodes"])):
122
+ returns = []
123
+ log_probs = []
124
+ obs, _ = env.reset()
125
+ terminated = False
126
+ truncated = False
127
+ while not (terminated or truncated):
128
+ # action = env.action_space.sample()
129
+ action, log_prob = policy_net.act(obs)
130
+ obs, reward, terminated, truncated, info = env.step(action)
131
+ returns.append(reward)
132
+ log_probs.append(log_prob)
133
+ # print(f"Action: {action}, Reward: {reward}, Terminated: {terminated}")
134
+
135
+ loss = 0.0
136
+ # print("Episode:", i+1)
137
+ for log_prob, Gt in zip(log_probs, discounted_returns(returns, hyperparams["gamma"])):
138
+ loss += -log_prob * Gt
139
+ loss = loss/len(returns)
140
+ # print("Loss:", loss)
141
+ # print("total reward:", sum(returns))
142
+ optimizer.zero_grad()
143
+ loss.backward()
144
+ torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm=1.0) # Add gradient clipping
145
+ optimizer.step()
146
+ losses.append(loss.item())
147
+ rewards.append(sum(returns))
148
+
149
+ # Save checkpoint every 100 episodes
150
+ if (i + 1) % 100 == 0:
151
+ checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pth")
152
+ torch.save({
153
+ 'model_state_dict': policy_net.state_dict(),
154
+ 'optimizer_state_dict': optimizer.state_dict(),
155
+ 'episode': i + 1,
156
+ 'rewards': rewards,
157
+ 'losses': losses
158
+ }, checkpoint_path)
159
+ print(f"Checkpoint saved at episode {i+1} to {checkpoint_path}")
160
+
161
+
162
+
163
+ # Update plots every 100 episodes
164
+ if (i + 1) % 100 == 0:
165
+ line1.set_xdata(range(1, len(rewards) + 1))
166
+ line1.set_ydata(rewards)
167
+ ax1.relim()
168
+ ax1.autoscale_view()
169
+
170
+ line2.set_xdata(range(1, len(losses) + 1))
171
+ line2.set_ydata(losses)
172
+ ax2.relim()
173
+ ax2.autoscale_view()
174
+
175
+ plt.draw()
176
+ plt.pause(0.01) # Small pause to update the plot
177
+
178
+
179
+ ensure_dir("output")
180
+ plt.close(fig) # Close the dynamic plot
181
+ plt.close('all') # Close all plots
182
+ plt.figure()
183
+ plt.plot(rewards)
184
+ plt.xlabel("Episode")
185
+ plt.ylabel("Total Reward")
186
+ plt.title("Total Reward vs Episode")
187
+ plt.savefig("output/rewards_vs_episodes.png")
188
+ plt.close()
189
+ plt.figure()
190
+ plt.plot(losses)
191
+ plt.xlabel("Episode")
192
+ plt.ylabel("Loss")
193
+ plt.title("Loss vs Episode")
194
+ plt.savefig("output/loss_vs_episodes.png")
195
+ plt.close()
196
+
197
+ return rewards, losses
198
+
199
+
200
+
201
+ def evaluate_policy(policy_net, env, episodes=5,save_videos=False, checkpoint_path=None,out_directory="output/flappy_bird_evaluation.mp4"):
202
+ if checkpoint_path is None:
203
+ checkpoint_path = get_latest_checkpoint("checkpoints") or "checkpoints/final_model.pth"
204
+
205
+ if checkpoint_path and os.path.exists(checkpoint_path):
206
+ checkpoint = torch.load(checkpoint_path)
207
+ policy_net.load_state_dict(checkpoint['model_state_dict'])
208
+ print(f"Loaded model from {checkpoint_path}")
209
+ else:
210
+ print("No checkpoint found, using current model")
211
+
212
+ if save_videos:
213
+ images = []
214
+
215
+ episode_rewards = []
216
+ episode_scores = []
217
+ for i in range(episodes):
218
+ obs, _ = env.reset()
219
+ if save_videos:
220
+ frame = env.render()
221
+ if frame is not None:
222
+ images.append(annotate_frame(frame, i + 1, 0))
223
+ terminated = False
224
+ truncated = False
225
+ total_reward = 0
226
+ while not (terminated or truncated):
227
+ action, _ = policy_net.act(obs)
228
+ obs, reward, terminated, truncated, info = env.step(action)
229
+ if save_videos:
230
+ frame = env.render()
231
+ if frame is not None:
232
+ images.append(annotate_frame(frame, i + 1, int(info.get("score"))))
233
+ total_reward += reward
234
+ env.render()
235
+ episode_rewards.append(total_reward)
236
+ episode_scores.append(info.get("score"))
237
+ print(f"Episode {i+1}: Total Reward: {total_reward}")
238
+
239
+ # Compute statistics
240
+ mean_reward = np.mean(episode_rewards)
241
+ std_reward = np.std(episode_rewards)
242
+ max_reward = np.max(episode_rewards)
243
+ mean_score = np.max(episode_scores)
244
+ max_score = np.max(episode_scores)
245
+ print(f"\nEvaluation Statistics:")
246
+ print(f"Mean Reward: {mean_reward:.2f}")
247
+ print(f"Standard Deviation: {std_reward:.2f}")
248
+ print(f"Max Reward: {max_reward:.2f}")
249
+ print(f"Mean Score: {mean_score:.2f}")
250
+ print(f"Max Score: {max_score:.2f}")
251
+
252
+
253
+ if save_videos:
254
+ ensure_dir("output")
255
+ imageio.mimwrite(out_directory, images, fps=hyperparams["fps"])
256
+
257
+
258
+
259
+ def push_to_hf(repo_name, token=None):
260
+ """
261
+ Push the project to Hugging Face Hub.
262
+
263
+ Args:
264
+ repo_name (str): Name of the HF repository (e.g., 'username/model-name')
265
+ token (str, optional): HF token. If None, uses environment variable HF_TOKEN
266
+ """
267
+ if token:
268
+ login(token)
269
+ else:
270
+ login()
271
+
272
+ # Upload the project folder, excluding unnecessary files
273
+ ignore_patterns = [
274
+ "env/",
275
+ "__pycache__/",
276
+ "*.pyc",
277
+ ".git/",
278
+ "*.log",
279
+ "text.ipynb",
280
+ # "*.mp4", # Optional: exclude videos if too large
281
+ # "*.gif" # Optional: exclude GIFs
282
+ ]
283
+
284
+ try:
285
+ upload_folder(
286
+ folder_path=".",
287
+ repo_id=repo_name,
288
+ repo_type="model",
289
+ ignore_patterns=ignore_patterns
290
+ )
291
+ print(f"Project uploaded to https://huggingface.co/{repo_name}")
292
+ except Exception as e:
293
+ print(f"Upload failed: {e}")
294
+
295
+
296
+
297
+ if __name__ == "__main__":
298
+ env = gymnasium.make("FlappyBird-v0",render_mode="rgb_array",use_lidar=False)
299
+ policy_net = PolicyNetwork(input_dim=env.observation_space.shape[0], output_dim=env.action_space.n)
300
+ optimizer = torch.optim.Adam(policy_net.parameters(), lr=hyperparams["lr"])
301
+
302
+ ensure_dir = lambda p: os.makedirs(p, exist_ok=True)
303
+
304
+ # Load checkpoint if exists (prefer checkpoint.pth for resuming, else final_model.pth)
305
+ checkpoint_path = get_latest_checkpoint("checkpoints")
306
+ if checkpoint_path:
307
+ start_episode, rewards, losses = load_checkpoint(policy_net, optimizer, checkpoint_path)
308
+ else:
309
+ start_episode, rewards, losses = 0, [], []
310
+
311
+ if hyperparams["train"]:
312
+ rewards, losses = reinforce(policy_net, optimizer, env, hyperparams,start_episode=start_episode, rewards=rewards, losses=losses)
313
+
314
+ # Save final model
315
+ final_checkpoint_path = "checkpoints/final_model.pth"
316
+ ensure_dir("checkpoints")
317
+ torch.save({
318
+ 'model_state_dict': policy_net.state_dict(),
319
+ 'optimizer_state_dict': optimizer.state_dict(),
320
+ 'episode': hyperparams["no_of_episodes"],
321
+ 'rewards': rewards,
322
+ 'losses': losses
323
+ }, final_checkpoint_path)
324
+ print(f"Final model saved to {final_checkpoint_path}")
325
+
326
+ if hyperparams["evaluate"]:
327
+ evaluate_policy(policy_net, env, episodes=hyperparams["eval_episodes"], save_videos=hyperparams["save_videos"], checkpoint_path=None, out_directory=hyperparams["out_directory"])
328
+
329
+ if hyperparams["push_hf"]:
330
+ if hyperparams["hf_repo"] is None:
331
+ print("Error: hf_repo is required when push_hf is True")
332
+ else:
333
+ push_to_hf(hyperparams["hf_repo"])
334
+
335
+ env.close()
checkpoints/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07e044c7c4e27d05f6b4d36714994120b51996a39703a964b95cef214f3ce238
3
+ size 1490549
checkpoints/final_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ea3fa287930789d32a4dac03f8ef60f5d7000b9020daf05758a4400c7727515
3
+ size 1490587
output/flappy_bird_evaluation.gif ADDED

Git LFS Details

  • SHA256: 9fc5fc23e280ca12408769462ea210a765750c17713bc10ce9660c333a354be4
  • Pointer size: 131 Bytes
  • Size of remote file: 700 kB
output/flappy_bird_evaluation.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:859c760509b5059fcd5798b260bcff082ce107a37a8af1f9d58c002f5b38e063
3
+ size 15957601
output/flappy_bird_training.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69aa820437344b05e9e3d009b96a36c3859df65914e1c305d87ef91ec7f696a1
3
+ size 2672703
output/loss_vs_episodes.png ADDED
output/rewards_vs_episodes.png ADDED
policy.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from torch.distributions import Categorical
4
+ import numpy as np
5
+
6
+
7
+ class PolicyNetwork(nn.Module):
8
+ def __init__(self, input_dim, output_dim):
9
+ super(PolicyNetwork, self).__init__()
10
+ self.fc1 = nn.Linear(input_dim, 256) # Increased from 128
11
+ self.fc2 = nn.Linear(256, 128) # Increased from 64
12
+ self.fc3 = nn.Linear(128, 64) # New layer
13
+ self.fc4 = nn.Linear(64, output_dim)
14
+ self.relu = nn.ReLU()
15
+ self.softmax = nn.Softmax(dim=-1)
16
+
17
+ def forward(self, x):
18
+ x = self.relu(self.fc1(x))
19
+ x = self.relu(self.fc2(x))
20
+ x = self.relu(self.fc3(x))
21
+ x = self.softmax(self.fc4(x))
22
+ return x
23
+
24
+ def act(self, state):
25
+ state = torch.from_numpy(state).float().unsqueeze(0)
26
+ probs = self.forward(state)
27
+ m = Categorical(probs)
28
+ # action = np.argmax(m)
29
+ action = m.sample()
30
+ return action.item(), m.log_prob(action)
31
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ numpy>=1.21.0
3
+ matplotlib>=3.5.0
4
+ gymnasium>=0.29.0
5
+ flappy-bird-gymnasium>=0.4.0
6
+ imageio>=2.31.0
7
+ tqdm>=4.64.0
8
+ Pillow>=9.0.0
9
+ huggingface-hub>=0.17.0