mirror of https://github.com/commaai/tinygrad.git
hcq profiler support args (#5989)
* hcq profiler support args
* bytes -> _bytes
* fix
* add test
* mypy
* not f strings
* precision
This commit is contained in:
parent
45b1761175
commit
38d5eecc68
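In short, this change threads an optional args dict through the HCQ profiler so that copy nodes in the exported trace report their transfer size and achieved bandwidth. A hedged sketch of the kind of event that ends up in the Chrome-trace/Perfetto JSON after this change (device name, pid/tid and timings are illustrative, not from a real run):

{"name": "NV -> NV:1", "ph": "X", "pid": 1, "tid": 2, "ts": 1000, "dur": 2,
 "args": {"Size": "400.00 B", "GB/S": "0.20"}}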
@@ -169,6 +169,27 @@ class TestProfiler(unittest.TestCase):
     assert l['name'].find("->") == -1, "should be kernel"
     assert r['name'] == f"{Device.DEFAULT} -> {Device.DEFAULT}:1", "should be copy"
 
+  @unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
+  def test_profile_copy_args(self):
+    d1 = Device[f"{Device.DEFAULT}:1"]
+
+    def f(a):
+      x = (a + 1).realize()
+      return x, x.to(d1.dname).realize()
+
+    a = Tensor.randn(10, 10, device=TestProfiler.d0.dname).realize()
+    with helper_collect_profile(TestProfiler.d0, d1) as profile:
+      jf = TinyJit(f)
+      for _ in range(3):
+        TestProfiler.d0.raw_prof_records, TestProfiler.d0.sig_prof_records = [], [] # reset to collect only graph logs
+        d1.raw_prof_records, d1.sig_prof_records = [], []
+        jf(a)
+      del jf
+
+    node = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[-1]
+    assert node['args']['Size'] == "400.00 B"
+    assert abs(float(node['args']['GB/S']) - ((10 * 10 * 4) / 1e3) / (node['dur'])) < 0.01
+
   @unittest.skipIf(CI, "skip CI")
   def test_profile_sync(self):
     mv = memoryview(bytearray(struct.pack("ff", 0, 1)))
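For orientation, the numbers the new test asserts work out as follows (a sketch, assuming durations are in microseconds as elsewhere in the HCQ profiler; dur_us below is a hypothetical duration):

# 10x10 float32 tensor copied to the second device: 10 * 10 * 4 bytes.
nbytes = 10 * 10 * 4
assert nbytes == 400                              # rendered by memsize_to_str as "400.00 B"

# GB/S is bytes / 1e3 / duration-in-us: bytes per us, scaled from 1e6 B/s to 1e9 B/s.
dur_us = 2.0                                      # hypothetical copy duration
assert f"{nbytes / 1e3 / dur_us:.2f}" == "0.20"   # 0.20 GB/s for this toy duration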
@@ -493,7 +493,7 @@ class HCQCompiled(Compiled):
     self.timeline_value:int = 1
     self.timeline_signal, self._shadow_timeline_signal = timeline_signals
     self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
-    self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool]] = []
+    self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
     self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
     if PROFILE: self._prof_setup()
 
@@ -509,7 +509,7 @@ class HCQCompiled(Compiled):
 
     if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
     if PROFILE:
-      self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
+      self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
       self.sig_prof_records = []
 
   def _alloc_kernargs(self, alloc_size:int) -> int:
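Taken together with the type change above, every raw profile record is now a 5-tuple whose last element is the optional args dict; records coming from sig_prof_records keep None there. A minimal sketch (the alias, timestamps, and names are hypothetical, not from the source):

import decimal
from typing import Dict, Optional, Tuple

# (start_ts, end_ts, name, is_copy, args)
RawProfRecord = Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]

kernel_rec: RawProfRecord = (decimal.Decimal(100), decimal.Decimal(150), "E_4_4", False, None)
copy_rec: RawProfRecord = (decimal.Decimal(200), decimal.Decimal(202), "NV -> NV:1", True,
                           {"Size": "400.00 B", "GB/S": lambda dur: f"{400/1e3/dur:.2f}"})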
@@ -580,8 +580,8 @@ class HCQCompiled(Compiled):
     # Sync to be sure all events on the device are recorded.
     self.synchronize()
 
-    for st, en, name, is_cp in self.raw_prof_records:
-      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp])]
+    for st, en, name, is_cp, args in self.raw_prof_records:
+      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
     for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
       # Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
       a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
@@ -27,6 +27,7 @@ def all_same(items:Union[Tuple[T, ...], List[T]]): return all(x == items[0] for
 def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
 def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
 def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
+def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
 def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
 def ansilen(s:str): return len(ansistrip(s))
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
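The new helper scales the value to the first decimal unit whose divisor it strictly exceeds; a few examples (the 400-byte case is the one the test above asserts on):

from tinygrad.helpers import memsize_to_str

assert memsize_to_str(400) == "400.00 B"
assert memsize_to_str(4096) == "4.10 KB"           # 4096 / 1e3, decimal units not binary
assert memsize_to_str(2_000_000_000) == "2.00 GB"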
@@ -166,7 +167,7 @@ class ProfileLogger:
 
   def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
 
-  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
+  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
 
   def _ensure_actor(self, actor_name, subactor_name):
     if actor_name not in self.actors:
@@ -181,9 +182,10 @@ class ProfileLogger:
 
   def __del__(self):
     # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
-    for name,st,et,actor_name,subactor_name in self.events:
+    for name, st, et, actor_name, subactor_name, args in self.events:
       pid, tid = self._ensure_actor(actor_name,subactor_name)
-      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts":st, "dur":et-st})
+      args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
+      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
 
     for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
       dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
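add_event gains an optional args parameter, and __del__ renders each value: strings pass through unchanged, anything else is treated as a callable and invoked with the event duration (et - st). A self-contained sketch of that comprehension (the duration value is hypothetical):

args = {"Size": "400.00 B", "GB/S": lambda dur: f"{400/1e3/dur:.2f}"}
dur = 2  # et - st, microseconds in the HCQ profiler
rendered = {k: (v if v.__class__ is str else v(dur)) for k, v in args.items()}
assert rendered == {"Size": "400.00 B", "GB/S": "0.20"}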
@@ -1,6 +1,6 @@
 import collections, time
 from typing import List, Any, Dict, cast, Optional, Tuple, Set
-from tinygrad.helpers import round_up, to_mv, PROFILE
+from tinygrad.helpers import round_up, to_mv, PROFILE, memsize_to_str
 from tinygrad.device import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, \
   Buffer, BufferOptions, Compiled, Device
 from tinygrad.shape.symbolic import Variable
@@ -47,7 +47,7 @@ class HCQGraph(MultiGraphRunner):
     self.kickoff_value: int = 0
 
     self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
-    self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int]]] = []
+    self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
 
     last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
     queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
@@ -96,7 +96,10 @@ class HCQGraph(MultiGraphRunner):
       sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
       if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
 
-      self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps]))
+      if is_exec_prg: prof_args = None
+      else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore
+
+      self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args))
 
       last_j[enqueue_queue] = j
 
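The GB/S entry is deliberately a callable: the rate cannot be computed until the copy's duration is known, and b=ji.bufs[0].nbytes binds the buffer size as a default argument so each jit item captures its own size rather than the last value of the loop variable. A small sketch of the same pattern (sizes and durations hypothetical):

# Default-argument binding freezes the per-iteration size; a plain closure
# over the loop variable would see only its final value.
sizes = [400, 1 << 20]
fns = [lambda dur, b=nbytes: f"{b/1e3/dur:.2f}" for nbytes in sizes]
assert fns[0](2) == "0.20"     # 400 B over 2 us   -> 0.20 GB/s
assert fns[1](100) == "10.49"  # 1 MiB over 100 us -> ~10.49 GB/s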
@@ -187,11 +190,11 @@ class HCQGraph(MultiGraphRunner):
   def collect_timestamps(self):
     timestamps = [s.timestamp for s in self.prof_signals]
 
-    for (st,_), (en,_), dev, desc, is_cp, deps in self.prof_records:
-      dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp)]
+    for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
+      dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]
 
       for x in deps:
-        (b_st,_), (b_en,_), b_dev, _, b_is_cp, _ = self.prof_records[x]
+        (b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
         dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
 
   def __del__(self):
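End to end, the args flow is: HCQGraph attaches prof_args to each copy node when the graph is built, collect_timestamps copies them into the owning device's raw_prof_records, HCQCompiled forwards them into profile_logger.events after the final synchronize, and ProfileLogger.__del__ resolves any callables with the event duration and writes them into the event's "args" field in the trace JSON.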