Skip to content

Commit 58be4fd

Browse files
committed
Update: Optimize download
1 parent e42797e commit 58be4fd

File tree

1 file changed

+75
-56
lines changed

1 file changed

+75
-56
lines changed

utils/download.py

+75-56
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,73 @@
22
import os
33
import hashlib
44
import tarfile
5-
import urllib.request
65
import zipfile
7-
8-
from tqdm import tqdm
96
from pathlib import Path
10-
from logger import logger
11-
from py7zr import SevenZipFile
127

13-
14-
class TqdmUpTo(tqdm):
15-
def update_to(self, b=1, bsize=1, tsize=None):
16-
if tsize is not None:
17-
self.total = tsize
18-
self.update(b * bsize - self.n)
8+
import requests
9+
from py7zr import SevenZipFile
10+
from tqdm import tqdm
11+
from config import ABS_PATH
1912

2013

21-
def _download_file(url, dest_path, max_retry=1):
    """Download ``url`` to ``dest_path`` with resume support.

    Issues a HEAD request first to learn the remote size, skips the download
    when the local file is already complete, resumes via an HTTP ``Range``
    header when a partial file exists, and retries up to ``max_retry``
    additional times on failure.

    Returns:
        tuple[bool, str]: (success, human-readable message). Network errors
        are caught and reported through the return value, never raised.
    """
    # NOTE(review): `logging` is used throughout but its import is not visible
    # in this diff hunk (file line 1 is outside the hunk) — confirm
    # `import logging` exists at the top of the file.
    logging.info(f"Downloading: {url}")

    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }

    # Probe the server first so we can compare sizes and decide whether to
    # skip, resume, or restart the download.
    try:
        response = requests.head(url, headers=headers, allow_redirects=True, timeout=10)
        if response.status_code >= 400:
            logging.error(f"Failed to connect to {url}, status code: {response.status_code}")
            return False, f"Failed to connect, status code: {response.status_code}"
    except Exception as e:
        logging.error(f"Failed to get file size for {url}: {e}")
        return False, f"Request timeout: {e}"

    total_size = int(response.headers.get('content-length', 0))
    file_size = os.path.getsize(dest_path) if os.path.exists(dest_path) else 0

    # Only trust size comparisons when the server actually reported a length.
    # Without this guard, a missing Content-Length (total_size == 0) and a
    # missing local file (file_size == 0) would wrongly report "complete".
    if total_size > 0:
        if file_size == total_size:
            logging.info(f"File {dest_path} already downloaded and complete.")
            return True, "File already downloaded and complete."
        if file_size > total_size:
            # Local file is larger than the remote one — it cannot be a valid
            # partial download, so discard it and start over.
            logging.warning(f"Local file size {file_size} exceeds server file size {total_size}. Removing local file.")
            os.remove(dest_path)
            if max_retry <= 0:
                return False, "Local file size exceeds server file size."
            return _download_file(url, dest_path, max_retry=max_retry - 1)

    # Resume from the current offset. Only set the header when there is
    # something to resume; the original assigned None for a fresh download
    # (which requests silently drops) — being explicit avoids relying on that.
    if file_size > 0:
        headers['Range'] = f'bytes={file_size}-'

    # dirname is '' for a bare filename; os.makedirs('') would raise.
    dest_dir = os.path.dirname(dest_path)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)

    relative_path = os.path.relpath(dest_path, ABS_PATH)
    chunk_size = 1024 * 1024  # 1MB

    try:
        # Stream the body in chunks; 'ab' append mode pairs with the Range
        # resume above. `initial=file_size` keeps the progress bar honest
        # when resuming.
        with requests.get(url, headers=headers, stream=True, timeout=10) as response, open(dest_path, 'ab') as file, tqdm(
                total=total_size,
                initial=file_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
                desc=f"Downloading: {relative_path or url.split('/')[-1]}",
        ) as progress:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:  # filter out keep-alive chunks
                    file.write(chunk)
                    progress.update(len(chunk))

        logging.info(f"Download completed: {dest_path}")
        return True, "Download completed."
    except Exception as e:
        logging.error(f"Error during downloading {url}: {e}")
        if max_retry > 0:
            logging.info(f"Retrying download ({max_retry} retries left)...")
            return _download_file(url, dest_path, max_retry=max_retry - 1)
        return False, f"Download failed: {e}"
4872

4973

5074
def verify_md5(file_path, expected_md5):
@@ -89,47 +113,43 @@ def extract_file(file_path, destination=None):
89113

90114
def download_file(urls, target_path, extract_destination=None, expected_md5=None, expected_sha256=None):
91115
if os.path.exists(target_path):
116+
success_msg = "File already exists and verified successfully!"
92117
if expected_md5 is not None:
93118
success, message = verify_md5(Path(target_path), expected_md5)
94-
if not success:
95-
os.remove(target_path)
96-
return False, message
119+
if success:
120+
return True, success_msg
97121

98122
if expected_sha256 is not None:
99123
success, message = verify_sha256(Path(target_path), expected_sha256)
100-
if not success:
101-
os.remove(target_path)
102-
return False, message
124+
if success:
125+
return True, success_msg
103126

104127
# If it's a compressed file and the target_path already exists, skip the download
105128
if extract_destination and target_path.endswith(('.zip', '.tar.gz', '.tar.bz2', '.7z')):
106129
extract_file(target_path, extract_destination)
107130
os.remove(target_path)
108-
109-
return True, "File already exists and verified successfully!"
131+
return True, success_msg
110132

111133
is_download = False
112134
for url in urls:
113135
try:
114-
_download_file(url, target_path)
115-
is_download = True
116-
break
136+
is_download, _ = _download_file(url, target_path)
137+
if is_download:
138+
break
117139
except Exception as error:
118-
logger.error(f"downloading from URL {url}: {error}")
140+
logging.error(f"downloading from URL {url}: {error}")
119141

120142
if not is_download:
121143
return False, "Error downloading from all provided URLs."
122144

123145
if expected_md5 is not None:
124146
success, message = verify_md5(Path(target_path), expected_md5)
125147
if not success:
126-
os.remove(target_path)
127148
return False, message
128149

129150
if expected_sha256 is not None:
130151
success, message = verify_sha256(Path(target_path), expected_sha256)
131152
if not success:
132-
os.remove(target_path)
133153
return False, message
134154

135155
# If it's a compressed file, extract it
@@ -141,14 +161,13 @@ def download_file(urls, target_path, extract_destination=None, expected_md5=None
141161

142162

143163
if __name__ == "__main__":
    # Imported for its side effect of configuring the project's logging;
    # the module object itself is not referenced below.
    import logger

    URL = [
        "https://hf-mirror.com/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
    ]
    TARGET_PATH = r"E:\work\vits-simple-api\data\bert\chinese-roberta-wwm-ext-large/pytorch_model1.bin"
    EXPECTED_MD5 = None
    EXTRACT_DESTINATION = None

    # download_file's signature is
    # (urls, target_path, extract_destination=None, expected_md5=None, ...),
    # so pass these by keyword: the original positional call routed
    # EXPECTED_MD5 into the extract_destination slot and EXTRACT_DESTINATION
    # into expected_md5 (silent only because both are None here).
    print(download_file(URL, TARGET_PATH,
                        extract_destination=EXTRACT_DESTINATION,
                        expected_md5=EXPECTED_MD5))

0 commit comments

Comments (0)