hcq profiler support args (#5989)

* hcq profiler support args

* bytes -> _bytes

* fix

* add test

* mypy

* not f strings

* precision
nimlgen 2024-08-09 00:18:36 +03:00 committed by GitHub
parent 45b1761175
commit 38d5eecc68
4 changed files with 41 additions and 15 deletions


@@ -169,6 +169,27 @@ class TestProfiler(unittest.TestCase):
assert l['name'].find("->") == -1, "should be kernel"
assert r['name'] == f"{Device.DEFAULT} -> {Device.DEFAULT}:1", "should be copy"
@unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
def test_profile_copy_args(self):
d1 = Device[f"{Device.DEFAULT}:1"]
def f(a):
x = (a + 1).realize()
return x, x.to(d1.dname).realize()
a = Tensor.randn(10, 10, device=TestProfiler.d0.dname).realize()
with helper_collect_profile(TestProfiler.d0, d1) as profile:
jf = TinyJit(f)
for _ in range(3):
TestProfiler.d0.raw_prof_records, TestProfiler.d0.sig_prof_records = [], [] # reset to collect only graph logs
d1.raw_prof_records, d1.sig_prof_records = [], []
jf(a)
del jf
node = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[-1]
assert node['args']['Size'] == "400.00 B"
assert abs(float(node['args']['GB/S']) - ((10 * 10 * 4) / 1e3) / (node['dur'])) < 0.01
@unittest.skipIf(CI, "skip CI")
def test_profile_sync(self):
mv = memoryview(bytearray(struct.pack("ff", 0, 1)))
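For reference, the expected values asserted in test_profile_copy_args follow directly from the tensor shape: a 10x10 float32 copy moves 400 bytes, and with trace durations in microseconds the GB/S figure is just bytes/1e3/dur. A quick sanity check with a hypothetical duration:

# Sanity check of the assertion values above; the duration is made up and given in
# microseconds, matching the "dur" field the test reads from the profile node.
nbytes = 10 * 10 * 4                   # 10x10 float32 tensor -> 400 bytes
dur_us = 2.0                           # hypothetical copy duration
print(f"{nbytes:.2f} B")               # "400.00 B", the Size string the test expects
print(f"{nbytes / 1e3 / dur_us:.2f}")  # "0.20" -> bytes/1e3 over microseconds is GB/s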


@@ -493,7 +493,7 @@ class HCQCompiled(Compiled):
self.timeline_value:int = 1
self.timeline_signal, self._shadow_timeline_signal = timeline_signals
self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool]] = []
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
if PROFILE: self._prof_setup()
@@ -509,7 +509,7 @@ class HCQCompiled(Compiled):
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
if PROFILE:
self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
self.sig_prof_records = []
def _alloc_kernargs(self, alloc_size:int) -> int:
@@ -580,8 +580,8 @@ class HCQCompiled(Compiled):
# Sync to be sure all events on the device are recorded.
self.synchronize()
for st, en, name, is_cp in self.raw_prof_records:
self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp])]
for st, en, name, is_cp, args in self.raw_prof_records:
self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
# Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
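In short, every entry in raw_prof_records now carries a fifth element, an optional args dict: records promoted from sig_prof_records get None, while graph-submitted copies can attach extra fields that end up in the trace. A minimal sketch of the two shapes, with hypothetical names and timestamps:

import decimal

# Hypothetical (st, en, name, is_cp, args) entries matching the tuple layout above.
kernel_rec = (decimal.Decimal("100.0"), decimal.Decimal("150.0"), "E_add", False, None)
copy_rec   = (decimal.Decimal("200.0"), decimal.Decimal("260.0"), "NV -> NV:1", True,
              {"Size": "400.00 B", "GB/S": lambda dur: f"{400/1e3/dur:.2f}"})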


@@ -27,6 +27,7 @@ def all_same(items:Union[Tuple[T, ...], List[T]]): return all(x == items[0] for
def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
def ansilen(s:str): return len(ansistrip(s))
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
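The new memsize_to_str helper picks the largest unit the byte count strictly exceeds and formats it to two decimals, for example:

from tinygrad.helpers import memsize_to_str

print(memsize_to_str(400))        # "400.00 B"
print(memsize_to_str(1500))       # "1.50 KB"
print(memsize_to_str(4_194_304))  # "4.19 MB"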
@@ -166,7 +167,7 @@ class ProfileLogger:
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
def _ensure_actor(self, actor_name, subactor_name):
if actor_name not in self.actors:
@@ -181,9 +182,10 @@ class ProfileLogger:
def __del__(self):
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
for name,st,et,actor_name,subactor_name in self.events:
for name, st, et, actor_name, subactor_name, args in self.events:
pid, tid = self._ensure_actor(actor_name,subactor_name)
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts":st, "dur":et-st})
args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
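The dict comprehension added to __del__ is what makes lambda-valued args useful: string values pass through untouched, while anything else is called with the event duration (et-st) once it is known, right before the JSON is written. A standalone illustration of that rule, continuing the 400-byte copy example with made-up timestamps:

# One event as stored by add_event: (name, start, end, actor, subactor, args), times in us.
name, st, et = "NV -> NV:1", 1000.0, 1002.0
args = {"Size": "400.00 B", "GB/S": lambda dur: f"{400/1e3/dur:.2f}"}

# Same rule as in __del__: keep strings, call everything else with the duration.
resolved = {k: (v if v.__class__ is str else v(et - st)) for k, v in args.items()}
print(resolved)  # {'Size': '400.00 B', 'GB/S': '0.20'}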


@@ -1,6 +1,6 @@
import collections, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, to_mv, PROFILE
from tinygrad.helpers import round_up, to_mv, PROFILE, memsize_to_str
from tinygrad.device import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, \
Buffer, BufferOptions, Compiled, Device
from tinygrad.shape.symbolic import Variable
@@ -47,7 +47,7 @@ class HCQGraph(MultiGraphRunner):
self.kickoff_value: int = 0
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int]]] = []
self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
@@ -96,7 +96,10 @@ class HCQGraph(MultiGraphRunner):
sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps]))
if is_exec_prg: prof_args = None
else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore
self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args))
last_j[enqueue_queue] = j
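Note that the GB/S lambda binds the buffer size through a default argument (b=ji.bufs[0].nbytes) rather than closing over ji: default arguments are evaluated when the lambda is created, so each record keeps the byte count it was built with even though ji changes every iteration. A quick sketch of the difference, with hypothetical sizes:

sizes = [400, 800]

late   = [lambda dur: f"{n/1e3/dur:.2f}" for n in sizes]       # closes over n (late binding)
pinned = [lambda dur, b=n: f"{b/1e3/dur:.2f}" for n in sizes]  # pins n via a default argument

print([f(2.0) for f in late])    # ['0.40', '0.40'] - both lambdas see the final n
print([f(2.0) for f in pinned])  # ['0.20', '0.40'] - each keeps its own size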
@@ -187,11 +190,11 @@ class HCQGraph(MultiGraphRunner):
def collect_timestamps(self):
timestamps = [s.timestamp for s in self.prof_signals]
for (st,_), (en,_), dev, desc, is_cp, deps in self.prof_records:
dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp)]
for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]
for x in deps:
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _ = self.prof_records[x]
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
def __del__(self):
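Putting it all together, a graph-launched copy now shows up in the trace JSON with an args object attached, roughly like the following (the name, ids and numbers are illustrative):

# Illustrative trace event after the args have been resolved; all values are hypothetical.
event = {
    "name": "NV -> NV:1", "ph": "X", "pid": 1, "tid": 2,
    "ts": 1000.0, "dur": 2.0,
    "args": {"Size": "400.00 B", "GB/S": "0.20"},
}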