from tinygrad.helpers import getenv
from tinygrad import dtypes, Tensor
# Input dtype selection: HALF=1 → fp16, BFLOAT16=1 → bf16, otherwise fp32.
if getenv("HALF"):
  dtype_in = dtypes.half
elif getenv("BFLOAT16"):
  dtype_in = dtypes.bfloat16
else:
  dtype_in = dtypes.float

# Accumulation dtype override for the conv; None leaves the backend default.
if getenv("ACC_HALF"):
  acc_dtype = dtypes.half
elif getenv("ACC_BFLOAT16"):
  acc_dtype = dtypes.bfloat16
else:
  acc_dtype = None

# Benchmark parameters, each overridable from the environment.
CNT = getenv("CNT", 8)          # number of conv2d invocations
BS = getenv("BS", 16)           # batch size
CIN = getenv("CIN", 128)        # input channels
COUT = getenv("COUT", 128)      # output channels
HW = getenv("HW", 128)          # input spatial height/width
K = getenv("K", 3)              # square kernel size
PADDING = getenv("PADDING", 1)  # conv2d padding
COMP = getenv("COMP", 0)        # nonzero → verify result against torch
ATOL = getenv("ATOL", 1e-4)     # absolute tolerance for the comparison
RTOL = getenv("RTOL", 3e-2)     # relative tolerance for the comparison

# 2*K*K*CIN multiply-accumulate FLOPs per output element, BS*COUT*HW*HW outputs.
# NOTE(review): the HW*HW output-count assumes stride 1 with size-preserving
# padding (true for the defaults K=3, PADDING=1) — confirm if defaults change.
FLOPS = 2 * BS * COUT * HW * HW * K * K * CIN
def rand_input():
  """Return a freshly realized (activation, weight) tensor pair.

  Shapes: activation (BS, CIN, HW, HW), weight (COUT, CIN, K, K), both in
  dtype_in. `.realize()` forces materialization so later timing/comparison
  does not include the random-fill work.
  """
  x = Tensor.rand(BS, CIN, HW, HW, dtype=dtype_in).realize()
  w = Tensor.rand(COUT, CIN, K, K, dtype=dtype_in).realize()
  return x, w
if __name__ == "__main__":
  # Run the convolution CNT times. With RAND=1, fresh inputs are generated
  # for every iteration after the first so each run does real work.
  a, b = rand_input()
  for i in range(CNT):
    if i > 0 and getenv("RAND", 0) != 0:
      a, b = rand_input()
    c = a.conv2d(b, padding=PADDING, acc_dtype=acc_dtype).realize()

  if COMP:
    # Verify the last tinygrad result against a torch reference convolution.
    # (fix: dropped the unused `time` import from this block)
    import numpy as np, torch
    torch_device = "cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")
    # Round-trip through numpy to hand the same data to torch.
    ta, tb = torch.from_numpy(a.numpy()).to(torch_device), torch.from_numpy(b.numpy()).to(torch_device)
    tc = torch.nn.functional.conv2d(ta, tb, padding=PADDING)
    # assert_allclose converts the torch CPU tensor via __array__; raises on mismatch.
    np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=ATOL, rtol=RTOL)