diff --git a/examples/gpt2.py b/examples/gpt2.py index 5a818a9c..b36e0d14 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -9,8 +9,8 @@ from tinygrad.nn import Embedding, Linear, LayerNorm from tinygrad.shape.symbolic import Variable from tinygrad.jit import TinyJit import tiktoken -from tinygrad.nn.state import torch_load, load_state_dict -from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, CI +from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict +from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes MAX_CONTEXT = getenv("MAX_CONTEXT", 128) @@ -154,8 +154,10 @@ if __name__ == "__main__": parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]") parser.add_argument('--timing', action='store_true', help="Print timing per token") parser.add_argument('--seed', type=int, help="Set the random seed") - parser.add_argument('--batch_size', type=int, default=1, help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]") + parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size") parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens") + parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16") + parser.add_argument('--noshow', action='store_true', help="Don't show the output") args = parser.parse_args() if args.seed is not None: @@ -165,11 +167,16 @@ if __name__ == "__main__": print(f"using {args.model_size}") gpt2 = GPT2.build(args.model_size) + if args.fp16: + for l in get_state_dict(gpt2).values(): + l.assign(l.cast(dtypes.float16).realize()) + if args.benchmark != -1: gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize() else: - print('Generating text...') texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size) - if len(texts) == 1: print(texts[0]) - else: - for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text) \ No newline at end of file + if not args.noshow: + print('Generating text...') + if len(texts) == 1: print(texts[0]) + else: + for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text) \ No newline at end of file