mirror of https://github.com/commaai/tinygrad.git
fix tvm gemm example
This commit is contained in:
parent
af6e2f31ca
commit
8db92bd060
|
@ -31,9 +31,6 @@ except ImportError:
|
||||||
import os
|
import os
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
# disable optimizations
|
|
||||||
os.environ["NOOPT"] = "1"
|
|
||||||
|
|
||||||
# define the compute
|
# define the compute
|
||||||
A = Tensor.rand(M, K, device="clang")
|
A = Tensor.rand(M, K, device="clang")
|
||||||
B = Tensor.rand(K, N, device="clang")
|
B = Tensor.rand(K, N, device="clang")
|
||||||
|
@ -42,7 +39,7 @@ C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
|
||||||
sched = C.lazydata.schedule()
|
sched = C.lazydata.schedule()
|
||||||
from tinygrad.codegen.linearizer import Linearizer
|
from tinygrad.codegen.linearizer import Linearizer
|
||||||
from tinygrad.codegen.kernel import LinearizerOptions
|
from tinygrad.codegen.kernel import LinearizerOptions
|
||||||
lin = Linearizer(sched[-1][0], LinearizerOptions(has_local=False, supports_float4=False))
|
lin = Linearizer(sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False))
|
||||||
#lin.hand_coded_optimizations()
|
#lin.hand_coded_optimizations()
|
||||||
lin.linearize()
|
lin.linearize()
|
||||||
from tinygrad.runtime.ops_clang import renderer
|
from tinygrad.runtime.ops_clang import renderer
|
||||||
|
|
Loading…
Reference in New Issue