mirror of https://github.com/commaai/tinygrad.git
Remove Tensor.data (#565)
This commit is contained in:
parent
4efe0169bb
commit
7944cfdadc
|
@ -44,7 +44,7 @@ def infer(model, img):
|
|||
# if you want to look at the outputs
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
plt.plot(out.data[0])
|
||||
plt.plot(out.numpy()[0])
|
||||
plt.show()
|
||||
"""
|
||||
return out, retimg
|
||||
|
@ -68,7 +68,7 @@ if __name__ == "__main__":
|
|||
ret, frame = cap.read()
|
||||
img = Image.fromarray(frame[:, :, [2,1,0]])
|
||||
out, retimg = infer(model, img)
|
||||
print(np.argmax(out.data), np.max(out.data), lbls[np.argmax(out.data)])
|
||||
print(np.argmax(out.numpy()), np.max(out.numpy()), lbls[np.argmax(out.numpy())])
|
||||
SCALE = 3
|
||||
simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
|
||||
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
|
||||
|
@ -84,5 +84,5 @@ if __name__ == "__main__":
|
|||
img = Image.open(url)
|
||||
st = time.time()
|
||||
out, _ = infer(model, img)
|
||||
print(np.argmax(out.data), np.max(out.data), lbls[np.argmax(out.data)])
|
||||
print(np.argmax(out.numpy()), np.max(out.numpy()), lbls[np.argmax(out.numpy())])
|
||||
print(f"did inference in {(time.time()-st):2f}")
|
||||
|
|
|
@ -71,14 +71,14 @@ class BigConvNet:
|
|||
with open(filename+'.npy', 'wb') as f:
|
||||
for par in optim.get_parameters(self):
|
||||
#if par.requires_grad:
|
||||
np.save(f, par.cpu().data)
|
||||
np.save(f, par.cpu().numpy())
|
||||
|
||||
def load(self, filename):
|
||||
with open(filename+'.npy', 'rb') as f:
|
||||
for par in optim.get_parameters(self):
|
||||
#if par.requires_grad:
|
||||
try:
|
||||
par.cpu().data[:] = np.load(f)
|
||||
par.cpu().numpy()[:] = np.load(f)
|
||||
if GPU:
|
||||
par.gpu()
|
||||
except:
|
||||
|
|
|
@ -88,8 +88,8 @@ if __name__ == "__main__":
|
|||
opt_time = (time.time()-st)*1000.0
|
||||
|
||||
st = time.time()
|
||||
loss = loss.cpu().data
|
||||
cat = np.argmax(out.cpu().data, axis=1)
|
||||
loss = loss.cpu().numpy()
|
||||
cat = np.argmax(out.cpu().numpy(), axis=1)
|
||||
accuracy = (cat == Y).mean()
|
||||
finish_time = (time.time()-st)*1000.0
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ cmd = sys.argv[1]
|
|||
vgg7 = Vgg7()
|
||||
|
||||
def nansbane(p):
|
||||
if numpy.isnan(numpy.min(p.data)):
|
||||
if numpy.isnan(numpy.min(p.numpy())):
|
||||
raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
|
||||
|
||||
def load_and_save(path, save):
|
||||
|
@ -90,7 +90,7 @@ elif cmd == "execute":
|
|||
|
||||
load_and_save(model, False)
|
||||
|
||||
image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).data)
|
||||
image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).numpy())
|
||||
elif cmd == "execute_full":
|
||||
model = sys.argv[2]
|
||||
in_file = sys.argv[3]
|
||||
|
@ -179,7 +179,7 @@ elif cmd == "train":
|
|||
optim.step()
|
||||
|
||||
# warning: used by sample probability adjuster
|
||||
loss_indicator = loss.max().data[0]
|
||||
loss_indicator = loss.max().numpy()[0]
|
||||
print("Round " + str(rnum) + " : " + str(loss_indicator))
|
||||
|
||||
if (rnum % rounds_per_save) == 0:
|
||||
|
|
|
@ -44,6 +44,6 @@ img -= 0.5
|
|||
img /= 0.5
|
||||
|
||||
out = m.forward(Tensor(img))
|
||||
outnp = out.cpu().data.ravel()
|
||||
outnp = out.cpu().numpy().ravel()
|
||||
choice = outnp.argmax()
|
||||
print(out.shape, choice, outnp[choice], lbls[choice])
|
||||
|
|
|
@ -7,7 +7,7 @@ import os
|
|||
# where the * is a number starting at 0.
|
||||
# Each file is simply raw little-endian floats,
|
||||
# as readable by: numpy.fromfile(path, "<f4")
|
||||
# and as writable by: t.data.astype("<f4", "C").tofile(path)
|
||||
# and as writable by: t.numpy().astype("<f4", "C").tofile(path)
|
||||
# This format is intended to be extremely simple to get into literally anything.
|
||||
# It is not intended to be structural or efficient - reloading a network when
|
||||
# unnecessary is inefficient anyway.
|
||||
|
@ -49,7 +49,7 @@ class KinneDir:
|
|||
"""
|
||||
path = f"{self.base}{self.next_part_index}.bin"
|
||||
if self.save:
|
||||
t.data.astype("<f4", "C").tofile(path)
|
||||
t.numpy().astype("<f4", "C").tofile(path)
|
||||
self.metadata.write(f"{self.next_part_index}: {t.shape}\n")
|
||||
else:
|
||||
t.assign(Tensor(numpy.fromfile(path, "<f4")).reshape(shape=t.shape))
|
||||
|
|
|
@ -172,7 +172,7 @@ class Vgg7:
|
|||
tile_t = Tensor(tile)
|
||||
tile_fwd_t = self.forward(tile_t)
|
||||
# Replace tile.
|
||||
image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.data
|
||||
image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.numpy()
|
||||
|
||||
return image_out
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ class Upsample:
|
|||
def upsampleNearest(self, input):
|
||||
# TODO: Implement actual interpolation function
|
||||
# inspired: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/functional/upsampling.h
|
||||
return input.cpu().data.repeat(self.scale_factor, axis=len(input.shape)-2).repeat(self.scale_factor, axis=len(input.shape)-1)
|
||||
return input.cpu().numpy().repeat(self.scale_factor, axis=len(input.shape)-2).repeat(self.scale_factor, axis=len(input.shape)-1)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Upsample(scale_factor={self.scale_factor!r}, mode={self.mode!r})"
|
||||
|
|
|
@ -18,7 +18,7 @@ def show_labels(prediction, confidence = 0.5, num_classes = 80):
|
|||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
|
||||
coco_labels = coco_labels.decode('utf-8').split('\n')
|
||||
|
||||
prediction = prediction.detach().cpu().data
|
||||
prediction = prediction.detach().cpu().numpy()
|
||||
|
||||
conf_mask = (prediction[:,:,4] > confidence)
|
||||
conf_mask = np.expand_dims(conf_mask, 2)
|
||||
|
@ -119,7 +119,7 @@ def bbox_iou(box1, box2):
|
|||
|
||||
|
||||
def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0.4):
|
||||
prediction = prediction.detach().cpu().data
|
||||
prediction = prediction.detach().cpu().numpy()
|
||||
conf_mask = (prediction[:,:,4] > confidence)
|
||||
conf_mask = np.expand_dims(conf_mask, 2)
|
||||
prediction = prediction * conf_mask
|
||||
|
@ -275,7 +275,7 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
|
|||
prediction = prediction.reshape(shape=(batch_size, grid_size*grid_size*num_anchors, bbox_attrs))
|
||||
|
||||
# st = time.time()
|
||||
prediction_cpu = prediction.cpu().data
|
||||
prediction_cpu = prediction.cpu().numpy()
|
||||
# print('put on CPU in %.2f s' % (time.time() - st))
|
||||
|
||||
anchors = [(a[0]/stride, a[1]/stride) for a in anchors]
|
||||
|
@ -417,11 +417,11 @@ class Darknet:
|
|||
print(self.blocks[i + 1]["type"], "weights", i)
|
||||
model = self.module_list[i]
|
||||
conv = model[0]
|
||||
print(conv.weight.cpu().data[0][0][0])
|
||||
print(conv.weight.cpu().numpy()[0][0][0])
|
||||
if conv.bias is not None:
|
||||
print("biases")
|
||||
print(conv.bias.shape)
|
||||
print(conv.bias.cpu().data[0][0:5])
|
||||
print(conv.bias.cpu().numpy()[0][0:5])
|
||||
else:
|
||||
print("None biases for layer", i)
|
||||
|
||||
|
@ -525,7 +525,7 @@ class Darknet:
|
|||
if (layers[1]) > 0: layers[1] = layers[1] - i
|
||||
map1 = outputs[i + layers[0]]
|
||||
map2 = outputs[i + layers[1]]
|
||||
x = Tensor(np.concatenate((map1.cpu().data, map2.cpu().data), 1))
|
||||
x = Tensor(np.concatenate((map1.cpu().numpy(), map2.cpu().numpy()), 1))
|
||||
elif module_type == "shortcut":
|
||||
from_ = int(module["from"])
|
||||
x = outputs[i - 1] + outputs[i + from_]
|
||||
|
@ -540,7 +540,7 @@ class Darknet:
|
|||
detections = x
|
||||
write = 1
|
||||
else:
|
||||
detections = Tensor(np.concatenate((detections.cpu().data, x.cpu().data), 1))
|
||||
detections = Tensor(np.concatenate((detections.cpu().numpy(), x.cpu().numpy()), 1))
|
||||
|
||||
# print(module_type, 'layer took %.2f s' % (time.time() - st))
|
||||
outputs[i] = x
|
||||
|
|
|
@ -9,8 +9,8 @@ def mask_like(like, mask_inx, mask_value = 1.0):
|
|||
def jacobian(func, input):
|
||||
output = func(input)
|
||||
|
||||
ji = input.data.reshape(-1).shape[-1]
|
||||
jo = output.data.reshape(-1).shape[-1]
|
||||
ji = input.numpy().reshape(-1).shape[-1]
|
||||
jo = output.numpy().reshape(-1).shape[-1]
|
||||
J = np.zeros((jo,ji), dtype=np.float32)
|
||||
|
||||
for o in range(jo):
|
||||
|
@ -19,25 +19,25 @@ def jacobian(func, input):
|
|||
|
||||
# tinygrad doesn't support slicing, tiny-hack to select
|
||||
# the needed scalar an backpropagate only through it
|
||||
o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
|
||||
o_scalar = Tensor(mask_like(output.numpy(), o, 1.)).mul(output).sum()
|
||||
o_scalar.backward()
|
||||
|
||||
for i, grad in enumerate(input.grad.data.reshape(-1)):
|
||||
for i, grad in enumerate(input.grad.numpy().reshape(-1)):
|
||||
J[o,i] = grad
|
||||
return J
|
||||
|
||||
def numerical_jacobian(func, input, eps = 1e-6):
|
||||
output = func(input)
|
||||
|
||||
ji = input.data.reshape(-1).shape[-1]
|
||||
jo = output.data.reshape(-1).shape[-1]
|
||||
ji = input.numpy().reshape(-1).shape[-1]
|
||||
jo = output.numpy().reshape(-1).shape[-1]
|
||||
NJ = np.zeros((jo, ji), dtype=np.float32)
|
||||
|
||||
for i in range(ji):
|
||||
eps_perturb = mask_like(input.data, i, mask_value = eps)
|
||||
eps_perturb = mask_like(input.numpy(), i, mask_value = eps)
|
||||
|
||||
output_perturb_add = func(Tensor(input.data + eps_perturb)).data.reshape(-1)
|
||||
output_perturb_sub = func(Tensor(input.data - eps_perturb)).data.reshape(-1)
|
||||
output_perturb_add = func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
|
||||
output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
|
||||
|
||||
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2*eps)
|
||||
|
||||
|
|
|
@ -33,10 +33,10 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
|
|||
|
||||
# printing
|
||||
if not noloss:
|
||||
cat = np.argmax(out.cpu().data, axis=-1)
|
||||
cat = np.argmax(out.cpu().numpy(), axis=-1)
|
||||
accuracy = (cat == y).mean()
|
||||
|
||||
loss = loss.detach().cpu().data
|
||||
loss = loss.detach().cpu().numpy()
|
||||
losses.append(loss)
|
||||
accuracies.append(accuracy)
|
||||
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
|
||||
|
@ -50,7 +50,7 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
|
|||
for i in trange((len(Y_test)-1)//BS+1, disable=getenv('CI', False)):
|
||||
x = Tensor(transform(X_test[i*BS:(i+1)*BS]))
|
||||
out = model.forward(x) if hasattr(model, 'forward') else model(x)
|
||||
Y_test_preds_out[i*BS:(i+1)*BS] = out.cpu().data
|
||||
Y_test_preds_out[i*BS:(i+1)*BS] = out.cpu().numpy()
|
||||
Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
|
||||
Y_test = target_transform(Y_test)
|
||||
return (Y_test == Y_test_preds).mean(), Y_test_preds
|
||||
|
|
|
@ -58,7 +58,7 @@ class Transformer:
|
|||
|
||||
def forward(self, x):
|
||||
bs = x.shape[0]
|
||||
xnp = x.cpu().data.astype(np.int32)
|
||||
xnp = x.cpu().numpy().astype(np.int32)
|
||||
onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
|
||||
for i in range(x.shape[1]):
|
||||
onehot[range(bs), i, i] = 1
|
||||
|
|
|
@ -46,7 +46,7 @@ def _infer(model: EfficientNet, img, bs=1):
|
|||
# run the net
|
||||
if bs > 1: img = img.repeat(bs, axis=0)
|
||||
out = model.forward(Tensor(img)).cpu()
|
||||
return _LABELS[np.argmax(out.data[0])]
|
||||
return _LABELS[np.argmax(out.numpy()[0])]
|
||||
|
||||
chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg')
|
||||
car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')
|
||||
|
|
|
@ -18,19 +18,19 @@ class TestNN(unittest.TestCase):
|
|||
bn.bias = Tensor.randn(sz)
|
||||
bn.running_mean = Tensor.randn(sz)
|
||||
bn.running_var = Tensor.randn(sz)
|
||||
bn.running_var.data[bn.running_var.data < 0] = 0
|
||||
bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
tbn = torch.nn.BatchNorm2d(sz).eval()
|
||||
tbn.training = training
|
||||
tbn.weight[:] = torch.tensor(bn.weight.data)
|
||||
tbn.bias[:] = torch.tensor(bn.bias.data)
|
||||
tbn.running_mean[:] = torch.tensor(bn.running_mean.data)
|
||||
tbn.running_var[:] = torch.tensor(bn.running_var.data)
|
||||
tbn.weight[:] = torch.tensor(bn.weight.numpy())
|
||||
tbn.bias[:] = torch.tensor(bn.bias.numpy())
|
||||
tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy())
|
||||
tbn.running_var[:] = torch.tensor(bn.running_var.numpy())
|
||||
|
||||
np.testing.assert_allclose(bn.running_mean.data, tbn.running_mean.detach().numpy(), rtol=1e-5)
|
||||
np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)
|
||||
np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5)
|
||||
np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5)
|
||||
|
||||
# trial
|
||||
inn = Tensor.randn(2, sz, 3, 3)
|
||||
|
@ -39,15 +39,15 @@ class TestNN(unittest.TestCase):
|
|||
outt = bn(inn)
|
||||
|
||||
# in torch
|
||||
toutt = tbn(torch.tensor(inn.cpu().data))
|
||||
toutt = tbn(torch.tensor(inn.cpu().numpy()))
|
||||
|
||||
# close
|
||||
np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=5e-4)
|
||||
np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4)
|
||||
|
||||
np.testing.assert_allclose(bn.running_mean.data, tbn.running_mean.detach().numpy(), rtol=1e-5)
|
||||
np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5)
|
||||
|
||||
# TODO: this is failing
|
||||
# np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)
|
||||
# np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5)
|
||||
|
||||
def test_batchnorm2d_training(self):
|
||||
self.test_batchnorm2d(True)
|
||||
|
@ -64,11 +64,11 @@ class TestNN(unittest.TestCase):
|
|||
torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
|
||||
torch_layer.weight[:] = torch.tensor(model.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(model.bias.numpy(), dtype=torch.float32)
|
||||
torch_x = torch.tensor(x.cpu().data, dtype=torch.float32)
|
||||
torch_x = torch.tensor(x.cpu().numpy(), dtype=torch.float32)
|
||||
torch_z = torch_layer(torch_x)
|
||||
|
||||
# test
|
||||
np.testing.assert_allclose(z.data, torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
BS, T, in_dim, out_dim = 4, 2, 8, 16
|
||||
_test_linear(Tensor.randn(BS, in_dim))
|
||||
|
@ -84,15 +84,15 @@ class TestNN(unittest.TestCase):
|
|||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.data, dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.data, dtype=torch.float32)
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.cpu().data)
|
||||
torch_x = torch.tensor(x.cpu().numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.data, torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
def test_groupnorm(self):
|
||||
BS, H, W, C, G = 20, 10, 10, 6, 3
|
||||
|
@ -103,15 +103,15 @@ class TestNN(unittest.TestCase):
|
|||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.GroupNorm(G, C).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.data, dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.data, dtype=torch.float32)
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.randn(BS, C, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.cpu().data)
|
||||
torch_x = torch.tensor(x.cpu().numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.data, torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
|
||||
def test_layernorm(self):
|
||||
N, C, H, W = 20, 5, 10, 10
|
||||
|
@ -122,15 +122,15 @@ class TestNN(unittest.TestCase):
|
|||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.LayerNorm([H, W]).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.data, dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.data, dtype=torch.float32)
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.randn(N, C, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.cpu().data)
|
||||
torch_x = torch.tensor(x.cpu().numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.data, torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -14,7 +14,7 @@ def step_tinygrad(optim, kwargs={}):
|
|||
out = net.forward()
|
||||
out.backward()
|
||||
optim.step()
|
||||
return net.x.cpu().data, net.W.cpu().data
|
||||
return net.x.cpu().numpy(), net.W.cpu().numpy()
|
||||
|
||||
def step_pytorch(optim, kwargs={}):
|
||||
net = TorchNet()
|
||||
|
|
|
@ -44,7 +44,7 @@ class TestTinygrad(unittest.TestCase):
|
|||
out = out.logsoftmax()
|
||||
out = out.mul(m).add(m).sum()
|
||||
out.backward()
|
||||
return out.cpu().data, x.grad.cpu().data, W.grad.cpu().data
|
||||
return out.cpu().numpy(), x.grad.cpu().numpy(), W.grad.cpu().numpy()
|
||||
|
||||
def test_pytorch():
|
||||
x = torch.tensor(x_init, requires_grad=True)
|
||||
|
@ -70,7 +70,7 @@ class TestTinygrad(unittest.TestCase):
|
|||
out = out.logsoftmax()
|
||||
out = out.sum()
|
||||
out.backward()
|
||||
return out.cpu().data, u.cpu().grad.data, v.cpu().grad.data, w.cpu().grad.data
|
||||
return out.cpu().numpy(), u.cpu().grad.numpy(), v.cpu().grad.numpy(), w.cpu().grad.numpy()
|
||||
|
||||
def test_pytorch():
|
||||
u = torch.tensor(U_init, requires_grad=True)
|
||||
|
@ -106,7 +106,7 @@ class TestTinygrad(unittest.TestCase):
|
|||
Tensor.training = True
|
||||
n, rate = 1_000_000, 0.1
|
||||
w = Tensor.ones(n).dropout(rate)
|
||||
non_zeros = np.count_nonzero(w.cpu().data)
|
||||
non_zeros = np.count_nonzero(w.cpu().numpy())
|
||||
expected = n * (1 - rate)
|
||||
np.testing.assert_allclose(non_zeros, expected, rtol=1e-3)
|
||||
|
||||
|
|
|
@ -94,10 +94,6 @@ class Tensor:
|
|||
def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
|
||||
def numpy(self) -> np.ndarray: return np.array(self.lazydata.toCPU())
|
||||
|
||||
# TODO: this keeps the legacy behavior working, remove it after refactor
|
||||
@property
|
||||
def data(self) -> np.ndarray: return self.numpy()
|
||||
|
||||
# TODO: if things are realized this won't work
|
||||
def to_(self, device:str):
|
||||
assert self.lazydata.realized is None
|
||||
|
|
Loading…
Reference in New Issue