2023-03-11 08:56:07 +08:00
# this can be constructed from a cl_cache or loaded from a thneed file
2022-10-21 03:35:59 +08:00
import time
import struct
import json
import traceback
import numpy as np
2023-11-02 14:01:32 +08:00
from tinygrad . runtime . ops_gpu import CLProgram , compile_gpu
2023-08-29 10:59:55 +08:00
from tinygrad . helpers import DEBUG , getenv
2022-10-21 10:36:43 +08:00
from collections import defaultdict
2022-10-21 03:35:59 +08:00
import pyopencl as cl
2023-03-02 10:57:29 +08:00
from tinygrad . runtime . ops_gpu import CL , OSX_TIMING_RATIO
2022-10-21 03:35:59 +08:00
2023-02-01 07:09:09 +08:00
# Verbosity of the profiling output printed by Thneed.run: >= 1 prints one line
# per kernel with its runtime, >= 2 also prints per-event queued/submit/start/end
# times (and the source of the kernel whose index equals PRINT_KERNEL), >= 3
# prints the source of every program that has one.
# Presumably read from the DEBUGCL environment variable, default 0 -- confirm
# against getenv.
DEBUGCL = getenv ( " DEBUGCL " , 0 )
# When truthy, Thneed.load/Thneed.save use RGBA half-float as the default image
# format instead of RGBA float.
# Presumably read from the FLOAT16 environment variable, default 0 -- confirm
# against getenv.
FLOAT16 = getenv ( " FLOAT16 " , 0 )
2022-10-21 03:35:59 +08:00
class Thneed:
    """A recorded list of OpenCL kernel launches together with their buffers.

    An instance is either built from an in-memory ``cl_cache`` (a list of
    ``(program, args)`` pairs where ``args`` is
    ``[global_work_size, local_work_size, output, *inputs]``) or populated from
    a thneed file with :meth:`load`.

    File layout, as written by :meth:`save` and read by :meth:`load`:
    a native-endian unsigned int holding the JSON length, the JSON header
    encoded as latin-1, the raw data of every object flagged ``needs_load``
    (in object order), then the program binaries (in ``binaries`` order).

    Attributes:
        cl_cache: list of ``(program, args)`` pairs executed in order by run().
        inputs: mapping of input name -> buffer object.
        outputs: list of buffers that no kernel consumes.
        buffers_to_save: set of buffers whose contents are written by save().
        gobj: counter used by save() to hand out ``global_id`` values.
    """

    # OpenCL helper that copies an image into a row-strided float4 buffer so it
    # can be read back to the host by save().
    _FROM_IMAGE_STRIDED_SRC = """
  __kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
    const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
    int2 l;
    l.y = get_global_id(1);
    l.x = get_global_id(0);
    out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
  }
  """

    def __init__(self, cl_cache=None, inputs=None):
        """Build the dependency graph of ``cl_cache`` and classify its buffers.

        Args:
            cl_cache: optional list of ``(program, args)`` pairs; it is copied.
            inputs: optional mapping of input name -> buffer. When given, the
                same dict object is kept and entries that no kernel reads are
                deleted from it (with a warning printed).
        """
        # None sentinels instead of mutable defaults: a shared default dict would
        # leak inputs between instances once load() or the pruning below mutates it.
        self.cl_cache = [] if cl_cache is None else cl_cache[:]
        self.inputs = {} if inputs is None else inputs
        self.gobj = 0

        # build graph
        # NOTE: if CLCACHE=1, this is wrong!
        nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []})
        for _, args in self.cl_cache:
            # output is always the first parameter
            out = args[2]
            for a in args[3:]:
                nodes[a]['out_edges'].append(out)
                nodes[out]['in_edges'].append(a)

        # get buffers to save: anything never written is a candidate weight,
        # anything never read is an output
        self.buffers_to_save = set()
        self.outputs = []
        for n, edges in nodes.items():
            if not edges['in_edges']:
                self.buffers_to_save.add(n)
            if not edges['out_edges']:
                self.outputs.append(n)

        # declared inputs are fed at runtime, so they are not saved; inputs that
        # no kernel reads are dropped
        fake_inputs = []
        for k, n in self.inputs.items():
            if n in self.buffers_to_save:
                self.buffers_to_save.remove(n)
            else:
                print(f"WARNING: {k} was not a used input, removing it")
                fake_inputs.append(k)
        for k in fake_inputs:
            del self.inputs[k]

    def load(self, input_fn):
        """Populate this instance from the thneed file at ``input_fn``.

        Creates the OpenCL buffers/images described in the file, builds the
        programs from the stored binaries, appends every kernel to
        ``self.cl_cache`` and fills ``self.inputs`` / ``self.outputs``.

        Raises:
            ValueError: if a kernel argument has an encoded length other than
                0 (local memory), 2, 4 (scalars) or 8 (buffer id).
        """
        float32 = not FLOAT16
        mf = cl.mem_flags
        image_fmt = cl.ImageFormat(cl.channel_order.RGBA,
                                   cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT)
        image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT)

        with open(input_fn, "rb") as f:
            json_len = struct.unpack("I", f.read(4))[0]
            jdat = json.loads(f.read(json_len).decode('latin_1'))
            weights = f.read()

        # load in the buffers
        # the all-zero id maps to None (presumably a null buffer argument -- confirm)
        bufs = {'\x00' * 8: None}
        bufs_loaded = {}
        ptr = 0
        for o in jdat['objects']:
            if o['needs_load']:
                nptr = ptr + o['size']
                o['data'] = weights[ptr:nptr]
                ptr = nptr

            if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t":
                tfmt = image_fmt_32 if o.get('float32') else image_fmt
                if o['arg_type'] == "image2d_t":
                    if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
                        # hack: use a image1d since we can back that with a buffer
                        buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt,
                                       shape=(o['width'],), buffer=bufs[o['buffer_id']])
                    elif 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
                        # buffer isn't supported in image2d, copy buffer into image
                        # NOTE(review): the backing buffer is always read back as float16,
                        # even when tfmt is the float32 format -- confirm this is intended
                        arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
                        cl.enqueue_copy(CL.cl_queue[0], arr, bufs[o['buffer_id']])
                        buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
                                       shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
                    elif o['needs_load']:
                        buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
                                       shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
                    else:
                        buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
                if o['arg_type'] == "image1d_t":
                    assert not o['needs_load']
                    assert not bufs_loaded[o['buffer_id']]
                    buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt,
                                   shape=(o['width'],), buffer=bufs[o['buffer_id']])
            elif 'data' in o:
                buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
            else:
                # zero out buffers
                buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00' * o['size'])

            bufs[o['id']] = buf
            bufs_loaded[o['id']] = 'data' in o
            # if it's loaded, it's saved
            if 'data' in o:
                self.buffers_to_save.add(buf)

        # load binaries (stored right after the object data)
        prgs = {}
        for o in jdat['binaries']:
            nptr = ptr + o['length']
            prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr])
            ptr = nptr

        # populate the cl_cache
        for k in jdat['kernels']:
            kernel = prgs[k['name']]
            aaa = []
            for a, sz in zip(k['args'], k['args_size']):
                # the encoded length of an argument identifies its kind
                if len(a) == 0:
                    aa = cl.LocalMemory(sz)
                elif len(a) == 4:
                    aa = np.uint32(struct.unpack("I", a.encode('latin_1'))[0])
                elif len(a) == 2:
                    aa = np.uint16(struct.unpack("H", a.encode('latin_1'))[0])
                elif len(a) == 8:
                    aa = bufs[a]
                else:
                    # without this, the previous argument would be silently reused
                    raise ValueError(f"thneed: kernel {k['name']} has an argument of unsupported length {len(a)}")
                aaa.append(aa)
            self.cl_cache.append((kernel, [k['global_work_size'], k['local_work_size'], *aaa]))

        if DEBUG >= 1: print(f"thneed: total bufs loaded: {len(bufs.keys())}")

        # load inputs
        for k in jdat['inputs']:
            self.inputs[k['name']] = bufs[k['buffer_id']]

        # load outputs
        for k in jdat['outputs']:
            self.outputs.append(bufs[k['buffer_id']])

    def _serialize_buffer(self, a, ptr, jdat, weights):
        """Append the descriptor (and data, if it must be saved) of cl.Buffer ``a``."""
        needs_load = a in self.buffers_to_save
        jdat['objects'].append({
            "id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size,
        })
        if needs_load:
            data = np.empty(a.size // 4, dtype=np.float32)
            cl.enqueue_copy(CL.cl_queue[0], data, a, is_blocking=True)
            weights.append(data.tobytes())

    def _serialize_image(self, a, ptr, jdat, weights):
        """Append the descriptor (and data, if it must be saved) of cl.Image ``a``."""
        expected_fmt = cl.ImageFormat(cl.channel_order.RGBA,
                                      cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
        assert a.format == expected_fmt, "wrong type"
        needs_load = a in self.buffers_to_save
        # bytes per row (4 channels of 2 or 4 bytes), rounded up to a multiple of 64
        row_pitch = (a.shape[0] * 4 * (2 if FLOAT16 else 4) + 63) // 64 * 64
        size = row_pitch * a.shape[1]
        # this is *2 if float16 and *4 if float32
        buf = cl.Buffer(CL.cl_ctxs[0], cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))

        # zero out the buffer
        cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00' * buf.size, is_blocking=True)

        # the image is always read out as float32 texels into the strided buffer
        CLProgram("from_image_strided", compile_gpu(self._FROM_IMAGE_STRIDED_SRC),
                  argdtypes=(None, None, np.int32))(
            a, buf, row_pitch // (4 * (2 if FLOAT16 else 4)), global_size=a.shape)

        # multiple of 32 isn't enough
        jdat['objects'].append({
            "id": ptr, "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
            "width": a.shape[0], "height": a.shape[1], "row_pitch": row_pitch, "float32": not FLOAT16,
        })

        if needs_load:
            data = np.empty(size // (2 if FLOAT16 else 4), dtype=np.float32)
            cl.enqueue_copy(CL.cl_queue[0], data, buf, is_blocking=True)
            if FLOAT16: data = data.astype(np.float16)
            weights.append(data.tobytes())

    def save(self, output_fn):
        """Serialize the kernels, objects, weights and binaries to ``output_fn``.

        Every buffer/image argument gets a ``global_id`` attribute on first use;
        its packed 8-byte form is the object's id in the file.

        Raises:
            Exception: if an untyped argument is neither a cl.Buffer nor a
                cl.Image, or an argument's dtype is not handled.
        """
        # this is the struct that will be saved
        jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}

        # build the pieces of this struct
        weights = []
        binaries = []
        saved_objs = set()
        saved_binaries = set()
        for prg, args in self.cl_cache:
            # get binaries for saving
            if prg.name not in saved_binaries:
                binary = prg.clprograms[0].get_info(cl.program_info.BINARIES)
                assert len(binary) == 1
                jdat['binaries'].append({"name": prg.name, "length": len(binary[0])})
                binaries.append(binary[0])
                saved_binaries.add(prg.name)

            # get the args from the kernel, some need the data saved
            targs, args_size = [], []
            argdtypes = prg.argdtypes if prg.argdtypes is not None else [None] * (len(args) - 2)
            for a, d in zip(args[2:], argdtypes):
                if d == np.int16:
                    targs.append(struct.pack("H", a).decode("latin_1"))
                    args_size.append(2)
                elif d == np.int32:
                    targs.append(struct.pack("I", a).decode("latin_1"))
                    args_size.append(4)
                elif isinstance(a, cl.LocalMemory):
                    # local memory is encoded as an empty string plus its size
                    targs.append("")
                    args_size.append(a.size)
                elif d is None:
                    if getattr(a, "global_id", None) is None:
                        setattr(a, "global_id", self.gobj)
                        self.gobj += 1
                    ptr = struct.pack("Q", a.global_id).decode("latin_1")
                    if ptr not in saved_objs:
                        if isinstance(a, cl.Buffer):
                            self._serialize_buffer(a, ptr, jdat, weights)
                        elif isinstance(a, cl.Image):
                            self._serialize_image(a, ptr, jdat, weights)
                        else:
                            raise Exception("unknown object", a)
                        saved_objs.add(ptr)
                    targs.append(ptr)
                    args_size.append(8)
                else:
                    raise Exception("idk this type")

            # save the kernel itself
            jdat['kernels'].append({
                "name": prg.name,
                "work_dim": len(args[0]),
                "global_work_size": args[0],
                # TODO: C++ thneed requires a local_work_size, so we fill it with ones
                "local_work_size": [1 for _ in args[0]] if args[1] is None else args[1],
                "num_args": len(args) - 2,
                "args": targs,
                "args_size": args_size
            })

        jdat['outputs'] = [{
            "buffer_id": struct.pack("Q", x.global_id).decode("latin_1"),
            "size": x.size,
        } for x in self.outputs]

        # inputs are written in reverse insertion order
        jdat['inputs'] = [{
            "buffer_id": struct.pack("Q", v.global_id).decode("latin_1"),
            "size": v.size,
            "name": k
        } for k, v in self.inputs.items()][::-1]

        print(f"saving thneed to {output_fn}")
        with open(output_fn, "wb") as f:
            j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
            f.write(struct.pack("I", len(j)))
            f.write(j)
            f.write(b''.join(weights))
            f.write(b''.join(binaries))

    def run(self):
        """Enqueue every kernel in ``cl_cache``, wait for completion and report timing.

        Returns:
            Seconds as a float: the summed per-kernel profiled runtime when
            DEBUGCL >= 1, otherwise the wall-clock time of submit + synchronize.
        """
        events = []
        st = time.monotonic()
        for prg, args in self.cl_cache:
            events.append(prg.clprgs[0](CL.cl_queue[0], *args))
        mt = time.monotonic()
        CL.synchronize()
        et = time.monotonic() - st
        print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")

        if DEBUGCL >= 2:
            # event timestamps are treated as nanoseconds and shown relative to `st`
            for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
                stamps = [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end]
                print(f"{i:3d} {prg.name:25s} " + "queued @ %5.2f ms, submit @ %5.2f ms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in stamps))

        if DEBUGCL < 1:
            return et

        total_runtime = 0
        for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
            runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
            # a zero-length profile interval would otherwise raise ZeroDivisionError
            gflops = getattr(prg, 'op_estimate', float('nan')) / runtime if runtime else float('nan')
            out_desc = args[2].shape if hasattr(args[2], 'shape') else args[2].size
            print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {gflops:9.2f} GFLOPS -> {out_desc}")
            if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3):
                print(prg.prg)
            total_runtime += runtime
        print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")
        return total_runtime/1e9