import hashlib
import logging
import os
import tarfile
import zipfile
from pathlib import Path

import requests
from py7zr import SevenZipFile
from tqdm import tqdm

from config import ABS_PATH
19
12
20
13
21
def _download_file(url, dest_path, max_retry=1):
    """Download ``url`` to ``dest_path`` with resume support and a progress bar.

    A HEAD request is issued first to learn the remote size; an existing
    partial file is resumed via an HTTP ``Range`` header, and a failed
    transfer is retried up to ``max_retry`` more times (via recursion).

    Args:
        url: Direct download URL.
        dest_path: Local file path to write to (parent dirs are created).
        max_retry: Retries remaining after a failure.

    Returns:
        Tuple ``(success: bool, message: str)``.
    """
    logging.info(f"Downloading: {url}")

    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }

    try:
        response = requests.head(url, headers=headers, allow_redirects=True, timeout=10)
        if response.status_code >= 400:
            logging.error(f"Failed to connect to {url}, status code: {response.status_code}")
            return False, f"Failed to connect, status code: {response.status_code}"
    except Exception as e:
        logging.error(f"Failed to get file size for {url}: {e}")
        return False, f"Request timeout: {e}"

    # content-length may be absent (chunked responses); 0 then means "unknown".
    total_size = int(response.headers.get('content-length', 0))
    file_size = os.path.getsize(dest_path) if os.path.exists(dest_path) else 0

    if total_size > 0:
        # Only trust these comparisons when the server actually reported a
        # size; otherwise a missing file (0 bytes) would "equal" an unknown
        # content-length (0) and we would return success without downloading.
        if file_size == total_size:
            logging.info(f"File {dest_path} already downloaded and complete.")
            return True, "File already downloaded and complete."
        if file_size > total_size:
            logging.warning(f"Local file size {file_size} exceeds server file size {total_size}. Removing local file.")
            os.remove(dest_path)
            if max_retry <= 0:
                return False, "Local file size exceeds server file size."
            return _download_file(url, dest_path, max_retry=max_retry - 1)
    elif file_size > 0:
        # Unknown remote size: a partial file cannot be verified or safely
        # resumed, so start over from scratch.
        logging.warning(f"Local file size {file_size} exceeds server file size {total_size}. Removing local file.")
        os.remove(dest_path)
        file_size = 0

    # Resume from the current offset; omit the Range header for a fresh start
    # (the original assigned None, relying on requests dropping None headers).
    if file_size > 0:
        headers['Range'] = f'bytes={file_size}-'

    # dirname() is '' for a bare filename; makedirs('') would raise.
    parent_dir = os.path.dirname(dest_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    relative_path = os.path.relpath(dest_path, ABS_PATH)
    chunk_size = 1024 * 1024  # 1MB

    try:
        with requests.get(url, headers=headers, stream=True, timeout=10) as response:
            # If we asked for a Range but the server answered 200 (full body)
            # instead of 206, appending would corrupt the file — restart it.
            mode, initial = 'ab', file_size
            if file_size > 0 and response.status_code != 206:
                mode, initial = 'wb', 0
            with open(dest_path, mode) as file, tqdm(
                total=total_size,
                initial=initial,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
                desc=f"Downloading: {relative_path or url.split('/')[-1]}",
            ) as progress:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        file.write(chunk)
                        progress.update(len(chunk))

        logging.info(f"Download completed: {dest_path}")
        return True, "Download completed."
    except Exception as e:
        logging.error(f"Error during downloading {url}: {e}")
        if max_retry > 0:
            logging.info(f"Retrying download ({max_retry} retries left)...")
            return _download_file(url, dest_path, max_retry=max_retry - 1)
        return False, f"Download failed: {e}"
48
72
49
73
50
74
def verify_md5 (file_path , expected_md5 ):
@@ -89,47 +113,43 @@ def extract_file(file_path, destination=None):
89
113
90
114
def download_file (urls , target_path , extract_destination = None , expected_md5 = None , expected_sha256 = None ):
91
115
if os .path .exists (target_path ):
116
+ success_msg = "File already exists and verified successfully!"
92
117
if expected_md5 is not None :
93
118
success , message = verify_md5 (Path (target_path ), expected_md5 )
94
- if not success :
95
- os .remove (target_path )
96
- return False , message
119
+ if success :
120
+ return True , success_msg
97
121
98
122
if expected_sha256 is not None :
99
123
success , message = verify_sha256 (Path (target_path ), expected_sha256 )
100
- if not success :
101
- os .remove (target_path )
102
- return False , message
124
+ if success :
125
+ return True , success_msg
103
126
104
127
# If it's a compressed file and the target_path already exists, skip the download
105
128
if extract_destination and target_path .endswith (('.zip' , '.tar.gz' , '.tar.bz2' , '.7z' )):
106
129
extract_file (target_path , extract_destination )
107
130
os .remove (target_path )
108
-
109
- return True , "File already exists and verified successfully!"
131
+ return True , success_msg
110
132
111
133
is_download = False
112
134
for url in urls :
113
135
try :
114
- _download_file (url , target_path )
115
- is_download = True
116
- break
136
+ is_download , _ = _download_file (url , target_path )
137
+ if is_download :
138
+ break
117
139
except Exception as error :
118
- logger .error (f"downloading from URL { url } : { error } " )
140
+ logging .error (f"downloading from URL { url } : { error } " )
119
141
120
142
if not is_download :
121
143
return False , "Error downloading from all provided URLs."
122
144
123
145
if expected_md5 is not None :
124
146
success , message = verify_md5 (Path (target_path ), expected_md5 )
125
147
if not success :
126
- os .remove (target_path )
127
148
return False , message
128
149
129
150
if expected_sha256 is not None :
130
151
success , message = verify_sha256 (Path (target_path ), expected_sha256 )
131
152
if not success :
132
- os .remove (target_path )
133
153
return False , message
134
154
135
155
# If it's a compressed file, extract it
@@ -141,14 +161,13 @@ def download_file(urls, target_path, extract_destination=None, expected_md5=None
141
161
142
162
143
163
if __name__ == "__main__":
    # Imported for its side effect: the project's logger module configures
    # the logging handlers used by this file.
    import logger  # noqa: F401

    URL = [
        "https://hf-mirror.com/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
    ]
    # NOTE(review): the original literal ended with a trailing space, which
    # would create a differently-named file on disk; removed here.
    TARGET_PATH = r"E:\work\vits-simple-api\data\bert\chinese-roberta-wwm-ext-large/pytorch_model1.bin"
    EXPECTED_MD5 = None
    EXTRACT_DESTINATION = None

    # Pass by keyword: the signature is (urls, target_path,
    # extract_destination, expected_md5, expected_sha256); the old positional
    # call swapped extract_destination and expected_md5 (masked only because
    # both happened to be None).
    print(download_file(URL, TARGET_PATH,
                        extract_destination=EXTRACT_DESTINATION,
                        expected_md5=EXPECTED_MD5))
0 commit comments