# tinygrad/extra/optimization/extract_sa_pairs.py

import sys, sqlite3, pickle, math
from collections import defaultdict
from tqdm import tqdm, trange
import numpy as np
# stuff needed to unpack a kernel
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
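# inf/nan can appear inside the repr'd ASTs stored in the cache, so define them for the eval calls below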
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps
# more stuff: featurizer, value net, and training utilities
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import actions
from extra.optimization.helpers import lin_to_feats
from extra.optimization.pretrain_valuenet import ValueNet
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
import random
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
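
# read the time_linearizer cache and turn each (kernel, applied opts) step into a
# (state features + action one-hot, log speedup) training pair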
def dataset_from_cache(fn):
  conn = sqlite3.connect(fn)
  cur = conn.cursor()
  cur.execute("SELECT * FROM time_linearizer")
  grouped = defaultdict(dict)
  for f in tqdm(cur.fetchall()): grouped[f[0]][f[1:-1]] = pickle.loads(f[-1])
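
  # index the timing lists by (ast, applied opts tuple); sks[1] == 1 appears to select a single cache configuration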
  opts_to_outcome = {}
  for ast,sk in grouped.items():
    cnts = defaultdict(int)
    for sks,tm in sk.items():
      if sks[1] != 1: continue
      opts = eval(sks[0])
      cnts[(len(opts), sks[1])] += 1
      opts_to_outcome[(ast, tuple(opts))] = tm
    #print(cnts)
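
  # for each opt sequence, compare the best time before and after the last opt:
  # state = featurized kernel with the prefix applied, action = the last opt, value = log(old/new)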
  S,A,V = [], [], []
  for ast,k in tqdm(opts_to_outcome):
    if len(k) == 0: continue
    old_tm = min(opts_to_outcome[(ast,k[:-1])])
    new_tm = min(opts_to_outcome[(ast,k)])
    if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: continue
    try:
      lin = Linearizer(eval(ast))
    except Exception:
      continue
    for opt in k[:-1]: lin.apply_opt(opt)
    act = k[-1]
    log_ratio = math.log(old_tm/new_tm)
    #print(f"ratio: {old_tm/new_tm:6.2f}x (log {log_ratio:5.2f}) from {str(act):50s} on {lin.colored_shape()}")
    S.append(lin_to_feats(lin, use_sts=True))
    A.append(actions.index(act))
    V.append([log_ratio])  # NOTE: i have written the bug many times with this having the wrong dim
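
  # X is the state features concatenated with a one-hot encoding of the action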
  S, A, V = np.array(S), np.array(A), np.array(V, dtype=np.float32)
  X = np.zeros((S.shape[0], S.shape[1]+len(actions)), dtype=np.float32)
  X[:, :S.shape[1]] = S
  X[range(S.shape[0]), S.shape[1]+A] = 1.0
  return X, V
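
# gaussian negative log likelihood of x under N(mu, exp(log_sigma)^2), dropping the constant log(2*pi)/2 term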
def log_likelihood(x:Tensor, mu:Tensor, log_sigma:Tensor):
  #print(x.shape, mu.shape, log_sigma.shape)
  #return (x-mu).abs() * (-log_sigma).exp() + log_sigma
  return (x-mu).square() * (-2*log_sigma).exp() / 2 + log_sigma

if __name__ == "__main__":
  if getenv("REGEN"):
    X,V = dataset_from_cache(sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache")
    safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset")
  else:
    ld = safe_load("/tmp/dataset")
    X,V = ld['X'].numpy(), ld['V'].numpy()
  print(X.shape, V.shape)
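
  # shuffle X and V together before splitting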
  order = list(range(X.shape[0]))
  random.shuffle(order)
  X, V = X[order], V[order]
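
  # hold out the last 512 samples as a test set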
  ratio = -512
  X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:])
  X,V = X[:ratio], V[:ratio]
  print(X.shape, V.shape)

  #print(X[0], V[0])
  #print(X[-1], V[-1])
  print(X.shape)
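
  # the value net has two outputs per sample: predicted mean and log-std of the log speedup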
  net = ValueNet(X.shape[1], 2)
  optim = Adam(get_parameters(net))
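
  # sample a minibatch with replacement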
  def get_minibatch(X,Y,bs):
    xs, ys = [], []
    #random.seed(1337)
    for _ in range(bs):
      sel = random.randint(0, len(X)-1)
      xs.append(X[sel])
      ys.append(Y[sel])
    return Tensor(xs), Tensor(ys)
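
  # train with the gaussian NLL; the test metric is plain MSE on the predicted mean, refreshed every 10 steps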
  Tensor.no_grad, Tensor.training = False, True
  losses = []
  test_losses = []
  test_loss = float('inf')
  for i in (t:=trange(2000)):
    x,y = get_minibatch(X,V,bs=256)
    out = net(x)
    #loss = (out-y).square().mean()
    loss = log_likelihood(y, out[:, 0:1], out[:, 1:2]).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()
    t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}")
    losses.append(loss.numpy().item())
    test_losses.append(test_loss)
    if i % 10 == 0: test_loss = (net(X_test)[:, 0:1]-V_test).square().mean().numpy().item()

  safe_save(get_state_dict(net), "/tmp/qnet.safetensors")
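
  # plot the training and test loss curves, skipping the noisy first 20 steps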
  import matplotlib.pyplot as plt
  plt.plot(losses[20:])
  plt.plot(test_losses[20:])
  plt.show()