2022-10-29 02:22:15 +08:00
import os
# Pin numerics and threading before torch/numpy are imported: thread-pool sizes from these
# variables are typically read when the native runtimes load, so they must stay above the imports.
# NVIDIA_TF32_OVERRIDE=0 disables TF32 math in NVIDIA libraries so GPU matmuls stay full fp32.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
# Single-threaded MKL / numexpr / OpenMP so the torch CPU baseline is a 1-thread comparison.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import unittest
import torch
# Also cap torch's intra-op thread pool at one thread (matches the env vars above).
torch.set_num_threads(1)

import time
import numpy as np
# Wide print lines so mismatching arrays in assert_allclose failures are readable.
np.set_printoptions(linewidth=160)
# NOTE(review): `partial` is not referenced anywhere in this file's visible code.
from functools import partial
from tinygrad.lazy import Device
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, CI
from tinygrad.jit import TinyJit
import pytest
# Marks applied to every test in this module; presumably consumed by a conftest/CI config to
# skip these benchmarks on the CUDA, GPU and CLANG backends — confirm.
pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang]
2022-10-11 07:06:00 +08:00
2023-02-01 07:09:09 +08:00
# Input channel counts swept by TestSpeed.test_conv2d; override with e.g. IN_CHANS=8,32.
IN_CHANS = list(map(int, getenv("IN_CHANS", "4,16,64").split(",")))

# Reference torch device: MPS=1 selects Apple's Metal backend, otherwise TORCHCUDA=1 selects
# CUDA, otherwise the CPU. `sync()` waits for queued torch work so that the wall-clock timings
# in helper_test_speed cover completed work; it is a no-op on the CPU.
if getenv("MPS", 0):
  torch_device = torch.device('mps')
  import torch.mps
  sync = lambda: torch.mps.synchronize()
elif getenv("TORCHCUDA", 0):
  torch_device = torch.device('cuda')
  import torch.cuda
  sync = lambda: torch.cuda.synchronize()
else:
  torch_device = torch.device('cpu')
  sync = lambda: None
2022-11-12 10:34:24 +08:00
2022-11-08 13:12:08 +08:00
def colorize_float(x):
  """Format the time ratio `x` as e.g. '   0.50x' and color it by outcome.

  Callers pass tinygrad_time / torch_time. The result is green below 0.75 (tinygrad clearly
  faster), red above 1.15 (clearly slower), and yellow in between.
  """
  text = f"{x:7.2f}x"
  if x < 0.75: color = 'green'
  elif x > 1.15: color = 'red'
  else: color = 'yellow'
  return colored(text, color)
2022-11-08 13:12:08 +08:00
2022-11-12 10:34:24 +08:00
# tinygrad op / memory counts captured from the most recent timed tinygrad run;
# read by helper_test_generic to compute GFLOPS and GB/s.
save_ops, save_mem = 0, 0

# Number of timed runs per callable; the first run is discarded (must be >= 2, otherwise
# np.min below receives an empty list and raises).
CNT = 8
def helper_test_speed(f1, *args):
  """Run `f1(*args)` CNT times and return (final result as a numpy array, fastest time in ms).

  Works for torch and tinygrad callables alike. Before each run:
  - every non-None argument is replaced by `arg + 1` so results cannot come from an op cache;
  - the inputs are materialized and the device is synchronized;
  - a 32MB numpy buffer is written to push stale data out of the caches.

  The first run is excluded from the minimum. It presumably carries warm-up cost, including
  TinyJit capture for jitted callables.

  Side effect: whenever a run reports non-zero tinygrad op or memory counts, those counts are
  stored in the module-level `save_ops` / `save_mem`.
  """
  global save_ops, save_mem
  ets = []
  ret = None
  cache_defeat = np.zeros((2048, 2048))  # 2048*2048 float64 = 32MB
  for i in range(CNT):
    # drop the previous output before producing the next one
    del ret

    # operation cache defeats
    args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args]

    # force syncing: pull every input to host so its pending computation has finished
    for x in args:
      if x is None: continue
      if isinstance(x, Tensor) or str(torch_device) == "cpu": x.numpy()
      else: x.cpu().numpy()

    # clear 32MB global memory cache (CPU and global memory only)
    cache_defeat += 1

    # manual pre sync
    if isinstance(args[0], Tensor): Device[args[0].device].synchronize()
    else: sync()

    GlobalCounters.global_ops = 0
    GlobalCounters.global_mem = 0
    st = time.perf_counter()
    ret = f1(*args)
    if isinstance(ret, Tensor): Device[ret.device].synchronize()
    else: sync()
    et = (time.perf_counter() - st) * 1000
    if i >= 1: ets.append(et)
    # Torch runs never touch GlobalCounters, so both counters stay 0 there and the saved
    # tinygrad numbers survive. Check memory as well as ops: a pure data-movement run (e.g.
    # permute + contiguous) may count 0 ops while still moving bytes. Checking ops alone would
    # leave the previous benchmark's counts in place.
    if GlobalCounters.global_ops or GlobalCounters.global_mem:
      save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem
  return ret.numpy() if isinstance(ret, Tensor) else ret.cpu().numpy(), np.min(ets)
2022-11-08 13:12:08 +08:00
2023-02-12 23:43:17 +08:00
def helper_test_generic_square(name, N, f1, f2, onearg=False):
  """Benchmark torch op `f1` against tinygrad op `f2` on random N x N inputs in [-0.5, 0.5).

  Both ops are called as `op(a, b)`. With `onearg=True` no second input is created and `b`
  is None. The tinygrad op is wrapped in TinyJit and its output is realized inside the
  jitted function.
  """
  torch.manual_seed(0)
  torch_a = (torch.rand(N, N) - 0.5).to(torch_device)
  torch_b = None if onearg else (torch.rand(N, N) - 0.5).to(torch_device)

  # tinygrad gets copies of exactly the same values
  tiny_a = Tensor(torch_a.cpu().numpy())
  tiny_b = None if onearg else Tensor(torch_b.cpu().numpy())

  label = f"{name:30s} {N:5d}x{N:5d}"
  jitted = TinyJit(lambda a, b: f2(a, b).realize())
  helper_test_generic(label, f1, (torch_a, torch_b), jitted, (tiny_a, tiny_b))
2022-11-12 10:34:24 +08:00
# Not read or written by anything in this file; presumably kept so external references keep working.
prefix = None
def helper_test_generic(name, f1, f1_args, f2, f2_args):
  """Time torch `f1(*f1_args)` against tinygrad `f2(*f2_args)`, print a report, check results agree.

  Each side's time is its fastest run in ms. Throughput uses the op and byte counts saved by the
  tinygrad run: millions of ops per ms equals GFLOPS, and MB per ms equals GB/s. The colored
  factor is tinygrad_time / torch_time.

  Raises AssertionError if the outputs differ beyond atol=1e-4, rtol=1e-3.
  """
  global save_ops, save_mem
  # Start from zero so a benchmark whose runs report no counters prints 0 instead of the
  # previous benchmark's numbers.
  save_ops, save_mem = 0, 0
  with torch.no_grad():
    val_torch, et_torch = helper_test_speed(f1, *f1_args)
    val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args)

  desc = "faster" if et_torch > et_tinygrad else "slower"
  flops = save_ops * 1e-6  # millions of ops ("MOPS")
  mem = save_mem * 1e-6    # megabytes ("MB")
  # "\r" lets the line overwrite pytest's progress output locally; CI logs get a plain line.
  print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB")
  np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4, rtol=1e-3)
2022-10-29 02:22:15 +08:00
2023-06-19 11:28:06 +08:00
def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
  """Benchmark a bias-free Conv2d in torch vs tinygrad on the same input and the same weights."""
  torch.manual_seed(0)
  torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x).to(torch_device)
  torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None).to(torch_device)

  tiny_dat = Tensor(torch_dat.cpu().numpy())
  tiny_conv = Conv2d(in_chans, out_chans, kernel_size, bias=None)
  # overwrite tinygrad's own init with torch's weights so the outputs are comparable
  tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

  label = f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:{kernel_size}"
  run_tiny = TinyJit(lambda dat: tiny_conv(dat).realize())
  helper_test_generic(label, lambda dat: torch_conv(dat), (torch_dat,), run_tiny, (tiny_dat,))
2023-06-19 11:28:06 +08:00
2023-08-07 00:32:33 +08:00
@unittest.skipIf(getenv("BIG") == 0, "no big tests")
class TestBigSpeed(unittest.TestCase):
  """Large-size torch vs tinygrad benchmarks; they only run when BIG is set to a nonzero value."""
  def test_add(self):
    def add(a, b): return a + b
    helper_test_generic_square('add', 8192, add, add)

  def test_exp(self):
    def exp(a, b): return a.exp()
    helper_test_generic_square('exp', 8192, exp, exp, onearg=True)

  def test_gemm_2048(self):
    def matmul(a, b): return a @ b
    helper_test_generic_square('gemm', 2048, matmul, matmul)

  def test_gemm_4096(self):
    def matmul(a, b): return a @ b
    helper_test_generic_square('gemm', 4096, matmul, matmul)

  # positional order: bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x
  def test_large_conv_1x1(self): helper_test_conv(32, 128, 128, 1, 128, 128)
  def test_large_conv_3x3(self): helper_test_conv(4, 128, 128, 3, 130, 130)
  def test_large_conv_5x5(self): helper_test_conv(4, 128, 128, 5, 130, 130)
2023-06-19 11:28:06 +08:00
2023-08-07 00:32:33 +08:00
@unittest.skipIf(getenv("BIG") == 1, "only big tests")
class TestSpeed(unittest.TestCase):
  """Default-size torch vs tinygrad benchmarks (skipped when BIG=1).

  Every test prints a timing comparison and asserts that tinygrad matches torch. Ops handed to
  helper_test_generic_square take (a, b); single-input ops ignore b.
  """
  def test_sub(self):
    def sub(a, b): return a - b
    helper_test_generic_square('sub', 4096, sub, sub)

  @unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT == "WEBGPU", "breaking on webgpu CI")
  def test_pow(self):
    def power(a, b): return a.pow(b)
    helper_test_generic_square('pow', 2048, power, power)

  def test_sum(self):
    def total(a, b): return a.sum()
    for size in (2048, 4096):
      helper_test_generic_square('sum', size, total, total, onearg=True)

  def test_partial_sum(self):
    R = 256
    def partial_sum(a, b): return a.reshape(int(4096//R), int(4096*R)).sum(axis=1)
    helper_test_generic_square('partial_sum', 4096, partial_sum, partial_sum, onearg=True)

  @unittest.skip("not really used in models")
  def test_cumsum(self):
    for axis in (0, 1):
      def cumsum(a, b, axis=axis): return a.cumsum(axis=axis)
      helper_test_generic_square(f'cumsum_{axis}', 256, cumsum, cumsum, onearg=True)

  def test_cat(self):
    for dim in (0, 1):
      helper_test_generic_square(f'cat_{dim}', 256,
                                 lambda x, y, dim=dim: torch.cat((x, y), dim=dim),
                                 lambda x, y, dim=dim: x.cat(y, dim=dim))

  def test_array_packing(self):
    N = 2048
    def pack(a, b): return a.reshape(N, N//32, 32).permute(1, 0, 2).contiguous()
    helper_test_generic_square('array_packing', N, pack, pack, onearg=True)

  def test_permute(self):
    def transpose(a, b): return a.permute(1, 0).contiguous()
    for N in (1024, 4096):
      # at N=4096 this is a 64MB float32 tensor, M1 L1 cache is 128kB
      # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
      helper_test_generic_square('permute', N, transpose, transpose, onearg=True)

  def test_double_permute(self):
    N = 64
    torch.manual_seed(0)
    torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
    tiny_a = Tensor(torch_a.cpu().numpy())
    def swap_pairs(a): return a.permute(1, 0, 3, 2).contiguous()
    jitted = TinyJit(lambda a: swap_pairs(a).realize())
    helper_test_generic(f"double_permute {tiny_a.shape}", swap_pairs, (torch_a,), jitted, (tiny_a,))

  def test_neg(self):
    def neg(a, b): return -a
    helper_test_generic_square('neg', 4096, neg, neg, onearg=True)

  def test_exp(self):
    def exp(a, b): return a.exp()
    helper_test_generic_square('exp', 2048, exp, exp, onearg=True)

  def test_relu(self):
    def relu(a, b): return a.relu()
    helper_test_generic_square('relu', 4096, relu, relu, onearg=True)

  def test_max(self):
    def amax(a, b): return a.max()
    helper_test_generic_square('max', 4096, amax, amax, onearg=True)

  def test_mul_sum(self):
    def dot(a, b): return (a*b).sum()
    helper_test_generic_square('mul_sum', 4096, dot, dot)

  def test_add(self):
    def add(a, b): return a + b
    for N in (1, 1024, 4096):
      helper_test_generic_square('add', N, add, add)

  def test_add_constant(self):
    def add_two(a, b): return a + 2.0
    helper_test_generic_square('add_constant', 4096, add_two, add_two, onearg=True)

  def test_add_sq(self):
    def sum_of_squares(a, b): return a*a + b*b
    helper_test_generic_square('add_sq', 4096, sum_of_squares, sum_of_squares)

  def test_gemm(self):
    def matmul(a, b): return a @ b
    helper_test_generic_square('gemm', 1024, matmul, matmul)

  def test_gemm_small(self):
    def matmul(a, b): return a @ b
    helper_test_generic_square('gemm', 256, matmul, matmul)

  # The "unrolled" gemms express out[i, j] = sum_k lhs[i, k] * rhs[j, k] as a broadcast
  # multiply over an (N, N, N) view followed by a reduction over the last axis.
  def test_gemm_unrolled(self):
    N = 512
    def torch_mm(a, b): return a @ b.T
    def tiny_mm(a, b):
      lhs = a.reshape(N, 1, N).expand(N, N, N)
      rhs = b.reshape(1, N, N).expand(N, N, N)
      return (lhs * rhs).sum(axis=2)
    helper_test_generic_square('gemm_unrolled', N, torch_mm, tiny_mm)

  def test_gemm_unrolled_permute_l(self):
    N = 512
    def torch_mm(a, b): return a.T @ b.T
    def tiny_mm(a, b):
      lhs = a.permute(1, 0).reshape(N, 1, N).expand(N, N, N)
      rhs = b.reshape(1, N, N).expand(N, N, N)
      return (lhs * rhs).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_l', N, torch_mm, tiny_mm)

  def test_gemm_unrolled_permute_r(self):
    N = 512
    def torch_mm(a, b): return a @ b
    def tiny_mm(a, b):
      lhs = a.reshape(N, 1, N).expand(N, N, N)
      rhs = b.permute(1, 0).reshape(1, N, N).expand(N, N, N)
      return (lhs * rhs).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_r', N, torch_mm, tiny_mm)

  def test_gemm_unrolled_permute_lr(self):
    N = 512
    def torch_mm(a, b): return a.T @ b
    def tiny_mm(a, b):
      lhs = a.permute(1, 0).reshape(N, 1, N).expand(N, N, N)
      rhs = b.permute(1, 0).reshape(1, N, N).expand(N, N, N)
      return (lhs * rhs).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_lr', N, torch_mm, tiny_mm)

  def test_openpilot_conv2d(self):
    bs, in_chans, out_chans = 1, 12, 32
    torch.manual_seed(0)
    # input is stored channels-last (the 12 = in_chans axis is last); both sides permute to NCHW
    torch_dat = torch.rand(bs, 64, 128, 12).to(torch_device)
    torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1).to(torch_device)
    tiny_dat = Tensor(torch_dat.cpu().numpy())
    tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
    tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
    def run_torch(dat): return torch_conv(dat.permute(0, 3, 1, 2))
    def run_tiny(dat): return tiny_conv(dat.permute(0, 3, 1, 2)).realize()
    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:3", run_torch, (torch_dat,), TinyJit(run_tiny), (tiny_dat,))

  def test_conv2d(self):
    # batch 32, 32 output channels, 3x3 kernel on 34x34 images, sweeping the input channel counts
    for in_chans in IN_CHANS:
      helper_test_conv(32, in_chans, 32, 3, 34, 34)
2022-10-11 07:06:00 +08:00
if __name__ == "__main__":
  # Script entry point: hand every TestCase in this module to unittest's command-line runner.
  unittest.main()