pad ops broke coder (#2881)

* pad ops broke coder

* that contiguous fixes it

* Update lazy.py
This commit is contained in:
George Hotz 2023-12-20 17:03:41 -08:00 committed by GitHub
parent e1861ab65e
commit 64dded27f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 5 deletions

View File

@ -19,12 +19,15 @@ def create_fixed_tokenizer(output_file):
with open(output_file, "wb") as f:
f.write(mp.SerializeToString())
# example:
# echo -en "write 2+2\nwrite hello world\ny\n" | TEMP=0 python3 examples/coder.py
if __name__ == "__main__":
Tensor.no_grad = True
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
with Timing("create model: "):
model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096)
model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096, jit=getenv("JIT", 1))
with Timing("download weights: "):
part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))

View File

@ -65,8 +65,8 @@ class Attention:
self.cache_k = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous()
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1).contiguous()
# update the cache
self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()

View File

@ -220,7 +220,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
if buf in realizes or buf.realized: return True
# NOTE: this broke to_image_idx
# NOTE: this broke to_image_idx and coder with JIT
if buf.op in UNSAFE_PAD_OPS: return False
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)