Remove Tensor.data (#565)

Kirill authored 2023-02-19 03:36:12 +03:00, committed by GitHub
parent 4efe0169bb
commit 7944cfdadc
17 changed files with 65 additions and 69 deletions


@@ -44,7 +44,7 @@ def infer(model, img):
# if you want to look at the outputs
"""
import matplotlib.pyplot as plt
-plt.plot(out.data[0])
+plt.plot(out.numpy()[0])
plt.show()
"""
return out, retimg
@@ -68,7 +68,7 @@ if __name__ == "__main__":
ret, frame = cap.read()
img = Image.fromarray(frame[:, :, [2,1,0]])
out, retimg = infer(model, img)
-print(np.argmax(out.data), np.max(out.data), lbls[np.argmax(out.data)])
+print(np.argmax(out.numpy()), np.max(out.numpy()), lbls[np.argmax(out.numpy())])
SCALE = 3
simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
@@ -84,5 +84,5 @@ if __name__ == "__main__":
img = Image.open(url)
st = time.time()
out, _ = infer(model, img)
-print(np.argmax(out.data), np.max(out.data), lbls[np.argmax(out.data)])
+print(np.argmax(out.numpy()), np.max(out.numpy()), lbls[np.argmax(out.numpy())])
print(f"did inference in {(time.time()-st):2f}")


@@ -71,14 +71,14 @@ class BigConvNet:
with open(filename+'.npy', 'wb') as f:
for par in optim.get_parameters(self):
#if par.requires_grad:
-np.save(f, par.cpu().data)
+np.save(f, par.cpu().numpy())
def load(self, filename):
with open(filename+'.npy', 'rb') as f:
for par in optim.get_parameters(self):
#if par.requires_grad:
try:
-par.cpu().data[:] = np.load(f)
+par.cpu().numpy()[:] = np.load(f)
if GPU:
par.gpu()
except:


@@ -88,8 +88,8 @@ if __name__ == "__main__":
opt_time = (time.time()-st)*1000.0
st = time.time()
-loss = loss.cpu().data
-cat = np.argmax(out.cpu().data, axis=1)
+loss = loss.cpu().numpy()
+cat = np.argmax(out.cpu().numpy(), axis=1)
accuracy = (cat == Y).mean()
finish_time = (time.time()-st)*1000.0


@@ -60,7 +60,7 @@ cmd = sys.argv[1]
vgg7 = Vgg7()
def nansbane(p):
-if numpy.isnan(numpy.min(p.data)):
+if numpy.isnan(numpy.min(p.numpy())):
raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
def load_and_save(path, save):
@@ -90,7 +90,7 @@ elif cmd == "execute":
load_and_save(model, False)
-image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).data)
+image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).numpy())
elif cmd == "execute_full":
model = sys.argv[2]
in_file = sys.argv[3]
@@ -179,7 +179,7 @@ elif cmd == "train":
optim.step()
# warning: used by sample probability adjuster
-loss_indicator = loss.max().data[0]
+loss_indicator = loss.max().numpy()[0]
print("Round " + str(rnum) + " : " + str(loss_indicator))
if (rnum % rounds_per_save) == 0:


@@ -44,6 +44,6 @@ img -= 0.5
img /= 0.5
out = m.forward(Tensor(img))
-outnp = out.cpu().data.ravel()
+outnp = out.cpu().numpy().ravel()
choice = outnp.argmax()
print(out.shape, choice, outnp[choice], lbls[choice])


@@ -7,7 +7,7 @@ import os
# where the * is a number starting at 0.
# Each file is simply raw little-endian floats,
# as readable by: numpy.fromfile(path, "<f4")
-# and as writable by: t.data.astype("<f4", "C").tofile(path)
+# and as writable by: t.numpy().astype("<f4", "C").tofile(path)
# This format is intended to be extremely simple to get into literally anything.
# It is not intended to be structural or efficient - reloading a network when
# unnecessary is inefficient anyway.
@@ -49,7 +49,7 @@ class KinneDir:
"""
path = f"{self.base}{self.next_part_index}.bin"
if self.save:
-t.data.astype("<f4", "C").tofile(path)
+t.numpy().astype("<f4", "C").tofile(path)
self.metadata.write(f"{self.next_part_index}: {t.shape}\n")
else:
t.assign(Tensor(numpy.fromfile(path, "<f4")).reshape(shape=t.shape))
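For context, a minimal sketch of the raw-float format these comments describe, round-tripping one tensor through a single part file (the file name, example shape, and reshape call are illustrative assumptions, not part of the format):

import numpy
from tinygrad.tensor import Tensor

t = Tensor.randn(4, 3)                          # some parameter tensor
t.numpy().astype("<f4", "C").tofile("0.bin")    # write: raw little-endian float32, C order

# read it back; the shape is not stored in the .bin file itself, which is why
# KinneDir records it separately in its metadata file
restored = Tensor(numpy.fromfile("0.bin", "<f4")).reshape(shape=t.shape)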


@@ -172,7 +172,7 @@ class Vgg7:
tile_t = Tensor(tile)
tile_fwd_t = self.forward(tile_t)
# Replace tile.
-image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.data
+image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.numpy()
return image_out


@@ -38,7 +38,7 @@ class Upsample:
def upsampleNearest(self, input):
# TODO: Implement actual interpolation function
# inspired: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/functional/upsampling.h
-return input.cpu().data.repeat(self.scale_factor, axis=len(input.shape)-2).repeat(self.scale_factor, axis=len(input.shape)-1)
+return input.cpu().numpy().repeat(self.scale_factor, axis=len(input.shape)-2).repeat(self.scale_factor, axis=len(input.shape)-1)
def __repr__(self):
return f"Upsample(scale_factor={self.scale_factor!r}, mode={self.mode!r})"


@@ -18,7 +18,7 @@ def show_labels(prediction, confidence = 0.5, num_classes = 80):
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
coco_labels = coco_labels.decode('utf-8').split('\n')
-prediction = prediction.detach().cpu().data
+prediction = prediction.detach().cpu().numpy()
conf_mask = (prediction[:,:,4] > confidence)
conf_mask = np.expand_dims(conf_mask, 2)
@@ -119,7 +119,7 @@ def bbox_iou(box1, box2):
def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0.4):
-prediction = prediction.detach().cpu().data
+prediction = prediction.detach().cpu().numpy()
conf_mask = (prediction[:,:,4] > confidence)
conf_mask = np.expand_dims(conf_mask, 2)
prediction = prediction * conf_mask
@@ -275,7 +275,7 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
prediction = prediction.reshape(shape=(batch_size, grid_size*grid_size*num_anchors, bbox_attrs))
# st = time.time()
-prediction_cpu = prediction.cpu().data
+prediction_cpu = prediction.cpu().numpy()
# print('put on CPU in %.2f s' % (time.time() - st))
anchors = [(a[0]/stride, a[1]/stride) for a in anchors]
@@ -417,11 +417,11 @@ class Darknet:
print(self.blocks[i + 1]["type"], "weights", i)
model = self.module_list[i]
conv = model[0]
-print(conv.weight.cpu().data[0][0][0])
+print(conv.weight.cpu().numpy()[0][0][0])
if conv.bias is not None:
print("biases")
print(conv.bias.shape)
-print(conv.bias.cpu().data[0][0:5])
+print(conv.bias.cpu().numpy()[0][0:5])
else:
print("None biases for layer", i)
@@ -525,7 +525,7 @@ class Darknet:
if (layers[1]) > 0: layers[1] = layers[1] - i
map1 = outputs[i + layers[0]]
map2 = outputs[i + layers[1]]
-x = Tensor(np.concatenate((map1.cpu().data, map2.cpu().data), 1))
+x = Tensor(np.concatenate((map1.cpu().numpy(), map2.cpu().numpy()), 1))
elif module_type == "shortcut":
from_ = int(module["from"])
x = outputs[i - 1] + outputs[i + from_]
@@ -540,7 +540,7 @@ class Darknet:
detections = x
write = 1
else:
-detections = Tensor(np.concatenate((detections.cpu().data, x.cpu().data), 1))
+detections = Tensor(np.concatenate((detections.cpu().numpy(), x.cpu().numpy()), 1))
# print(module_type, 'layer took %.2f s' % (time.time() - st))
outputs[i] = x


@@ -9,8 +9,8 @@ def mask_like(like, mask_inx, mask_value = 1.0):
def jacobian(func, input):
output = func(input)
-ji = input.data.reshape(-1).shape[-1]
-jo = output.data.reshape(-1).shape[-1]
+ji = input.numpy().reshape(-1).shape[-1]
+jo = output.numpy().reshape(-1).shape[-1]
J = np.zeros((jo,ji), dtype=np.float32)
for o in range(jo):
@@ -19,25 +19,25 @@ def jacobian(func, input):
# tinygrad doesn't support slicing, tiny-hack to select
# the needed scalar and backpropagate only through it
-o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
+o_scalar = Tensor(mask_like(output.numpy(), o, 1.)).mul(output).sum()
o_scalar.backward()
-for i, grad in enumerate(input.grad.data.reshape(-1)):
+for i, grad in enumerate(input.grad.numpy().reshape(-1)):
J[o,i] = grad
return J
def numerical_jacobian(func, input, eps = 1e-6):
output = func(input)
-ji = input.data.reshape(-1).shape[-1]
-jo = output.data.reshape(-1).shape[-1]
+ji = input.numpy().reshape(-1).shape[-1]
+jo = output.numpy().reshape(-1).shape[-1]
NJ = np.zeros((jo, ji), dtype=np.float32)
for i in range(ji):
-eps_perturb = mask_like(input.data, i, mask_value = eps)
+eps_perturb = mask_like(input.numpy(), i, mask_value = eps)
-output_perturb_add = func(Tensor(input.data + eps_perturb)).data.reshape(-1)
-output_perturb_sub = func(Tensor(input.data - eps_perturb)).data.reshape(-1)
+output_perturb_add = func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
+output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2*eps)
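A small usage sketch of how these two helpers are typically compared against each other; the function, input, and tolerance below are made up for illustration and may need tweaking for a given tinygrad version:

import numpy as np
from tinygrad.tensor import Tensor

W = Tensor(np.random.randn(3, 2).astype(np.float32), requires_grad=True)
func = lambda x: x.dot(W).relu()    # any differentiable tinygrad function of one tensor

x = Tensor(np.random.randn(1, 3).astype(np.float32), requires_grad=True)
J = jacobian(func, x)               # analytic: one backward pass per output element
NJ = numerical_jacobian(func, x)    # central differences with perturbation eps

np.testing.assert_allclose(J, NJ, atol=1e-4)   # the two estimates should agree closely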


@@ -33,10 +33,10 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
# printing
if not noloss:
-cat = np.argmax(out.cpu().data, axis=-1)
+cat = np.argmax(out.cpu().numpy(), axis=-1)
accuracy = (cat == y).mean()
-loss = loss.detach().cpu().data
+loss = loss.detach().cpu().numpy()
losses.append(loss)
accuracies.append(accuracy)
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
@@ -50,7 +50,7 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
for i in trange((len(Y_test)-1)//BS+1, disable=getenv('CI', False)):
x = Tensor(transform(X_test[i*BS:(i+1)*BS]))
out = model.forward(x) if hasattr(model, 'forward') else model(x)
-Y_test_preds_out[i*BS:(i+1)*BS] = out.cpu().data
+Y_test_preds_out[i*BS:(i+1)*BS] = out.cpu().numpy()
Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
Y_test = target_transform(Y_test)
return (Y_test == Y_test_preds).mean(), Y_test_preds


@@ -58,7 +58,7 @@ class Transformer:
def forward(self, x):
bs = x.shape[0]
-xnp = x.cpu().data.astype(np.int32)
+xnp = x.cpu().numpy().astype(np.int32)
onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
for i in range(x.shape[1]):
onehot[range(bs), i, i] = 1


@@ -46,7 +46,7 @@ def _infer(model: EfficientNet, img, bs=1):
# run the net
if bs > 1: img = img.repeat(bs, axis=0)
out = model.forward(Tensor(img)).cpu()
-return _LABELS[np.argmax(out.data[0])]
+return _LABELS[np.argmax(out.numpy()[0])]
chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg')
car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')


@@ -18,19 +18,19 @@ class TestNN(unittest.TestCase):
bn.bias = Tensor.randn(sz)
bn.running_mean = Tensor.randn(sz)
bn.running_var = Tensor.randn(sz)
-bn.running_var.data[bn.running_var.data < 0] = 0
+bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0
# create in torch
with torch.no_grad():
tbn = torch.nn.BatchNorm2d(sz).eval()
tbn.training = training
-tbn.weight[:] = torch.tensor(bn.weight.data)
-tbn.bias[:] = torch.tensor(bn.bias.data)
-tbn.running_mean[:] = torch.tensor(bn.running_mean.data)
-tbn.running_var[:] = torch.tensor(bn.running_var.data)
+tbn.weight[:] = torch.tensor(bn.weight.numpy())
+tbn.bias[:] = torch.tensor(bn.bias.numpy())
+tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy())
+tbn.running_var[:] = torch.tensor(bn.running_var.numpy())
-np.testing.assert_allclose(bn.running_mean.data, tbn.running_mean.detach().numpy(), rtol=1e-5)
-np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)
+np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5)
+np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5)
# trial
inn = Tensor.randn(2, sz, 3, 3)
@@ -39,15 +39,15 @@ class TestNN(unittest.TestCase):
outt = bn(inn)
# in torch
-toutt = tbn(torch.tensor(inn.cpu().data))
+toutt = tbn(torch.tensor(inn.cpu().numpy()))
# close
-np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=5e-4)
+np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4)
-np.testing.assert_allclose(bn.running_mean.data, tbn.running_mean.detach().numpy(), rtol=1e-5)
+np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5)
# TODO: this is failing
-# np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)
+# np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5)
def test_batchnorm2d_training(self):
self.test_batchnorm2d(True)
@@ -64,11 +64,11 @@ class TestNN(unittest.TestCase):
torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
torch_layer.weight[:] = torch.tensor(model.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(model.bias.numpy(), dtype=torch.float32)
-torch_x = torch.tensor(x.cpu().data, dtype=torch.float32)
+torch_x = torch.tensor(x.cpu().numpy(), dtype=torch.float32)
torch_z = torch_layer(torch_x)
# test
-np.testing.assert_allclose(z.data, torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
+np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
BS, T, in_dim, out_dim = 4, 2, 8, 16
_test_linear(Tensor.randn(BS, in_dim))
@@ -84,15 +84,15 @@ class TestNN(unittest.TestCase):
# create in torch
with torch.no_grad():
torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
-torch_layer.weight[:] = torch.tensor(layer.weight.data, dtype=torch.float32)
-torch_layer.bias[:] = torch.tensor(layer.bias.data, dtype=torch.float32)
+torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.uniform(BS, C1, H, W)
z = layer(x)
-torch_x = torch.tensor(x.cpu().data)
+torch_x = torch.tensor(x.cpu().numpy())
torch_z = torch_layer(torch_x)
-np.testing.assert_allclose(z.data, torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
+np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
def test_groupnorm(self):
BS, H, W, C, G = 20, 10, 10, 6, 3
@@ -103,15 +103,15 @@ class TestNN(unittest.TestCase):
# create in torch
with torch.no_grad():
torch_layer = torch.nn.GroupNorm(G, C).eval()
-torch_layer.weight[:] = torch.tensor(layer.weight.data, dtype=torch.float32)
-torch_layer.bias[:] = torch.tensor(layer.bias.data, dtype=torch.float32)
+torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.randn(BS, C, H, W)
z = layer(x)
-torch_x = torch.tensor(x.cpu().data)
+torch_x = torch.tensor(x.cpu().numpy())
torch_z = torch_layer(torch_x)
-np.testing.assert_allclose(z.data, torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
+np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
def test_layernorm(self):
N, C, H, W = 20, 5, 10, 10
@@ -122,15 +122,15 @@ class TestNN(unittest.TestCase):
# create in torch
with torch.no_grad():
torch_layer = torch.nn.LayerNorm([H, W]).eval()
-torch_layer.weight[:] = torch.tensor(layer.weight.data, dtype=torch.float32)
-torch_layer.bias[:] = torch.tensor(layer.bias.data, dtype=torch.float32)
+torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.randn(N, C, H, W)
z = layer(x)
-torch_x = torch.tensor(x.cpu().data)
+torch_x = torch.tensor(x.cpu().numpy())
torch_z = torch_layer(torch_x)
-np.testing.assert_allclose(z.data, torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
+np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
if __name__ == '__main__':
unittest.main()


@@ -14,7 +14,7 @@ def step_tinygrad(optim, kwargs={}):
out = net.forward()
out.backward()
optim.step()
-return net.x.cpu().data, net.W.cpu().data
+return net.x.cpu().numpy(), net.W.cpu().numpy()
def step_pytorch(optim, kwargs={}):
net = TorchNet()


@@ -44,7 +44,7 @@ class TestTinygrad(unittest.TestCase):
out = out.logsoftmax()
out = out.mul(m).add(m).sum()
out.backward()
-return out.cpu().data, x.grad.cpu().data, W.grad.cpu().data
+return out.cpu().numpy(), x.grad.cpu().numpy(), W.grad.cpu().numpy()
def test_pytorch():
x = torch.tensor(x_init, requires_grad=True)
@@ -70,7 +70,7 @@ class TestTinygrad(unittest.TestCase):
out = out.logsoftmax()
out = out.sum()
out.backward()
-return out.cpu().data, u.cpu().grad.data, v.cpu().grad.data, w.cpu().grad.data
+return out.cpu().numpy(), u.cpu().grad.numpy(), v.cpu().grad.numpy(), w.cpu().grad.numpy()
def test_pytorch():
u = torch.tensor(U_init, requires_grad=True)
@@ -106,7 +106,7 @@ class TestTinygrad(unittest.TestCase):
Tensor.training = True
n, rate = 1_000_000, 0.1
w = Tensor.ones(n).dropout(rate)
-non_zeros = np.count_nonzero(w.cpu().data)
+non_zeros = np.count_nonzero(w.cpu().numpy())
expected = n * (1 - rate)
np.testing.assert_allclose(non_zeros, expected, rtol=1e-3)


@@ -94,10 +94,6 @@ class Tensor:
def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
def numpy(self) -> np.ndarray: return np.array(self.lazydata.toCPU())
-# TODO: this keeps the legacy behavior working, remove it after refactor
-@property
-def data(self) -> np.ndarray: return self.numpy()
# TODO: if things are realized this won't work
def to_(self, device:str):
assert self.lazydata.realized is None
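The removal above is the core of this change: Tensor.data was only an alias for numpy(), so downstream code migrates mechanically by spelling out the call. A hypothetical before/after snippet:

import numpy as np
from tinygrad.tensor import Tensor

t = Tensor(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))

# before this commit (legacy property, now removed):
#   arr = t.data

# after this commit:
arr = t.numpy()     # explicit NumPy copy of the tensor's contents
print(arr.sum())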