diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py index 7306801f..7bc74592 100644 --- a/test/unit/test_helpers.py +++ b/test/unit/test_helpers.py @@ -1,4 +1,4 @@ -import unittest +import gzip, unittest from PIL import Image from tinygrad.helpers import Context, ContextVar from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape @@ -161,6 +161,24 @@ class TestFetch(unittest.TestCase): assert pimg.size == (77, 77), pimg.size assert img.parent.name == "images" + def test_fetch_gunzip_valid(self): + # compare fetch(gunzip=True) to fetch(gunzip=False) plus decompressing afterwards + gzip_url: str = 'https://ftp.gnu.org/gnu/gzip/gzip-1.13.tar.gz' + fp_gz = fetch(gzip_url, gunzip=True) + fp_no_gz = fetch(gzip_url, gunzip=False) + with open(fp_gz, 'rb') as f: content_gz = f.read() + with open(fp_no_gz, 'rb') as f: content_no_gz = gzip.decompress(f.read()) + assert fp_gz.stat().st_size > fp_no_gz.stat().st_size + assert isinstance(content_gz, bytes) and isinstance(content_no_gz, bytes) + assert len(content_gz) == len(content_no_gz) + assert content_gz == content_no_gz + + def test_fetch_gunzip_invalid(self): + # given a non-gzipped file, fetch(gunzip=True) fails + no_gzip_url: str = 'https://ftp.gnu.org/gnu/gzip/gzip-1.13.zip' + with self.assertRaises(gzip.BadGzipFile): + fetch(no_gzip_url, gunzip=True) + class TestFullyFlatten(unittest.TestCase): def test_fully_flatten(self): self.assertEqual(fully_flatten([[1, 3], [1, 2]]), [1, 3, 1, 2]) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 5e7a117e..9a11e5da 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,5 +1,5 @@ from __future__ import annotations -import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys +import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip import itertools, urllib.request, subprocess, shutil, math, json, contextvars from dataclasses import dataclass from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence @@ -262,19 +262,22 @@ def diskcache(func): # *** http support *** -def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, +def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path: if url.startswith(("/", ".")): return pathlib.Path(url) if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name) - else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest()) + else: + fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / \ + ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else "")) if not fp.is_file() or not allow_caching: with urllib.request.urlopen(url, timeout=10) as r: assert r.status == 200 total_length = int(r.headers.get('content-length', 0)) progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}", disable=CI) (path := fp.parent).mkdir(parents=True, exist_ok=True) + readfile = gzip.GzipFile(fileobj=r) if gunzip else r with tempfile.NamedTemporaryFile(dir=path, delete=False) as f: - while chunk := r.read(16384): progress_bar.update(f.write(chunk)) + while chunk := readfile.read(16384): progress_bar.update(f.write(chunk)) f.close() progress_bar.update(close=True) if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")