mirror of https://github.com/commaai/tinygrad.git
move MultiGraphRunner logic to GraphRunner [pr] (#7083)
* move MultiGraphRunner logic to GraphRunner [pr] * _access_resources
This commit is contained in:
parent
207fbc4bc7
commit
023b77cc6e
|
@ -97,25 +97,23 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
|
||||||
global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
|
global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
|
||||||
if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
|
if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
|
||||||
|
|
||||||
|
# used in MultiGraphRunner. the ints are id() of _bufs
|
||||||
|
self.w_dependency_map: Dict[int, Any] = {}
|
||||||
|
self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
|
||||||
|
|
||||||
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
|
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
|
||||||
ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
|
ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
|
||||||
|
|
||||||
def updated_vars(self, var_vals):
|
def updated_vars(self, var_vals: Dict[Variable, int]):
|
||||||
vals = [var_vals[v] for v in self.vars]
|
vals = [var_vals[v] for v in self.vars]
|
||||||
for j, vidxs in self.var_vals_replace.items():
|
for j, vidxs in self.var_vals_replace.items():
|
||||||
for i, v in enumerate(vidxs): yield j, i, vals[v]
|
for i, v in enumerate(vidxs): yield j, i, vals[v]
|
||||||
|
|
||||||
def updated_launch_dims(self, var_vals):
|
def updated_launch_dims(self, var_vals: Dict[Variable, int]):
|
||||||
dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
|
dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
|
||||||
for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
|
for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
|
||||||
|
|
||||||
class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
|
def _access_resources(self, read:List[Buffer], write:List[Buffer], new_dependency:Any):
|
||||||
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
|
||||||
self.w_dependency_map: Dict[Any, Any] = {}
|
|
||||||
self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
|
|
||||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
|
||||||
|
|
||||||
def _access_resources(self, read, write, new_dependency:Any):
|
|
||||||
# To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
|
# To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
|
||||||
# whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
|
# whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
|
||||||
wait_nodes = []
|
wait_nodes = []
|
||||||
|
@ -129,6 +127,9 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
|
||||||
for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
|
for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
|
||||||
return list({id(x):x for x in wait_nodes}.values())
|
return list({id(x):x for x in wait_nodes}.values())
|
||||||
|
|
||||||
|
# a marker for your graph supporting multiple devices of the same type
|
||||||
|
class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method
|
||||||
|
|
||||||
ReturnType = TypeVar('ReturnType')
|
ReturnType = TypeVar('ReturnType')
|
||||||
@dataclass
|
@dataclass
|
||||||
class CapturedJit(Generic[ReturnType]):
|
class CapturedJit(Generic[ReturnType]):
|
||||||
|
|
Loading…
Reference in New Issue