This commit is contained in:
George Hotz 2020-10-19 09:37:07 -07:00
parent bee89a4840
commit 4faf3a0aed
1 changed files with 6 additions and 7 deletions

View File

@ -159,13 +159,12 @@ class Conv2D(Function):
def forward(ctx, x, w): def forward(ctx, x, w):
cout,cin,H,W = w.shape cout,cin,H,W = w.shape
ret = np.zeros((x.shape[0], cout, x.shape[2]-(H-1), x.shape[3]-(W-1)), dtype=w.dtype) ret = np.zeros((x.shape[0], cout, x.shape[2]-(H-1), x.shape[3]-(W-1)), dtype=w.dtype)
for Y in range(ret.shape[2]): for j in range(H):
for X in range(ret.shape[3]): for i in range(W):
for j in range(H): tw = w[:, :, j, i]
for i in range(W): for Y in range(ret.shape[2]):
tx = x[:, :, Y+j, X+i] for X in range(ret.shape[3]):
tw = w[:, :, j, i] ret[:, :, Y, X] += x[:, :, Y+j, X+i].dot(tw.T)
ret[:, :, Y, X] += tx.dot(tw.T)
return ret return ret
@staticmethod @staticmethod