gemm is 1.7 TFLOPS on a single M1 core

George Hotz 2022-10-29 13:42:33 -07:00
parent 52bfbc31be
commit fdb43fe553
2 changed files with 61 additions and 1 deletion


@@ -91,6 +91,29 @@ class LLVM:
# looks like we have two options, either use clang or handle vectorization in tinygrad
# for the sake of the GPU, we should probably do in tinygrad
# ARM NEON is 128b wide, aka <4 x float> (similar to most GPUs)
# Firestorm (big M1 core) can do up to 4 ops per cycle @ 3.2 GHz = 3.2*4*4*2 = 102.4 GFLOPS (fma)
# There's also AMX https://github.com/corsix/amx/blob/main/README.md
# It seems like torch CPU must be using this? I'm seeing ~150 GFLOPS with convs
# Calling nnp_s4gemm_only_3x3__neon and nnp_owt8x8_3x3_with_bias__neon which don't seem like AMX
# Could this be a winograd conv?
# 2048x2048 matmul in 9.88 ms (17.18 GOPS) = 1739 GFLOPS (so much! this has to be the AMX)
# calling libBLAS.dylib`SGEMM
# 0x1c3ac5070: 0x0020100d .long 0x0020100d ; AMX instruction 0 = ldx
# 0x1c3ac5074: 0x0020102b .long 0x0020102b ; AMX instruction 1 = ldy (presumed typo in ldst.md)
# 0x1c3ac5078: 0x0020119f .long 0x0020119f ; AMX instruction 12 = fma32
# 0x1c3ac507c: 0x0020118e .long 0x0020118e ; AMX instruction 12 = fma32
# 0x1c3ac5080: 0x9144410f add x15, x8, #0x110, lsl #12 ; =0x110000
# 0x1c3ac5084: 0x00201188 .long 0x00201188 ; AMX instruction 12 = fma32
# 0x1c3ac5088: 0x0020118f .long 0x0020118f ; AMX instruction 12 = fma32
# 0x1c3ac508c: 0x8b0a016b add x11, x11, x10
# 0x1c3ac5090: 0x8b0c01ad add x13, x13, x12
# 0x1c3ac5094: 0xf1000529 subs x9, x9, #0x1
# 0x1c3ac5098: 0x54fffec1 b.ne 0x1c3ac5070 ; <+140>
  def __init__(self):
    if LLVM.engine is not None:
      return
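
A quick sanity check of the arithmetic in the comments above, using only the numbers quoted there (the 3.2 GHz clock, 4 FMA ops per cycle, and the 9.88 ms measurement are the comments' own figures):

# peak single-core NEON throughput quoted in the comment:
# 3.2 GHz * 4 FMA ops/cycle * 4 floats per 128-bit vector * 2 FLOPs per FMA
neon_peak = 3.2e9 * 4 * 4 * 2
print(f"NEON peak: {neon_peak / 1e9:.1f} GFLOPS")      # 102.4 GFLOPS

# the measured 2048x2048 matmul: 2*N^3 FLOPs in 9.88 ms
N, ms = 2048, 9.88
work = 2 * N**3
print(f"work: {work / 1e9:.2f} GFLOP")                 # 17.18 GFLOP
print(f"rate: {work / (ms / 1e3) / 1e9:.0f} GFLOPS")   # ~1739 GFLOPS

That puts the AMX-backed SGEMM at roughly 17x the NEON peak of a single Firestorm core, which is the ~1.7 TFLOPS in the commit title.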

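For reference, the .long values in the disassembly above decode with the AMX encoding described in the corsix/amx notes linked in the comment: each instruction is 0x00201000 | (op << 5) | operand, where operand selects the GPR holding the argument. A minimal Python sketch (decode_amx and AMX_OPS are illustrative names, and only the op numbers seen above are listed):

AMX_OPS = {0: "ldx", 1: "ldy", 12: "fma32"}  # just the ops that appear in this snippet

def decode_amx(word):
  # AMX instructions live in reserved encoding space: 0x00201000 | (op << 5) | operand
  assert word & 0xfffffc00 == 0x00201000
  op, operand = (word >> 5) & 0x1f, word & 0x1f
  return AMX_OPS.get(op, f"op {op}"), operand

for w in [0x0020100d, 0x0020102b, 0x0020119f, 0x0020118e, 0x00201188, 0x0020118f]:
  print(hex(w), *decode_amx(w))
# -> 0x20100d ldx 13, 0x20102b ldy 11, 0x20119f fma32 31, 0x20118e fma32 14, ...

The op numbers 0, 1, and 12 are what the ldx / ldy / fma32 annotations in the comments refer to.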

@@ -1,6 +1,7 @@
import os
import unittest
import torch
torch.set_num_threads(1)
import time
import numpy as np
from tinygrad.tensor import Tensor
@@ -9,7 +10,43 @@ from tinygrad.nn import Conv2d
IN_CHANS = [int(x) for x in os.getenv("IN_CHANS", "1,16,64").split(",")]
CNT = 5
-class TestSpeedVTorch(unittest.TestCase):
+class TestSpeed(unittest.TestCase):
  def test_gemm(self):
    N = 2048
    torch.manual_seed(0)
    torch_a = torch.rand(N, N)
    torch_b = torch.rand(N, N)
    tiny_a = Tensor(torch_a.cpu().numpy())
    tiny_b = Tensor(torch_b.cpu().numpy())

    ets_torch = []
    for _ in range(CNT):
      torch_a += 1  # perturb the input so each timed matmul sees different data
      st = time.monotonic()
      torch_c = torch_a @ torch_b
      et_torch = (time.monotonic() - st) * 1000
      ets_torch.append(et_torch)

    ets_tinygrad = []
    for _ in range(CNT):
      tiny_a += 1
      tiny_a.realize()  # materialize the updated input before starting the timer
      st = time.monotonic()
      tiny_c = tiny_a @ tiny_b
      tiny_c.realize()  # tinygrad is lazy: realize() forces the matmul to actually run
      et_tinygrad = (time.monotonic() - st) * 1000
      ets_tinygrad.append(et_tinygrad)

    val_torch = torch_c.numpy().sum()
    val_tinygrad = tiny_c.numpy().sum()

    et_torch = np.median(ets_torch)
    et_tinygrad = np.median(ets_tinygrad)
    print(f"{N}x{N} {et_torch:7.2f} ms in torch, {et_tinygrad:7.2f} ms in tinygrad, {et_tinygrad/et_torch:7.2f}x slower", val_torch, val_tinygrad)
    relative_error = abs((val_tinygrad-val_torch)/val_torch)
    assert relative_error < 0.01
  def test_conv2d(self):
    torch.manual_seed(0)
    for bs in [32]: