from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn
import gymnasium as gym
from tinygrad.helpers import trange
import numpy as np  # TODO: remove numpy import

ENVIRONMENT_NAME = 'CartPole-v1'
#ENVIRONMENT_NAME = 'LunarLander-v2'
#import examples.rl.lightupbutton
#ENVIRONMENT_NAME = 'PressTheLightUpButton-v0'

# *** hyperparameters ***
# https://github.com/llSourcell/Unity_ML_Agents/blob/master/docs/best-practices-ppo.md
BATCH_SIZE = 256
ENTROPY_SCALE = 0.0005
REPLAY_BUFFER_SIZE = 2000
PPO_EPSILON = 0.2
HIDDEN_UNITS = 32
LEARNING_RATE = 1e-2
TRAIN_STEPS = 5
EPISODES = 40
DISCOUNT_FACTOR = 0.99

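# actor-critic model: __call__ returns (action log-probabilities, state-value estimate)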
class ActorCritic:
  def __init__(self, in_features, out_features, hidden_state=HIDDEN_UNITS):
    # actor head: two-layer MLP producing action log-probabilities
    self.l1 = nn.Linear(in_features, hidden_state)
    self.l2 = nn.Linear(hidden_state, out_features)

    # critic head: two-layer MLP producing a scalar value estimate
    self.c1 = nn.Linear(in_features, hidden_state)
    self.c2 = nn.Linear(hidden_state, 1)

  def __call__(self, obs:Tensor) -> Tuple[Tensor, Tensor]:
    x = self.l1(obs).tanh()
    act = self.l2(x).log_softmax()
    x = self.c1(obs).relu()
    return act, self.c2(x)

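# run one episode with the greedy (argmax) policy and return the total reward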
def evaluate(model:ActorCritic, test_env:gym.Env) -> float:
  (obs, _), terminated, truncated = test_env.reset(), False, False
  total_rew = 0.0
  while not terminated and not truncated:
    act = model(Tensor(obs))[0].argmax().item()
    obs, rew, terminated, truncated, _ = test_env.step(act)
    total_rew += float(rew)
  return total_rew

if __name__ == "__main__":
  env = gym.make(ENVIRONMENT_NAME)

  model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore
  opt = nn.optim.Adam(nn.state.get_parameters(model), lr=LEARNING_RATE)

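  # one PPO update on a sampled minibatch: clipped surrogate actor loss, entropy bonus, and squared-error critic loss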
  @TinyJit
  def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    with Tensor.train():
      log_dist, value = model(x)
      # one-hot mask over actions selecting the action that was actually taken
      action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

      # get real advantage using the value function
      advantage = reward.reshape(-1, 1) - value
      masked_advantage = action_mask * advantage.detach()

      # PPO
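      # clipped surrogate objective: maximize min(r*A, clip(r, 1-eps, 1+eps)*A), where r = pi_new/pi_old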
      ratios = (log_dist - old_log_dist).exp()
      unclipped_ratio = masked_advantage * ratios
      clipped_ratio = masked_advantage * ratios.clip(1-PPO_EPSILON, 1+PPO_EPSILON)
      action_loss = -unclipped_ratio.minimum(clipped_ratio).sum(-1).mean()

      entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean() # this encourages diversity
      critic_loss = advantage.square().mean()
      opt.zero_grad()
      (action_loss + entropy_loss*ENTROPY_SCALE + critic_loss).backward()
      opt.step()
      return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()

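  # sample an action from the current policy: exp(log-probs) -> categorical draw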
  @TinyJit
  def get_action(obs:Tensor) -> Tensor:
    # TODO: with no_grad
    Tensor.no_grad = True
    ret = model(obs)[0].exp().multinomial().realize()
    Tensor.no_grad = False
    return ret

  st, steps = time.perf_counter(), 0
  Xn, An, Rn = [], [], []
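  # main loop: roll out one episode, append it to the replay buffer, then run TRAIN_STEPS PPO updates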
  for episode_number in (t:=trange(EPISODES)):
    get_action.reset() # NOTE: if you don't reset the jit here it captures the wrong model on the first run through

    obs:np.ndarray = env.reset()[0]
    rews, terminated, truncated = [], False, False
    # NOTE: we don't want to early stop since then the rewards are wrong for the last episode
    while not terminated and not truncated:
      # pick actions
      # TODO: what's the temperature here?
      act = get_action(Tensor(obs)).item()

      # save this state action pair
      # TODO: don't use np.copy here on the CPU, what's the tinygrad way to do this and keep on device? need __setitem__ assignment
      Xn.append(np.copy(obs))
      An.append(act)

      obs, rew, terminated, truncated, _ = env.step(act)
      rews.append(float(rew))
    steps += len(rews)

    # reward to go: R_t = sum_k DISCOUNT_FACTOR**k * r_{t+k}
    # TODO: move this into tinygrad
    discounts = np.power(DISCOUNT_FACTOR, np.arange(len(rews)))
    Rn += [np.sum(rews[i:] * discounts[:len(rews)-i]) for i in range(len(rews))]

    # keep only the most recent transitions in the replay buffer
    Xn, An, Rn = Xn[-REPLAY_BUFFER_SIZE:], An[-REPLAY_BUFFER_SIZE:], Rn[-REPLAY_BUFFER_SIZE:]
    X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)

    # TODO: make this work
    #vsz = Variable("sz", 1, REPLAY_BUFFER_SIZE-1).bind(len(Xn))
    #X, A, R = Tensor(Xn).reshape(vsz, None), Tensor(An).reshape(vsz), Tensor(Rn).reshape(vsz)

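    # snapshot log-probs under the data-collecting policy; the PPO ratio in train_step is taken against these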
    old_log_dist = model(X)[0].detach() # TODO: could save these instead of recomputing
    for i in range(TRAIN_STEPS):
      samples = Tensor.randint(BATCH_SIZE, high=X.shape[0]).realize() # TODO: remove the need for this
      # TODO: is this recompiling based on the shape?
      action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples])
    t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.3f} entropy_loss: {entropy_loss.item():7.3f} critic_loss: {critic_loss.item():8.3f} reward: {sum(rews):6.2f}")

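  # after training, render one greedy episode and print its total reward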
  test_rew = evaluate(model, gym.make(ENVIRONMENT_NAME, render_mode='human'))
  print(f"test reward: {test_rew}")