import os
# Set before `import torch` below so torch sees it at load time.
# NOTE(review): presumably disables TF32 math on NVIDIA GPUs so the torch reference
# computes in full fp32 precision — confirm.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
import unittest
import torch
# Restrict torch's CPU intra-op thread pool to a single thread for all benchmarks
# (presumably so the CPU comparison against tinygrad is single-threaded — confirm).
torch.set_num_threads(1)
import time
import numpy as np
# Widen numpy's printed line length to 160 characters (affects array dumps,
# e.g. in assert_allclose failure messages).
np.set_printoptions(linewidth=160)
from functools import partial
from tinygrad.ops import GlobalCounters, DEBUG
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv
from extra.jit import TinyJit
# Pick how to wait for outstanding tinygrad device work. With METAL set, the Metal
# runtime's sync is used; otherwise the OpenCL command queue is flushed. If the
# OpenCL runtime cannot be imported at all, CL is None and syncing does nothing.
METAL = getenv("METAL")
try:
  from tinygrad.runtime.opencl import CL
  if not METAL:
    def sync():
      """Call finish() on tinygrad's OpenCL command queue (presumably blocks until queued work is done)."""
      CL().cl_queue.finish()
  else:
    from tinygrad.runtime.metal import sync
except ImportError:
  CL = None
  def sync():
    """No-op stand-in used when the OpenCL runtime is unavailable."""
    return None

# Input-channel counts swept by test_conv2d; override with e.g. IN_CHANS=8,32.
IN_CHANS = list(map(int, getenv("IN_CHANS", "4,16,64").split(",")))

# Device for the torch reference: MPS=1 -> 'mps', else TORCHCUDA=1 -> 'cuda', else 'cpu'.
torch_device = torch.device(
  'mps' if getenv("MPS", 0)
  else 'cuda' if getenv("TORCHCUDA", 0)
  else 'cpu')
def colorize_float(x):
  """Format a ratio as e.g. '   1.23x' and colour it for the terminal.

  Callers pass tinygrad time / torch time. Values below 0.75 are green, values
  above 1.5 are red, and everything else (including NaN) is yellow.
  """
  label = f"{x:7.2f}x"
  if x < 0.75: color = 'green'
  elif x > 1.5: color = 'red'
  else: color = 'yellow'
  return colored(label, color)
# tinygrad op / memory counts captured by the most recent helper_test_speed run
# that recorded any ops; helper_test_generic turns them into GFLOPS and GB/s.
save_ops, save_mem = 0, 0
# Number of timed repetitions per benchmark; the fastest one is reported.
CNT = 8
def helper_test_speed(f1, *args):
  """Run f1(*args) CNT times and return (last result as numpy, fastest run in ms).

  Before each repetition every non-None argument gets +1 added (tinygrad Tensors
  are realized), so no previously computed result can be reused. Pending tinygrad
  and torch device work is also flushed so it is not charged to the timed region.

  Side effect: updates the module globals save_ops / save_mem from GlobalCounters
  whenever a repetition recorded a non-zero op count.
  """
  global save_ops, save_mem
  # torch.device never compares equal to a plain str (`torch.device("cpu") != "cpu"`
  # is always True), so compare the device *type* to decide whether torch needs a sync.
  torch_needs_sync = torch_device.type != "cpu"
  ets = []
  ret = None
  for _ in range(CNT):
    del ret
    GlobalCounters.global_ops = 0
    GlobalCounters.global_mem = 0
    args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args]  # cache defeats

    # sync all before!
    sync()
    torch.zeros(1, device=torch_device).cpu()

    if DEBUG >= 4: print("benchmark start")
    # perf_counter is the highest-resolution clock intended for timing short intervals;
    # monotonic() can be coarse enough on some platforms to report 0 ms.
    st = time.perf_counter()
    ret = f1(*args)
    # NOTE(review): only tinygrad results on the "GPU" device are synced here; confirm
    # other backends (e.g. with METAL=1) don't leave work queued when the timer stops.
    if isinstance(ret, Tensor) and CL is not None and ret.device in ["GPU"]:
      sync()
    if not isinstance(ret, Tensor) and torch_needs_sync:
      # TODO: better way to sync? Copying a tiny tensor back to the host presumably
      # waits for previously queued torch work on the accelerator.
      torch.zeros(1, device=torch_device).cpu()
    et = (time.perf_counter() - st) * 1000
    ets.append(et)
    if DEBUG >= 4: print("benchmark stop")
    if GlobalCounters.global_ops:
      save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem
  return ret.cpu().numpy(), np.min(ets)
def helper_test_generic_square(name, N, f1, f2, onearg=False):
  """Benchmark f1 (torch) against f2 (tinygrad) on random N x N inputs in [-0.5, 0.5).

  Both functions take (a, b). With onearg=True only `a` is generated and `b` is None.
  The tinygrad side is wrapped in a TinyJit that realizes f2's output.
  """
  torch.manual_seed(0)
  def random_square(): return (torch.rand(N, N) - 0.5).to(torch_device)
  torch_a = random_square()
  torch_b = None if onearg else random_square()

  # mirror the exact same values into tinygrad via a host numpy copy
  tiny_a, tiny_b = (None if t is None else Tensor(t.cpu().numpy()) for t in (torch_a, torch_b))

  jitted = TinyJit(lambda a, b: f2(a, b).realize())
  label = f"{name:30s} {N:4d}x{N:4d}"
  helper_test_generic(label, f1, (torch_a, torch_b), jitted, (tiny_a, tiny_b))
# Text printed before the next report line; None until TestSpeed.setUp first runs,
# then " " after every report line printed by helper_test_generic.
prefix = None
def helper_test_generic(name, f1, f1_args, f2, f2_args):
  """Time torch's f1(*f1_args) against tinygrad's f2(*f2_args), print a report, check results.

  Both sides are timed with helper_test_speed under torch.no_grad(). Throughput is
  derived from the save_ops / save_mem counters that helper_test_speed recorded.

  Raises:
    AssertionError: (from np.testing.assert_allclose) if the two outputs disagree.
  """
  global prefix
  with torch.no_grad():
    val_torch, et_torch = helper_test_speed(f1, *f1_args)
    val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args)

  if et_torch > et_tinygrad: desc = "faster"
  else: desc = "slower"
  mops = save_ops * 1e-6
  # NOTE(review): the factor 4 assumes 4-byte (float32) elements — confirm for other dtypes
  mbytes = save_mem * 4 * 1e-6

  def timing(ms):
    # MOPS per millisecond == GFLOPS, MB per millisecond == GB/s
    return f"{ms:7.2f} ms ({mops/ms:8.2f} GFLOPS {mbytes/ms:8.2f} GB/s)"

  print(f"{prefix}{name:40s} {timing(et_torch)} in torch, {timing(et_tinygrad)} in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {mops:7.2f} MOPS {mbytes:7.2f} MB")
  prefix = " "

  np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4, rtol=1e-3)
class TestSpeed(unittest.TestCase):
  """Speed comparison of tinygrad against torch on individual ops and convolutions.

  Each test builds matching inputs for both frameworks and times them through
  helper_test_generic / helper_test_generic_square. Those helpers print one report
  line per benchmark and assert that the two results agree numerically.
  """
  def setUp(self):
    global prefix
    # The first report line of the run gets a leading space; later tests get none.
    # NOTE(review): presumably this aligns report lines with the "." progress marker
    # unittest prints after each finished test — confirm.
    prefix = " " if prefix is None else ""
    return super().setUp()

  def test_sub(self):
    def f(a, b): return a-b
    helper_test_generic_square('sub', 4096, f, f)

  def test_pow(self):
    def f(a, b): return a.pow(b)
    helper_test_generic_square('pow', 2048, f, f)

  def test_sum(self):
    def f(a, b): return a.sum()
    helper_test_generic_square('sum', 2048, f, f, onearg=True)
    helper_test_generic_square('sum', 4096, f, f, onearg=True)

  def test_partial_sum(self):
    N, R = 4096, 256
    # View the N x N input as N//R rows of N*R elements and reduce each row.
    # The shape is derived from N so it always matches the tensor actually passed in.
    def f(a, b): return a.reshape(N//R, N*R).sum(axis=1)
    helper_test_generic_square('partial_sum', N, f, f, onearg=True)

  def test_array_packing(self):
    N = 2048
    def f(a, b): return a.reshape(N, N//32, 32).permute(1,0,2).contiguous()
    helper_test_generic_square('array_packing', N, f, f, onearg=True)

  def test_permute(self):
    for N in [1024, 4096]:
      # at N=4096 this is a 64MB float32 tensor (4MB at N=1024), M1 L1 cache is 128kB
      # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
      def f(a, b): return a.permute(1,0).contiguous()
      helper_test_generic_square('permute', N, f, f, onearg=True)

  def test_double_permute(self):
    N = 64
    torch.manual_seed(0)
    torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
    tiny_a = Tensor(torch_a.cpu().numpy())
    def f(a): return a.permute(1,0,3,2).contiguous()
    helper_test_generic(f"double_permute {tiny_a.shape}", f, (torch_a,), TinyJit(lambda a: f(a).realize()), (tiny_a,))

  def test_neg(self):
    def f(a, b): return -a
    helper_test_generic_square('neg', 4096, f, f, onearg=True)

  def test_exp(self):
    def f(a, b): return a.exp()
    helper_test_generic_square('exp', 2048, f, f, onearg=True)

  def test_relu(self):
    def f(a, b): return a.relu()
    helper_test_generic_square('relu', 4096, f, f, onearg=True)

  def test_max(self):
    def f(a, b): return a.max()
    helper_test_generic_square('max', 4096, f, f, onearg=True)

  def test_mul_sum(self):
    def f(a, b): return (a*b).sum()
    helper_test_generic_square('mul_sum', 4096, f, f)

  def test_add(self):
    for N in [1, 1024, 4096]:
      def f(a, b): return a + b
      helper_test_generic_square('add', N, f, f)

  def test_add_constant(self):
    def f(a, b): return a+2.0
    helper_test_generic_square('add_constant', 4096, f, f, onearg=True)

  def test_add_constant_zero(self):
    def f(a, b): return a+0.0
    helper_test_generic_square('add_constant_zero', 4096, f, f, onearg=True)

  def test_add_sq(self):
    def f(a, b): return a*a + b*b
    helper_test_generic_square('add_sq', 4096, f, f)

  def test_gemm(self):
    def f(a, b): return a @ b
    helper_test_generic_square('gemm', 512, f, f)

  # The gemm_unrolled variants compare a torch matmul (f1) with the same product
  # written in tinygrad as a broadcast multiply over an (N, N, N) view reduced on axis 2.
  def test_gemm_unrolled(self):
    N = 512
    def f1(a, b): return a@b.T
    def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled', N, f1, f2)

  def test_gemm_unrolled_permute_l(self):
    N = 512
    def f1(a, b): return a.T@b.T
    def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_l', N, f1, f2)

  def test_gemm_unrolled_permute_r(self):
    N = 512
    def f1(a, b): return a@b
    def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_r', N, f1, f2)

  def test_gemm_unrolled_permute_lr(self):
    N = 512
    def f1(a, b): return a.T@b
    def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_lr', N, f1, f2)

  def test_openpilot_conv2d(self):
    bs, in_chans, out_chans = 1,12,32
    torch.manual_seed(0)
    # channels-last input, permuted to NCHW inside the timed functions
    torch_dat = torch.rand(bs, 64, 128, 12).to(torch_device)
    torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1).to(torch_device)

    tiny_dat = Tensor(torch_dat.cpu().numpy())
    tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
    # share torch's weights so the two outputs are directly comparable
    tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

    def f1(torch_dat): return torch_conv(torch_dat.permute(0,3,1,2))
    def f2(tiny_dat): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()
    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))

  def test_conv2d(self):
    torch.manual_seed(0)
    for bs in [32]:
      for in_chans in IN_CHANS:
        for out_chans in [32]:
          img_size = 34
          torch_dat = torch.rand(bs, in_chans, img_size, img_size).to(torch_device)
          torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None).to(torch_device)

          tiny_dat = Tensor(torch_dat.cpu().numpy())
          tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None)
          # share torch's weights so the two outputs are directly comparable
          tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

          def f1(torch_dat): return torch_conv(torch_dat)
          def f2(tiny_dat): return tiny_conv(tiny_dat).realize()
          helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))
if __name__ == '__main__':
  # When executed as a script, run the TestSpeed benchmarks via unittest's CLI.
  unittest.main()