mirror of https://github.com/commaai/tinygrad.git
Replicate llm.c in tinygrad (#4179)
* write llm.c and add a few new methods to tensor
* training works
* add jit
* tests for new functions
* test tolist
* simple fix for onnx test failures (#4186)
* write llm.c and add a few new methods to tensor
* training works
* add jit
* tests for new functions
* bump line count to 7500
* simplest fix
* safenumpy tolist for now

---------

Co-authored-by: George Hotz <geohot@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>

---------

Co-authored-by: geohotstan <135171913+geohotstan@users.noreply.github.com>
This commit is contained in:
parent b6e7243bfa
commit 55ae73e951
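The diff below adds a handful of torch-style conveniences to Tensor (__len__, tolist, view, masked_fill, size) alongside the llm.c training script. A minimal sketch of how they compose, assuming a tinygrad checkout that includes this commit (the tensors and values here are illustrative only, not from the diff):

from tinygrad import Tensor

t = Tensor([[1, 2, 3], [4, 5, 6]])
print(len(t))                          # 2, the first dimension, via the new __len__
print(t.size(), t.size(-1))            # (2, 3) and 3, torch-style size()
print(t.view(3, 2).shape)              # (3, 2); view is just reshape in tinygrad
print(Tensor([1.5, 2, 3]).tolist())    # [1.5, 2.0, 3.0]; 1-D tensors only for now
print(t.masked_fill(t > 3, -1).numpy())  # entries greater than 3 replaced by -1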
@@ -0,0 +1 @@
data

@@ -0,0 +1,190 @@
#!/usr/bin/env python3
import os, math, time
import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
from dataclasses import dataclass

@dataclass
class GPTConfig:
  block_size: int = 1024
  vocab_size: int = 50257
  n_layer: int = 12
  n_head: int = 12
  n_embd: int = 768

class CausalSelfAttention:
  def __init__(self, config:GPTConfig):
    assert config.n_embd % config.n_head == 0
    # key, query, value projections for all heads, but in a batch
    self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
    # output projection
    self.c_proj = nn.Linear(config.n_embd, config.n_embd)
    # regularization
    self.n_head = config.n_head
    self.n_embd = config.n_embd
    # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
    self.bias = Tensor.ones(1, 1, config.block_size, config.block_size).tril()
    self.bias.requires_grad = False

  def __call__(self, x:Tensor):
    B, T, C = x.shape
    qkv = self.c_attn(x)
    q, k, v = qkv.split(self.n_embd, dim=2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

    # manual implementation of attention
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
    att = att.softmax()
    y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    y = y.transpose(1, 2).view(B, T, C) # re-assemble all head outputs side by side
    # output projection
    y = self.c_proj(y)
    return y

class MLP:
  def __init__(self, config:GPTConfig):
    self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
    self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

  def __call__(self, x:Tensor) -> Tensor:
    return self.c_proj(self.c_fc(x).gelu())

class Block:
  def __init__(self, config:GPTConfig):
    self.ln_1 = nn.LayerNorm(config.n_embd)
    self.attn = CausalSelfAttention(config)
    self.ln_2 = nn.LayerNorm(config.n_embd)
    self.mlp = MLP(config)

  def __call__(self, x:Tensor):
    x = x + self.attn(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x

class GPT:
  def __init__(self, config:GPTConfig):
    self.config = config

    self.wte = nn.Embedding(config.vocab_size, config.n_embd)
    self.wpe = nn.Embedding(config.block_size, config.n_embd)
    self.h = [Block(config) for _ in range(config.n_layer)]
    self.ln_f = nn.LayerNorm(config.n_embd)
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
    self.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

  def load_pretrained(self):
    weights = nn.state.torch_load(fetch(f'https://huggingface.co/gpt2/resolve/main/pytorch_model.bin'))
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
    for k in weights:
      if k.endswith(transposed):
        weights[k] = weights[k].to(Device.DEFAULT).T.contiguous()
    # lm head and wte are tied
    weights['lm_head.weight'] = weights['wte.weight']
    nn.state.load_state_dict(self, weights)

  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
      idx_cond = idx if idx.shape[1] <= self.config.block_size else idx[:, -self.config.block_size:]
      logits, _ = self(idx_cond)
      logits = logits[:, -1, :] / temperature
      idx_next = logits.softmax().multinomial()
      idx = Tensor.cat(idx, idx_next, dim=1)
    return idx

  def __call__(self, idx:Tensor, targets=None):
    b, t = idx.shape
    pos = Tensor.arange(0, t)

    tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
    pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
    x = tok_emb + pos_emb

    x = self.ln_f(x.sequential(self.h))

    if targets is not None:
      logits = self.lm_head(x)
      loss = logits.sparse_categorical_crossentropy(targets)
    else:
      logits = self.lm_head(x[:, [-1], :])
      loss = None

    return logits, loss

if __name__ == "__main__":
  import tiktoken, argparse

  parser = argparse.ArgumentParser()
  parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
  parser.add_argument("--batch_size", type=int, default=4, help="batch size")
  parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
  args = parser.parse_args()
  B, T = args.batch_size, args.sequence_length
  assert 1 <= T <= 1024

  model = GPT(GPTConfig(n_layer=12, n_head=12, n_embd=768))
  model.load_pretrained()

  # init the tokenizer
  enc = tiktoken.get_encoding("gpt2")
  encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
  decode = lambda l: enc.decode(l)

  # load the tokens
  # prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories
  # we're using val instead of train split just because it is smaller/faster
  shake_tokens_bin = "data/tiny_shakespeare_val.bin"
  story_tokens_bin = "data/TinyStories_val.bin"
  assert os.path.isfile(shake_tokens_bin) or os.path.isfile(story_tokens_bin), "you must run prepro on some dataset"
  tokens_bin = shake_tokens_bin if os.path.isfile(shake_tokens_bin) else story_tokens_bin
  assert os.path.isfile(tokens_bin)
  print(f"loading cached tokens in {tokens_bin}")
  with open(tokens_bin, "rb") as f:
    tokens = np.frombuffer(f.read(), dtype=np.int32)
  tokens = Tensor(tokens)

  # lightweight dataloader
  def get_batch():
    assert B*T+1 <= len(tokens), "not enough tokens"
    # for 338,025 tokens. E.g. with B=8 T=1024, this will yield 41 batches before looping
    i = 0
    while True:
      x = tokens[i:i+B*T].view(B, T)
      y = tokens[i+1:i+B*T+1].view(B, T)
      yield x, y
      i += B*T
      if i + B*T + 1 >= len(tokens):
        i = 0 # in prod we'd want to randomize the start point a bit

  # forward backward for a few iterations
  data_iter = iter(get_batch())
  x, y = next(data_iter) # we'll overfit this batch below
  optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4)

  @TinyJit
  def step(x, y):
    _, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

  for i in range(args.num_iterations):
    GlobalCounters.reset()
    t0 = time.time()
    loss = step(x.contiguous(), y.contiguous())
    Device[Device.DEFAULT].synchronize()
    t1 = time.time()
    print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")

  start = "<|endoftext|>"
  start_ids = encode(start)
  x = (Tensor(start_ids)[None, ...])
  max_new_tokens = 16
  temperature = 1.0
  top_k = 40
  y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
  print(decode(y[0].tolist()))
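Not part of the diff: a small smoke-test sketch for the classes above, using a deliberately tiny, made-up config so the forward/loss and sampling paths can be exercised without downloading GPT-2 weights or prepping token files. It assumes GPTConfig, GPT, Tensor, and np are in scope exactly as in the script.

cfg = GPTConfig(block_size=32, vocab_size=256, n_layer=2, n_head=2, n_embd=64)  # illustrative sizes only
tiny = GPT(cfg)
xb = Tensor(np.random.randint(0, cfg.vocab_size, size=(2, 16)).astype(np.int32))  # fake (B=2, T=16) tokens
yb = Tensor(np.random.randint(0, cfg.vocab_size, size=(2, 16)).astype(np.int32))
logits, loss = tiny(xb, yb)            # same call the jitted step() makes
print(logits.shape, loss.item())       # (2, 16, 256) and a scalar cross-entropy
out = tiny.generate(xb[:, :4], max_new_tokens=3)
print(out.shape)                       # (2, 7): 4 prompt tokens plus 3 sampled tokens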
@@ -428,7 +428,7 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
    y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
    return x_out.clip(0, X.shape[-1]-1), y_out.clip(0, X.shape[-2]-1)
  if roi is not None:
-   roi = safe_numpy(roi)
+   roi = safe_numpy(roi).tolist()
    roi = [(st,ed) for st, ed in zip(roi[:len(roi)//2], roi[len(roi)//2:])]
    roi_ = [(1,1)] * 4
    if axes is not None:
@@ -691,7 +691,7 @@ def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
    X.grad.requires_grad, H.requires_grad = False, False # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py
    H.assign(H.detach() + X.grad * X.grad).realize()
    H_adaptive = H.sqrt() + epsilon
-   X.assign(X.detach() - r * X.grad / H_adaptive)
+   X.assign(X.detach() - r.tolist() * X.grad / H_adaptive)
    ret.extend([X, H])
  ret = ret[::2] + ret[1::2]
  return tuple(ret)
@@ -699,7 +699,7 @@ def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
  groups = len(inputs) // 3
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
- T, R = safe_numpy(T), safe_numpy(R)
+ T, R.requires_grad = T.item(), False
  beta_adjusted = beta if T > 0 else 1
  ret = []
  for X, G, V in grouped_inputs:
@@ -716,7 +716,7 @@ def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0, norm_coefficient_post=0.0):
  groups = len(inputs) // 4
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
- T, R = safe_numpy(T), safe_numpy(R)
+ T, R.requires_grad = T.item(), False
  ret = []
  for X, G, V, H in grouped_inputs:
    X.grad = (norm_coefficient * X + G).realize()
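The three optimizer hunks above swap the safe_numpy round-trip for the new Tensor.tolist()/Tensor.item() accessors ("safenumpy tolist for now" in the commit message). A hedged illustration of the difference; the variable names below are made up, not the onnx_ops ones:

from tinygrad import Tensor

r = Tensor([1e-3])     # scalar hyperparameter arriving as a 1-element tensor
t = Tensor([5])
print(r.tolist())      # roughly [0.001]: a plain python list, no numpy detour
print(t.item())        # 5: a plain python int, requires numel() == 1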
@@ -1695,6 +1695,10 @@ class TestOps(unittest.TestCase):
    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8).type(torch.int32),
                       lambda: Tensor(data).one_hot(8), forward_only=True)

+  def test_masked_fill(self):
+    helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
+    helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))
+
if __name__ == '__main__':
  np.random.seed(1337)
  unittest.main(verbosity=2)
@@ -228,6 +228,29 @@ class TestTinygrad(unittest.TestCase):
    assert Tensor([]).numel() == 0
    assert Tensor.randn(1,0,2,5).numel() == 0

+  def test_len(self):
+    assert len(torch.zeros(7)) == len(Tensor.zeros(7))
+    assert len(torch.zeros(10,20)) == len(Tensor.zeros(10,20))
+    assert len(torch.zeros(10,20)) == len(Tensor.zeros(10,20,30))
+    assert len(torch.zeros(1).flatten()) == len(Tensor.zeros(1).flatten())
+
+  def test_size(self):
+    t1, t2 = torch.zeros(10,20), Tensor.zeros(10,20)
+    assert t1.size() == t2.size()
+    assert t1.size(0) == t2.size(0)
+    assert t1.size(1) == t2.size(1)
+    assert t1.size(-1) == t2.size(-1)
+    assert t1.size(-2) == t2.size(-2)
+    with self.assertRaises(IndexError): t2.size(2)
+
+  def test_tolist(self):
+    assert Tensor([1,2,3]).tolist() == [1,2,3]
+    assert Tensor([1.5,2,3]).tolist() == [1.5,2,3]
+
+    # TODO: match torch here
+    # NotImplementedError: multi-dimensional sub-views are not implemented
+    #assert Tensor([[1,2,3], [4,5,6]]).tolist() == [[1,2,3], [4,5,6]]
+
  def test_element_size(self):
    for _, dtype in dtypes.fields().items():
      assert dtype.itemsize == Tensor.randn(3, dtype=dtype).element_size(), f"Tensor.element_size() not matching Tensor.dtype.itemsize for {dtype}"
@@ -126,6 +126,8 @@ class Tensor:

  def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")

+  def __len__(self): return self.shape[0] if len(self.shape) else 1
+
  @property
  def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device

@@ -188,6 +190,7 @@ class Tensor:
    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
    assert self.numel() == 1, "must have one element for item"
    return self._data().cast(self.dtype.fmt)[0]
+  def tolist(self) -> List[ConstType]: return list(self.data())
  def numpy(self) -> np.ndarray:
    if self.dtype == dtypes.bfloat16: return self.float().numpy()
    assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
@@ -357,7 +360,7 @@ class Tensor:
      self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

    for t0 in reversed(self.deepwalk()):
-     if t0.grad is None: raise RuntimeError("tensor has no grad")
+     if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
      grads = t0._ctx.backward(t0.grad.lazydata)
      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
        for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
@@ -370,6 +373,7 @@ class Tensor:

  # ***** movement mlops *****

+  def view(self, *shape) -> Tensor: return self.reshape(shape) # in tinygrad, view and reshape are the same thing
  def reshape(self, shape, *args) -> Tensor:
    new_shape = argfix(shape, *args)
    new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
@@ -938,6 +942,8 @@ class Tensor:
    x,z = x_._broadcasted(other, match_dtype=False)
    return F.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))

+  def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
+
  # ***** op wrappers (wasted lines to make the typechecker happy) *****

  def __neg__(self) -> Tensor: return self.neg()
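A quick illustration, not taken from the diff, of how the new masked_fill pairs with tril() to build the causal mask that CausalSelfAttention in the llm.c script uses:

from tinygrad import Tensor

T = 4
scores = Tensor.ones(T, T)               # stand-in attention scores
causal = Tensor.ones(1, 1, T, T).tril()  # same construction as self.bias above
scores = scores.masked_fill(causal[0, 0] == 0, float('-inf'))
print(scores.softmax().numpy())          # each row sums to 1, future positions get ~0 weight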
@@ -1045,6 +1051,7 @@ class Tensor:
  def element_size(self) -> int: return self.dtype.itemsize
  def nbytes(self) -> int: return self.numel() * self.element_size()
  def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
+  def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]

# register functions to move between devices
for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))