mirror of https://github.com/commaai/tinygrad.git
fix buf_index not found case in _apply_tc_opt (#3739)
ValueError if src.src[0] is not a LOAD. Replaced with returning None in _apply_tc_opt and test to make sure the net output is KernelOptError.
This commit is contained in:
parent
6bf11a2ce3
commit
90e55a9fd1
|
@ -4,7 +4,7 @@ import unittest
|
|||
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, tensor_cores
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
|
||||
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps, ReduceOps, UnaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
|
||||
|
@ -628,6 +628,17 @@ class TestLinearizerOpts(unittest.TestCase):
|
|||
with self.assertRaises(AssertionError):
|
||||
assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
|
||||
|
||||
def test_buf_index_not_found_tensor_core(self):
|
||||
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
if Device.DEFAULT not in tensor_cores:
|
||||
self.skipTest("No tensor cores for device")
|
||||
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
k = Linearizer(ast, opts=Device[Device.DEFAULT].compiler.linearizer_opts)
|
||||
with self.assertRaises(KernelOptError):
|
||||
k.apply_opt(Opt(OptOps.TC, 0, 1))
|
||||
|
||||
def test_tensor_core_opts(self):
|
||||
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
|
|
|
@ -364,8 +364,11 @@ class Kernel:
|
|||
if mul_op.op != BinaryOps.MUL: continue
|
||||
|
||||
def buf_index(src: LazyOp) -> Optional[int]:
|
||||
# TODO: apply tc even if the sources are not from LOAD
|
||||
if src.op == BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
|
||||
if opt_level >= 1 and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
|
||||
try:
|
||||
if opt_level >= 1 and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
|
||||
except ValueError: return None
|
||||
return None
|
||||
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
|
||||
|
||||
|
|
Loading…
Reference in New Issue