fix buf_index not found case in _apply_tc_opt (#3739)

ValueError if src.src[0] is not a LOAD. Replaced with returning None in _apply_tc_opt and test to make sure the net output is KernelOptError.
This commit is contained in:
chenyu 2024-03-14 14:27:05 -04:00 committed by GitHub
parent 6bf11a2ce3
commit 90e55a9fd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 2 deletions

View File

@ -4,7 +4,7 @@ import unittest
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, tensor_cores
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs
from tinygrad.device import Device, Buffer
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps, ReduceOps, UnaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
@ -628,6 +628,17 @@ class TestLinearizerOpts(unittest.TestCase):
with self.assertRaises(AssertionError):
assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
def test_buf_index_not_found_tensor_core(self):
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
k = Linearizer(ast, opts=Device[Device.DEFAULT].compiler.linearizer_opts)
with self.assertRaises(KernelOptError):
k.apply_opt(Opt(OptOps.TC, 0, 1))
def test_tensor_core_opts(self):
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")

View File

@ -364,8 +364,11 @@ class Kernel:
if mul_op.op != BinaryOps.MUL: continue
def buf_index(src: LazyOp) -> Optional[int]:
# TODO: apply tc even if the sources are not from LOAD
if src.op == BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
if opt_level >= 1 and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
try:
if opt_level >= 1 and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
except ValueError: return None
return None
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue