From 023b77cc6e0824b62e03819335535cc2f7b2a161 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 16 Oct 2024 11:04:30 +0800 Subject: [PATCH] move MultiGraphRunner logic to GraphRunner [pr] (#7083) * move MultiGraphRunner logic to GraphRunner [pr] * _access_resources --- tinygrad/engine/jit.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index a8f6b2b4..49a3416e 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -97,25 +97,23 @@ class GraphRunner(Runner): # pylint: disable=abstract-method global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size) if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx) + # used in MultiGraphRunner. the ints are id() of _bufs + self.w_dependency_map: Dict[int, Any] = {} + self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list) + super().__init__(colored(f"", "cyan"), jit_cache[0].prg.dname.split(":")[0], ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate)) - def updated_vars(self, var_vals): + def updated_vars(self, var_vals: Dict[Variable, int]): vals = [var_vals[v] for v in self.vars] for j, vidxs in self.var_vals_replace.items(): for i, v in enumerate(vidxs): yield j, i, vals[v] - def updated_launch_dims(self, var_vals): + def updated_launch_dims(self, var_vals: Dict[Variable, int]): dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims] for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None) -class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method - def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): - self.w_dependency_map: Dict[Any, Any] = {} - self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list) - super().__init__(jit_cache, input_rawbuffers, var_vals) - - def _access_resources(self, read, write, new_dependency:Any): + def _access_resources(self, read:List[Buffer], write:List[Buffer], new_dependency:Any): # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource, # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers. wait_nodes = [] @@ -129,6 +127,9 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency return list({id(x):x for x in wait_nodes}.values()) +# a marker for your graph supporting multiple devices of the same type +class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method + ReturnType = TypeVar('ReturnType') @dataclass class CapturedJit(Generic[ReturnType]):