mirror of https://github.com/commaai/tinygrad.git
failing llama test
This commit is contained in:
parent 8aa63847c7
commit 3ec457248c
@@ -114,11 +114,11 @@ class Transformer:
   def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, max_batch_size=32, max_seq_len=1024):
     self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps) for _ in range(n_layers)]
     self.norm = RMSNorm(dim, norm_eps)
-    self.tok_embeddings = {"weight": Tensor.zeros(vocab_size, dim)}
+    self.tok_embeddings = {"weight": Tensor.glorot_uniform(vocab_size, dim)}
     self.output = Linear(dim, vocab_size, bias=False)
     self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))

-  def __call__(self, tokens:Tensor, start_pos:int):
+  def __call__(self, tokens:Tensor, start_pos:int, early_realize_freqs_cis:bool=True):
     _bsz, seqlen, _ = tokens.shape
     h = tokens @ self.tok_embeddings['weight']
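The token embedding here is just a matmul against a one-hot encoding of the tokens (see onehot_encode below), which is presumably why this commit swaps the placeholder Tensor.zeros for Tensor.glorot_uniform: a zero weight matrix would map every token of the randomly initialized test model to the same activations. A minimal numpy-only sketch of the lookup equivalence, with illustrative shapes:

import numpy as np

vocab_size, dim = 4, 2
weight = np.random.randn(vocab_size, dim).astype(np.float32)  # stands in for Tensor.glorot_uniform(vocab_size, dim)

tok = 3
onehot = np.zeros((1, 1, vocab_size), dtype=np.float32)  # same layout as onehot_encode below
onehot[0, 0, tok] = 1.0

h = onehot @ weight                                 # shape (1, 1, dim)
np.testing.assert_allclose(h[0, 0], weight[tok])    # identical to indexing the embedding row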
@@ -128,7 +128,7 @@ class Transformer:
     # WTF!!! This changes the output, and fixes the kv caching. Most serious tinygrad bug in a while.
     # It is not fixed by disabling the method cache.
     # TODO: P0. Fix this bug. An offset is likely getting lost somewhere.
-    freqs_cis.realize()
+    if early_realize_freqs_cis: freqs_cis.realize()

     if seqlen > 1:
       mask = np.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=np.float32)
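The guarded realize is exactly what the new test below toggles. As a rough illustration of the suspected failure class (an offset lost when slicing a not-yet-realized buffer), here is a hedged sketch; the names and shapes are illustrative, and on its own this snippet is expected to print matching values, since the kv-cache path is what actually exposes the bug:

import numpy as np
from tinygrad.tensor import Tensor

freqs = Tensor(np.arange(8, dtype=np.float32).reshape(4, 2))  # stands in for the precomputed freqs_cis table
start_pos, seqlen = 1, 1

lazy_slice = freqs[start_pos:start_pos+seqlen]    # slice taken while freqs is still lazy
freqs.realize()                                   # the "early realize" this diff makes optional
eager_slice = freqs[start_pos:start_pos+seqlen]   # slice taken after the parent is realized

print(lazy_slice.numpy())   # per the comment above, an offset like start_pos appears to be
print(eager_slice.numpy())  # what goes missing, but only once the kv cache is involved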
@@ -160,9 +160,9 @@ WEIGHTS1_FILENAME = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..
 # **** helper functions ****

-def onehot_encode(toks):
+def onehot_encode(toks, vocab_size=VOCAB_SIZE):
   # this allows the embedding to work in tinygrad
-  onehot = np.zeros((1, len(toks), VOCAB_SIZE), dtype=np.float32)
+  onehot = np.zeros((1, len(toks), vocab_size), dtype=np.float32)
   onehot[0,range(len(toks)),toks] = 1
   return Tensor(onehot)
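The new vocab_size parameter is what lets the tiny test below build (1, len(toks), vocab_size) inputs for a 4-token vocabulary instead of the full VOCAB_SIZE. A quick usage check, assuming examples.llama is importable as in the new test file:

from examples.llama import onehot_encode

t = onehot_encode([1, 3], vocab_size=4)
print(t.shape)     # (1, 2, 4)
print(t.numpy())   # [[[0. 1. 0. 0.]
                   #   [0. 0. 0. 1.]]]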
@@ -178,6 +178,8 @@ def sample(logits, temperature):
 # **** main code ****

 if __name__ == "__main__":
   Tensor.no_grad = True
+
+  print(f"using {Device.DEFAULT} backend")
   from sentencepiece import SentencePieceProcessor
   sp_model = SentencePieceProcessor(model_file=TOKENIZER_FILENAME)
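The tokenizer side is plain sentencepiece; a hedged sketch of the calls the example relies on, with a placeholder model path instead of the real TOKENIZER_FILENAME:

from sentencepiece import SentencePieceProcessor

sp_model = SentencePieceProcessor(model_file="/path/to/tokenizer.model")  # placeholder path, not the repo's constant
ids = sp_model.encode("Hello.")   # list of token ids for the prompt
text = sp_model.decode(ids)       # back to a string
print(sp_model.vocab_size(), ids, text)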
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+import numpy as np
+from examples.llama import Transformer, onehot_encode
+from tinygrad.tensor import Tensor
+
+VOCAB_SIZE = 4
+args_test = {"dim": 2, "multiple_of": 1, "n_heads": 1, "n_layers": 1, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE}
+
+if __name__ == "__main__":
+  Tensor.no_grad = True
+
+  Tensor.manual_seed(1337)
+  model = Transformer(**args_test)
+
+  print("run a")
+  outa_0 = model(onehot_encode([1], VOCAB_SIZE), 0).numpy()
+  print(outa_0)
+  outa_1 = model(onehot_encode([3], VOCAB_SIZE), 1).numpy()
+  print(outa_1)
+
+  print("run b")
+  outb_0 = model(onehot_encode([1], VOCAB_SIZE), 0, False).numpy()
+  print(outb_0)
+  outb_1 = model(onehot_encode([3], VOCAB_SIZE), 1, False).numpy()
+  print(outb_1)
+
+  print("run c")
+  outc_0 = model(onehot_encode([1], VOCAB_SIZE), 0).numpy()
+  print(outc_0)
+  outc_1 = model(onehot_encode([3], VOCAB_SIZE), 1).numpy()
+  print(outc_1)
+
+  # a and c are the same
+  np.testing.assert_allclose(outa_0, outc_0)
+  np.testing.assert_allclose(outa_1, outc_1)
+
+  # b and c should be the same
+  np.testing.assert_allclose(outb_0, outc_0)
+  print("FAILS")
+  np.testing.assert_allclose(outb_1, outc_1)
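The same repro could be phrased as a self-checking test instead of a print script; a sketch assuming examples.llama is importable from the repo root, mirroring the a/b/c runs above:

import unittest
import numpy as np
from examples.llama import Transformer, onehot_encode
from tinygrad.tensor import Tensor

class TestEarlyRealizeFreqsCis(unittest.TestCase):
  def test_early_realize_does_not_change_output(self):
    Tensor.no_grad = True
    Tensor.manual_seed(1337)
    model = Transformer(dim=2, multiple_of=1, n_heads=1, n_layers=1, norm_eps=1e-05, vocab_size=4)

    def run(early_realize):
      # token 1 at start_pos 0, then token 3 at start_pos 1, exercising the kv cache
      return [model(onehot_encode([t], 4), pos, early_realize).numpy() for pos, t in enumerate([1, 3])]

    run_a = run(True)    # "run a" above
    run_b = run(False)   # "run b": early realize disabled
    run_c = run(True)    # "run c": early realize enabled again

    for a, c in zip(run_a, run_c): np.testing.assert_allclose(a, c)  # holds
    for b, c in zip(run_b, run_c): np.testing.assert_allclose(b, c)  # fails on the second position, per the commit

if __name__ == "__main__":
  unittest.main()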
@@ -151,10 +151,12 @@ class LazyBuffer:
   # NOTE: we have to make a copy of the numpy array here in case the user changes it. expose this?
   @staticmethod
   def fromCPU(x, device) -> LazyBuffer: return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x))

+  # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
     ret = self.realize().toCPU()
     log_op(InterpretedBuffer(ret), LazyOp(LoadOps.TOCPU, (self.realized,), None))
-    return ret
+    return ret.copy()

   def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
   def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
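The return ret.copy() change guards against handing the caller a view into memory the backend may later reuse or free; a minimal numpy-only sketch of the aliasing hazard (the buffer names are illustrative):

import numpy as np

backend_buffer = np.zeros(4, dtype=np.float32)
view = backend_buffer[:]       # aliases backend_buffer, like returning ret directly
safe = backend_buffer.copy()   # independent memory, like returning ret.copy()

backend_buffer[:] = 42         # the backend reuses or overwrites the buffer
print(view)   # [42. 42. 42. 42.] -- the caller's "CPU copy" changed underneath them
print(safe)   # [0. 0. 0. 0.]     -- the copied array is unaffected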