llm.c updates

This commit is contained in:
George Hotz 2024-09-27 15:25:59 +08:00
parent eaa1e0eeeb
commit b0e70ab04f
1 changed files with 11 additions and 12 deletions

View File

@ -120,6 +120,7 @@ if __name__ == "__main__":
parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
parser.add_argument("--skip_test", action="store_true", help="skip test")
args = parser.parse_args()
B, T = args.batch_size, args.sequence_length
assert 1 <= T <= 1024
@ -135,10 +136,7 @@ if __name__ == "__main__":
# load the tokens
# prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories
# we're using val instead of train split just because it is smaller/faster
shake_tokens_bin = "data/tiny_shakespeare_val.bin"
story_tokens_bin = "data/TinyStories_val.bin"
assert os.path.isfile(shake_tokens_bin) or os.path.isfile(story_tokens_bin), "you must run prepro on some dataset"
tokens_bin = shake_tokens_bin if os.path.isfile(shake_tokens_bin) else story_tokens_bin
tokens_bin = fetch("https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin")
assert os.path.isfile(tokens_bin)
print(f"loading cached tokens in {tokens_bin}")
with open(tokens_bin, "rb") as f:
@ -181,12 +179,13 @@ if __name__ == "__main__":
t1 = time.time()
print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")
start = "<|endoftext|>"
start_ids = encode(start)
x = (Tensor(start_ids)[None, ...])
max_new_tokens = 16
temperature = 1.0
top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist()))
if not args.skip_test:
start = "<|endoftext|>"
start_ids = encode(start)
x = (Tensor(start_ids)[None, ...])
max_new_tokens = 16
temperature = 1.0
top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist()))