summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--fbin-scanner.py72
-rw-r--r--fbin/__init__.py3
-rw-r--r--fbin/api.py26
-rw-r--r--fbin/db.py71
-rwxr-xr-xfbin/fbin.py152
-rw-r--r--fbin/file_storage/base.py17
-rw-r--r--fbin/login.py20
7 files changed, 164 insertions, 197 deletions
diff --git a/fbin-scanner.py b/fbin-scanner.py
index 2e60522..76bc33b 100644
--- a/fbin-scanner.py
+++ b/fbin-scanner.py
@@ -80,47 +80,47 @@ def get_report(dbfile, digest, fileobj):
def main():
storage = importlib.import_module(current_app.config.get('STORAGE_MODULE', 'fbin.file_storage.filesystem')).Storage(current_app)
- with session_scope() as session:
- files = deque(session.query(File).filter(File.scanned == False).all())
- while len(files):
- dbfile = files.pop()
- if not dbfile.get_size():
- logger.info('Ignoring file %s/%s due to unknown size', dbfile.filename, dbfile.hash)
- continue
- if dbfile.get_size() > 32*10**6:
- logger.info('Ignoring file %s/%s due to size (%s)', dbfile.filename, dbfile.hash, dbfile.formatted_size)
- continue
- logger.info('Checking file %s/%s (%s)', dbfile.filename, dbfile.hash, dbfile.formatted_size)
- try:
- with storage.temp_file(dbfile) as f:
- h = hashlib.sha256()
+ files = deque(db.session.query(File).filter(File.scanned == False).all())
+ while len(files):
+ dbfile = files.pop()
+ if not dbfile.get_size():
+ logger.info('Ignoring file %s/%s due to unknown size', dbfile.filename, dbfile.hash)
+ continue
+ if dbfile.get_size() > 32*10**6:
+ logger.info('Ignoring file %s/%s due to size (%s)', dbfile.filename, dbfile.hash, dbfile.formatted_size)
+ continue
+ logger.info('Checking file %s/%s (%s)', dbfile.filename, dbfile.hash, dbfile.formatted_size)
+ try:
+ with storage.temp_file(dbfile) as f:
+ h = hashlib.sha256()
+ chunk = f.read(2**10*16)
+ while chunk:
+ h.update(chunk)
chunk = f.read(2**10*16)
- while chunk:
- h.update(chunk)
- chunk = f.read(2**10*16)
- f.seek(0)
- digest = h.hexdigest()
- logger.info('SHA-256: %s', digest)
- report = get_report(dbfile, digest, f)
- except:
- logger.exception('Failed to get report for %s/%s', dbfile.filename, dbfile.hash)
- # Most likely an error from virustotal, so just break here and retry later.
- break
- dbfile.scanned = True
- if report and any(r.get('detected', False) for r in report['scans'].values()):
- logger.warning('Positive match')
- dbfile.blocked_reason = report
- else:
- logger.info('No match')
- session.add(dbfile)
- session.commit()
- time.sleep(FILE_DELAY)
- logger.info('No more files to scan')
+ f.seek(0)
+ digest = h.hexdigest()
+ logger.info('SHA-256: %s', digest)
+ report = get_report(dbfile, digest, f)
+ except:
+ logger.exception('Failed to get report for %s/%s', dbfile.filename, dbfile.hash)
+ # Most likely an error from virustotal, so just break here and retry later.
+ break
+ dbfile.scanned = True
+ if report and any(r.get('detected', False) for r in report['scans'].values()):
+ logger.warning('Positive match')
+ dbfile.blocked_reason = report
+ else:
+ logger.info('No match')
+ db.session.add(dbfile)
+ db.session.commit()
+ time.sleep(FILE_DELAY)
+ logger.info('No more files to scan')
app = Flask('scanner')
with app.app_context():
app.config.from_pyfile(args.config_file)
- from fbin.db import session_scope, File
+ from fbin.db import db, File
+ db.init_app(app)
config = app.config
main()
diff --git a/fbin/__init__.py b/fbin/__init__.py
index 6c6a9f5..4bc2ef3 100644
--- a/fbin/__init__.py
+++ b/fbin/__init__.py
@@ -30,9 +30,10 @@ def context_processors():
}
with app.app_context():
- from .fbin import app as fbin
+ from .fbin import app as fbin, db
from .api import app as api
from .login import login_manager
app.register_blueprint(fbin)
app.register_blueprint(api, url_prefix = '/api')
login_manager.init_app(app)
+ db.init_app(app)
diff --git a/fbin/api.py b/fbin/api.py
index 4f605f0..8f3f86c 100644
--- a/fbin/api.py
+++ b/fbin/api.py
@@ -6,7 +6,7 @@ from flask.views import MethodView
from flask_login import current_user
import jwt
-from . import db
+from .db import db, User, NoResultFound
from .fbin import upload as fbin_upload, get_file
app = Blueprint('api', __name__)
@@ -32,17 +32,16 @@ def authenticate():
token = jwt.decode(token, current_app.config['SECRET_KEY'], issuer = request.url_root)
except jwt.InvalidTokenError:
abort(403)
- with db.session_scope() as s:
- try:
- user = s.query(db.User).filter(db.User.id == token['sub']).one()
- token_datetime = datetime.datetime.fromtimestamp(token['iat'])
- # If token was issued before api_key_date was updated, consider it invalid.
- if token_datetime < user.api_key_date:
- abort(403)
- else:
- g.user = user
- except db.NoResultFound:
+ try:
+ user = db.session.query(User).filter(User.id == token['sub']).one()
+ token_datetime = datetime.datetime.fromtimestamp(token['iat'])
+ # If token was issued before api_key_date was updated, consider it invalid.
+ if token_datetime < user.api_key_date:
abort(403)
+ else:
+ g.user = user
+ except NoResultFound:
+ abort(403)
def api_login_required(f):
def wrapper(*args, **kwargs):
@@ -74,9 +73,8 @@ class FileAPI(MethodView):
'status': False,
'message': 'Empty or missing filename',
}
- with db.session_scope() as sess:
- f.filename = filename
- sess.add(f)
+ f.filename = filename
+ db.session.add(f)
return {
'status': True,
}
diff --git a/fbin/db.py b/fbin/db.py
index be79c76..876efeb 100644
--- a/fbin/db.py
+++ b/fbin/db.py
@@ -4,57 +4,52 @@ import mimetypes
import os
from flask import current_app
-from sqlalchemy import create_engine, Column, Integer, String, DateTime, Text, Index, ForeignKey, Boolean, JSON, BigInteger
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import sessionmaker, relation, backref
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.exc import IntegrityError
-from sqlalchemy.sql import and_
+from flask_sqlalchemy import SQLAlchemy
-engine = create_engine(current_app.config['DB_URI'])
+db = SQLAlchemy()
-Base = declarative_base(bind = engine)
-
-class User(Base):
+class User(db.Model):
__tablename__ = 'users'
- id = Column(Integer, primary_key = True)
- username = Column(String, unique = True, index = True)
- jab_id = Column(String(24), unique = True, index = True)
- api_key_date = Column(DateTime, default = datetime.datetime.utcnow)
- files = relation('File', backref = 'user', order_by = 'File.date.desc()')
+ id = db.Column(db.Integer, primary_key=True)
+ username = db.Column(db.String, unique=True, index=True)
+ jab_id = db.Column(db.String(24), unique=True, index=True)
+ api_key_date = db.Column(db.DateTime, default=datetime.datetime.utcnow)
+ files = db.relation('File', backref='user', order_by='File.date.desc()')
def __init__(self, username, jab_id):
self.username = username
self.jab_id = jab_id
-class UserSession(Base):
+class UserSession(db.Model):
__tablename__ = 'sessions'
- id = Column(Integer, primary_key = True)
- user_id = Column(Integer, ForeignKey('users.id'), index = True)
- access_token = Column(String)
- refresh_token = Column(String)
- updated = Column(DateTime)
+ id = db.Column(db.Integer, primary_key=True)
+ user_id = db.Column(db.Integer, db.ForeignKey('users.id'), index=True)
+ access_token = db.Column(db.String)
+ refresh_token = db.Column(db.String)
+ updated = db.Column(db.DateTime)
def __init__(self, user_id, access_token, refresh_token):
self.user_id = user_id
self.access_token = access_token
self.refresh_token = refresh_token
-class File(Base):
+class File(db.Model):
__tablename__ = 'files'
- id = Column(Integer, primary_key = True)
- hash = Column(String, unique = True, index = True)
- filename = Column(String)
- size = Column(BigInteger)
- date = Column(DateTime)
- user_id = Column(Integer, ForeignKey('users.id'), nullable = True)
- ip = Column(String)
- accessed = Column(DateTime)
- scanned = Column(Boolean, nullable=False, default=False)
- blocked_reason = Column(JSON)
+ id = db.Column(db.Integer, primary_key=True)
+ hash = db.Column(db.String, unique=True, index=True)
+ filename = db.Column(db.String)
+ size = db.Column(db.BigInteger)
+ date = db.Column(db.DateTime)
+ user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=True)
+ ip = db.Column(db.String)
+ accessed = db.Column(db.DateTime)
+ scanned = db.Column(db.Boolean, nullable=False, default=False)
+ blocked_reason = db.Column(db.JSON)
def __init__(self, hash, filename, size, date, user_id = None, ip = None):
self.hash = hash
@@ -114,19 +109,3 @@ class File(Base):
@property
def exists(self):
return os.path.exists(self.get_path())
-
-Base.metadata.create_all()
-Session = sessionmaker(bind = engine, autoflush = True, autocommit = False)
-
-@contextmanager
-def session_scope():
- session = Session()
- try:
- session.expire_on_commit = False
- yield session
- session.commit()
- except:
- session.rollback()
- raise
- finally:
- session.close()
diff --git a/fbin/fbin.py b/fbin/fbin.py
index fce6ea5..ac4569e 100755
--- a/fbin/fbin.py
+++ b/fbin/fbin.py
@@ -22,7 +22,7 @@ from PIL import Image
import requests
from werkzeug.utils import secure_filename
-from . import db
+from .db import db, User, UserSession, File, NoResultFound, IntegrityError
from .monkey import patch as monkey_patch
from .login import login_manager, load_user
from .file_storage.exceptions import StorageError
@@ -45,54 +45,50 @@ else:
has_mogrify = True
def get_or_create_user(username, jab_id):
- with db.session_scope() as sess:
+ try:
+ return db.session.query(User).filter(User.jab_id == jab_id).one()
+ except NoResultFound:
try:
- return sess.query(db.User).filter(db.User.jab_id == jab_id).one()
- except db.NoResultFound:
- try:
- user = db.User(username, jab_id)
- sess.add(user)
- sess.commit()
- sess.refresh(user)
- return user
- except db.IntegrityError:
- return None
+ user = User(username, jab_id)
+ db.session.add(user)
+ db.session.commit()
+ db.session.refresh(user)
+ return user
+ except IntegrityError:
+ return None
def get_file(file_hash, user_id=None, update_accessed=False):
- with db.session_scope() as sess:
- try:
- f = sess.query(db.File).filter(db.File.hash == file_hash)
- if user_id:
- f = f.filter(db.File.user_id == user_id)
- f = f.one()
- except db.NoResultFound:
- return None
- if update_accessed:
- f.accessed = datetime.datetime.utcnow()
- sess.add(f)
- sess.commit()
- # Refresh after field update.
- sess.refresh(f)
- return f
+ try:
+ f = db.session.query(File).filter(File.hash == file_hash)
+ if user_id:
+ f = f.filter(File.user_id == user_id)
+ f = f.one()
+ except NoResultFound:
+ return None
+ if update_accessed:
+ f.accessed = datetime.datetime.utcnow()
+ db.session.add(f)
+ db.session.commit()
+ # Refresh after field update.
+ db.session.refresh(f)
+ return f
def get_files(user):
- with db.session_scope() as sess:
- try:
- sess.add(user)
- files = user.files
- except db.NoResultFound:
- return []
+ try:
+ db.session.add(user)
+ files = user.files
+ except NoResultFound:
+ return []
return files
def delete_file(file):
- with db.session_scope() as sess:
- sess.delete(file)
- sess.commit()
- filename = file.get_path()
- storage.delete_file(file)
- thumbfile = file.get_thumb_path()
- if os.path.exists(thumbfile):
- os.unlink(thumbfile)
+ db.session.delete(file)
+ db.session.commit()
+ filename = file.get_path()
+ storage.delete_file(file)
+ thumbfile = file.get_thumb_path()
+ if os.path.exists(thumbfile):
+ os.unlink(thumbfile)
app = Blueprint('fbin', __name__)
@@ -223,11 +219,10 @@ def logout():
if not current_user.is_authenticated:
return redirect(url_for('.index'))
session_id = int(current_user.get_id().split(':', 1)[-1])
- with db.session_scope() as s:
- try:
- s.query(db.UserSession).filter_by(id = session_id).delete()
- except:
- raise
+ try:
+ db.session.query(UserSession).filter_by(id = session_id).delete()
+ except:
+ raise
logout_user()
return redirect(url_for('.index'))
@@ -272,12 +267,11 @@ def auth():
response = rs.get(urljoin(current_app.config['OAUTH_URL'], '/api/user'), headers = {'Authorization': 'Bearer {}'.format(token['access_token'])})
user = response.json()
user = get_or_create_user(user['username'], user['id'])
- with db.session_scope() as s:
- us = db.UserSession(user.id, token['access_token'], token['refresh_token'])
- us.updated = datetime.datetime.utcnow()
- s.add(us)
- s.commit()
- s.refresh(us)
+ us = UserSession(user.id, token['access_token'], token['refresh_token'])
+ us.updated = datetime.datetime.utcnow()
+ db.session.add(us)
+ db.session.commit()
+ db.session.refresh(us)
user = load_user('{}:{}'.format(user.id, us.id))
if not user:
flash('Failed to retrieve user instance.', 'error')
@@ -293,7 +287,7 @@ def files():
context = {
'title': 'Files',
'files': files,
- 'total_size': db.File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)),
+ 'total_size': File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)),
}
return render_template('files.html', **context)
@@ -306,25 +300,24 @@ def file_edit():
flash('File not found.', 'error')
return redirect(url_for('.files'))
if 'filename' in request.form:
- with db.session_scope() as sess:
- old_path = f.get_path()
- filename = request.form.get('filename', f.filename)
- f.filename = filename
- new_path = f.get_path()
- # If extension changed, the local filename also changes. We could just store the file without the extension,
- # but that would break the existing files, requiring a manual rename.
- if old_path != new_path:
- try:
- if os.path.exists(new_path):
- # This shouldn't happen unless we have two files with the same hash, which should be impossible.
- raise RuntimeError()
- else:
- os.rename(old_path, new_path)
- except:
- flash(Markup('Internal rename failed; file may have become unreachable. '
- 'Please contact an admin and specify <strong>hash={}</strong>.'.format(f.hash)), 'error')
- sess.add(f)
- flash('Filename changed to "{}".'.format(f.filename), 'success')
+ old_path = f.get_path()
+ filename = request.form.get('filename', f.filename)
+ f.filename = filename
+ new_path = f.get_path()
+ # If extension changed, the local filename also changes. We could just store the file without the extension,
+ # but that would break the existing files, requiring a manual rename.
+ if old_path != new_path:
+ try:
+ if os.path.exists(new_path):
+ # This shouldn't happen unless we have two files with the same hash, which should be impossible.
+ raise RuntimeError()
+ else:
+ os.rename(old_path, new_path)
+ except:
+ flash(Markup('Internal rename failed; file may have become unreachable. '
+ 'Please contact an admin and specify <strong>hash={}</strong>.'.format(f.hash)), 'error')
+ db.session.add(f)
+ flash('Filename changed to "{}".'.format(f.filename), 'success')
elif 'delete' in request.form:
try:
delete_file(f)
@@ -345,7 +338,7 @@ def images():
'title': 'Images',
'fullwidth': True,
'files': files,
- 'total_size': db.File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)),
+ 'total_size': File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)),
}
return render_template('images.html', **context)
@@ -358,7 +351,7 @@ def videos():
'title': 'Videos',
'fullwidth': True,
'files': files,
- 'total_size': db.File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)),
+ 'total_size': File.pretty_size(sum(size for size in (f.get_size() for f in files) if size is not None)),
}
return render_template('images.html', **context)
@@ -428,12 +421,11 @@ def generate_api_key():
@app.route('/invalidate-api-keys')
@login_required
def invalidate_api_keys():
- with db.session_scope() as s:
- user = current_user.user
- s.add(user)
- user.api_key_date = datetime.datetime.utcnow()
- s.commit()
- flash('All API keys invalidated.', 'success')
+ user = current_user.user
+ db.session.add(user)
+ user.api_key_date = datetime.datetime.utcnow()
+ db.session.commit()
+ flash('All API keys invalidated.', 'success')
return redirect(request.referrer)
login_manager.login_view = '.login'
diff --git a/fbin/file_storage/base.py b/fbin/file_storage/base.py
index e2ca1a6..abdf580 100644
--- a/fbin/file_storage/base.py
+++ b/fbin/file_storage/base.py
@@ -1,6 +1,6 @@
import datetime
-from .. import db
+from ..db import db, File
from .exceptions import *
class BaseStorage:
@@ -11,19 +11,18 @@ class BaseStorage:
user = file.user_id is not None
size_limit = self.app.config.get('USER_FILE_SIZE_LIMIT' if user else 'ANONYMOUS_FILE_SIZE_LIMIT')
if size_limit is not None and file.size > size_limit:
- raise FileSizeError('The file size is too large (max {})'.format(db.File.pretty_size(size_limit)))
+ raise FileSizeError('The file size is too large (max {})'.format(File.pretty_size(size_limit)))
def add_file(self, file_hash, filename, size, user=None, ip=None, verify=True):
'''Adds the file to the database.
Call from store_file after the file is successfully stored.'''
- with db.session_scope() as sess:
- f = db.File(file_hash, filename, size, datetime.datetime.utcnow(), user.id if user else None, ip)
- # Raises on invalid files
- self.verify_file(f)
- sess.add(f)
- sess.commit()
- sess.refresh(f)
+ f = File(file_hash, filename, size, datetime.datetime.utcnow(), user.id if user else None, ip)
+ # Raises on invalid files
+ self.verify_file(f)
+ db.session.add(f)
+ db.session.commit()
+ db.session.refresh(f)
return f
def store_file(self, uploaded_file, file_hash, filename, user, ip):
diff --git a/fbin/login.py b/fbin/login.py
index b365e75..b9602a8 100644
--- a/fbin/login.py
+++ b/fbin/login.py
@@ -7,11 +7,11 @@ from flask_login import LoginManager
import jwt
import requests
-from . import db
+from .db import db, User, UserSession
login_manager = LoginManager()
-class User:
+class BinUser:
def __init__(self, user, user_session):
self.user = user
self.user_session = user_session
@@ -38,12 +38,11 @@ class User:
traceback.print_exc()
flash('Failed to refresh authentication token (verification failed)', 'error')
return
- with db.session_scope() as sess:
- self.user_session.access_token = token['access_token']
- self.user_session.refresh_token = token['refresh_token']
- self.user_session.updated = datetime.datetime.utcnow()
- sess.add(self.user_session)
- sess.commit()
+ self.user_session.access_token = token['access_token']
+ self.user_session.refresh_token = token['refresh_token']
+ self.user_session.updated = datetime.datetime.utcnow()
+ db.session.add(self.user_session)
+ db.session.commit()
return True
@property
@@ -88,9 +87,8 @@ class User:
def load_user(user_id):
user_id, session_id = map(int, user_id.split(':', 1))
try:
- with db.session_scope() as sess:
- user, user_session = sess.query(db.User, db.UserSession).join(db.UserSession).filter(db.User.id == user_id, db.UserSession.id == session_id).one()
- return User(user, user_session)
+ user, user_session = db.session.query(User, UserSession).join(UserSession).filter(User.id == user_id, UserSession.id == session_id).one()
+ return BinUser(user, user_session)
except:
traceback.print_exc()
return None