update hip_matmul with new abstraction (#2605)

This commit is contained in:
Yixiang Gao 2023-12-04 15:37:10 -06:00 committed by GitHub
parent 5540f6e966
commit fde44aed76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 9 deletions

View File

@ -1,7 +1,7 @@
import time
import numpy as np
from tinygrad.helpers import dtypes, getenv, prod
from tinygrad.runtime.ops_hip import RawHIPBuffer, HIPProgram, compile_hip
from tinygrad.helpers import dtypes, getenv, prod, flat_mv
from tinygrad.runtime.ops_hip import HIPAllocator, HIPProgram, compile_hip
# AMD_LOG_LEVEL=3 ./MIOpenDriver gemm --iter 1000 --time 1 --a_w 2048 --a_h 2048 --b_w 2048
# 5.5: Cijk_Ailk_Bljk_HHS_BH_MT128x128x16_MI16x16x16x1_SN_1LDSB0_APM1_ABV0_ACED0_AF0EM1_AF1EM1_AMAS3_ASE_ASGT_ASAE01_ASCE01_ASEM1_AAC0_BL1_BS1_DTL0_DTVA0_DVO0_ETSP_EPS1_FL0_GRVW8_GSU1_GSUASB_GLS0_ISA1100_IU1_K1_KLA_LBSPP128_LPA0_LPB8_LDL1_LRVW16_LWPMn1_LDW0_FMA_MIAV1_MDA2_NTA0_NTB0_NTC0_NTD0_NEPBS0_NLCA1_NLCB1_ONLL1_OPLV0_PK0_PAP0_PGR1_PLR1_RK0_SIA1_SS1_SU32_SUM0_SUS128_SCIUI1_SPO0_SRVW0_SSO0_SVW4_SNLL0_TT4_64_TLDS1_USFGROn1_VAW2_VSn1_VW4_WSGRA1_WSGRB1_WS32_WG32_4_1_WGM4
@ -23,14 +23,19 @@ assert N%(16*KY) == 0, f"N must be multiple of {16*KY}"
FLOPS = N*N*N*2
BW = N*N*3*4
a = RawHIPBuffer(N*N, dtypes.float32)
# Can HIPAllocator initialized as device=0 by default?
device = 0
hipallocator = HIPAllocator(device)
a = hipallocator.alloc(N*N*4)
b = hipallocator.alloc(N*N*2)
c = hipallocator.alloc(N*N*2)
na = np.empty(N*N, np.float32)
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
b = RawHIPBuffer.fromCPU(nb)
c = RawHIPBuffer.fromCPU(nc)
hipallocator.copyin(b, bytearray(nb))
hipallocator.copyin(c, bytearray(nc))
prog = HIPProgram("test", compile_hip(f"""
lib = compile_hip(f"""
#define F32
typedef float float8 __attribute__((ext_vector_type(8)));
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
@ -87,7 +92,9 @@ extern "C" __global__ void __launch_bounds__ (128, 1) test(float* c, __half* a,
}}
}}
}}
}}"""))
}}""")
prog = HIPProgram(device, "test", lib)
def timeit(fxn):
st = time.perf_counter()
@ -99,7 +106,8 @@ def timeit(fxn):
global_size, local_size = [N//(KX*16*2), N//(KY*16*2), 1], [32, 2, 2]
print("global/local size", global_size, local_size, f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}")
tm = min([timeit(lambda: prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)) for _ in range(1000)])
na = a.toCPU().reshape(N,N)
hipallocator.copyout(flat_mv(na.data),a)
na = na.reshape(N,N)
comp = nb.astype(np.float32) @ nc.astype(np.float32)
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2)