2023-09-08 02:50:41 +08:00
|
|
|
# https://tvm.apache.org/docs/tutorial/tensor_expr_get_started.html#example-2-manually-optimizing-matrix-multiplication-with-te
|
|
|
|
|
|
|
|
# Matmul problem size: C[M, N] = A[M, K] @ B[K, N], used by both the TVM and the tinygrad sections.
M = N = K = 1024
|
|
|
|
|
2023-09-29 00:24:32 +08:00
|
|
|
# TVM reference: build a naive matmul with Tensor Expressions and print the generated source.
# Only the imports are guarded by the try. An ImportError raised later by TVM itself is a
# genuine failure and must not be misreported as "TVM is not installed".
try:
  import tvm
  from tvm import te
except ImportError:
  print("** please install TVM for TVM output")
else:
  # print(tvm.target.Target.list_kinds())  # debug: list the target kinds this TVM build supports

  # Code-generation target: "c" here; "opencl" is another option.
  target = tvm.target.Target(target="c")

  # TVM Matrix Multiplication using TE: C[x, y] = sum over k of A[x, k] * B[k, y]
  k = te.reduce_axis((0, K), "k")
  A = te.placeholder((M, K), name="A")
  B = te.placeholder((K, N), name="B")
  C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")

  # Default schedule (no manual loop transformations applied)
  s = te.create_schedule(C.op)
  # print(tvm.lower(s, [A, B, C], simple_mode=True))  # debug: dump the lowered IR

  # Build a function named "mmult" over (A, B, C) for `target` and print its generated source.
  func = tvm.build(s, [A, B, C], target=target, name="mmult")
  print(func.get_source())
|
2023-09-08 02:50:41 +08:00
|
|
|
|
|
|
|
# tinygrad version
|
|
|
|
|
|
|
|
import os
|
|
|
|
from tinygrad.tensor import Tensor
|
2024-03-27 12:02:46 +08:00
|
|
|
from tinygrad.engine.schedule import create_schedule
|
2023-09-08 02:50:41 +08:00
|
|
|
|
|
|
|
# define the compute: two random inputs on the "clang" device, multiplied as C = A @ B.
A, B = Tensor.rand(M, K, device="clang"), Tensor.rand(K, N, device="clang")

# The matmul is spelled as broadcast-multiply + reduce:
#   A viewed as (M, 1, K), B transposed to (N, K) and viewed as (1, N, K).
#   Their product broadcasts to (M, N, K).
#   Summing the last axis gives C[m, n] = sum over k of A[m, k] * B[k, n].
C = (A.reshape(M, 1, K) * B.permute(1, 0).reshape(1, N, K)).sum(axis=2)

# Turn the lazy graph behind C into schedule items. `sched` is indexed later, so keep the name.
sched = create_schedule([C.lazydata])
|
2024-07-13 09:50:55 +08:00
|
|
|
# Lower the last scheduled kernel to uops and print it rendered as C source named "mmult".
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import CompilerOptions
from tinygrad.runtime.ops_clang import renderer

# Options handed to Kernel: has_local=False, supports_float4=False.
# NOTE(review): presumably chosen to keep the emitted code plain scalar C — confirm.
opts = CompilerOptions(has_local=False, supports_float4=False)

# NOTE(review): assumes sched[-1] is the matmul reduction, with earlier items presumably
# realizing the random inputs — confirm against create_schedule's ordering.
kernel = Kernel(sched[-1].ast, opts)

# Optional experiment: call kernel.hand_coded_optimizations() here, before linearize().
kernel.linearize()

print(renderer("mmult", kernel.uops))
|