fp16 and noshow flags for gpt2 (#2470)

This commit is contained in:
chenyu 2023-11-27 16:23:03 -05:00 committed by GitHub
parent e267a93124
commit 7f9a4c1285
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 14 additions and 7 deletions

View File

@ -9,8 +9,8 @@ from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import TinyJit
import tiktoken
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, CI
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
@ -154,8 +154,10 @@ if __name__ == "__main__":
parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
parser.add_argument('--timing', action='store_true', help="Print timing per token")
parser.add_argument('--seed', type=int, help="Set the random seed")
parser.add_argument('--batch_size', type=int, default=1, help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
parser.add_argument('--noshow', action='store_true', help="Don't show the output")
args = parser.parse_args()
if args.seed is not None:
@ -165,11 +167,16 @@ if __name__ == "__main__":
print(f"using {args.model_size}")
gpt2 = GPT2.build(args.model_size)
if args.fp16:
for l in get_state_dict(gpt2).values():
l.assign(l.cast(dtypes.float16).realize())
if args.benchmark != -1:
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
else:
print('Generating text...')
texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
if not args.noshow:
print('Generating text...')
if len(texts) == 1: print(texts[0])
else:
for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)