Skip to content

Movement

Movement (low level)¤

view ¤

view(*shape) -> Tensor

.view is an alias for .reshape.

Source code in tinygrad/tensor.py
880
881
882
def view(self, *shape) -> Tensor:
  """`.view` is an alias for `.reshape`; all positional dimensions are forwarded to it as one tuple."""
  dims = shape
  return self.reshape(dims)

reshape ¤

reshape(shape, *args) -> Tensor

Returns a tensor with the same data as the original tensor but with a different shape. shape can be passed as a tuple or as separate arguments.

t = Tensor.arange(6)
print(t.reshape(2, 3).numpy())
[[0 1 2]
 [3 4 5]]
Source code in tinygrad/tensor.py
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
def reshape(self, shape, *args) -> Tensor:
  """
  Returns a tensor with the same data as the original tensor but with a different shape.
  `shape` can be passed as a tuple or as separate arguments.
  A `None` entry keeps the size of that axis from the current shape, and at most one entry may be `-1`,
  in which case its size is inferred from the total number of elements.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6)
  print(t.reshape(2, 3).numpy())
  ```
  """
  requested = argfix(shape, *args)
  # `None` means "keep the current size of this axis"
  resolved = tuple([self.shape[i] if s is None else s for i, s in enumerate(requested)])
  inferred = resolved.count(-1)
  if inferred > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {resolved}")
  if inferred:
    # exactly one -1 is present, so prod(resolved) is negative; negating the numerator yields the missing size
    missing = -prod(self.shape) // prod(resolved)
    resolved = tuple([missing if s == -1 else s for s in resolved])
  if resolved == self.shape: return self
  return F.Reshape.apply(self, shape=resolved)

expand ¤

expand(shape, *args) -> Tensor

Returns a tensor that is expanded to the shape that is specified. Expand can also increase the number of dimensions that a tensor has.

Passing a -1 or None to a dimension means that its size will not be changed.

t = Tensor([1, 2, 3])
print(t.expand(4, -1).numpy())
[[1 2 3]
 [1 2 3]
 [1 2 3]
 [1 2 3]]
Source code in tinygrad/tensor.py
901
902
903
904
905
906
907
908
909
910
911
912
913
def expand(self, shape, *args) -> Tensor:
  """
  Returns a tensor broadcast to the requested shape.
  The requested shape may have more dimensions than the tensor; the tensor's shape is left-padded with 1s to match.

  A `-1` or `None` in a position keeps that dimension's current size.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.expand(4, -1).numpy())
  ```
  """
  current, target = _pad_left(self.shape, argfix(shape, *args))
  new_shape = tuple(cur if want is None or want == -1 else want for cur, want in zip(current, target))
  return self._broadcast_to(new_shape)

permute ¤

permute(order, *args) -> Tensor

Returns a tensor that is a permutation of the original tensor. The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified. order can be passed as a tuple or as separate arguments.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.permute(1, 0).numpy())
[[0 3]
 [1 4]
 [2 5]]

Source code in tinygrad/tensor.py
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
def permute(self, order, *args) -> Tensor:
  """
  Returns a tensor whose dimensions are reordered according to `order`; the underlying data is unchanged.
  `order` can be passed as a tuple or as separate arguments, and negative axes are accepted.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.permute(1, 0).numpy())
  ```
  """
  axes = tuple(self._resolve_dim(ax) for ax in argfix(order, *args))
  # a valid permutation names every axis exactly once
  if sorted(axes) != list(range(self.ndim)):
    raise RuntimeError(f"order is not a valid permutation, getting {axes}")
  return F.Permute.apply(self, order=axes)

flip ¤

flip(axis, *args) -> Tensor

Returns a tensor that reverses the order of the original tensor along the given `axis`. `axis` can be passed as a tuple or as separate arguments.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.flip(0).numpy())
[[3 4 5]
 [0 1 2]]
print(t.flip((0, 1)).numpy())
[[5 4 3]
 [2 1 0]]

Source code in tinygrad/tensor.py
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
def flip(self, axis, *args) -> Tensor:
  """
  Returns a tensor that reverses the order of the original tensor along the given `axis`.
  `axis` can be passed as a tuple or as separate arguments. Negative axes are accepted,
  but each axis may be listed at most once; duplicates raise `RuntimeError`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip((0, 1)).numpy())
  ```
  """
  axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
  # reject repeated axes (after resolving negatives, so e.g. 0 and -2 on a 2-D tensor collide)
  if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
  return F.Flip.apply(self, axis=axis_arg)

shrink ¤

shrink(
    arg: Tuple[Optional[Tuple[sint, sint]], ...]
) -> Tensor

Returns a tensor that shrinks each axis based on the input `arg`. `arg` must have the same length as `self.ndim`. For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as a Python slice.

t = Tensor.arange(9).reshape(3, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]
 [6 7 8]]
print(t.shrink(((None, (1, 3)))).numpy())
[[1 2]
 [4 5]
 [7 8]]
print(t.shrink((((0, 2), (0, 2)))).numpy())
[[0 1]
 [3 4]]

Source code in tinygrad/tensor.py
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
  """
  Returns a tensor that shrinks each axis based on the input `arg`.
  `arg` must have the same length as `self.ndim`.
  For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as a Python slice.
  If no axis actually changes, `self` is returned as-is.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink(((None, (1, 3)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink((((0, 2), (0, 2)))).numpy())
  ```
  """
  # replace `None` with the full range of that axis
  bounds = tuple((0, s) if x is None else x for x, s in zip(arg, self.shape))
  if all(b == (0, s) for b, s in zip(bounds, self.shape)): return self
  return F.Shrink.apply(self, arg=bounds)

pad ¤

pad(
    arg: Tuple[Optional[Tuple[sint, sint]], ...],
    value: float = 0.0,
) -> Tensor

Returns a tensor that pads each axis based on the input `arg`. `arg` must have the same length as `self.ndim`. For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`. If `value` is specified, the tensor is padded with `value` instead of `0.0`.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.pad(((None, (1, 2)))).numpy())
[[0 0 1 2 0 0]
 [0 3 4 5 0 0]]
print(t.pad(((None, (1, 2))), -2).numpy())
[[-2  0  1  2 -2 -2]
 [-2  3  4  5 -2 -2]]

Source code in tinygrad/tensor.py
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
  """
  Returns a tensor that pads each axis based on the input `arg`.
  `arg` must have the same length as `self.ndim`.
  For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
  If `value` is specified, the tensor is padded with `value` instead of `0.0`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad(((None, (1, 2)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad(((None, (1, 2))), -2).numpy())
  ```
  """
  if all(p is None or p == (0,0) for p in arg): return self
  widths = tuple((0,0) if p is None else p for p in arg)
  padded = F.Pad.apply(self, arg=widths)
  if 0 == value: return padded
  # pad a tensor of ones identically: original cells stay 1, new cells are 0;
  # `where` turns that into 0 for original cells and `value` for new cells, which is then added on
  fill = F.Pad.apply(Tensor.ones_like(self), arg=widths).where(0, value)
  return padded + fill

Movement (high level)¤

gather ¤

gather(dim: int, index: Tensor) -> Tensor

Gathers values along an axis specified by dim.

t = Tensor([[1, 2], [3, 4]])
print(t.numpy())
[[1 2]
 [3 4]]
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
[[1 1]
 [4 3]]

Source code in tinygrad/tensor.py
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
  """
  Gathers values along the axis `dim`: each output element is taken from `self` at the position given by `index` along that axis.
  `index` must have the same number of dimensions as `self` and must not exceed `self` in any other axis.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2], [3, 4]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
  ```
  """
  assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
  dim = self._resolve_dim(dim)
  assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
  index = index.to(self.device)
  # trim every non-gathered axis to the index extent, then swap the gathered axis with a new trailing size-1 axis
  window = self.shrink(tuple(None if d == dim else (0, i) for d, i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
  # mask[..., k] is True exactly where index == k
  positions = Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)
  mask = index.unsqueeze(-1) == positions
  return (mask * window).sum(-1, acc_dtype=self.dtype)

cat ¤

cat(*args: Tensor, dim: int = 0) -> Tensor

Concatenates `self` with the other tensors in `args` along the axis specified by `dim`. All tensors must have the same shape except in the concatenating dimension.

t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
print(t0.cat(t1, t2, dim=0).numpy())
[[1 2]
 [3 4]
 [5 6]]
print(t0.cat(t1, t2, dim=1).numpy())
[[1 2 3 4 5 6]]

Source code in tinygrad/tensor.py
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Concatenates `self` with the other tensors in `args` along the existing axis `dim`.
  All tensors must have the same rank and the same shape except in the concatenating dimension.
  Each input is zero-padded along `dim` to the combined length and the padded tensors are summed.

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
  print(t0.cat(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.cat(t1, t2, dim=1).numpy())
  ```
  """
  dim = self._resolve_dim(dim)
  assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
  pieces = [self, *args]
  lengths = [p.shape[dim] for p in pieces]
  starts = [0, *itertools.accumulate(lengths)]
  total = starts[-1]
  padded = []
  for piece, length, start in zip(pieces, lengths, starts):
    widths:List[Optional[Tuple[sint, sint]]] = [None] * len(self.shape)
    # `start` zeros before this piece and the rest of the combined length after it
    widths[dim] = (start, total - start - length)
    padded.append(piece.pad(tuple(widths)))
  return functools.reduce(Tensor.__add__, padded)

stack ¤

stack(*args: Tensor, dim: int = 0) -> Tensor

Concatenates `self` with the other tensors in `args` along a new dimension specified by `dim`.

t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
print(t0.stack(t1, t2, dim=0).numpy())
[[1 2]
 [3 4]
 [5 6]]
print(t0.stack(t1, t2, dim=1).numpy())
[[1 3 5]
 [2 4 6]]

Source code in tinygrad/tensor.py
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Joins `self` and the tensors in `args` along a brand-new axis inserted at `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
  print(t0.stack(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.stack(t1, t2, dim=1).numpy())
  ```
  """
  # give every input a size-1 axis at `dim`; `cat` then validates the shapes and joins along it
  lifted = [t.unsqueeze(dim) for t in (self, *args)]
  return lifted[0].cat(*lifted[1:], dim=dim)

repeat ¤

repeat(repeats, *args) -> Tensor

Repeats tensor number of times along each dimension specified by repeats. repeats can be passed as a tuple or as separate arguments.

t = Tensor([1, 2, 3])
print(t.repeat(4, 2).numpy())
[[1 2 3 1 2 3]
 [1 2 3 1 2 3]
 [1 2 3 1 2 3]
 [1 2 3 1 2 3]]
print(t.repeat(4, 2, 1).shape)
(4, 2, 3)

Source code in tinygrad/tensor.py
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
def repeat(self, repeats, *args) -> Tensor:
  """
  Repeats the tensor the given number of times along each dimension specified by `repeats`.
  `repeats` can be passed as a tuple or as separate arguments.
  `repeats` must have at least `self.ndim` entries; any extra leading entries create new leading dimensions.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat(4, 2).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.repeat(4, 2, 1).shape)
  ```
  """
  repeats = argfix(repeats, *args)
  # without this check the zips below silently truncate and produce a wrong shape or a cryptic broadcast error
  if len(repeats) < self.ndim:
    raise RuntimeError(f"repeats must have at least as many entries as the tensor has dimensions ({self.ndim}), getting {repeats}")
  base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
  # insert a size-1 axis before every axis: (b0, b1, ...) -> (1, b0, 1, b1, ...)
  new_shape = [x for b in base_shape for x in [1, b]]
  # broadcast each inserted axis to its repeat count: (r0, b0, r1, b1, ...)
  expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
  # merge each (r, b) pair into a single axis of size r*b
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
  return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)

repeat_interleave ¤

repeat_interleave(
    repeats: int, dim: Optional[int] = None
) -> Tensor

Repeat elements of a tensor.

t = Tensor([1, 2, 3])
print(t.repeat_interleave(2).numpy())
[1 1 2 2 3 3]
Source code in tinygrad/tensor.py
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
  """
  Repeat each element of a tensor `repeats` times along `dim`.
  If `dim` is None the tensor is flattened first, so the result is 1-D. Negative `dim` values are accepted.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat_interleave(2).numpy())
  ```
  """
  # resolve negative dims: the slicing below (shp[:dim+1], shp[dim+1:]) is only correct for non-negative indices
  x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
  shp = x.shape
  # insert a size-1 axis right after `dim`, broadcast it to `repeats`, then merge it into `dim`
  return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])

split ¤

split(
    sizes: Union[int, List[int]], dim: int = 0
) -> Tuple[Tensor, ...]

Splits the tensor into chunks along the dimension specified by `dim`. If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller. If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `sizes`.

t = Tensor.arange(10).reshape(5, 2)
print(t.numpy())
[[0 1]
 [2 3]
 [4 5]
 [6 7]
 [8 9]]
split = t.split(2)
print("\n".join([repr(x.numpy()) for x in split]))
array([[0, 1],
       [2, 3]], dtype=int32)
array([[4, 5],
       [6, 7]], dtype=int32)
array([[8, 9]], dtype=int32)
split = t.split([1, 4])
print("\n".join([repr(x.numpy()) for x in split]))
array([[0, 1]], dtype=int32)
array([[2, 3],
       [4, 5],
       [6, 7],
       [8, 9]], dtype=int32)

Source code in tinygrad/tensor.py
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
  """
  Splits the tensor into chunks along the dimension specified by `dim`.
  If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
  If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `sizes`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(10).reshape(5, 2)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split(2)
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split([1, 4])
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  dim = self._resolve_dim(dim)
  if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
  assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
  # running offsets along `dim`, computed once instead of re-summing a prefix for every chunk:
  # chunk i covers [offsets[i], offsets[i+1])
  offsets = [0, *itertools.accumulate(sizes)]
  lead = [slice(None)] * dim
  return tuple(self[tuple(lead + [slice(start, end)])] for start, end in zip(offsets, offsets[1:]))

chunk ¤

chunk(chunks: int, dim: int = 0) -> List[Tensor]

Splits the tensor into `chunks` number of chunks along the dimension `dim`. If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one. The function may return fewer than the specified number of chunks.

chunked = Tensor.arange(11).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1], dtype=int32)
array([2, 3], dtype=int32)
array([4, 5], dtype=int32)
array([6, 7], dtype=int32)
array([8, 9], dtype=int32)
array([10], dtype=int32)
chunked = Tensor.arange(12).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1], dtype=int32)
array([2, 3], dtype=int32)
array([4, 5], dtype=int32)
array([6, 7], dtype=int32)
array([8, 9], dtype=int32)
array([10, 11], dtype=int32)
chunked = Tensor.arange(13).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1, 2], dtype=int32)
array([3, 4, 5], dtype=int32)
array([6, 7, 8], dtype=int32)
array([ 9, 10, 11], dtype=int32)
array([12], dtype=int32)

Source code in tinygrad/tensor.py
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
  """
  Splits the tensor into `chunks` pieces along the dimension `dim`.
  Each piece has ceil(size / `chunks`) elements along `dim` except possibly the last, which may be smaller.
  Because of that rounding, fewer than `chunks` pieces may be returned.

  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(11).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(12).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(13).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
  dim = self._resolve_dim(dim)
  length = self.shape[dim]
  # an empty axis yields `chunks` empty pieces; otherwise split into ceil-sized pieces
  sizes = [0]*chunks if length == 0 else ceildiv(length, chunks)
  return list(self.split(sizes, dim=dim))

squeeze ¤

squeeze(dim: Optional[int] = None) -> Tensor

Returns a tensor with specified dimensions of input of size 1 removed. If dim is not specified, all dimensions with size 1 are removed.

t = Tensor.zeros(2, 1, 2, 1, 2)
print(t.squeeze().shape)
(2, 2, 2)
print(t.squeeze(0).shape)
(2, 1, 2, 1, 2)
print(t.squeeze(1).shape)
(2, 2, 1, 2)

Source code in tinygrad/tensor.py
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
def squeeze(self, dim:Optional[int]=None) -> Tensor:
  """
  Returns a tensor with size-1 dimensions removed.
  With `dim` given, only that axis is removed, and only if its size is 1; otherwise `self` is returned unchanged.
  Without `dim`, every size-1 axis is removed.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.zeros(2, 1, 2, 1, 2)
  print(t.squeeze().shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(0).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(1).shape)
  ```
  """
  if dim is None:
    return self.reshape(tuple(size for size in self.shape if size != 1))
  dim = self._resolve_dim(dim)
  if not self.ndim or self.shape[dim] != 1: return self
  return self.reshape(self.shape[:dim] + self.shape[dim+1:])

unsqueeze ¤

unsqueeze(dim: int) -> Tensor

Returns a tensor with a new dimension of size 1 inserted at the specified dim.

t = Tensor([1, 2, 3, 4])
print(t.unsqueeze(0).numpy())
[[1 2 3 4]]
print(t.unsqueeze(1).numpy())
[[1]
 [2]
 [3]
 [4]]

Source code in tinygrad/tensor.py
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
def unsqueeze(self, dim:int) -> Tensor:
  """
  Returns a tensor with a new size-1 dimension inserted at position `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(t.unsqueeze(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.unsqueeze(1).numpy())
  ```
  """
  # NOTE(review): outer=True presumably lets `dim` address the ndim+1 insertion points (incl. after the last axis) — confirm in _resolve_dim
  pos = self._resolve_dim(dim, outer=True)
  before, after = self.shape[:pos], self.shape[pos:]
  return self.reshape(before + (1,) + after)

pad2d ¤

pad2d(padding: Sequence[int], value: float = 0.0) -> Tensor

Returns a tensor that pads the last two axes specified by padding (padding_left, padding_right, padding_top, padding_bottom). If value is specified, the tensor is padded with value instead of 0.0.

t = Tensor.arange(9).reshape(1, 1, 3, 3)
print(t.numpy())
[[[[0 1 2]
   [3 4 5]
   [6 7 8]]]]
print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
[[[[-inf -inf -inf -inf -inf]
   [-inf -inf -inf -inf -inf]
   [-inf   0.   1.   2. -inf]
   [-inf   3.   4.   5. -inf]
   [-inf   6.   7.   8. -inf]]]]

Source code in tinygrad/tensor.py
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
  """
  Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
  If `value` is specified, the tensor is padded with `value` instead of `0.0`.
  Negative padding amounts crop that side instead of padding it.
  The code handles one trailing axis per (before, after) pair in `padding`, starting from the last axis.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
  ```
  """
  # pair up (before, after) amounts and keep only the positive parts; pairs are given last-axis-first,
  # so reverse them to line up with the trailing axes in natural order
  pads = tuple((max(p0, 0), max(p1, 0)) for p0, p1 in zip(padding[::2], padding[1::2]))[::-1]
  # leading axes not covered by `padding` are left untouched (None)
  padded = self.pad((None,) * (self.ndim - len(padding) // 2) + tuple(pads), value=value)
  # negative parts become a shrink on the padded tensor: a negative `before` moves the start in by |p0|,
  # a negative `after` moves the end in by |p1|; s is the padded size of the matching axis (last-axis-first)
  shrink = tuple((-min(p0, 0), min(p1 + s, s)) for p0, p1, s in zip(padding[::2], padding[1::2], padded.shape[::-1]))[::-1]
  return padded.shrink((None,) * (self.ndim - len(padding) // 2) + shrink)

T property ¤

T: Tensor

.T is an alias for .transpose().

transpose ¤

transpose(dim0=1, dim1=0) -> Tensor

Returns a tensor that is a transposed version of the original tensor. The given dimensions dim0 and dim1 are swapped.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.transpose(0, 1).numpy())
[[0 3]
 [1 4]
 [2 5]]

Source code in tinygrad/tensor.py
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
def transpose(self, dim0=1, dim1=0) -> Tensor:
  """
  Returns a view of the tensor with axes `dim0` and `dim1` swapped.
  With the defaults, the first two axes are exchanged.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.transpose(0, 1).numpy())
  ```
  """
  perm = list(range(self.ndim))
  # plain list indexing, so negative axis numbers also work here
  a, b = perm[dim0], perm[dim1]
  perm[dim0], perm[dim1] = b, a
  return self.permute(perm)

flatten ¤

flatten(start_dim=0, end_dim=-1)

Flattens the tensor by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only dimensions starting with start_dim and ending with end_dim are flattened.

t = Tensor.arange(8).reshape(2, 2, 2)
print(t.flatten().numpy())
[0 1 2 3 4 5 6 7]
print(t.flatten(start_dim=1).numpy())
[[0 1 2 3]
 [4 5 6 7]]

Source code in tinygrad/tensor.py
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
def flatten(self, start_dim=0, end_dim=-1):
  """
  Collapses the axes from `start_dim` through `end_dim` (inclusive) into a single axis.
  With the defaults, the whole tensor becomes one-dimensional.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(8).reshape(2, 2, 2)
  print(t.flatten().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flatten(start_dim=1).numpy())
  ```
  """
  lo, hi = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
  merged = prod(self.shape[lo:hi+1])
  return self.reshape(self.shape[:lo] + (merged,) + self.shape[hi+1:])

unflatten ¤

unflatten(dim: int, sizes: Tuple[int, ...])

Unflattens dimension dim of the tensor into multiple dimensions specified by sizes. Tensor.flatten() is the inverse of this function.

print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
(3, 2, 2, 1)
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
(3, 2, 2, 1)
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
(5, 2, 2, 3, 1, 1, 3)

Source code in tinygrad/tensor.py
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
def unflatten(self, dim:int, sizes:Tuple[int,...]):
  """
  Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
  `sizes` may be any sequence of ints (tuple or list); one entry may be `-1` to have it inferred, as in `reshape`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
  ```
  """
  dim = self._resolve_dim(dim)
  # tuple() so a list of sizes concatenates with the shape tuples instead of raising TypeError
  return self.reshape(self.shape[:dim] + tuple(sizes) + self.shape[dim+1:])