waifu2x vgg7: testcase, auto-RGBA->RGB, function to grab pretrained models, training "fix" (#2117)

Author: 20kdc, 2023-10-20 06:07:15 +01:00 (committed by GitHub)
Parent: e0b2bf46b4
Commit: bedd028061
5 changed files with 51 additions and 9 deletions


@@ -26,23 +26,23 @@ def set_sample_count(samples_dir, sc):
file.write(str(sc) + "\n")
if len(sys.argv) < 2:
-print("python3 -m examples.vgg7 import MODELJSON MODELDIR")
+print("python3 -m examples.vgg7 import MODELJSON MODEL")
print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
print(" into a safetensors file")
print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
print(" *this format is used by most other commands in this program*")
-print("python3 -m examples.vgg7 import_kinne MODELDIR MODEL_SAFETENSORS")
+print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS")
print(" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors")
-print("python3 -m examples.vgg7 execute MODELDIR IMG_IN IMG_OUT")
+print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT")
print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
print(" output image has 7 pixels removed on all edges")
print(" do not run on large images, will have *hilarious* RAM use")
-print("python3 -m examples.vgg7 execute_full MODELDIR IMG_IN IMG_OUT")
+print("python3 -m examples.vgg7 execute_full MODEL IMG_IN IMG_OUT")
print(" does the 'whole thing' (padding, tiling)")
print(" safe for large images, etc.")
-print("python3 -m examples.vgg7 new MODELDIR")
+print("python3 -m examples.vgg7 new MODEL")
print(" creates a new model (experimental)")
-print("python3 -m examples.vgg7 train MODELDIR SAMPLES_DIR ROUNDS ROUNDS_SAVE")
+print("python3 -m examples.vgg7 train MODEL SAMPLES_DIR ROUNDS ROUNDS_SAVE")
print(" trains a model (experimental)")
print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)")
print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.")
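The renames above reflect that MODEL is now the path of a single safetensors file; the directory-style MODELDIR naming appears to be a leftover from the older KINNE format. A rough, illustrative pair of invocations matching the updated help text (all file names here are hypothetical):

python3 -m examples.vgg7 import scale2.0x_model.json vgg7.safetensors
python3 -m examples.vgg7 execute_full vgg7.safetensors input.png output.png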
@@ -130,7 +130,7 @@ elif cmd == "train":
# This is used to try and get the network to focus on "interesting" samples,
# which works nicely with the microsample system.
sample_probs = None
-sample_probs_path = model + "/sample_probs.bin"
+sample_probs_path = model + "_sample_probs.bin"
try:
# try to read...
sample_probs = numpy.fromfile(sample_probs_path, "<f8")
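Since MODEL is a file path rather than a directory, the per-sample probability table now sits alongside it as MODEL + "_sample_probs.bin". A minimal numpy sketch of that raw little-endian float64 format, with a hypothetical model path and illustrative values:

import numpy

model = "vgg7.safetensors"                        # hypothetical model path
sample_probs_path = model + "_sample_probs.bin"

probs = numpy.full(16, 1.0, dtype="<f8")          # one weight per training sample (illustrative)
probs.tofile(sample_probs_path)                   # raw little-endian float64, no header
loaded = numpy.fromfile(sample_probs_path, "<f8")
assert numpy.array_equal(probs, loaded)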
@@ -191,7 +191,7 @@ elif cmd == "train":
optim.step()
# warning: used by sample probability adjuster
-loss_indicator = loss.max().numpy()[0]
+loss_indicator = loss.max().numpy()
print("Round " + str(rnum) + " : " + str(loss_indicator))
if (rnum % rounds_per_save) == 0:
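Dropping the [0] suggests loss.max().numpy() now comes back as a 0-dimensional array rather than a 1-element one, and indexing a 0-d array raises an IndexError. A numpy-only sketch of the distinction (the value is made up):

import numpy

scalar = numpy.array(3.5)   # 0-d array, shape (), like a reduced-to-scalar result
print(float(scalar))        # fine: 3.5
try:
  print(scalar[0])          # the old [0] index fails on a 0-d array
except IndexError as err:
  print("IndexError:", err)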


@@ -4,6 +4,8 @@
import numpy
from tinygrad.tensor import Tensor
from PIL import Image
+from pathlib import Path
+from extra.utils import download_file
# File Formats
@@ -17,6 +19,9 @@ def image_load(path) -> numpy.ndarray:
"""
# file
na = numpy.array(Image.open(path))
+if na.shape[2] == 4:
+  # RGBA -> RGB (covers opaque images with alpha channels)
+  na = na[:,:,0:3]
# fix shape
na = numpy.moveaxis(na, [2,0,1], [0,1,2])
# shape is now (3,h,w), add 1
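A minimal numpy sketch of the shape handling in this hunk, using a synthetic RGBA array in place of a decoded PIL image; the final batch-dimension step is assumed from the "(3,h,w), add 1" comment above:

import numpy

na = numpy.zeros((2, 2, 4), dtype=numpy.uint8)  # stand-in for a decoded RGBA image, shape (h, w, 4)
if na.shape[2] == 4:
  na = na[:, :, 0:3]                            # drop the alpha channel -> (h, w, 3)
na = numpy.moveaxis(na, [2, 0, 1], [0, 1, 2])   # channels first -> (3, h, w)
na = na.reshape((1,) + na.shape)                # add a batch dimension -> (1, 3, h, w)
print(na.shape)                                 # (1, 3, 2, 2)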
@@ -113,6 +118,19 @@ class Vgg7:
def get_parameters(self) -> list:
return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters()
+def load_from_pretrained(self, intent = "art", subtype = "scale2.0x"):
+  """
+  Downloads a nagadomi/waifu2x JSON weight file and loads it.
+  """
+  fn = Path(__file__).parents[2] / ("weights/vgg_7_" + intent + "_" + subtype + "_model.json")
+  download_file("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json", fn)
+  import json
+  with open(fn, "rb") as f:
+    data = json.load(f)
+  self.load_waifu2x_json(data)
def load_waifu2x_json(self, data: list):
"""
Loads weights from one of the waifu2x JSON files, i.e. waifu2x/models/vgg_7/art/noise0_model.json
@@ -126,7 +144,6 @@ class Vgg7:
self.conv6.load_waifu2x_json(data[5])
self.conv7.load_waifu2x_json(data[6])
def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray:
"""
Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it.
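Putting the new loader together with the existing tiled forward pass, usage looks roughly like the sketch below; it mirrors the new testcase, the input path is hypothetical, and the 156 tile size is taken from that test:

import numpy
from examples.vgg7_helpers.waifu2x import Vgg7, image_load

mdl = Vgg7()
mdl.load_from_pretrained()               # defaults: intent="art", subtype="scale2.0x"
img = image_load("input.png")            # hypothetical path; RGBA inputs are now reduced to RGB
out = mdl.forward_tiled(img, 156)        # scales, pads, tiles, forwards, and reassembles
out = numpy.fmax(0, numpy.fmin(1, out))  # clamp to [0, 1] as the testcase does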


@@ -0,0 +1,25 @@
#!/usr/bin/env python
import pathlib
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.ops import Device
class TestVGG7(unittest.TestCase):
def test_vgg7(self):
from examples.vgg7_helpers.waifu2x import Vgg7, image_load
# Create in tinygrad
Tensor.manual_seed(1337)
mdl = Vgg7()
mdl.load_from_pretrained()
# Scale up an image
test_x = image_load(pathlib.Path(__file__).parent / 'waifu2x/input.png')
test_y = image_load(pathlib.Path(__file__).parent / 'waifu2x/output.png')
scaled = mdl.forward_tiled(test_x, 156)
scaled = np.fmax(0, np.fmin(1, scaled))
np.testing.assert_allclose(scaled, test_y, atol=5e-3, rtol=5e-3)
if __name__ == '__main__':
unittest.main()

Binary file added (not shown): 7.1 KiB

Binary file added (not shown): 15 KiB