diff --git a/test/test_profiler.py b/test/test_profiler.py
index 2dac69cc..2d0fc5d5 100644
--- a/test/test_profiler.py
+++ b/test/test_profiler.py
@@ -169,6 +169,27 @@ class TestProfiler(unittest.TestCase):
     assert l['name'].find("->") == -1, "should be kernel"
     assert r['name'] == f"{Device.DEFAULT} -> {Device.DEFAULT}:1", "should be copy"
 
+  @unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
+  def test_profile_copy_args(self):
+    d1 = Device[f"{Device.DEFAULT}:1"]
+
+    def f(a):
+      x = (a + 1).realize()
+      return x, x.to(d1.dname).realize()
+
+    a = Tensor.randn(10, 10, device=TestProfiler.d0.dname).realize()
+    with helper_collect_profile(TestProfiler.d0, d1) as profile:
+      jf = TinyJit(f)
+      for _ in range(3):
+        TestProfiler.d0.raw_prof_records, TestProfiler.d0.sig_prof_records = [], [] # reset to collect only graph logs
+        d1.raw_prof_records, d1.sig_prof_records = [], []
+        jf(a)
+      del jf
+
+    node = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[-1]
+    assert node['args']['Size'] == "400.00 B"
+    assert abs(float(node['args']['GB/S']) - ((10 * 10 * 4) / 1e3) / (node['dur'])) < 0.01
+
   @unittest.skipIf(CI, "skip CI")
   def test_profile_sync(self):
     mv = memoryview(bytearray(struct.pack("ff", 0, 1)))
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 856ff3d2..4a5857e8 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -493,7 +493,7 @@ class HCQCompiled(Compiled):
     self.timeline_value:int = 1
     self.timeline_signal, self._shadow_timeline_signal = timeline_signals
     self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
-    self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool]] = []
+    self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
     self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
     if PROFILE: self._prof_setup()
 
@@ -509,7 +509,7 @@ class HCQCompiled(Compiled):
 
     if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
     if PROFILE:
-      self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
+      self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
       self.sig_prof_records = []
 
   def _alloc_kernargs(self, alloc_size:int) -> int:
@@ -580,8 +580,8 @@ class HCQCompiled(Compiled):
     # Sync to be sure all events on the device are recorded.
     self.synchronize()
 
-    for st, en, name, is_cp in self.raw_prof_records:
-      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp])]
+    for st, en, name, is_cp, args in self.raw_prof_records:
+      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
     for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
       # Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
       a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index df9d7948..390bd1d9 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -27,6 +27,7 @@ def all_same(items:Union[Tuple[T, ...], List[T]]): return all(x == items[0] for
 def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
 def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
 def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
+def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
 def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
 def ansilen(s:str): return len(ansistrip(s))
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
@@ -166,7 +167,7 @@
 
   def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
 
-  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
+  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
 
   def _ensure_actor(self, actor_name, subactor_name):
     if actor_name not in self.actors:
@@ -181,15 +182,16 @@
 
   def __del__(self):
     # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
-    for name,st,et,actor_name,subactor_name in self.events:
+    for name, st, et, actor_name, subactor_name, args in self.events:
       pid, tid = self._ensure_actor(actor_name,subactor_name)
-      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts":st, "dur":et-st})
+      args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
+      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
 
     for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
       dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
       pid, tid = self._ensure_actor(actor_name,subactor_name)
-      self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts":en, "bp": "e"})
-      self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts":st, "bp": "e"})
+      self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
+      self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
 
     ProfileLogger.writers -= 1
     if ProfileLogger.writers == 0 and len(self.mjson) > 0:
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index c23d2468..b5a65ce5 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -1,6 +1,6 @@
 import collections, time
 from typing import List, Any, Dict, cast, Optional, Tuple, Set
-from tinygrad.helpers import round_up, to_mv, PROFILE
+from tinygrad.helpers import round_up, to_mv, PROFILE, memsize_to_str
 from tinygrad.device import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, \
   Buffer, BufferOptions, Compiled, Device
 from tinygrad.shape.symbolic import Variable
@@ -47,7 +47,7 @@ class HCQGraph(MultiGraphRunner):
     self.kickoff_value: int = 0
 
     self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
-    self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int]]] = []
+    self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
 
     last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
     queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
@@ -96,7 +96,10 @@ class HCQGraph(MultiGraphRunner):
         sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
         if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
 
-        self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps]))
+        if is_exec_prg: prof_args = None
+        else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore
+
+        self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args))
       last_j[enqueue_queue] = j
 
@@ -187,11 +190,11 @@
 
   def collect_timestamps(self):
     timestamps = [s.timestamp for s in self.prof_signals]
 
-    for (st,_), (en,_), dev, desc, is_cp, deps in self.prof_records:
-      dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp)]
+    for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
+      dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]
       for x in deps:
-        (b_st,_), (b_en,_), b_dev, _, b_is_cp, _ = self.prof_records[x]
+        (b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
         dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
 
   def __del__(self):
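
Note: a minimal standalone sketch (not part of the patch, duration value made up) of how the new profiler `args` dict is resolved. String values pass through unchanged, while callables such as the GB/S lambda are invoked with the event duration in microseconds, mirroring the comprehension added to ProfileLogger.__del__ above:

# standalone illustration of the args resolution; memsize_to_str copied from tinygrad/helpers.py
def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]

nbytes = 10 * 10 * 4   # the 10x10 float32 copy from test_profile_copy_args
args = {"Size": memsize_to_str(nbytes), "GB/S": lambda dur, b=nbytes: f"{b/1e3/dur:.2f}"}
dur_us = 2.0           # hypothetical copy duration in microseconds
resolved = {k: (v if v.__class__ is str else v(dur_us)) for k, v in args.items()}
print(resolved)        # {'Size': '400.00 B', 'GB/S': '0.20'}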