mirror of https://github.com/commaai/tinygrad.git
comments and pow
This commit is contained in:
parent
2075fdeb4f
commit
4535d39baa
|
@ -1,12 +1,15 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# RISK architecture is going to change everything
|
||||
# implement on S7t-VG6 (lol, too much $$$)
|
||||
|
||||
# Arty A7-100T
|
||||
# 256 MB of DDR3 with 2.6 GB/s of RAM bandwidth (vs 512 GB/s on S7t-VG6)
|
||||
# 255K 19-bit elements
|
||||
|
||||
# S7t-VG6
|
||||
# 16 GB of GDDR6
|
||||
# 189 Mb embedded RAM, aka 9M 19-bit elements
|
||||
|
||||
import functools
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
|
@ -34,7 +37,7 @@ from collections import defaultdict
|
|||
# <weight> <input> <empty> <output>
|
||||
|
||||
SZ = 32
|
||||
SLOTSIZE = 1024*1024*2 # 5MB, for 20MB total
|
||||
SLOTSIZE = 1024*1024*2 # 5MB, for 20MB total. 8M elements
|
||||
sram = np.zeros((SLOTSIZE*4), dtype=np.float32)
|
||||
regfile = {}
|
||||
SLOT = lambda x: x*SLOTSIZE
|
||||
|
@ -47,7 +50,7 @@ class Reg(Enum):
|
|||
MATMUL_WEIGHTS = 2
|
||||
MATMUL_OUTPUT = 3
|
||||
|
||||
# this should be a generic function
|
||||
# this should be a generic function with a LUT, similar to the ANE
|
||||
class UnaryOps(Enum):
|
||||
RELU = 0
|
||||
EXP = 1
|
||||
|
@ -60,6 +63,7 @@ class BinaryOps(Enum):
|
|||
MUL = 2
|
||||
DIV = 3
|
||||
MULACC = 4
|
||||
POW = 5
|
||||
|
||||
for t in Reg:
|
||||
regfile[t] = np.zeros((SZ, SZ), dtype=np.float32)
|
||||
|
@ -133,11 +137,17 @@ def riski_div():
|
|||
def riski_mulacc():
|
||||
regfile[Reg.MATMUL_OUTPUT] += regfile[Reg.MATMUL_INPUT] * regfile[Reg.MATMUL_WEIGHTS]
|
||||
|
||||
@count
def riski_pow():
    """Elementwise power on the register file.

    Raises each element of the MATMUL_INPUT register to the power held in the
    matching element of the MATMUL_WEIGHTS register and stores the result in
    the MATMUL_OUTPUT register (replacing, not accumulating into, its previous
    contents). Takes no arguments; operates on the module-level ``regfile``.
    """
    # Use np.power rather than np.pow: the `pow` alias was only added in
    # NumPy 2.0, so np.pow raises AttributeError on every NumPy 1.x release.
    regfile[Reg.MATMUL_OUTPUT] = np.power(regfile[Reg.MATMUL_INPUT], regfile[Reg.MATMUL_WEIGHTS])
|
||||
|
||||
# TODO: make accumulate a bit in the instruction available to all
|
||||
binops = {BinaryOps.ADD: riski_add,
|
||||
BinaryOps.SUB: riski_sub,
|
||||
BinaryOps.MUL: riski_mul,
|
||||
BinaryOps.DIV: riski_div,
|
||||
BinaryOps.MULACC: riski_mulacc}
|
||||
BinaryOps.MULACC: riski_mulacc,
|
||||
BinaryOps.POW: riski_pow}
|
||||
|
||||
@count
|
||||
def riski_matmul():
|
||||
|
@ -181,7 +191,7 @@ def riski_dmaw(address, shp):
|
|||
print("DMAW %d elements" % np.prod(shp))
|
||||
return np.copy(sram[address:address+np.prod(shp)].reshape(shp))
|
||||
|
||||
# *** RISK-5 code ***
|
||||
# *** RISK-5 code to be compiled ***
|
||||
|
||||
def risk_unop(x, op):
|
||||
riski_dmar(SLOT(0), x)
|
||||
|
|
Loading…
Reference in New Issue