fix debug=1 llama/gpt2 timings (#2143)

This commit is contained in:
nimlgen 2023-10-24 22:45:00 +03:00 committed by GitHub
parent 4444e6d4b3
commit e21bf776c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 4 deletions

View File

@@ -180,9 +180,9 @@ class GPT2:
if timing: print("")
st = GlobalCounters.time_sum_s
with Timing("total ", enabled=timing):
-with Timing(f"ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"+
+with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
-f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s") if DEBUG else None, enabled=timing):
+(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
probs = self.model(Tensor([toks[start_pos:]]), start_pos, temperature)
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))

View File

@@ -558,9 +558,9 @@ After you are done speaking, output [EOS]. You are not Chad.
if args.timing: print("")
st = GlobalCounters.time_sum_s
with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
-with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"+
+with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
-f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s") if DEBUG else None, enabled=args.timing):
+(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))