"""Flask-Login integration: users are authenticated by JWT access tokens stored on a user session row."""
import datetime
import traceback
from urllib.parse import urljoin

from flask import current_app, flash
from flask_login import LoginManager
import jwt
import requests

from . import db

login_manager = LoginManager()


def _decode_jwt(encoded):
    """Verify and decode *encoded* against the app's JWT public key and OAuth client audience.

    ``JWT_ALGORITHMS`` is read from the app config when present; when absent,
    ``algorithms=None`` is passed, which behaves exactly like omitting it.

    Raises:
        jwt.InvalidTokenError (or a subclass such as jwt.ExpiredSignatureError)
        when the token fails verification.
    """
    config = current_app.config
    # NOTE(review): pinning the accepted algorithms (e.g. JWT_ALGORITHMS = ['RS256'])
    # is recommended, and newer PyJWT releases require it -- confirm the deployed version.
    return jwt.decode(
        encoded,
        key=config['JWT_PUBLIC_KEY'],
        audience=config['OAUTH_CLIENT_ID'],
        algorithms=config.get('JWT_ALGORITHMS'),
    )


class User:
    """Flask-Login user object wrapping a ``db.User`` row and one of its ``db.UserSession`` rows."""

    def __init__(self, user, user_session):
        self.user = user
        self.user_session = user_session
        # Decoded access-token claims; None until verified by is_authenticated
        # or by a successful refresh_access_token().
        self.token = None

    def refresh_access_token(self):
        """Exchange the session's refresh token for a new access/refresh token pair.

        POSTs a ``refresh_token`` grant to the ``token`` endpoint (resolved
        against ``OAUTH_URL`` with ``urljoin``), verifies both returned JWTs and
        persists them on ``self.user_session``.

        Returns:
            True on success. None (after flashing an error) when the endpoint
            answers with a non-200 status or either token fails verification.

        Other errors (network failures, timeouts, a response missing the
        expected keys, database errors) propagate to the caller.
        """
        config = current_app.config
        response = requests.post(
            urljoin(config['OAUTH_URL'], 'token'),
            data={
                'grant_type': 'refresh_token',
                'client_id': config['OAUTH_CLIENT_ID'],
                'client_secret': config['OAUTH_CLIENT_SECRET'],
                'refresh_token': self.user_session.refresh_token,
            },
            # Without a timeout a stalled token endpoint would block this
            # request handler indefinitely.
            timeout=config.get('OAUTH_TIMEOUT', 10),
        )
        if response.status_code != 200:
            flash('Failed to refresh authentication token (API call returned {} {})'.format(response.status_code, response.reason), 'error')
            return

        token = response.json()
        try:
            access_data = _decode_jwt(token['access_token'])
            # The refresh token is decoded only to verify it before it is stored.
            _decode_jwt(token['refresh_token'])
        except jwt.InvalidTokenError:
            traceback.print_exc()
            flash('Failed to refresh authentication token (verification failed)', 'error')
            return

        with db.session_scope() as sess:
            self.user_session.access_token = token['access_token']
            self.user_session.refresh_token = token['refresh_token']
            self.user_session.updated = datetime.datetime.utcnow()
            sess.add(self.user_session)
        # Cache the verified claims so is_authenticated does not decode again.
        self.token = access_data
        return True

    @property
    def is_authenticated(self):
        """Return True when the session's access token verifies, refreshing it once if expired.

        Decoded claims are cached in ``self.token``. An expired access token
        triggers refresh_access_token(). Any failure there yields False;
        unexpected exceptions are printed and flashed first. Any other
        token-verification error also yields False.
        """
        if self.user is None:
            return False
        if self.token:
            return True
        try:
            self.token = _decode_jwt(self.user_session.access_token)
        except jwt.ExpiredSignatureError:
            # Must precede InvalidTokenError: in PyJWT, ExpiredSignatureError is its subclass.
            try:
                if not self.refresh_access_token():
                    return False
            except Exception:
                traceback.print_exc()
                flash('Failed to refresh authentication token (unhandled error; contact an admin)', 'error')
                return False
        except jwt.InvalidTokenError:
            return False
        return True

    @property
    def is_active(self):
        return True

    @property
    def is_anonymous(self):
        return False

    def get_id(self):
        """Return the Flask-Login ID ``"<user id>:<session id>"`` (parsed back by load_user)."""
        return '{}:{}'.format(self.user.id, self.user_session.id)

    def get_user_id(self):
        """Return the user's database ID if authenticated, otherwise None."""
        return self.user.id if self.is_authenticated else None

    @property
    def username(self):
        return self.user.username


@login_manager.user_loader
def load_user(user_id):
    """Rebuild a User from the ``"<user id>:<session id>"`` string produced by User.get_id().

    Returns None (which Flask-Login treats as "not logged in") when the ID is
    malformed or when the lookup of the matching user/session row fails.
    """
    try:
        uid, session_id = map(int, user_id.split(':', 1))
    except (AttributeError, ValueError):
        # Malformed ID (e.g. a stale cookie in another format): treat as anonymous
        # instead of letting the exception turn into a server error.
        return None
    try:
        with db.session_scope() as sess:
            user, user_session = (
                sess.query(db.User, db.UserSession)
                .join(db.UserSession)
                .filter(db.User.id == uid, db.UserSession.id == session_id)
                .one()
            )
            return User(user, user_session)
    except Exception:
        traceback.print_exc()
        return None