FPGA Based Accelerator for Tinygrad (#258)

* ops_risk

* risk sim

* guessing is for winners

* minor

* better

* matmal with risk

* conv doesn't work

* closer

* conv2d works

* ops_risk

* opt2 works

* opt1 may not be possible

* opt1 is a mulacc

* arty

* attosoc example building on mac

* minor

* riscv assembler

* gucci gang

* we got C code

* not a scam

* hello

* make risk mergeable into master

* unop support
This commit is contained in:
George Hotz 2021-06-07 17:45:09 -07:00 committed by GitHub
parent 77ba198b57
commit 2075fdeb4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 4074 additions and 1 deletions

View File

@ -9,6 +9,13 @@ Why aren't the other accelerators 3D like this?
--
Tesla chip
96x96 array
32 MiB SRAM
--
SNPE is using 4x4x4 -> 4x4 (64 FMAs) in the convs.
Then it's accumulating in that matrix.

206
extra/ops_risk.py Normal file
View File

@ -0,0 +1,206 @@
import numpy as np
from tinygrad.tensor import Function
from extra.risk import *
# ************* unary ops *************
class ReLU(Function):
def forward(ctx, input):
ctx.save_for_backward(input)
return risk_unop(input, UnaryOps.RELU)
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return risk_binop(grad_output, risk_unop(input, UnaryOps.GT0), BinaryOps.MUL)
class Log(Function):
def forward(ctx, input):
ctx.save_for_backward(input)
return risk_unop(input, UnaryOps.LOG)
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return risk_binop(grad_output, input, BinaryOps.DIV)
class Exp(Function):
def forward(ctx, input):
ret = risk_unop(input, UnaryOps.EXP)
ctx.save_for_backward(ret)
return ret
def backward(ctx, grad_output):
ret, = ctx.saved_tensors
return risk_binop(grad_output, ret, BinaryOps.MUL)
# ************* processing ops *************
class Matmul(Function):
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
return risk_matmul(input, weight)
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = risk_matmul(grad_output, weight, transpose_w=True)
grad_weight = risk_matmul(input, grad_output, transpose_x=True)
return grad_input, grad_weight
class Conv2D(Function):
def forward(ctx, x, w, stride=1, groups=1):
if type(ctx.stride) == int:
ctx.stride = (ctx.stride, ctx.stride)
cout,cin,H,W = w.shape
ys,xs = ctx.stride
bs,cin_ = x.shape[0], x.shape[1]
iy,ix = x.shape[2],x.shape[3]
oy,ox = (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
assert cin*ctx.groups == cin_
assert cout % ctx.groups == 0
rcout = cout//ctx.groups
# if H == 1 and W == 1 and ctx.groups == 1 and ctx.stride == (1,1):
gx = x.reshape(bs,ctx.groups,cin,x.shape[2],x.shape[3])
tx = np.lib.stride_tricks.as_strided(gx,
shape=(bs, ctx.groups, cin, oy, ox, H, W),
strides=(*gx.strides[0:3], gx.strides[3]*ys, gx.strides[4]*xs, *gx.strides[3:5]),
writeable=False,
)
tw = w.reshape(ctx.groups, rcout, cin, H, W)
ctx.save_for_backward(tx, tw, x.shape)
print((*gx.strides[0:3], gx.strides[3]*ys, gx.strides[4]*xs, *gx.strides[3:5]))
"""
ret = np.zeros((bs,ctx.groups,oy,ox,rcout),dtype=x.dtype)
for g in range(ctx.groups):
#ijYXyx,kjyx -> iYXk ->ikYX
ret[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
print(bs, ctx.groups, cin)
return np.moveaxis(ret,4,2).reshape(bs, cout, oy, ox)
"""
riski_dmar(SLOT(0), x) # bs, groups, cin, x.shape[2], x.shape[3]
riski_dmar(SLOT(1), w) # groups, rcout, cin, H, W
risk_reset_counts()
print(bs, ctx.groups, rcout, oy, ox, cin, H, W)
for B in range(0, bs):
if cin == 1 and rcout == 1 and ctx.groups > 1:
# hmm, this doesn't work, it's not a matmul
# you always have to loop over the groups, since they aren't joint
# the idea would be to collapse the HxW into the matmul, but you'd be limited to 9 for 3x3
# and while the load is easy in the weight matrix, it's hard in the image matrix (3 strides)
# and only the diagonal of the matrix would be useful! groups aren't channels!
# [(1, 144, 58, 58), (144, 1, 3, 3)] -> (1, 144, 56, 56)
# what does a grouped 1x1 conv look like?
# bs x groups x yx -- groups x 1 --> bs x groups x yx
# it looks like a broadcasted multiply
print("opt1")
# x: bs x groups x iy x ix
# w: groups x H x W
# out: bs x groups x oy x ox
# ix x groups x groups
for g in range(0, groups, SZ):
for Y in range(0, oy):
for X in range(0, ox, SZ):
IY,IX = Y*ys,X*xs
riski_mov(Reg.MATMUL_OUTPUT, Reg.ZERO)
for y in range(IY, IY+H):
for x in range(IX, IX+W):
riski_load(Reg.MATMUL_INPUT,
SLOT(0) + B*groups*iy*ix + g*iy*ix + y*ix + x,
xs, iy*ix, min(SZ, ox-X), min(SZ, groups-g))
# 0 here is for broadcasting
riski_load(Reg.MATMUL_WEIGHTS,
SLOT(1) + g*H*W + (y-IY)*W + (x-IX),
0, H*W, SZ, min(SZ, groups-g))
riski_mulacc()
#risk_regdump()
riski_store(Reg.MATMUL_OUTPUT,
SLOT(2) + B*groups*oy*ox + g*oy*ox + Y*ox + X,
1, oy*ox, min(SZ, ox-X), min(SZ, groups-g))
elif H == 1 and W == 1 and xs == 1 and ys == 1:
print("opt2")
# oxy x cin x rcout -- unstrided 1x1
# this is a simple matmul
for g in range(0, groups):
for c in range(0, rcout, SZ):
yx = oy*ox
assert yx == iy*ix
for YX in range(0, oy*ox, SZ): # these are next to each other
# inner conv
riski_mov(Reg.MATMUL_OUTPUT, Reg.ZERO)
for ci in range(0, cin, SZ):
riski_load(Reg.MATMUL_INPUT,
SLOT(0) + B*groups*cin*yx + g*cin*yx + ci*yx + YX,
1, yx, min(SZ, yx-YX), min(SZ, cin-ci))
riski_load(Reg.MATMUL_WEIGHTS,
SLOT(1) + g*rcout*cin + c*cin + ci,
1, cin, min(SZ, cin-ci), min(SZ, rcout-c))
riski_matmul()
riski_store(Reg.MATMUL_OUTPUT,
SLOT(2) + B*groups*rcout*yx + g*rcout*yx + c*yx + YX,
1, yx, min(SZ, yx-YX), min(SZ, rcout-c))
else:
print("unoptimized")
# ox x cin x rcout -- unoptimized
for g in range(0, groups):
for c in range(0, rcout, SZ):
for Y in range(0, oy):
for X in range(0, ox, SZ):
IY,IX = Y*ys,X*xs
# inner conv
riski_mov(Reg.MATMUL_OUTPUT, Reg.ZERO)
for ci in range(0, cin, SZ):
# not a loop in 1x1 convs, 9 in 3x3, 25 in 5x5
for y in range(IY, IY+H):
for x in range(IX, IX+W):
riski_load(Reg.MATMUL_INPUT,
SLOT(0) + B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x,
xs, iy*ix, min(SZ, ox-X), min(SZ, cin-ci))
riski_load(Reg.MATMUL_WEIGHTS,
SLOT(1) + g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX),
H*W, cin*H*W, min(SZ, cin-ci), min(SZ, rcout-c))
riski_matmul()
riski_store(Reg.MATMUL_OUTPUT,
SLOT(2) + B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X,
1, oy*ox, min(SZ, ox-X), min(SZ, rcout-c))
risk_print_counts()
#print(x.shape, w.shape, "->", ret.shape)
return riski_dmaw(SLOT(2), (bs, cout, oy, ox))
def backward(ctx, grad_output):
bs,_,oy,ox = grad_output.shape
tx, tw, x_shape = ctx.saved_tensors
_,rcout,cin,H,W = tw.shape
ys,xs = ctx.stride
OY,OX = x_shape[2:4]
ggg = grad_output.reshape(bs,ctx.groups,rcout,oy,ox)
gdw = np.zeros((ctx.groups,rcout,cin,H,W), dtype=tx.dtype)
for g in range(ctx.groups):
#'ikYX,ijYXyx -> kjyx'
gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
# needs to be optimized
gdx = np.zeros((bs,ctx.groups,cin,OY,OX), dtype=tx.dtype)
for k in range(oy*ox):
Y, X = k//ox, k%ox
iY,iX = Y*ys, X*xs
#gdx[:,:,: , iY:iY+H, iX:iX+W] += np.einsum('igk,gkjyx->igjyx', ggg[:,:,:,Y,X], tw)
for g in range(ctx.groups):
tg = np.dot(ggg[:,g,:,Y,X].reshape(bs, -1), tw[g].reshape(rcout, -1))
gdx[:, g, :, iY:iY+H, iX:iX+W] += tg.reshape((bs, cin, H, W))
return gdx.reshape((bs, ctx.groups*cin, OY, OX)), gdw.reshape((ctx.groups*rcout, cin, H, W))

283
extra/risk.py Executable file
View File

@ -0,0 +1,283 @@
#!/usr/bin/env python3
# RISK architecture is going to change everything
# implement on S7t-VG6 (lol, too much $$$)
# Arty A7-100T
# 256 MB of DDR3 with 2.6 GB/s of RAM bandwidth (vs 512 GB/s on S7t-VG6)
# 255K 19-bit elements
import functools
import numpy as np
from collections import defaultdict
# 32x32 * 32x32 -> 32x32 matmul = 65536 FLOPS @ 1 GHz = 64 TOPS
# mulacc is 2048 FLOPS, 32x less
# 32x32 (aka 1024 element) ALU
# 1024 wide permute
# 1024 wide load/store (1 cycle to SRAM)
# all in elements, aka TF32 (19 bits)
# targets:
# matmul input
# matmul weights
# ALU
# permute
# 1024x1024x4x19 bits = 10MB
# fully strided
# load1024 <target>, <address>, <stride x (32)>, <stride y (32)>
# 4 slots
# <input> <weight> <output> <empty>
# <empty> <output> <input> <weight>
# <weight> <input> <empty> <output>
SZ = 32
SLOTSIZE = 1024*1024*2 # 5MB, for 20MB total
sram = np.zeros((SLOTSIZE*4), dtype=np.float32)
regfile = {}
SLOT = lambda x: x*SLOTSIZE
from enum import Enum
class Reg(Enum):
ZERO = 0
# can the ALU use the same registers?
MATMUL_INPUT = 1
MATMUL_WEIGHTS = 2
MATMUL_OUTPUT = 3
# this should be a generic function
class UnaryOps(Enum):
RELU = 0
EXP = 1
LOG = 2
GT0 = 3
class BinaryOps(Enum):
ADD = 0
SUB = 1
MUL = 2
DIV = 3
MULACC = 4
for t in Reg:
regfile[t] = np.zeros((SZ, SZ), dtype=np.float32)
# *** profiler ***
cnts = defaultdict(int)
tcnts = defaultdict(int)
utils = defaultdict(int)
maxdma = 0
def count(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
cnts[func.__name__] += 1
tcnts[func.__name__] += 1
return func(*args, **kwargs)
return wrapper
import atexit
@atexit.register
def risk_print_counts():
print(cnts)
print(tcnts)
print(utils)
util_n = sum([k[0]*k[1]*v for k,v in utils.items()])
util_d = sum([SZ*SZ*v for k,v in utils.items()])
print("%.2f GOPS %d maxdma" % ((tcnts['riski_matmul']*SZ*SZ*SZ*2 + tcnts['riski_mulacc']*SZ*SZ*2)*1e-9, maxdma))
print("ran in %.2f us with util %.2f%% total %.2f us" % (sum(cnts.values())*1e-3, util_n*100/(util_d+1), sum(tcnts.values())*1e-3))
def risk_reset_counts():
global cnts, utils
cnts = defaultdict(int)
utils = defaultdict(int)
def risk_regdump():
print("\n***** regdump *****")
print(regfile[Reg.MATMUL_INPUT])
print(regfile[Reg.MATMUL_WEIGHTS])
print(regfile[Reg.MATMUL_OUTPUT])
# *** instructions ***
@count
def riski_unop(op):
if op == UnaryOps.RELU:
regfile[Reg.MATMUL_OUTPUT] = np.maximum(regfile[Reg.MATMUL_INPUT], 0)
elif op == UnaryOps.LOG:
regfile[Reg.MATMUL_OUTPUT] = np.log(regfile[Reg.MATMUL_INPUT])
elif op == UnaryOps.EXP:
regfile[Reg.MATMUL_OUTPUT] = np.exp(regfile[Reg.MATMUL_INPUT])
elif op == UnaryOps.GT0:
regfile[Reg.MATMUL_OUTPUT] = (regfile[Reg.MATMUL_INPUT] >= 0)
@count
def riski_add():
regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] + regfile[Reg.MATMUL_WEIGHTS]
@count
def riski_sub():
regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] - regfile[Reg.MATMUL_WEIGHTS]
@count
def riski_mul():
regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] * regfile[Reg.MATMUL_WEIGHTS]
@count
def riski_div():
regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] / regfile[Reg.MATMUL_WEIGHTS]
@count
def riski_mulacc():
regfile[Reg.MATMUL_OUTPUT] += regfile[Reg.MATMUL_INPUT] * regfile[Reg.MATMUL_WEIGHTS]
binops = {BinaryOps.ADD: riski_add,
BinaryOps.SUB: riski_sub,
BinaryOps.MUL: riski_mul,
BinaryOps.DIV: riski_div,
BinaryOps.MULACC: riski_mulacc}
@count
def riski_matmul():
#print("LLL:\n",regfile[Reg.MATMUL_INPUT],"\n",regfile[Reg.MATMUL_WEIGHTS])
regfile[Reg.MATMUL_OUTPUT] += \
regfile[Reg.MATMUL_INPUT] @ \
regfile[Reg.MATMUL_WEIGHTS]
@count
def riski_mov(tout, tin):
regfile[tout][:] = regfile[tin]
@count
def riski_load(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ):
global util_n, util_d
utils[(len_y, len_x)] += 1
d = regfile[target]
d[:] = 0
for y in range(0, len_y):
for x in range(0, len_x):
d[y, x] = sram[address + y*stride_y + x*stride_x]
@count
def riski_store(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ):
d = regfile[target]
for y in range(0, len_y):
for x in range(0, len_x):
sram[address + y*stride_y + x*stride_x] = d[y, x]
@count
def riski_dmar(address, arr):
global maxdma
arr = arr.reshape(-1)
assert(arr.shape[0] <= SLOTSIZE)
maxdma = max(maxdma, arr.shape[0])
print("DMAR %d elements" % arr.shape[0])
sram[address:address+arr.shape[0]] = arr
@count
def riski_dmaw(address, shp):
print("DMAW %d elements" % np.prod(shp))
return np.copy(sram[address:address+np.prod(shp)].reshape(shp))
# *** RISK-5 code ***
def risk_unop(x, op):
riski_dmar(SLOT(0), x)
cnt = np.prod(x.shape)
for i in range(0, np.prod(x.shape), SZ*SZ):
riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
riski_unop(op)
riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
return riski_dmaw(SLOT(2), x.shape)
def risk_binop(x, w, op):
riski_dmar(SLOT(0), x)
riski_dmar(SLOT(1), w)
for i in range(0, np.prod(x.shape), SZ*SZ):
riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
riski_load(Reg.MATMUL_WEIGHTS, SLOT(1)+i)
binops[op]()
riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
return riski_dmaw(SLOT(2), x.shape)
def risk_matmul(x, w, transpose_x=False, transpose_w=False):
# copy matrices into SRAM
# x is M x K
# w is K x N
# out is M x N
riski_dmar(SLOT(0), x)
riski_dmar(SLOT(1), w)
if transpose_x:
K,M = x.shape[-2], x.shape[-1]
else:
M,K = x.shape[-2], x.shape[-1]
if transpose_w:
N = w.shape[-2]
assert w.shape[-1] == K
else:
N = w.shape[-1]
assert w.shape[-2] == K
cnt = np.prod(x.shape[0:-2]) if len(x.shape) > 2 else 1
# do matmul
for c in range(cnt):
for m in range(0, M, SZ):
for n in range(0, N, SZ):
riski_mov(Reg.MATMUL_OUTPUT, Reg.ZERO)
for k in range(0, K, SZ):
if transpose_x:
riski_load(Reg.MATMUL_INPUT, SLOT(0)+c*M*K + k*M+m, 1, M, min(SZ, M-m), min(SZ, K-k))
else:
riski_load(Reg.MATMUL_INPUT, SLOT(0)+c*M*K + m*K+k, K, 1, min(SZ, M-m), min(SZ, K-k))
if transpose_w:
riski_load(Reg.MATMUL_WEIGHTS, SLOT(1)+c*K*N + n*K+k, 1, K, min(SZ, K-k), min(SZ, N-n))
else:
riski_load(Reg.MATMUL_WEIGHTS, SLOT(1)+c*K*N + k*N+n, N, 1, min(SZ, K-k), min(SZ, N-n))
riski_matmul()
riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+c*M*N + m*N+n, N, 1, min(SZ, M-m), min(SZ, N-n))
# copy back from SRAM
return riski_dmaw(SLOT(2), (*x.shape[0:-2],M,N))
import unittest
class TestRisk(unittest.TestCase):
def test_matmul_even(self):
x = np.random.uniform(size=(SZ*8, SZ*8)).astype(np.float32)
w = np.random.uniform(size=(SZ*8, SZ*8)).astype(np.float32)
np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
def test_matmul_small(self):
x = np.array([[1,2,3],[4,5,6],[7,8,9]])
w = np.array([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]])
np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
def test_matmul_uneven(self):
x = np.random.uniform(size=(47, 79)).astype(np.float32)
w = np.random.uniform(size=(79, 42)).astype(np.float32)
np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
def test_matmul_transpose(self):
x = np.random.uniform(size=(33, 33)).astype(np.float32)
w = np.random.uniform(size=(33, 33)).astype(np.float32)
np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
np.testing.assert_allclose(x.T @ w, risk_matmul(x, w, True), rtol=1e-5)
np.testing.assert_allclose(x @ w.T, risk_matmul(x, w, False, True), rtol=1e-5)
np.testing.assert_allclose(x.T @ w.T, risk_matmul(x, w, True, True), rtol=1e-5)
def test_matmul_transpose_uneven_w(self):
x = np.random.uniform(size=(47, 79)).astype(np.float32)
w = np.random.uniform(size=(42, 79)).astype(np.float32)
np.testing.assert_allclose(x @ w.T, risk_matmul(x, w, transpose_w=True), rtol=1e-5)
def test_matmul_transpose_uneven_x(self):
x = np.random.uniform(size=(79, 47)).astype(np.float32)
w = np.random.uniform(size=(79, 42)).astype(np.float32)
np.testing.assert_allclose(x.T @ w, risk_matmul(x, w, transpose_x=True), rtol=1e-5)
if __name__ == "__main__":
np.random.seed(1337)
unittest.main(verbosity=2)

1
fpga/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
out/

5
fpga/all.sh Executable file
View File

@ -0,0 +1,5 @@
#!/bin/bash -e
./riscv.sh
./build.sh
./prog.sh

24
fpga/build.sh Executable file
View File

@ -0,0 +1,24 @@
#!/usr/bin/env bash
set -ex
mkdir -p out
cd out
BASE=/Users/taylor/fun/fpga
# yosys commit 82f5829aba108be4a3786e7a237fd7bcebe61eb6
# build normally
$BASE/yosys/yosys -p "synth_xilinx -flatten -nowidelut -abc9 -arch xc7 -top top; write_json attosoc.json" ../src/attosoc.v ../src/attosoc_top.v ../src/simpleuart.v
# nextpnr-xilinx 0be5cc19f3261101730ce9274720aaf3784f83e2
# cmake -DARCH=xilinx -DBUILD_GUI=no -DBUILD_PYTHON=no -DUSE_OPENMP=No .
# python3 xilinx/python/bbaexport.py --device xc7a100tcsg324-1 --bba xilinx/xc7a100t.bba
# ./bbasm -l xilinx/xc7a100t.bba xilinx/xc7a100t.bin
$BASE/nextpnr-xilinx/nextpnr-xilinx --chipdb $BASE/nextpnr-xilinx/xilinx/xc7a100t.bin --xdc ../src/arty.xdc --json attosoc.json --write attosoc_routed.json --fasm attosoc.fasm
XRAY_UTILS_DIR=$BASE/prjxray/utils
XRAY_TOOLS_DIR=$BASE/prjxray/build/tools
XRAY_DATABASE_DIR=$BASE/prjxray/database
"${XRAY_UTILS_DIR}/fasm2frames.py" --db-root "${XRAY_DATABASE_DIR}/artix7" --part xc7a100tcsg324-1 attosoc.fasm > attosoc.frames
"${XRAY_TOOLS_DIR}/xc7frames2bit" --part_file "${XRAY_DATABASE_DIR}/artix7/xc7a100tcsg324-1/part.yaml" --part_name xc7a100tcsg324-1 --frm_file attosoc.frames --output_file attosoc.bit

15
fpga/console.py Executable file
View File

@ -0,0 +1,15 @@
#!/usr/bin/env python3
import time
import pyftdi.serialext
port = pyftdi.serialext.serial_for_url('ftdi://ftdi:2232h/2', baudrate=115200)
print(port)
while 1:
#port.write(b'a')
data = port.read(1)
print(data)
time.sleep(0.01)

27
fpga/digilent_arty.cfg Normal file
View File

@ -0,0 +1,27 @@
#
# Digilent Arty with Xilinx Artix-7 FPGA
#
# http://store.digilentinc.com/arty-artix-7-fpga-development-board-for-makers-and-hobbyists/
#
# iManufacturer 1 Digilent
# iProduct 2 Digilent USB Device
# iSerial 3 210319A28C7F
interface ftdi
ftdi_device_desc "Digilent USB Device"
ftdi_vid_pid 0x0403 0x6010
# channel 1 does not have any functionality
ftdi_channel 0
# just TCK TDI TDO TMS, no reset
ftdi_layout_init 0x0088 0x008b
reset_config none
adapter_khz 10000
source [find cpld/xilinx-xc7.cfg]
source [find cpld/jtagspi.cfg]
init
xc7_program xc7.tap
pld load 0 out/attosoc.bit
exit

3
fpga/prog.sh Executable file
View File

@ -0,0 +1,3 @@
#!/bin/bash
openocd -d -f digilent_arty.cfg

10
fpga/riscv.sh Executable file
View File

@ -0,0 +1,10 @@
#!/bin/bash -e
cd out
riscv64-unknown-elf-gcc -Os -march=rv32i -mabi=ilp32 -nostdlib ../src/main.c
#riscv64-unknown-elf-as ../src/riscv.asm
riscv64-unknown-elf-objdump -d a.out
riscv64-unknown-elf-objcopy -O binary a.out a.asm
xxd a.asm
python -c 'import struct; dat = open("a.asm", "rb").read(); print("\n".join(["%08x" % c for c in struct.unpack("I"*(len(dat)//4), dat)]))' > ../src/firmware.hex

1
fpga/src/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
firmware.bin

58
fpga/src/arty.xdc Normal file
View File

@ -0,0 +1,58 @@
# R
set_property LOC G6 [get_ports led[0]]
set_property LOC G3 [get_ports led[1]]
set_property LOC J3 [get_ports led[2]]
set_property LOC K1 [get_ports led[3]]
# G
set_property LOC F6 [get_ports led[4]]
set_property LOC J4 [get_ports led[5]]
set_property LOC J2 [get_ports led[6]]
set_property LOC H6 [get_ports led[7]]
# B
set_property LOC E1 [get_ports led[8]]
set_property LOC G4 [get_ports led[9]]
set_property LOC H4 [get_ports led[10]]
set_property LOC K2 [get_ports led[11]]
# second row
# set_property LOC H5 [get_ports led[12]]
# set_property LOC J5 [get_ports led[13]]
# set_property LOC T9 [get_ports led[14]]
# set_property LOC T10 [get_ports led[15]]
set_property IOSTANDARD LVCMOS33 [get_ports led[0]]
set_property IOSTANDARD LVCMOS33 [get_ports led[1]]
set_property IOSTANDARD LVCMOS33 [get_ports led[2]]
set_property IOSTANDARD LVCMOS33 [get_ports led[3]]
set_property IOSTANDARD LVCMOS33 [get_ports led[4]]
set_property IOSTANDARD LVCMOS33 [get_ports led[5]]
set_property IOSTANDARD LVCMOS33 [get_ports led[6]]
set_property IOSTANDARD LVCMOS33 [get_ports led[7]]
set_property IOSTANDARD LVCMOS33 [get_ports led[8]]
set_property IOSTANDARD LVCMOS33 [get_ports led[9]]
set_property IOSTANDARD LVCMOS33 [get_ports led[10]]
set_property IOSTANDARD LVCMOS33 [get_ports led[11]]
set_property IOSTANDARD LVCMOS33 [get_ports led[12]]
set_property IOSTANDARD LVCMOS33 [get_ports led[13]]
set_property IOSTANDARD LVCMOS33 [get_ports led[14]]
set_property IOSTANDARD LVCMOS33 [get_ports led[15]]
set_property LOC A8 [get_ports sw[0]]
set_property LOC C11 [get_ports sw[1]]
set_property LOC C10 [get_ports sw[2]]
set_property LOC A10 [get_ports sw[3]]
set_property IOSTANDARD LVCMOS33 [get_ports sw[0]]
set_property IOSTANDARD LVCMOS33 [get_ports sw[1]]
set_property IOSTANDARD LVCMOS33 [get_ports sw[2]]
set_property IOSTANDARD LVCMOS33 [get_ports sw[3]]
set_property LOC E3 [get_ports clk_i]
set_property IOSTANDARD LVCMOS33 [get_ports clk_i]
set_property LOC A9 [get_ports ser_rx]
set_property IOSTANDARD LVCMOS33 [get_ports ser_rx]
set_property LOC D10 [get_ports ser_tx]
set_property IOSTANDARD LVCMOS33 [get_ports ser_tx]

3136
fpga/src/attosoc.v Normal file

File diff suppressed because it is too large Load Diff

48
fpga/src/attosoc_top.v Normal file
View File

@ -0,0 +1,48 @@
module top (
input clk_i,
input [3:0] sw,
output [11:0] led,
output ser_tx,
input ser_rx,
);
//assign led = {&sw, |sw, ^sw, ~^sw};
reg clk50 = 1'b0;
always @(posedge clk_i)
clk50 <= ~clk50;
wire clk;
BUFGCTRL bufg_i (
.I0(clk50),
.CE0(1'b1),
.S0(1'b1),
.O(clk)
);
// wire clk = clk_i;
//reg clkdiv;
//reg [22:0] ctr;
//always @(posedge clk) {clkdiv, ctr} <= ctr + 1'b1;
wire [7:0] soc_led;
attosoc soc_i(
.clk(clk),
.reset(sw[0]),
.led(soc_led),
.ser_tx(ser_tx),
.ser_rx(ser_rx),
);
// this maps 2 bits to each LED
generate
genvar i;
for (i = 0; i < 4; i++) begin
assign led[0 + i] = soc_led[2 * i]; // R
assign led[4 + i] = soc_led[(2 * i) + 1]; // G
assign led[8 + i] = &soc_led[2 * i +: 2]; // B
end
endgenerate
endmodule

46
fpga/src/firmware.hex Normal file
View File

@ -0,0 +1,46 @@
fd010113
01312e23
01412c23
01512a23
01612823
02112623
02812423
02912223
03212023
01712623
020007b7
1b200713
00e7a223
00400a13
06800a93
06500b13
06c00993
00800913
00100493
02000437
06f00b93
00942023
fff90913
03c000ef
00149493
fe0900e3
ff4496e3
03c000ef
01542423
034000ef
01642423
02c000ef
01342423
024000ef
01342423
01c000ef
01742423
fc1ff06f
001002b7
fff28293
fe029ee3
00008067
000102b7
fff28293
fe029ee3
00008067

46
fpga/src/main.c Normal file
View File

@ -0,0 +1,46 @@
#include <stdint.h>
#define reg_leds (*(volatile uint32_t*)0x02000000)
#define reg_uart_clkdiv (*(volatile uint32_t*)0x02000004)
#define reg_uart_data (*(volatile uint32_t*)0x02000008)
void delay();
int main() {
// 50 mhz clock
reg_uart_clkdiv = 434;
while (1) {
for (int i = 1; i < 0x100; i <<= 1) {
if (i == 4) {
sdelay();
reg_uart_data = 'h';
sdelay();
reg_uart_data = 'e';
sdelay();
reg_uart_data = 'l';
sdelay();
reg_uart_data = 'l';
sdelay();
reg_uart_data = 'o';
}
reg_leds = i;
delay();
}
}
}
void __attribute__ ((noinline)) delay() {
asm ("lui t0, 0x100\n"
"lop:"
"addi t0,t0,-0x1\n"
"bne t0,zero,lop\n"
::);
}
void __attribute__ ((noinline)) sdelay() {
asm ("lui t0, 0x10\n"
"lop2:"
"addi t0,t0,-0x1\n"
"bne t0,zero,lop2\n"
::);
}

137
fpga/src/simpleuart.v Normal file
View File

@ -0,0 +1,137 @@
/*
* PicoSoC - A simple example SoC using PicoRV32
*
* Copyright (C) 2017 Clifford Wolf <clifford@clifford.at>
*
* Permission to use, copy, modify, and/or distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*
*/
module simpleuart #(parameter integer DEFAULT_DIV = 1) (
input clk,
input resetn,
output ser_tx,
input ser_rx,
input [3:0] reg_div_we,
input [31:0] reg_div_di,
output [31:0] reg_div_do,
input reg_dat_we,
input reg_dat_re,
input [31:0] reg_dat_di,
output [31:0] reg_dat_do,
output reg_dat_wait
);
reg [31:0] cfg_divider;
reg [3:0] recv_state;
reg [31:0] recv_divcnt;
reg [7:0] recv_pattern;
reg [7:0] recv_buf_data;
reg recv_buf_valid;
reg [9:0] send_pattern;
reg [3:0] send_bitcnt;
reg [31:0] send_divcnt;
reg send_dummy;
assign reg_div_do = cfg_divider;
assign reg_dat_wait = reg_dat_we && (send_bitcnt || send_dummy);
assign reg_dat_do = recv_buf_valid ? recv_buf_data : ~0;
always @(posedge clk) begin
if (!resetn) begin
cfg_divider <= DEFAULT_DIV;
end else begin
if (reg_div_we[0]) cfg_divider[ 7: 0] <= reg_div_di[ 7: 0];
if (reg_div_we[1]) cfg_divider[15: 8] <= reg_div_di[15: 8];
if (reg_div_we[2]) cfg_divider[23:16] <= reg_div_di[23:16];
if (reg_div_we[3]) cfg_divider[31:24] <= reg_div_di[31:24];
end
end
always @(posedge clk) begin
if (!resetn) begin
recv_state <= 0;
recv_divcnt <= 0;
recv_pattern <= 0;
recv_buf_data <= 0;
recv_buf_valid <= 0;
end else begin
recv_divcnt <= recv_divcnt + 1;
if (reg_dat_re)
recv_buf_valid <= 0;
case (recv_state)
0: begin
if (!ser_rx)
recv_state <= 1;
recv_divcnt <= 0;
end
1: begin
if (2*recv_divcnt > cfg_divider) begin
recv_state <= 2;
recv_divcnt <= 0;
end
end
10: begin
if (recv_divcnt > cfg_divider) begin
recv_buf_data <= recv_pattern;
recv_buf_valid <= 1;
recv_state <= 0;
end
end
default: begin
if (recv_divcnt > cfg_divider) begin
recv_pattern <= {ser_rx, recv_pattern[7:1]};
recv_state <= recv_state + 1;
recv_divcnt <= 0;
end
end
endcase
end
end
assign ser_tx = send_pattern[0];
always @(posedge clk) begin
if (reg_div_we)
send_dummy <= 1;
send_divcnt <= send_divcnt + 1;
if (!resetn) begin
send_pattern <= ~0;
send_bitcnt <= 0;
send_divcnt <= 0;
send_dummy <= 1;
end else begin
if (send_dummy && !send_bitcnt) begin
send_pattern <= ~0;
send_bitcnt <= 15;
send_divcnt <= 0;
send_dummy <= 0;
end else
if (reg_dat_we && !send_bitcnt) begin
send_pattern <= {1'b1, reg_dat_di[7:0], 1'b0};
send_bitcnt <= 10;
send_divcnt <= 0;
end else
if (send_divcnt > cfg_divider && send_bitcnt) begin
send_pattern <= {1'b1, send_pattern[9:1]};
send_bitcnt <= send_bitcnt - 1;
send_divcnt <= 0;
end
end
end
endmodule

View File

@ -161,6 +161,22 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
def test_grouped_conv2d(self):
groups = 2
helper_test_op([(1,2,5,5), (groups,1,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=True)
def test_fancy_conv2d(self):
bs = 2
cin = 3
cout = 1
groups = 3
H,W = 3,3
helper_test_op([(bs,cin,11,28), (groups*cout,cin//groups,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=True)
def test_strided_conv2d(self):
bs = 4
cin = 3
@ -191,4 +207,5 @@ class TestOps(unittest.TestCase):
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)
if __name__ == '__main__':
np.random.seed(1337)
unittest.main(verbosity=2)

View File

@ -144,7 +144,7 @@ class Tensor:
for t0 in reversed(self.deepwalk()):
assert (t0.grad is not None)
with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
po.output = grads = t0._ctx.backward(t0._ctx, t0.grad.data)
grads = t0._ctx.backward(t0._ctx, t0.grad.data)
if len(t0._ctx.parents) == 1:
grads = [grads]
for t, g in zip(t0._ctx.parents, grads):
@ -355,6 +355,9 @@ def _register_ops(namespace, device=Device.CPU):
from tinygrad import ops_cpu
_register_ops(ops_cpu)
if os.getenv("RISK", None) is not None:
from extra import ops_risk
_register_ops(ops_risk)
try:
import pyopencl as cl
# TODO: move this import to require_init_gpu?