mirror of https://github.com/commaai/tinygrad.git
hlb_cifar10 init from torch weights
commit e4db0c820f
parent a6b9733256
@ -83,6 +83,17 @@ def train_cifar():
  print(X_train.shape, Y_train.shape)
  Xt, Yt = fetch_batch(X_test, Y_test, BS=BS)
  model = SpeedyResNet()

  # init weights with torch
  if getenv("TORCHWEIGHTS"):
    from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch
    torch_model = SpeedyResNetTorch()
    model_state_dict = optim.get_state_dict(model)
    torch_state_dict = torch_model.state_dict()
    for k,v in torch_state_dict.items():
      print(f"initting {k} from torch")
      model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()

  if getenv("ADAM"):
    optimizer = optim.Adam(optim.get_parameters(model), lr=Tensor([0.001]).realize())
  else:
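The copy loop above can only work if optim.get_state_dict on the tinygrad SpeedyResNet and torch_model.state_dict() agree on parameter names and shapes; otherwise model_state_dict[k] raises a KeyError or the assign is shape-mismatched. A minimal sketch of that assumption (check_state_dicts_match is a hypothetical helper, not part of this commit):

# Sketch only: verifies the assumption behind the copy loop, namely that every
# torch parameter has a tinygrad counterpart with the same name and shape.
def check_state_dicts_match(tinygrad_state_dict, torch_state_dict):
  for k, v in torch_state_dict.items():
    assert k in tinygrad_state_dict, f"tinygrad model is missing parameter {k}"
    assert tuple(tinygrad_state_dict[k].shape) == tuple(v.shape), \
      f"shape mismatch for {k}: {tinygrad_state_dict[k].shape} vs {tuple(v.shape)}"

If the keys line up, the path is presumably exercised with something like TORCHWEIGHTS=1 python3 examples/hlb_cifar10.py; the exact invocation is not shown in this diff.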
@ -30,28 +30,28 @@ class ConvGroup(nn.Module):
    x = self.norm[2](self.conv[2](x) * mult).relu()
    return x + residual

class GlobalMaxPool(nn.Module):
  def forward(self, x): return torch.amax(x, dim=(2,3))

class SpeedyResNet(nn.Module):
  def __init__(self):
    super().__init__()
    # TODO: add whitening
    self.ic = nn.Conv2d(3, 64, kernel_size=1)
    self.ib = nn.BatchNorm2d(64, track_running_stats=False, eps=1e-12, momentum=0.8)
    self.net = nn.ModuleList([
      nn.Conv2d(3, 64, kernel_size=1),
      nn.BatchNorm2d(64, track_running_stats=False, eps=1e-12, momentum=0.8),
      nn.ReLU(),
      ConvGroup(64, 128, short=False),
      ConvGroup(128, 256, short=True),
      ConvGroup(256, 512, short=False),
      GlobalMaxPool(),
      nn.Linear(512, num_classes, bias=False)
    ])
    self.lin = nn.Linear(512, num_classes, bias=False)

  # note, pytorch just uses https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html instead of log_softmax
  def forward(self, x):
    x = self.ic(x)
    x = self.ib(x)
    x = x.relu()
    for layer in self.net:
      x = layer(x)
    x = torch.amax(x, dim=(2,3))
    x = self.lin(x)
    return x.log_softmax(-1)

def train_step_jitted(model, optimizer, X, Y):
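Presumably the point of pulling ic, ib, and lin out as direct attributes is that the first hunk copies parameters by state_dict key, so the torch model's names have to match what tinygrad's get_state_dict produces: module attributes get keys like "ic.weight", while entries of an nn.ModuleList get positional keys like "net.0.weight". A small self-contained illustration of that PyTorch naming rule (the Toy module is hypothetical, not part of this commit):

import torch.nn as nn

# Hypothetical toy module: shows how attribute vs. ModuleList placement changes
# the names that state_dict() reports.
class Toy(nn.Module):
  def __init__(self):
    super().__init__()
    self.ic = nn.Conv2d(3, 64, kernel_size=1)                   # -> "ic.weight", "ic.bias"
    self.net = nn.ModuleList([nn.Linear(512, 10, bias=False)])  # -> "net.0.weight"

print(list(Toy().state_dict().keys()))  # ['ic.weight', 'ic.bias', 'net.0.weight']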
@ -0,0 +1,57 @@
import numpy as np
import ctypes
from pyhip import hip, hiprtc # type: ignore
from tinygrad.helpers import DEBUG
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut
from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage

# The default HIP stream is used for everything.

class RawHIPBuffer(RawBufferCopyInOut):
  def __init__(self, size, dtype):
    self.buf_sz = size * dtype.itemsize
    super().__init__(size, dtype, hip.hipMalloc(self.buf_sz))
  def _copyin(self, x:np.ndarray): hip.hipMemcpyAsync_htod(self._buf, x.ctypes.data, self.buf_sz, 0)
  def _copyout(self, x:np.ndarray): hip.hipMemcpyAsync_dtoh(x.ctypes.data, self._buf, self.buf_sz, 0)

class HIPProgram:
  def __init__(self, name:str, prg:str, binary=False):
    try:
      if not binary:
        prog = hiprtc.hiprtcCreateProgram(prg, name, [], [])
        device_properties = hip.hipGetDeviceProperties(0)
        hiprtc.hiprtcCompileProgram(prog, [f'--offload-arch={device_properties.gcnArchName}'])
        prg = hiprtc.hiprtcGetCode(prog)
    except hip.hipError as e:
      if DEBUG >= 3: print("FAILED TO BUILD", prg)
      raise e
    if DEBUG >= 5: print(prg)
    module = hip.hipModuleLoadData(prg)
    self.prg = hip.hipModuleGetFunction(module, name)

  def __call__(self, global_size, local_size, *args, wait=False):
    local_size = (local_size + [1] * (3 - len(local_size))) if local_size is not None else (1,1,1)
    global_size = global_size + [1] * (3 - len(global_size))
    assert all(x%y == 0 for x,y in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
    global_size = [x//y for x,y in zip(global_size, local_size)]
    if wait:
      start, end = hip.hipEventCreate(), hip.hipEventCreate()
      hip.hipEventRecord(start)
    class PackageStruct(ctypes.Structure):
      _fields_ = [(f'field{idx}', ctypes.c_void_p) for idx in range(len(args))]
    struct = PackageStruct(*[data._buf for data in args])
    hip.hipModuleLaunchKernel(self.prg, global_size[0], global_size[1], global_size[2], local_size[0], local_size[1], local_size[2], 0, 0, struct)
    if wait:
      hip.hipEventRecord(end)
      hip.hipEventSynchronize(end)
      return hip.hipEventElapsedTime(start, end)*1e-3

class HIPCodegen(CStyleCodegen):
  lang = CStyleLanguage(
    kernel_prefix = "#define INFINITY (__builtin_inff())\nextern \"C\" __global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
    half_prekernel = "",
    gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
    lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])

HIPBuffer = Compiled(RawHIPBuffer, HIPCodegen, HIPProgram, hip.hipDeviceSynchronize)
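Two details of the launch path above are easy to miss. First, hipModuleLaunchKernel (like CUDA's cuLaunchKernel) takes the grid size in blocks rather than total threads, which is presumably why __call__ divides global_size by local_size after asserting divisibility. Second, the kernel arguments are handed over as one contiguous blob: PackageStruct is a ctypes.Structure whose fields are all void pointers, filled with each buffer's _buf handle. The gid/lid strings simply expand to "blockDim.x*blockIdx.x+threadIdx.x" and "threadIdx.x" for x, y, z, since chr(120), chr(121), chr(122) are 'x', 'y', 'z'. A standalone sketch of the argument-packing idea, with no HIP calls and made-up pointer values:

import ctypes

# Illustrative only: mirrors the PackageStruct trick without touching HIP.
fake_device_ptrs = [0x1000, 0x2000, 0x3000]  # stand-ins for data._buf values

class PackageStruct(ctypes.Structure):
  _fields_ = [(f"field{idx}", ctypes.c_void_p) for idx in range(len(fake_device_ptrs))]

struct = PackageStruct(*fake_device_ptrs)
print(ctypes.sizeof(struct))  # 3 pointers -> 24 bytes on a typical 64-bit machine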