move MultiGraphRunner logic to GraphRunner [pr] (#7083)

* move MultiGraphRunner logic to GraphRunner [pr]

* _access_resources
This commit is contained in:
George Hotz 2024-10-16 11:04:30 +08:00 committed by GitHub
parent 207fbc4bc7
commit 023b77cc6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 10 additions and 9 deletions

View File

@ -97,25 +97,23 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
# used in MultiGraphRunner. the ints are id() of _bufs
self.w_dependency_map: Dict[int, Any] = {}
self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
def updated_vars(self, var_vals):
def updated_vars(self, var_vals: Dict[Variable, int]):
vals = [var_vals[v] for v in self.vars]
for j, vidxs in self.var_vals_replace.items():
for i, v in enumerate(vidxs): yield j, i, vals[v]
def updated_launch_dims(self, var_vals):
def updated_launch_dims(self, var_vals: Dict[Variable, int]):
dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
self.w_dependency_map: Dict[Any, Any] = {}
self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
super().__init__(jit_cache, input_rawbuffers, var_vals)
def _access_resources(self, read, write, new_dependency:Any):
def _access_resources(self, read:List[Buffer], write:List[Buffer], new_dependency:Any):
# To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
# whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
wait_nodes = []
@ -129,6 +127,9 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
return list({id(x):x for x in wait_nodes}.values())
# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method
ReturnType = TypeVar('ReturnType')
@dataclass
class CapturedJit(Generic[ReturnType]):