From 0faa732c9a3e1ffced2b26bee682f513b0e5f0ae Mon Sep 17 00:00:00 2001 From: Jon Bergli Heier Date: Wed, 28 Oct 2020 19:16:34 +0100 Subject: Use flask-sqlalchemy instead of using sqlalchemy directly This makes database access a bit easier and also greatly simplifies some upcoming changes. --- fbin/fbin.py | 152 ++++++++++++++++++++++++++++------------------------------- 1 file changed, 72 insertions(+), 80 deletions(-) (limited to 'fbin/fbin.py') diff --git a/fbin/fbin.py b/fbin/fbin.py index fce6ea5..ac4569e 100755 --- a/fbin/fbin.py +++ b/fbin/fbin.py @@ -22,7 +22,7 @@ from PIL import Image import requests from werkzeug.utils import secure_filename -from . import db +from .db import db, User, UserSession, File, NoResultFound, IntegrityError from .monkey import patch as monkey_patch from .login import login_manager, load_user from .file_storage.exceptions import StorageError @@ -45,54 +45,50 @@ else: has_mogrify = True def get_or_create_user(username, jab_id): - with db.session_scope() as sess: + try: + return db.session.query(User).filter(User.jab_id == jab_id).one() + except NoResultFound: try: - return sess.query(db.User).filter(db.User.jab_id == jab_id).one() - except db.NoResultFound: - try: - user = db.User(username, jab_id) - sess.add(user) - sess.commit() - sess.refresh(user) - return user - except db.IntegrityError: - return None + user = User(username, jab_id) + db.session.add(user) + db.session.commit() + db.session.refresh(user) + return user + except IntegrityError: + return None def get_file(file_hash, user_id=None, update_accessed=False): - with db.session_scope() as sess: - try: - f = sess.query(db.File).filter(db.File.hash == file_hash) - if user_id: - f = f.filter(db.File.user_id == user_id) - f = f.one() - except db.NoResultFound: - return None - if update_accessed: - f.accessed = datetime.datetime.utcnow() - sess.add(f) - sess.commit() - # Refresh after field update. - sess.refresh(f) - return f + try: + f = db.session.query(File).filter(File.hash == file_hash) + if user_id: + f = f.filter(File.user_id == user_id) + f = f.one() + except NoResultFound: + return None + if update_accessed: + f.accessed = datetime.datetime.utcnow() + db.session.add(f) + db.session.commit() + # Refresh after field update. + db.session.refresh(f) + return f def get_files(user): - with db.session_scope() as sess: - try: - sess.add(user) - files = user.files - except db.NoResultFound: - return [] + try: + db.session.add(user) + files = user.files + except NoResultFound: + return [] return files def delete_file(file): - with db.session_scope() as sess: - sess.delete(file) - sess.commit() - filename = file.get_path() - storage.delete_file(file) - thumbfile = file.get_thumb_path() - if os.path.exists(thumbfile): - os.unlink(thumbfile) + db.session.delete(file) + db.session.commit() + filename = file.get_path() + storage.delete_file(file) + thumbfile = file.get_thumb_path() + if os.path.exists(thumbfile): + os.unlink(thumbfile) app = Blueprint('fbin', __name__) @@ -223,11 +219,10 @@ def logout(): if not current_user.is_authenticated: return redirect(url_for('.index')) session_id = int(current_user.get_id().split(':', 1)[-1]) - with db.session_scope() as s: - try: - s.query(db.UserSession).filter_by(id = session_id).delete() - except: - raise + try: + db.session.query(UserSession).filter_by(id = session_id).delete() + except: + raise logout_user() return redirect(url_for('.index')) @@ -272,12 +267,11 @@ def auth(): response = rs.get(urljoin(current_app.config['OAUTH_URL'], '/api/user'), headers = {'Authorization': 'Bearer {}'.format(token['access_token'])}) user = response.json() user = get_or_create_user(user['username'], user['id']) - with db.session_scope() as s: - us = db.UserSession(user.id, token['access_token'], token['refresh_token']) - us.updated = datetime.datetime.utcnow() - s.add(us) - s.commit() - s.refresh(us) + us = UserSession(user.id, token['access_token'], token['refresh_token']) + us.updated = datetime.datetime.utcnow() + db.session.add(us) + db.session.commit() + db.session.refresh(us) user = load_user('{}:{}'.format(user.id, us.id)) if not user: flash('Failed to retrieve user instance.', 'error') @@ -293,7 +287,7 @@ def files(): context = { 'title': 'Files', 'files': files, - 'total_size': db.File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), + 'total_size': File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), } return render_template('files.html', **context) @@ -306,25 +300,24 @@ def file_edit(): flash('File not found.', 'error') return redirect(url_for('.files')) if 'filename' in request.form: - with db.session_scope() as sess: - old_path = f.get_path() - filename = request.form.get('filename', f.filename) - f.filename = filename - new_path = f.get_path() - # If extension changed, the local filename also changes. We could just store the file without the extension, - # but that would break the existing files, requiring a manual rename. - if old_path != new_path: - try: - if os.path.exists(new_path): - # This shouldn't happen unless we have two files with the same hash, which should be impossible. - raise RuntimeError() - else: - os.rename(old_path, new_path) - except: - flash(Markup('Internal rename failed; file may have become unreachable. ' - 'Please contact an admin and specify hash={}.'.format(f.hash)), 'error') - sess.add(f) - flash('Filename changed to "{}".'.format(f.filename), 'success') + old_path = f.get_path() + filename = request.form.get('filename', f.filename) + f.filename = filename + new_path = f.get_path() + # If extension changed, the local filename also changes. We could just store the file without the extension, + # but that would break the existing files, requiring a manual rename. + if old_path != new_path: + try: + if os.path.exists(new_path): + # This shouldn't happen unless we have two files with the same hash, which should be impossible. + raise RuntimeError() + else: + os.rename(old_path, new_path) + except: + flash(Markup('Internal rename failed; file may have become unreachable. ' + 'Please contact an admin and specify hash={}.'.format(f.hash)), 'error') + db.session.add(f) + flash('Filename changed to "{}".'.format(f.filename), 'success') elif 'delete' in request.form: try: delete_file(f) @@ -345,7 +338,7 @@ def images(): 'title': 'Images', 'fullwidth': True, 'files': files, - 'total_size': db.File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), + 'total_size': File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), } return render_template('images.html', **context) @@ -358,7 +351,7 @@ def videos(): 'title': 'Videos', 'fullwidth': True, 'files': files, - 'total_size': db.File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), + 'total_size': File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)), } return render_template('images.html', **context) @@ -428,12 +421,11 @@ def generate_api_key(): @app.route('/invalidate-api-keys') @login_required def invalidate_api_keys(): - with db.session_scope() as s: - user = current_user.user - s.add(user) - user.api_key_date = datetime.datetime.utcnow() - s.commit() - flash('All API keys invalidated.', 'success') + user = current_user.user + db.session.add(user) + user.api_key_date = datetime.datetime.utcnow() + db.session.commit() + flash('All API keys invalidated.', 'success') return redirect(request.referrer) login_manager.login_view = '.login' -- cgit v1.2.3