mirror of https://github.com/commaai/tinygrad.git
pad ops broke coder (#2881)
* pad ops broke coder * that contiguous fixes it * Update lazy.py
This commit is contained in:
parent
e1861ab65e
commit
64dded27f0
|
@ -19,12 +19,15 @@ def create_fixed_tokenizer(output_file):
|
|||
with open(output_file, "wb") as f:
|
||||
f.write(mp.SerializeToString())
|
||||
|
||||
# example:
|
||||
# echo -en "write 2+2\nwrite hello world\ny\n" | TEMP=0 python3 examples/coder.py
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.no_grad = True
|
||||
|
||||
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
|
||||
with Timing("create model: "):
|
||||
model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096)
|
||||
model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096, jit=getenv("JIT", 1))
|
||||
|
||||
with Timing("download weights: "):
|
||||
part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
|
||||
|
|
|
@ -65,8 +65,8 @@ class Attention:
|
|||
self.cache_k = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
|
||||
self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
|
||||
|
||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous()
|
||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1).contiguous()
|
||||
|
||||
# update the cache
|
||||
self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
|
||||
|
|
|
@ -220,7 +220,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
|
|||
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
|
||||
def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
|
||||
if buf in realizes or buf.realized: return True
|
||||
# NOTE: this broke to_image_idx
|
||||
# NOTE: this broke to_image_idx and coder with JIT
|
||||
if buf.op in UNSAFE_PAD_OPS: return False
|
||||
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue