mirror of https://github.com/commaai/tinygrad.git
fp16 and noshow flags for gpt2 (#2470)
commit 7f9a4c1285 (parent e267a93124)

examples/gpt2.py
@@ -9,8 +9,8 @@ from tinygrad.nn import Embedding, Linear, LayerNorm
 from tinygrad.shape.symbolic import Variable
 from tinygrad.jit import TinyJit
 import tiktoken
-from tinygrad.nn.state import torch_load, load_state_dict
-from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, CI
+from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
+from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes
 
 MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
 
@@ -154,8 +154,10 @@ if __name__ == "__main__":
   parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
   parser.add_argument('--timing', action='store_true', help="Print timing per token")
   parser.add_argument('--seed', type=int, help="Set the random seed")
-  parser.add_argument('--batch_size', type=int, default=1, help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
+  parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
   parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
+  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
+  parser.add_argument('--noshow', action='store_true', help="Don't show the output")
   args = parser.parse_args()
 
   if args.seed is not None:
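Both new switches are ordinary argparse booleans, so they default to False and leave existing invocations unchanged; the same hunk also fixes the --batch_size help string, which had been copy-pasted from --model_size. A standalone sketch of the flag parsing, for reference (the argv list passed to parse_args here is illustrative, not from the commit):

import argparse

# Same flag definitions as the commit; store_true flags are False unless given.
parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
parser.add_argument('--noshow', action='store_true', help="Don't show the output")

args = parser.parse_args(['--fp16'])   # e.g. the script invoked with --fp16 only
assert args.fp16 and not args.noshow
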
@@ -165,11 +167,16 @@ if __name__ == "__main__":
   print(f"using {args.model_size}")
   gpt2 = GPT2.build(args.model_size)
 
+  if args.fp16:
+    for l in get_state_dict(gpt2).values():
+      l.assign(l.cast(dtypes.float16).realize())
+
   if args.benchmark != -1:
     gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
   else:
-    print('Generating text...')
     texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
-    if len(texts) == 1: print(texts[0])
-    else:
-      for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)
+    if not args.noshow:
+      print('Generating text...')
+      if len(texts) == 1: print(texts[0])
+      else:
+        for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)
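When --fp16 is set, the script walks the model's state dict and casts every weight tensor to float16 in place before running; --noshow then gates all output printing. A minimal sketch of the same cast pattern on a stand-in module (TinyModel is hypothetical, but the get_state_dict / cast / assign / realize calls are the ones this commit adds):

from tinygrad.tensor import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import dtypes

class TinyModel:
  def __init__(self): self.proj = Linear(8, 8)
  def __call__(self, x: Tensor) -> Tensor: return self.proj(x)

model = TinyModel()
for t in get_state_dict(model).values():      # every parameter tensor by name
  t.assign(t.cast(dtypes.float16).realize())  # cast, materialize, write back

assert all(t.dtype == dtypes.float16 for t in get_state_dict(model).values())

Casting after GPT2.build keeps the float32 checkpoint-loading path untouched while roughly halving weight memory at run time.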