search: catch RuntimeError when timing acted_lins (#3664)

when compilation succeeds, but runtime fails due to thread limits
on METAL, this allows a beam search to proceed, treating this
the same way as a compile failure.
This commit is contained in:
Francis Lam 2024-03-11 13:14:03 -07:00 committed by GitHub
parent 490c5a3ec3
commit 9f13960f72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 3 deletions

View File

@ -124,7 +124,8 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
if lib in seen_libs: continue
#print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
seen_libs.add(lib)
tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
try: tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
except RuntimeError: continue # for runtime issues
timed_lins.append((acted_lins[i], min(tms)))
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501

View File

@ -40,7 +40,8 @@ class HSAProgram:
if not hasattr(self, "args_struct_t"):
self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
[(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
assert ctypes.sizeof(self.args_struct_t) == self.kernargs_segment_size, f"{ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}"
if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size:
raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}")
kernargs = None
if self.kernargs_segment_size > 0:

View File

@ -43,7 +43,7 @@ class MetalProgram:
self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(),f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" # noqa: E501
if prod(local_size) > self.pipeline_state.maxTotalThreadsPerThreadgroup(): raise RuntimeError(f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}") # noqa: E501
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(self.pipeline_state)