-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtesting_stuff.py
More file actions
38 lines (30 loc) · 804 Bytes
/
testing_stuff.py
File metadata and controls
38 lines (30 loc) · 804 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Policy-Gradient RL
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym
import argparse
import wandb
import os
from reinforce_agent import ReinforceAgent
from actor_critic_agent import ActorCriticAgent
# --- hyperparameters ---
env_name = 'LunarLander-v2'
episodes = 1_000  # NOTE(review): defined but unused in this smoke-test script
gamma = 0.99      # NOTE(review): defined but unused here — presumably consumed by a training loop elsewhere
lr = 0.001

# --- environment setup ---
print(f'Training in the {env_name} environment.')
env = gym.make(env_name)  # new_step_api=True
obs_shape = env.observation_space.shape[0]  # flat observation vector length (8 for LunarLander)
action_shape = env.action_space.n           # number of discrete actions
device = torch.device('cpu')

# --- init agent ---
agent = ActorCriticAgent(n_features=obs_shape, n_actions=action_shape, device=device, lr=lr)
# parameters() follows the nn.Module convention of returning a generator;
# materialize it so the print shows the actual parameter tensors instead of
# '<generator object Module.parameters at 0x...>'.
print(list(agent.parameters()))

# smoke-test: reset the env and run a single forward pass through the policy
obs, _ = env.reset()  # gym >= 0.26 API: reset() returns (obs, info)
print(obs)
# NOTE(review): obs is a numpy ndarray at this point — confirm that
# ActorCriticAgent.forward accepts ndarray input (many implementations
# expect a torch.Tensor and convert internally).
print(agent.forward(obs))
env.close()