From 0faa732c9a3e1ffced2b26bee682f513b0e5f0ae Mon Sep 17 00:00:00 2001 From: Jon Bergli Heier Date: Wed, 28 Oct 2020 19:16:34 +0100 Subject: Use flask-sqlalchemy instead of using sqlalchemy directly This makes database access a bit easier and also greatly simplifies some upcoming changes. --- fbin-scanner.py | 72 +++++++++++----------- fbin/__init__.py | 3 +- fbin/api.py | 26 ++++---- fbin/db.py | 71 ++++++++-------------- fbin/fbin.py | 152 ++++++++++++++++++++++------------------------ fbin/file_storage/base.py | 17 +++--- fbin/login.py | 20 +++--- 7 files changed, 164 insertions(+), 197 deletions(-) diff --git a/fbin-scanner.py b/fbin-scanner.py index 2e60522..76bc33b 100644 --- a/fbin-scanner.py +++ b/fbin-scanner.py @@ -80,47 +80,47 @@ def get_report(dbfile, digest, fileobj): def main(): storage = importlib.import_module(current_app.config.get('STORAGE_MODULE', 'fbin.file_storage.filesystem')).Storage(current_app) - with session_scope() as session: - files = deque(session.query(File).filter(File.scanned == False).all()) - while len(files): - dbfile = files.pop() - if not dbfile.get_size(): - logger.info('Ignoring file %s/%s due to unknown size', dbfile.filename, dbfile.hash) - continue - if dbfile.get_size() > 32*10**6: - logger.info('Ignoring file %s/%s due to size (%s)', dbfile.filename, dbfile.hash, dbfile.formatted_size) - continue - logger.info('Checking file %s/%s (%s)', dbfile.filename, dbfile.hash, dbfile.formatted_size) - try: - with storage.temp_file(dbfile) as f: - h = hashlib.sha256() + files = deque(db.session.query(File).filter(File.scanned == False).all()) + while len(files): + dbfile = files.pop() + if not dbfile.get_size(): + logger.info('Ignoring file %s/%s due to unknown size', dbfile.filename, dbfile.hash) + continue + if dbfile.get_size() > 32*10**6: + logger.info('Ignoring file %s/%s due to size (%s)', dbfile.filename, dbfile.hash, dbfile.formatted_size) + continue + logger.info('Checking file %s/%s (%s)', dbfile.filename, dbfile.hash, dbfile.formatted_size) + try: + with storage.temp_file(dbfile) as f: + h = hashlib.sha256() + chunk = f.read(2**10*16) + while chunk: + h.update(chunk) chunk = f.read(2**10*16) - while chunk: - h.update(chunk) - chunk = f.read(2**10*16) - f.seek(0) - digest = h.hexdigest() - logger.info('SHA-256: %s', digest) - report = get_report(dbfile, digest, f) - except: - logger.exception('Failed to get report for %s/%s', dbfile.filename, dbfile.hash) - # Most likely an error from virustotal, so just break here and retry later. - break - dbfile.scanned = True - if report and any(r.get('detected', False) for r in report['scans'].values()): - logger.warning('Positive match') - dbfile.blocked_reason = report - else: - logger.info('No match') - session.add(dbfile) - session.commit() - time.sleep(FILE_DELAY) - logger.info('No more files to scan') + f.seek(0) + digest = h.hexdigest() + logger.info('SHA-256: %s', digest) + report = get_report(dbfile, digest, f) + except: + logger.exception('Failed to get report for %s/%s', dbfile.filename, dbfile.hash) + # Most likely an error from virustotal, so just break here and retry later. + break + dbfile.scanned = True + if report and any(r.get('detected', False) for r in report['scans'].values()): + logger.warning('Positive match') + dbfile.blocked_reason = report + else: + logger.info('No match') + db.session.add(dbfile) + db.session.commit() + time.sleep(FILE_DELAY) + logger.info('No more files to scan') app = Flask('scanner') with app.app_context(): app.config.from_pyfile(args.config_file) - from fbin.db import session_scope, File + from fbin.db import db, File + db.init_app(app) config = app.config main() diff --git a/fbin/__init__.py b/fbin/__init__.py index 6c6a9f5..4bc2ef3 100644 --- a/fbin/__init__.py +++ b/fbin/__init__.py @@ -30,9 +30,10 @@ def context_processors(): } with app.app_context(): - from .fbin import app as fbin + from .fbin import app as fbin, db from .api import app as api from .login import login_manager app.register_blueprint(fbin) app.register_blueprint(api, url_prefix = '/api') login_manager.init_app(app) + db.init_app(app) diff --git a/fbin/api.py b/fbin/api.py index 4f605f0..8f3f86c 100644 --- a/fbin/api.py +++ b/fbin/api.py @@ -6,7 +6,7 @@ from flask.views import MethodView from flask_login import current_user import jwt -from . import db +from .db import db, User, NoResultFound from .fbin import upload as fbin_upload, get_file app = Blueprint('api', __name__) @@ -32,17 +32,16 @@ def authenticate(): token = jwt.decode(token, current_app.config['SECRET_KEY'], issuer = request.url_root) except jwt.InvalidTokenError: abort(403) - with db.session_scope() as s: - try: - user = s.query(db.User).filter(db.User.id == token['sub']).one() - token_datetime = datetime.datetime.fromtimestamp(token['iat']) - # If token was issued before api_key_date was updated, consider it invalid. - if token_datetime < user.api_key_date: - abort(403) - else: - g.user = user - except db.NoResultFound: + try: + user = db.session.query(User).filter(User.id == token['sub']).one() + token_datetime = datetime.datetime.fromtimestamp(token['iat']) + # If token was issued before api_key_date was updated, consider it invalid. + if token_datetime < user.api_key_date: abort(403) + else: + g.user = user + except NoResultFound: + abort(403) def api_login_required(f): def wrapper(*args, **kwargs): @@ -74,9 +73,8 @@ class FileAPI(MethodView): 'status': False, 'message': 'Empty or missing filename', } - with db.session_scope() as sess: - f.filename = filename - sess.add(f) + f.filename = filename + db.session.add(f) return { 'status': True, } diff --git a/fbin/db.py b/fbin/db.py index be79c76..876efeb 100644 --- a/fbin/db.py +++ b/fbin/db.py @@ -4,57 +4,52 @@ import mimetypes import os from flask import current_app -from sqlalchemy import create_engine, Column, Integer, String, DateTime, Text, Index, ForeignKey, Boolean, JSON, BigInteger -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, relation, backref from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.exc import IntegrityError -from sqlalchemy.sql import and_ +from flask_sqlalchemy import SQLAlchemy -engine = create_engine(current_app.config['DB_URI']) +db = SQLAlchemy() -Base = declarative_base(bind = engine) - -class User(Base): +class User(db.Model): __tablename__ = 'users' - id = Column(Integer, primary_key = True) - username = Column(String, unique = True, index = True) - jab_id = Column(String(24), unique = True, index = True) - api_key_date = Column(DateTime, default = datetime.datetime.utcnow) - files = relation('File', backref = 'user', order_by = 'File.date.desc()') + id = db.Column(db.Integer, primary_key=True) + username = db.Column(db.String, unique=True, index=True) + jab_id = db.Column(db.String(24), unique=True, index=True) + api_key_date = db.Column(db.DateTime, default=datetime.datetime.utcnow) + files = db.relation('File', backref='user', order_by='File.date.desc()') def __init__(self, username, jab_id): self.username = username self.jab_id = jab_id -class UserSession(Base): +class UserSession(db.Model): __tablename__ = 'sessions' - id = Column(Integer, primary_key = True) - user_id = Column(Integer, ForeignKey('users.id'), index = True) - access_token = Column(String) - refresh_token = Column(String) - updated = Column(DateTime) + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('users.id'), index=True) + access_token = db.Column(db.String) + refresh_token = db.Column(db.String) + updated = db.Column(db.DateTime) def __init__(self, user_id, access_token, refresh_token): self.user_id = user_id self.access_token = access_token self.refresh_token = refresh_token -class File(Base): +class File(db.Model): __tablename__ = 'files' - id = Column(Integer, primary_key = True) - hash = Column(String, unique = True, index = True) - filename = Column(String) - size = Column(BigInteger) - date = Column(DateTime) - user_id = Column(Integer, ForeignKey('users.id'), nullable = True) - ip = Column(String) - accessed = Column(DateTime) - scanned = Column(Boolean, nullable=False, default=False) - blocked_reason = Column(JSON) + id = db.Column(db.Integer, primary_key=True) + hash = db.Column(db.String, unique=True, index=True) + filename = db.Column(db.String) + size = db.Column(db.BigInteger) + date = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=True) + ip = db.Column(db.String) + accessed = db.Column(db.DateTime) + scanned = db.Column(db.Boolean, nullable=False, default=False) + blocked_reason = db.Column(db.JSON) def __init__(self, hash, filename, size, date, user_id = None, ip = None): self.hash = hash @@ -114,19 +109,3 @@ class File(Base): @property def exists(self): return os.path.exists(self.get_path()) - -Base.metadata.create_all() -Session = sessionmaker(bind = engine, autoflush = True, autocommit = False) - -@contextmanager -def session_scope(): - session = Session() - try: - session.expire_on_commit = False - yield session - session.commit() - except: - session.rollback() - raise - finally: - session.close() diff --git a/fbin/fbin.py b/fbin/fbin.py index fce6ea5..ac4569e 100755 --- a/fbin/fbin.py +++ b/fbin/fbin.py @@ -22,7 +22,7 @@ from PIL import Image import requests from werkzeug.utils import secure_filename -from . import db +from .db import db, User, UserSession, File, NoResultFound, IntegrityError from .monkey import patch as monkey_patch from .login import login_manager, load_user from .file_storage.exceptions import StorageError @@ -45,54 +45,50 @@ else: has_mogrify = True def get_or_create_user(username, jab_id): - with db.session_scope() as sess: + try: + return db.session.query(User).filter(User.jab_id == jab_id).one() + except NoResultFound: try: - return sess.query(db.User).filter(db.User.jab_id == jab_id).one() - except db.NoResultFound: - try: - user = db.User(username, jab_id) - sess.add(user) - sess.commit() - sess.refresh(user) - return user - except db.IntegrityError: - return None + user = User(username, jab_id) + db.session.add(user) + db.session.commit() + db.session.refresh(user) + return user + except IntegrityError: + return None def get_file(file_hash, user_id=None, update_accessed=False): - with db.session_scope() as sess: - try: - f = sess.query(db.File).filter(db.File.hash == file_hash) - if user_id: - f = f.filter(db.File.user_id == user_id) - f = f.one() - except db.NoResultFound: - return None - if update_accessed: - f.accessed = datetime.datetime.utcnow() - sess.add(f) - sess.commit() - # Refresh after field update. - sess.refresh(f) - return f + try: + f = db.session.query(File).filter(File.hash == file_hash) + if user_id: + f = f.filter(File.user_id == user_id) + f = f.one() + except NoResultFound: + return None + if update_accessed: + f.accessed = datetime.datetime.utcnow() + db.session.add(f) + db.session.commit() + # Refresh after field update. + db.session.refresh(f) + return f def get_files(user): - with db.session_scope() as sess: - try: - sess.add(user) - files = user.files - except db.NoResultFound: - return [] + try: + db.session.add(user) + files = user.files + except NoResultFound: + return [] return files def delete_file(file): - with db.session_scope() as sess: - sess.delete(file) - sess.commit() - filename = file.get_path() - storage.delete_file(file) - thumbfile = file.get_thumb_path() - if os.path.exists(thumbfile): - os.unlink(thumbfile) + db.session.delete(file) + db.session.commit() + filename = file.get_path() + storage.delete_file(file) + thumbfile = file.get_thumb_path() + if os.path.exists(thumbfile): + os.unlink(thumbfile) app = Blueprint('fbin', __name__) @@ -223,11 +219,10 @@ def logout(): if not current_user.is_authenticated: return redirect(url_for('.index')) session_id = int(current_user.get_id().split(':', 1)[-1]) - with db.session_scope() as s: - try: - s.query(db.UserSession).filter_by(id = session_id).delete() - except: - raise + try: + db.session.query(UserSession).filter_by(id = session_id).delete() + except: + raise logout_user() return redirect(url_for('.index')) @@ -272,12 +267,11 @@ def auth(): response = rs.get(urljoin(current_app.config['OAUTH_URL'], '/api/user'), headers = {'Authorization': 'Bearer {}'.format(token['access_token'])}) user = response.json() user = get_or_create_user(user['username'], user['id']) - with db.session_scope() as s: - us = db.UserSession(user.id, token['access_token'], token['refresh_token']) - us.updated = datetime.datetime.utcnow() - s.add(us) - s.commit() - s.refresh(us) + us = UserSession(user.id, token['access_token'], token['refresh_token']) + us.updated = datetime.datetime.utcnow() + db.session.add(us) + db.session.commit() + db.session.refresh(us) user = load_user('{}:{}'.format(user.id, us.id)) if not user: flash('Failed to retrieve user instance.', 'error') @@ -293,7 +287,7 @@ def files(): context = { 'title': 'Files', 'files': files, - 'total_size': db.File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), + 'total_size': File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), } return render_template('files.html', **context) @@ -306,25 +300,24 @@ def file_edit(): flash('File not found.', 'error') return redirect(url_for('.files')) if 'filename' in request.form: - with db.session_scope() as sess: - old_path = f.get_path() - filename = request.form.get('filename', f.filename) - f.filename = filename - new_path = f.get_path() - # If extension changed, the local filename also changes. We could just store the file without the extension, - # but that would break the existing files, requiring a manual rename. - if old_path != new_path: - try: - if os.path.exists(new_path): - # This shouldn't happen unless we have two files with the same hash, which should be impossible. - raise RuntimeError() - else: - os.rename(old_path, new_path) - except: - flash(Markup('Internal rename failed; file may have become unreachable. ' - 'Please contact an admin and specify hash={}.'.format(f.hash)), 'error') - sess.add(f) - flash('Filename changed to "{}".'.format(f.filename), 'success') + old_path = f.get_path() + filename = request.form.get('filename', f.filename) + f.filename = filename + new_path = f.get_path() + # If extension changed, the local filename also changes. We could just store the file without the extension, + # but that would break the existing files, requiring a manual rename. + if old_path != new_path: + try: + if os.path.exists(new_path): + # This shouldn't happen unless we have two files with the same hash, which should be impossible. + raise RuntimeError() + else: + os.rename(old_path, new_path) + except: + flash(Markup('Internal rename failed; file may have become unreachable. ' + 'Please contact an admin and specify hash={}.'.format(f.hash)), 'error') + db.session.add(f) + flash('Filename changed to "{}".'.format(f.filename), 'success') elif 'delete' in request.form: try: delete_file(f) @@ -345,7 +338,7 @@ def images(): 'title': 'Images', 'fullwidth': True, 'files': files, - 'total_size': db.File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), + 'total_size': File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), } return render_template('images.html', **context) @@ -358,7 +351,7 @@ def videos(): 'title': 'Videos', 'fullwidth': True, 'files': files, - 'total_size': db.File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), + 'total_size': File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), } return render_template('images.html', **context) @@ -428,12 +421,11 @@ def generate_api_key(): @app.route('/invalidate-api-keys') @login_required def invalidate_api_keys(): - with db.session_scope() as s: - user = current_user.user - s.add(user) - user.api_key_date = datetime.datetime.utcnow() - s.commit() - flash('All API keys invalidated.', 'success') + user = current_user.user + db.session.add(user) + user.api_key_date = datetime.datetime.utcnow() + db.session.commit() + flash('All API keys invalidated.', 'success') return redirect(request.referrer) login_manager.login_view = '.login' diff --git a/fbin/file_storage/base.py b/fbin/file_storage/base.py index e2ca1a6..abdf580 100644 --- a/fbin/file_storage/base.py +++ b/fbin/file_storage/base.py @@ -1,6 +1,6 @@ import datetime -from .. import db +from ..db import db, File from .exceptions import * class BaseStorage: @@ -11,19 +11,18 @@ class BaseStorage: user = file.user_id is not None size_limit = self.app.config.get('USER_FILE_SIZE_LIMIT' if user else 'ANONYMOUS_FILE_SIZE_LIMIT') if size_limit is not None and file.size > size_limit: - raise FileSizeError('The file size is too large (max {})'.format(db.File.pretty_size(size_limit))) + raise FileSizeError('The file size is too large (max {})'.format(File.pretty_size(size_limit))) def add_file(self, file_hash, filename, size, user=None, ip=None, verify=True): '''Adds the file to the database. Call from store_file after the file is successfully stored.''' - with db.session_scope() as sess: - f = db.File(file_hash, filename, size, datetime.datetime.utcnow(), user.id if user else None, ip) - # Raises on invalid files - self.verify_file(f) - sess.add(f) - sess.commit() - sess.refresh(f) + f = File(file_hash, filename, size, datetime.datetime.utcnow(), user.id if user else None, ip) + # Raises on invalid files + self.verify_file(f) + db.session.add(f) + db.session.commit() + db.session.refresh(f) return f def store_file(self, uploaded_file, file_hash, filename, user, ip): diff --git a/fbin/login.py b/fbin/login.py index b365e75..b9602a8 100644 --- a/fbin/login.py +++ b/fbin/login.py @@ -7,11 +7,11 @@ from flask_login import LoginManager import jwt import requests -from . import db +from .db import db, User, UserSession login_manager = LoginManager() -class User: +class BinUser: def __init__(self, user, user_session): self.user = user self.user_session = user_session @@ -38,12 +38,11 @@ class User: traceback.print_exc() flash('Failed to refresh authentication token (verification failed)', 'error') return - with db.session_scope() as sess: - self.user_session.access_token = token['access_token'] - self.user_session.refresh_token = token['refresh_token'] - self.user_session.updated = datetime.datetime.utcnow() - sess.add(self.user_session) - sess.commit() + self.user_session.access_token = token['access_token'] + self.user_session.refresh_token = token['refresh_token'] + self.user_session.updated = datetime.datetime.utcnow() + db.session.add(self.user_session) + db.session.commit() return True @property @@ -88,9 +87,8 @@ class User: def load_user(user_id): user_id, session_id = map(int, user_id.split(':', 1)) try: - with db.session_scope() as sess: - user, user_session = sess.query(db.User, db.UserSession).join(db.UserSession).filter(db.User.id == user_id, db.UserSession.id == session_id).one() - return User(user, user_session) + user, user_session = db.session.query(User, UserSession).join(UserSession).filter(User.id == user_id, UserSession.id == session_id).one() + return BinUser(user, user_session) except: traceback.print_exc() return None -- cgit v1.2.3