-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfileserver.py
392 lines (317 loc) · 13.6 KB
/
fileserver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
import logging
import itertools
import threading
import os
import os.path
import json
import sys
import zmq
import serialization
class DownloadError(Exception):
    """Signals a failed download.

    Within this module it is raised with a single dict argument of the
    form {'success': False, 'reason': <text>}, which the download code
    forwards to the caller's callback.
    """
def file_request_msg(filename, offset, size):
    """Build the JSON text a client sends to request one chunk of a file.

    offset and size are passed through int() before encoding.  The
    encoded object carries the keys 'request' (the filename), 'offset'
    and 'size'.
    """
    payload = {'request': filename,
               'offset': int(offset), 'size': int(size)}
    return json.dumps(payload)
def parse_file_request(message):
    """Decode a JSON file request into a (filename, offset, size) tuple.

    The decoded object has to provide exactly the keys 'request',
    'offset' and 'size'; offset and size are converted with int().
    """
    # Binding the decoded keys to named parameters means a missing or
    # unexpected key fails with TypeError instead of passing silently.
    def as_triple(request, offset, size):
        return request, int(offset), int(size)

    fields = json.loads(message)
    return as_triple(**fields)
def file_header(filename, offset, size):
    """Build the JSON header frame describing one chunk of file data.

    The values are encoded as given under the keys 'filename',
    'offset' and 'size'.
    """
    fields = {'filename': filename, 'offset': offset, 'size': size}
    return json.dumps(fields)
def parse_file_header(header):
    """Decode a JSON header frame into a dict.

    A dict is returned rather than a tuple because a header comes in
    more than one shape, and the keys of the result differ accordingly:
    a header containing 'error' yields {'error': ...} only, any other
    header yields 'filename', 'offset' and 'size', with offset and size
    converted by int().
    """
    fields = json.loads(header)
    if 'error' in fields:
        return {'error': fields['error']}

    # Named parameters make a missing or unexpected key a TypeError.
    def as_triple(filename, offset, size):
        return filename, int(offset), int(size)

    name, start, length = as_triple(**fields)
    return {'filename': name, 'offset': start, 'size': length}
def error_msg(description):
    """Build the JSON error frame holding description under 'error'."""
    payload = {'error': description}
    return json.dumps(payload)
class FileChunker(object):
    """Handles chunked reading of file contents.

    Reads chunks of a file into frames ready to be transmitted.  The
    file is opened in binary mode on construction and stays open until
    close() is called, a ``with`` block is left, or the object is
    finalized.
    """

    def __init__(self, filename):
        """Open filename for binary reading.

        Whatever exception open() raises propagates to the caller.
        """
        self.filename = filename
        self.f = open(filename, 'rb')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress an exception raised inside the with block.
        return False

    def __del__(self):
        self.close()

    def close(self):
        """Close the underlying file.  Safe to call more than once."""
        # The attribute does not exist when open() raised in __init__,
        # so do not try to close a file that was never opened.
        handle = getattr(self, 'f', None)
        if handle is not None:
            handle.close()

    def read(self, offset, size):
        """Read a chunk from opened file.

        Given an offset and size, return a header and actual chunk
        contents as a tuple.

        The header contains offset, size and filename.  The offset and
        size will have the same values as the passed arguments.  The
        filename on the other hand will only contain the basename of
        the opened file.  This is because this header will be
        transmitted and then the path to the file should not be
        present.

        The contents may be shorter than size when the file ends
        before offset + size.

        Raises IOError if size is negative.
        """
        if size < 0:
            # file.read() treats a negative size as "read until end of
            # file", which would return data the header does not
            # describe.  IOError is used so that code already handling
            # read failures covers this case too.
            raise IOError('negative chunk size: %r' % (size,))
        self.f.seek(offset)
        contents = self.f.read(size)
        header = {'filename': os.path.basename(self.filename),
                  'offset': offset, 'size': size}
        return header, contents
class FileServer(threading.Thread):
    """Thread that serves chunks of shared files over a zmq ROUTER socket.

    Other threads control the server through a pair of PAIR sockets on
    "inproc://fs-pipe": the public methods (add_file, remove_file,
    get_files, get_bound_port, stop) send a request on self.pipe, and
    run() dispatches it to the matching on_* handler and sends the
    result back.
    """

    def __init__(self, context, frontend_addr=None, pipe=None):
        """Set up the server.

        context is the zmq context used to create sockets.
        frontend_addr is the address the frontend is bound to; when it
        is empty a random TCP port is used instead.  pipe is the
        caller-side control socket; when it is not given, a PAIR socket
        is created and bound to "inproc://fs-pipe".
        """
        super(FileServer, self).__init__()
        self.context = context
        self.frontend_addr = frontend_addr
        self.pipe = pipe
        self._shared_files = []
        self._bound_port = None
        if not self.pipe:
            self.pipe = self.context.socket(zmq.PAIR)
            self.pipe.bind("inproc://fs-pipe")

    def run(self, iterations=None,
            use_frontend=True, frontend=None,
            use_thread_pipe=True, thread_pipe=None):
        """Poll the frontend and the control pipe and dispatch messages.

        iterations limits the number of poll rounds; None means run
        until a stop request arrives.  use_frontend / use_thread_pipe
        switch the respective socket off entirely, and frontend /
        thread_pipe let the caller supply the sockets instead of having
        them created here.
        """
        poller = zmq.Poller()
        # Sockets created by this call.  They are closed when the loop
        # ends; sockets handed in by the caller stay the caller's
        # responsibility.
        owned = []
        try:
            if use_frontend:
                if frontend is None:
                    frontend = self.context.socket(zmq.ROUTER)
                    owned.append(frontend)
                    # NOTE(review): binding is done only for a socket
                    # created here; a frontend passed in by the caller
                    # is assumed to be set up already -- confirm.
                    if self.frontend_addr:
                        frontend.bind(self.frontend_addr)
                        port_str = self.frontend_addr.rpartition(':')[-1]
                        self._bound_port = int(port_str)
                    else:
                        self._bound_port = frontend.bind_to_random_port(
                            "tcp://*")
                poller.register(frontend, zmq.POLLIN)

            if use_thread_pipe:
                if thread_pipe is None:
                    thread_pipe = self.context.socket(zmq.PAIR)
                    owned.append(thread_pipe)
                    thread_pipe.connect("inproc://fs-pipe")
                poller.register(thread_pipe, zmq.POLLIN)

            for i in itertools.count():
                if iterations is not None and iterations == i:
                    break
                socks = dict(poller.poll())
                # Control requests take precedence; a frontend message
                # that is ready at the same time is picked up by the
                # next poll round.
                if thread_pipe in socks and socks[thread_pipe] == zmq.POLLIN:
                    if self._handle_pipe(thread_pipe):
                        break
                elif frontend in socks and socks[frontend] == zmq.POLLIN:
                    self._handle_frontend(frontend)
        finally:
            for sock in owned:
                sock.close()

    def _handle_pipe(self, pipe):
        # Handle one control message.  Returns True when the server was
        # asked to stop, False otherwise.
        data = pipe.recv()
        # Bytes literal: identical to 'STOP' on Python 2 and matches
        # what recv() returns on Python 3.
        if data == b'STOP':
            return True
        method, params = serialization.deserialize(data)
        result = getattr(self, 'on_%s' % method)(params)
        pipe.send(serialization.s_res(result))
        return False

    def _handle_frontend(self, frontend):
        # Read the identity frame and the request frame, then answer
        # with the identity followed by the response frames.
        identity = frontend.recv()
        message = frontend.recv()
        frontend.send_multipart([identity] + self.on_frontend_message(message))

    def add_file(self, filename):
        """Add file to be shared.

        filename is the full path of the file to be shared.

        The basename of the filename cannot collide with the basename
        of another file already shared.  This is because when serving
        files on the frontend, basenames are used to identify the
        files.

        Raises IndexError if the server rejects the file.
        """
        self.pipe.send(serialization.s_req('add_file', filename))
        response = self.pipe.recv()
        if response and serialization.deserialize(response).result is False:
            raise IndexError("File already present")

    def on_add_file(self, filename):
        # Returns False when a shared file already has the same
        # basename (which includes the same path being added twice),
        # True once the file has been added.
        basename = os.path.basename(filename)
        for path in self._shared_files:
            if os.path.basename(path) == basename:
                return False
        self._shared_files.append(filename)
        return True

    def remove_file(self, filename):
        """Stop sharing filename.

        Removing a file that is not shared has no effect.
        """
        self.pipe.send(serialization.s_req('remove_file', filename))
        # The reply carries nothing the caller needs, but it has to be
        # read so that the next request sees its own reply.
        self.pipe.recv()

    def on_remove_file(self, filename):
        # Returns True when the file was shared and has been removed,
        # False when it was not shared.  An unknown file must not raise
        # here: the exception would end the server thread and leave the
        # requesting thread waiting for a reply that never comes.
        try:
            self._shared_files.remove(filename)
        except ValueError:
            return False
        return True

    def get_files(self):
        """Return the result of the server's get_files handler."""
        self.pipe.send(serialization.s_req('get_files', None))
        response = self.pipe.recv()
        return serialization.deserialize(response).result

    def on_get_files(self, dummy=None):
        # A tuple copy, so the caller cannot modify the internal list.
        return tuple(self._shared_files)

    def on_frontend_message(self, message):
        # Process message from client.  Returns list of zmq frames to
        # be sent back to client.
        try:
            filename, offset, size = parse_file_request(message)
        except (ValueError, TypeError):
            # Malformed client input must not take the server down.
            return [error_msg('invalid request')]
        if size < 0:
            # A negative size would make the file read return
            # everything up to the end of the file.
            return [error_msg('invalid request')]

        for path in self._shared_files:
            if os.path.basename(path) != filename:
                continue
            try:
                # This opens the file for every request. We'll
                # want to cache this later
                chunker = FileChunker(path)
                header, contents = chunker.read(offset, size)
            except IOError:
                return [error_msg('read error')]
            logging.debug('Server response header: %s', header)
            return [file_header(**header), contents]

        # If we got here we didn't have the requested file
        return [error_msg('file not found')]

    def get_bound_port(self):
        """Return the result of the server's get_bound_port handler."""
        self.pipe.send(serialization.s_req('get_bound_port', None))
        response = self.pipe.recv()
        return serialization.deserialize(response).result

    def on_get_bound_port(self, dummy=None):
        # None until run() has bound a frontend socket.
        return self._bound_port

    def stop(self):
        """Ask the server loop to terminate."""
        self.pipe.send(b"STOP")
class Downloader(object):
    """Client for downloading files from server"""

    def __init__(self, context, endpoint, filename, filesize, chunksize=None):
        """Prepare a download; nothing is transferred yet.

        context is the zmq context used to create the client socket and
        endpoint the address it connects to.  filename is the name the
        file is requested under and also the name it is stored under in
        the current working directory.  filesize is the total number of
        bytes to fetch and chunksize the number of bytes per request;
        when chunksize is None the whole file is requested at once.
        """
        self.context = context
        self.endpoint = endpoint
        self.filename = filename
        self.filesize = filesize
        if chunksize is None:
            chunksize = filesize
        self.chunksize = chunksize
        self.destination = os.path.join(os.getcwd(), filename)
        self._has_downloaded = threading.Event()
        # Set to True to run the download in the calling thread.
        self._synchronous = False

    def _validate_header(self, header, offset, size):
        # Check that header is consistent with what we expect.
        # Otherwise raise DownloadError.
        if 'error' in header:
            raise DownloadError({'success': False, 'reason': header['error']})
        if header['filename'] != self.filename:
            logging.debug('Expected filename %r but server sent %r',
                          self.filename, header['filename'])
            raise DownloadError(
                {'success': False,
                 'reason': 'Wrong filename received from server'})
        if header['offset'] != offset:
            raise DownloadError(
                {'success': False,
                 'reason': 'Wrong offset received from server'})
        if header['size'] != size:
            raise DownloadError(
                {'success': False,
                 'reason': 'Wrong size received from server'})

    def _validate_chunk(self, chunk, size):
        # Check that chunk is consistent with what we expect.
        # Otherwise raise DownloadError.
        if len(chunk) != size:
            raise DownloadError(
                {'success': False,
                 'reason': 'Wrong data received from server'})

    def _chunks(self, chunksize, total):
        # Generate tuples of offset and size describing chunks of a
        # file. Continue until total size has been exhausted.
        # Raises ValueError for a non-positive chunksize when there is
        # data to split, since the offset would never advance.
        offset = 0
        while offset < total:
            size = min(chunksize, total - offset)
            if size <= 0:
                raise ValueError('chunksize must be positive')
            yield offset, size
            offset += size

    def _get_all_chunks(self, socket, filehandle, progress_callback):
        # Query all chunks making up a file from socket and write it
        # to filehandle.  Raises DownloadError on any invalid response.
        if self.filesize <= 0:
            # Nothing to request for an empty file.  This also covers
            # the default chunksize, which is 0 for an empty file and
            # would otherwise cause a division by zero below.
            return
        if self.chunksize <= 0:
            raise DownloadError(
                {'success': False, 'reason': 'Invalid chunk size'})

        # Number of chunks, rounded up.
        total_chunks = (self.filesize + self.chunksize - 1) // self.chunksize
        chunk_params = self._chunks(self.chunksize, self.filesize)

        # For each chunk, send a request, receive the response and
        # append it to file.
        for curr_chunk, (offset, size) in enumerate(chunk_params):
            # Create and send request to server
            request_msg = file_request_msg(self.filename, offset, size)
            logging.debug('Request chunk: %s', request_msg)
            socket.send(request_msg.encode('utf-8'))

            # The first frame contains the header
            first_frame = socket.recv()
            try:
                header = parse_file_header(first_frame)
            except (ValueError, TypeError):
                raise DownloadError(
                    {'success': False,
                     'reason': 'Invalid data received from server'})

            # Header had the right structure, now check that it is
            # what we expect it to be
            self._validate_header(header, offset, size)

            # Retrieve contents and validate
            file_chunk = socket.recv()
            self._validate_chunk(file_chunk, size)
            filehandle.write(file_chunk)

            if progress_callback is not None:
                progress_callback(float(curr_chunk + 1) / total_chunks)

    def _discard_partial_file(self):
        # Remove whatever a failed download left at the destination.
        try:
            os.remove(self.destination)
        except OSError:
            # Best effort: the failure reported to the callback is the
            # download error, not a cleanup problem.
            pass

    def _do_download(self, callback, progress_callback):
        # This is the part that normally runs in a thread
        socket = self.context.socket(zmq.DEALER)
        failure = None
        try:
            socket.connect(self.endpoint)
            with open(self.destination, 'wb') as f:
                try:
                    self._get_all_chunks(socket, f, progress_callback)
                except DownloadError as error:
                    failure = error.args[0]
        finally:
            socket.close()

        if failure is not None:
            # An incomplete file left behind would make every retry
            # fail with 'File already exists'.
            self._discard_partial_file()
            callback(failure)
            return

        self._has_downloaded.set()
        callback({'success': True, 'path': self.destination})

    @property
    def has_downloaded(self):
        """True if file has been stored on file system."""
        return self._has_downloaded.is_set()

    def download(self, callback, progress_callback=None):
        """Start download of file to disk

        Use callback to signal the result of the operation. callback
        is passed a dict as argument which contains a key 'success'
        with a value of True or False. If it is True, it indicates
        success and the dict will also contain a key 'path' with a
        value giving the path of the downloaded file on the file
        system. If 'success' is False, the dict will instead contain
        a key 'reason' with a string as value, stating the reason for
        the failure.

        If progress_callback is not None it will be called for each
        received chunk, after it has been written to file. The value
        of the fraction num_chunks_received / total_chunks is passed
        as argument for each call to indicate download progress.
        """
        if os.path.exists(self.destination):
            callback({'success': False, 'reason': 'File already exists'})
            return

        # If we are unit testing, don't start a new thread
        if self._synchronous:
            self._do_download(callback, progress_callback)
        else:
            threading.Thread(target=self._do_download,
                             args=(callback, progress_callback)).start()
if __name__ == '__main__':
    # Manual smoke test: start a server sharing tests.py, then read file
    # names from stdin and print the first 42 bytes of each requested
    # file until EOF or Ctrl-C.
    context = zmq.Context()
    server = FileServer(context)
    server.start()
    try:
        server.add_file("tests.py")
        server.add_file("README.md")
        print(server.get_files())
        server.remove_file("README.md")
        print(server.get_files())
        serverport = server.get_bound_port()
        print(serverport)

        # One client socket serves every request; it is closed on the
        # way out instead of creating a new socket per request.
        s = context.socket(zmq.DEALER)
        s.connect("tcp://localhost:%d" % serverport)
        try:
            req = raw_input()
            while True:
                msg = file_request_msg(req, 0, 42)
                print("Sending: %s" % msg)
                s.send(msg)
                header = s.recv()
                print(header)
                # Only a successful response carries a second frame
                # with the file contents.
                if 'filename' in header:
                    print(s.recv())
                req = raw_input()
        except (EOFError, KeyboardInterrupt):
            pass
        finally:
            s.close()
    finally:
        # Always ask the server thread to stop; it is not a daemon
        # thread and would otherwise keep the process alive.
        server.stop()