diff options
Diffstat (limited to 'fbin/file_storage')
-rw-r--r-- | fbin/file_storage/base.py | 39 | ||||
-rw-r--r-- | fbin/file_storage/filesystem.py | 44 | ||||
-rw-r--r-- | fbin/file_storage/s3.py | 55 |
3 files changed, 138 insertions, 0 deletions
diff --git a/fbin/file_storage/base.py b/fbin/file_storage/base.py new file mode 100644 index 0000000..6f39665 --- /dev/null +++ b/fbin/file_storage/base.py @@ -0,0 +1,39 @@ +import datetime + +from .. import db + +class BaseStorage: + def __init__(self, app): + self.app = app + + def add_file(self, file_hash, filename, size, user=None, ip=None): + '''Adds the file to the database. + + Call from store_file after the file is successfully stored.''' + with db.session_scope() as sess: + f = db.File(file_hash, filename, size, datetime.datetime.utcnow(), user.id if user else None, ip) + sess.add(f) + sess.commit() + sess.refresh(f) + return f + + def store_file(self, uploaded_file, file_hash, filename, user, ip): + '''Store uploaded_file.''' + raise NotImplementedError() + + def get_file(self, f): + '''Return a file object for the specified file. + + Subclasses can also return a flask.Response instance if required.''' + raise NotImplementedError() + + def delete_file(self, f): + '''Delete the specified file.''' + raise NotImplementedError() + + def temp_file(self, f): + '''Context manager which returns a temporary file for reading. + + This is used internally for eg. thumbnails.''' + raise NotImplementedError() + diff --git a/fbin/file_storage/filesystem.py b/fbin/file_storage/filesystem.py new file mode 100644 index 0000000..3433baf --- /dev/null +++ b/fbin/file_storage/filesystem.py @@ -0,0 +1,44 @@ +import contextlib +import os +import tempfile + +from .base import BaseStorage + +class Storage(BaseStorage): + def __init__(self, app): + super().__init__(app) + os.makedirs(self.app.config['FILE_DIRECTORY'], exist_ok=True) + + def store_file(self, uploaded_file, file_hash, user, ip): + size = uploaded_file.content_length + if hasattr(uploaded_file.stream, 'file'): + temp = None + temp_path = uploaded_file.stream.name + else: + temp = tempfile.NamedTemporaryFile(prefix='upload_', dir=self.app.config['FILE_DIRECTORY'], delete=False) + uploaded_file.save(temp.file) + temp_path = temp.name + size = os.path.getsize(temp_path) + try: + new_file = self.add_file(file_hash, uploaded_file.filename, size, user, ip) + os.rename(temp_path, new_file.get_path()) + return new_file + except: + os.unlink(temp.name) + raise + + def get_file(self, f): + path = f.get_path() + if not os.path.exists(path): + return + return path + + def delete_file(self, f): + path = f.get_path() + if os.path.exists(path): + os.unlink(path) + + @contextlib.contextmanager + def temp_file(self, f): + with open(f.get_path(), 'rb') as f: + yield f diff --git a/fbin/file_storage/s3.py b/fbin/file_storage/s3.py new file mode 100644 index 0000000..2f0b87b --- /dev/null +++ b/fbin/file_storage/s3.py @@ -0,0 +1,55 @@ +import contextlib +import tempfile + +import boto3 +from flask import request, send_file + +from .base import BaseStorage + +class Storage(BaseStorage): + def __init__(self, app): + super().__init__(app) + self.client = boto3.resource('s3', **self.app.config['S3_CONFIG']) + + def _get_object_key(self, file_hash, user_id): + return '{}_{}'.format(file_hash, user_id) + + def get_object_key(self, f): + return self._get_object_key(f.hash, f.user_id if f.user_id else 0) + + def store_file(self, uploaded_file, file_hash, user, ip): + bucket = self.client.Bucket(self.app.config['S3_BUCKET']) + key = self._get_object_key(file_hash, user.id if user else 0) + obj = bucket.upload_fileobj(Fileobj=uploaded_file.stream, Key=key) + size = uploaded_file.content_length + if not size: + obj = self.client.ObjectSummary(self.app.config['S3_BUCKET'], key) + size = obj.size + return self.add_file(file_hash, uploaded_file.filename, size, user, ip) + + def get_file(self, f): + obj = self.client.Object(self.app.config['S3_BUCKET'], self.get_object_key(f)) + kwargs = {} + if 'Range' in request.headers: + kwargs['Range'] = request.headers['Range'] + data = obj.get(**kwargs) + rv = send_file(data['Body'], attachment_filename=f.filename) + rv.headers['Content-Length'] = data['ContentLength'] + rv.headers['Accept-Ranges'] = data['AcceptRanges'] + if 'ContentRange' in data: + rv.headers['Content-Range'] = data['ContentRange'] + rv.status_code = 206 + return rv + + def delete_file(self, f): + obj = self.client.Object(self.app.config['S3_BUCKET'], self.get_object_key(f)) + obj.delete() + + @contextlib.contextmanager + def temp_file(self, f): + obj = self.client.Object(self.app.config['S3_BUCKET'], self.get_object_key(f)) + with tempfile.NamedTemporaryFile() as f: + obj.download_fileobj(f) + f.seek(0) + yield f + |