mirror of https://github.com/commaai/tinygrad.git
simpler
This commit is contained in:
parent
bee89a4840
commit
4faf3a0aed
|
@ -159,13 +159,12 @@ class Conv2D(Function):
|
||||||
def forward(ctx, x, w):
|
def forward(ctx, x, w):
|
||||||
cout,cin,H,W = w.shape
|
cout,cin,H,W = w.shape
|
||||||
ret = np.zeros((x.shape[0], cout, x.shape[2]-(H-1), x.shape[3]-(W-1)), dtype=w.dtype)
|
ret = np.zeros((x.shape[0], cout, x.shape[2]-(H-1), x.shape[3]-(W-1)), dtype=w.dtype)
|
||||||
for Y in range(ret.shape[2]):
|
for j in range(H):
|
||||||
for X in range(ret.shape[3]):
|
for i in range(W):
|
||||||
for j in range(H):
|
tw = w[:, :, j, i]
|
||||||
for i in range(W):
|
for Y in range(ret.shape[2]):
|
||||||
tx = x[:, :, Y+j, X+i]
|
for X in range(ret.shape[3]):
|
||||||
tw = w[:, :, j, i]
|
ret[:, :, Y, X] += x[:, :, Y+j, X+i].dot(tw.T)
|
||||||
ret[:, :, Y, X] += tx.dot(tw.T)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
Loading…
Reference in New Issue