diff --git a/ReportState.py b/ReportState.py index c1263c0..b355643 100644 --- a/ReportState.py +++ b/ReportState.py @@ -1,9 +1,12 @@ +import logging import time import json from flask import current_app from google.auth import crypt, jwt import requests +logger = logging.getLogger(__name__) + def generate_jwt(service_account): signer = crypt.RSASigner.from_string(service_account['private_key']) @@ -16,6 +19,7 @@ def generate_jwt(service_account): 'scope': 'https://www.googleapis.com/auth/homegraph' } + # google-auth >= 2.x returns a str; no .decode() needed return jwt.encode(signer, payload) @@ -41,19 +45,19 @@ def report_state(access_token, report_state_file): } data = report_state_file response = requests.post(url, headers=headers, json=data) - print('Response: ' + response.text) + logger.info('Response: %s', response.text) return response.status_code == requests.codes.ok def main(report_state_file): service_account = current_app.config['SERVICE_ACCOUNT_DATA'] - print('By ReportState') - signed_jwt = generate_jwt(service_account).decode("utf-8") # Decode + logger.info('By ReportState') + signed_jwt = generate_jwt(service_account) access_token = get_access_token(signed_jwt) success = report_state(access_token, report_state_file) if success: - print('Report State has been done successfully.') + logger.info('Report State has been done successfully.') else: - print('Report State failed. Please check the log above.') + logger.error('Report State failed. Please check the log above.') diff --git a/action_devices.py b/action_devices.py index 372c230..efc6c78 100644 --- a/action_devices.py +++ b/action_devices.py @@ -2,17 +2,20 @@ # Code By DaTi_Co import json +import logging import requests from flask import current_app from notifications import mqtt +logger = logging.getLogger(__name__) + try: import ReportState as state REPORTSTATE_AVAILABLE = True except ImportError: state = None REPORTSTATE_AVAILABLE = False - print("ReportState module not available, some features may be disabled.") + logger.warning("ReportState module not available, some features may be disabled.") # Try to import firebase_admin, but provide fallback if not available try: @@ -20,7 +23,7 @@ FIREBASE_AVAILABLE = True except ImportError: FIREBASE_AVAILABLE = False - print("Firebase admin not available, using mock data for testing") + logger.warning("Firebase admin not available, using mock data for testing") # Mock data for testing when Firebase is not available MOCK_DEVICES = { @@ -90,7 +93,7 @@ def reference(): return db.reference('/devices') except Exception as e: # Firebase is installed but not initialized (e.g. missing credentials in dev) - print(f"Firebase not initialized, falling back to mock data: {e}") + logger.warning("Firebase not initialized, falling back to mock data: %s", e) return MockRef() @@ -109,15 +112,15 @@ def rstate(): } for device in devices: device = str(device) - print('\nGetting Device status from: ' + device) + logger.debug('Getting Device status from: %s', device) state_data = rquery(device) if state_data: payload['devices']['states'][device] = state_data - print(state_data) + logger.debug('Device state: %s', state_data) return payload except Exception as e: - print(f"Error in rstate: {e}") + logger.error("Error in rstate: %s", e) return {"devices": {"states": {}}} @@ -139,7 +142,7 @@ def rsync(): DEVICES.append(DEVICE) return DEVICES except Exception as e: - print(f"Error in rsync: {e}") + logger.error("Error in rsync: %s", e) return [] @@ -148,7 +151,7 @@ def rquery(deviceId): ref = reference() return ref.child(deviceId).child('states').get() except Exception as e: - print(f"Error querying device {deviceId}: {e}") + logger.error("Error querying device %s: %s", deviceId, e) return {"online": False} @@ -158,7 +161,7 @@ def rexecute(deviceId, parameters): ref.child(deviceId).child('states').update(parameters) return ref.child(deviceId).child('states').get() except Exception as e: - print(f"Error executing on device {deviceId}: {e}") + logger.error("Error executing on device %s: %s", deviceId, e) return parameters @@ -169,7 +172,7 @@ def onSync(): "devices": rsync() } except Exception as e: - print(f"Error in onSync: {e}") + logger.error("Error in onSync: %s", e) return {"agentUserId": "test-user", "devices": []} @@ -182,12 +185,12 @@ def onQuery(body): for i in body['inputs']: for device in i['payload']['devices']: deviceId = device['id'] - print('DEVICE ID: ' + deviceId) + logger.debug('DEVICE ID: %s', deviceId) data = rquery(deviceId) payload['devices'][deviceId] = data return payload except Exception as e: - print(f"Error in onQuery: {e}") + logger.error("Error in onQuery: %s", e) return {"devices": {}} @@ -211,81 +214,93 @@ def onExecute(body): for execution in command['execution']: execCommand = execution['command'] params = execution['params'] - # First try to refactor payload = commands(payload, deviceId, execCommand, params) return payload except Exception as e: - print(f"Error in onExecute: {e}") + logger.error("Error in onExecute: %s", e) return {'commands': [{'ids': [], 'status': 'ERROR', 'errorCode': 'deviceNotFound'}]} def commands(payload, deviceId, execCommand, params): - """ more clean code as was bedore. - dont remember how state ad parameters is used """ + """Map an execution command to its device-state parameters and apply them.""" + # Dispatch map: command → parameter transformer + _COMMAND_PARAMS = { + 'action.devices.commands.OnOff': lambda p: {'on': p['on']} if 'on' in p else None, + 'action.devices.commands.BrightnessAbsolute': lambda p: {'brightness': p.get('brightness', 100), 'on': True}, + 'action.devices.commands.StartStop': lambda p: {'isRunning': p['start']}, + 'action.devices.commands.PauseUnpause': lambda p: {'isPaused': p['pause']}, + 'action.devices.commands.GetCameraStream': lambda p: p, + 'action.devices.commands.LockUnlock': lambda p: {'isLocked': p['lock']}, + } + try: - if execCommand == 'action.devices.commands.OnOff': - if 'on' not in params: - print("Error: 'on' parameter missing for OnOff command") + transformer = _COMMAND_PARAMS.get(execCommand) + if transformer is None: + logger.debug('Unhandled command: %s', execCommand) + else: + transformed = transformer(params) + if transformed is None: + logger.error("'on' parameter missing for OnOff command") payload['commands'][0]['status'] = 'ERROR' payload['commands'][0]['errorCode'] = 'hardError' return payload - params = {'on': params['on']} - print('OnOff') - elif execCommand == 'action.devices.commands.BrightnessAbsolute': - params = {'brightness': params.get('brightness', 100), 'on': True} - print('BrightnessAbsolute') - elif execCommand == 'action.devices.commands.StartStop': - params = {'isRunning': params['start']} - print('StartStop') - elif execCommand == 'action.devices.commands.PauseUnpause': - params = {'isPaused': params['pause']} - print('PauseUnpause') - elif execCommand == 'action.devices.commands.GetCameraStream': - print('GetCameraStream') - elif execCommand == 'action.devices.commands.LockUnlock': - params = {'isLocked': params['lock']} - print('LockUnlock') - - # Out from elif + params = transformed + logger.debug('Executing command: %s', execCommand) + states = rexecute(deviceId, params) payload['commands'][0]['states'] = states return payload except Exception as e: - print(f"Error in commands: {e}") + logger.error("Error in commands: %s", e) payload['commands'][0]['status'] = 'ERROR' return payload +def _handle_execute(req): + """Execute intent handler – runs onExecute and publishes MQTT notification.""" + payload = onExecute(req) + try: + if (payload.get('commands') + and payload['commands'][0]['ids']): + deviceId = payload['commands'][0]['ids'][0] + params = payload['commands'][0]['states'] + mqtt.publish( + topic=str(deviceId) + '/notification', + payload=str(params), + qos=0, + ) + except Exception as mqtt_error: + logger.warning("MQTT error: %s", mqtt_error) + return payload + + +# --------------------------------------------------------------------------- +# Dispatch map: Google Home intent → handler function +# --------------------------------------------------------------------------- +_INTENT_DISPATCH = { + "action.devices.SYNC": lambda req: onSync(), + "action.devices.QUERY": onQuery, + "action.devices.EXECUTE": _handle_execute, + "action.devices.DISCONNECT": lambda req: {}, +} + + def actions(req): try: payload = {} for i in req['inputs']: - print(i['intent']) - if i['intent'] == "action.devices.SYNC": - payload = onSync() - elif i['intent'] == "action.devices.QUERY": - payload = onQuery(req) - elif i['intent'] == "action.devices.EXECUTE": - payload = onExecute(req) - # SEND TEST MQTT - try: - if payload.get('commands') and len(payload['commands']) > 0 and len(payload['commands'][0]['ids']) > 0: - deviceId = payload['commands'][0]['ids'][0] - params = payload['commands'][0]['states'] - mqtt.publish(topic=str(deviceId) + '/' + 'notification', - payload=str(params), qos=0) # SENDING MQTT MESSAGE - except Exception as mqtt_error: - print(f"MQTT error: {mqtt_error}") - elif i['intent'] == "action.devices.DISCONNECT": - print("\nDISCONNECT ACTION") - payload = {} + intent = i['intent'] + logger.debug('Intent: %s', intent) + handler = _INTENT_DISPATCH.get(intent) + if handler is not None: + payload = handler(req) else: - print('Unexpected action requested: %s', json.dumps(req)) + logger.warning('Unexpected action requested: %s', json.dumps(req)) payload = {} return payload except Exception as e: - print(f"Error in actions: {e}") + logger.error("Error in actions: %s", e) return {} @@ -297,19 +312,19 @@ def request_sync(api_key, agent_user_id): response = requests.post(url, json=data) - print(f'\nRequests Code: {requests.codes["ok"]}\nResponse Code: {response.status_code}') - print(f'\nResponse: {response.text}') + logger.debug('Requests Code: %s Response Code: %s', requests.codes["ok"], response.status_code) + logger.debug('Response: %s', response.text) return response.status_code == requests.codes['ok'] except Exception as e: - print(f"Error in request_sync: {e}") + logger.error("Error in request_sync: %s", e) return False def report_state(): try: if not REPORTSTATE_AVAILABLE: - print("ReportState module not available, skipping report_state") + logger.warning("ReportState module not available, skipping report_state") return "ReportState not available" import random n = random.randint(10**19, 10**20) @@ -323,5 +338,6 @@ def report_state(): return "THIS IS TEST NO RETURN" except Exception as e: - print(f"Error in report_state: {e}") + logger.error("Error in report_state: %s", e) return f"Error: {e}" + diff --git a/app.py b/app.py index 13e1302..3de2bf1 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,6 @@ # coding: utf-8 # Code By DaTi_Co +import logging import os import secrets @@ -7,129 +8,148 @@ from dotenv import load_dotenv load_dotenv() except ImportError: - print("python-dotenv not available, continuing without loading .env file") + pass from flask import Flask, jsonify, send_from_directory, request +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + # Try importing Firebase, but don't fail if not available try: from firebase_admin import credentials, initialize_app FIREBASE_AVAILABLE = True except ImportError: - print("Firebase admin not available, continuing without it") + logger.warning("Firebase admin not available, continuing without it") FIREBASE_AVAILABLE = False # Try importing other modules, but provide fallbacks try: from flask_login import LoginManager from models import User, db - from my_oauth import oauth + from my_oauth import init_oauth from notifications import mqtt from routes import bp from auth import auth FULL_FEATURES = True except ImportError as e: - print(f"Some modules not available: {e}") + logger.warning("Some modules not available: %s", e) FULL_FEATURES = False -# Flask Application Configuration -app = Flask(__name__, template_folder='templates') -app.config['UPLOAD_FOLDER'] = './static/upload' +# File Extensions for Upload Folder +ALLOWED_EXTENSIONS = {'txt', 'py'} -if app.config.get("ENV") == "production": - try: - app.config.from_object("config.ProductionConfig") - except Exception as e: - print(f"Could not load ProductionConfig: {e}") -else: - try: - app.config.from_object("config.DevelopmentConfig") - except Exception as e: - print(f"Could not load DevelopmentConfig: {e}") - -# Apply fallback defaults AFTER from_object so config values are not overwritten. -# Config classes set these from environment variables; if the env var is unset, -# from_object produces None — these defaults keep the app startable in dev/test. -if not app.config.get('AGENT_USER_ID'): - app.config['AGENT_USER_ID'] = 'test-user' -if not app.config.get('API_KEY'): - app.config['API_KEY'] = 'test-api-key' -if not app.config.get('DATABASEURL'): - app.config['DATABASEURL'] = 'https://test-project-default-rtdb.firebaseio.com/' - -# Ensure SECRET_KEY is set; generate a random one if missing (not suitable for production) -if not app.config.get('SECRET_KEY'): - app.config['SECRET_KEY'] = secrets.token_urlsafe(16) - print("WARNING: SECRET_KEY not set in environment. Using a generated key (not suitable for production).") - -print(f'ENV is set to: {app.config.get("ENV", "development")}') -print(f'AGENT_USER_ID: {app.config.get("AGENT_USER_ID")}') - -# Register blueprints if available — initialize extensions first, then blueprints. -# This order ensures that a failure in extension setup does not leave blueprints -# partially registered while FULL_FEATURES is flipped to False. -if FULL_FEATURES: + +def allowed_file(filename): + """File Uploading Function""" + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + + +def create_app(config_class=None): + """Application factory.""" + application = Flask(__name__, template_folder='templates') + application.config['UPLOAD_FOLDER'] = './static/upload' + + # Load configuration + env = os.environ.get('FLASK_ENV', os.environ.get('ENV', 'development')) + if config_class is not None: + application.config.from_object(config_class) + elif env == 'production': + try: + application.config.from_object('config.ProductionConfig') + except Exception as e: + logger.warning("Could not load ProductionConfig: %s", e) + else: + try: + application.config.from_object('config.DevelopmentConfig') + except Exception as e: + logger.warning("Could not load DevelopmentConfig: %s", e) + + # Apply fallback defaults so the app starts cleanly in dev/test + if not application.config.get('AGENT_USER_ID'): + application.config['AGENT_USER_ID'] = 'test-user' + if not application.config.get('API_KEY'): + application.config['API_KEY'] = 'test-api-key' + if not application.config.get('DATABASEURL'): + application.config['DATABASEURL'] = 'https://test-project-default-rtdb.firebaseio.com/' + if not application.config.get('SECRET_KEY'): + application.config['SECRET_KEY'] = secrets.token_urlsafe(16) + logger.warning("SECRET_KEY not set in environment. Using a generated key (not suitable for production).") + + logger.info('ENV is set to: %s', env) + logger.info('AGENT_USER_ID: %s', application.config.get('AGENT_USER_ID')) + + # Register extensions and blueprints when all features are available; + # fall back to minimal routes if extension initialisation fails. + if FULL_FEATURES: + if not _init_full_features(application): + _register_fallback_routes(application) + else: + _register_fallback_routes(application) + + # Initialize Firebase when full features are active + if FIREBASE_AVAILABLE and FULL_FEATURES: + try: + svc = application.config.get('SERVICE_ACCOUNT_DATA') + if svc: + firebase_creds = credentials.Certificate(svc) + firebase_opts = {'databaseURL': application.config['DATABASEURL']} + initialize_app(firebase_creds, firebase_opts) + logger.info("Firebase initialized successfully") + except Exception as e: + logger.warning("Could not initialize Firebase: %s", e) + + return application + + +def _init_full_features(application): + """Initialize extensions and blueprints for the full-featured app. + + Returns True when all extensions and blueprints were successfully + registered, False otherwise. + """ try: - # Initialize all extensions before touching the blueprint registry - mqtt.init_app(app) + mqtt.init_app(application) mqtt.subscribe('+/notification') mqtt.subscribe('+/status') - db.init_app(app) - oauth.init_app(app) + db.init_app(application) + init_oauth(application) + login_manager = LoginManager() login_manager.login_view = 'auth.login' - login_manager.init_app(app) + login_manager.init_app(application) @login_manager.user_loader def load_user(user_id): - """Get User ID""" - print(user_id) - return User.query.get(int(user_id)) + """Get User by ID.""" + return db.session.get(User, int(user_id)) - # Register blueprints only after all extensions succeed - app.register_blueprint(bp, url_prefix='') - app.register_blueprint(auth, url_prefix='') - except Exception as e: - print(f"Could not initialize full features: {e}") - FULL_FEATURES = False + # Create database tables within the application context + with application.app_context(): + try: + db_uri = application.config.get('SQLALCHEMY_DATABASE_URI', '') + logger.info('DB Engine: %s', db_uri.split(':')[0] if db_uri else 'sqlite') + db.create_all() + logger.info('Initialized the database.') + except Exception as e: + logger.warning("Could not create database tables: %s", e) -# Initialize Firebase if available -if FIREBASE_AVAILABLE and FULL_FEATURES: - try: - FIREBASE_ADMINSDK_FILE = app.config.get('SERVICE_ACCOUNT_DATA') - if FIREBASE_ADMINSDK_FILE: - FIREBASE_CREDENTIALS = credentials.Certificate(FIREBASE_ADMINSDK_FILE) - FIREBASE_DATABASEURL = app.config['DATABASEURL'] - FIREBASE_OPTIONS = {'databaseURL': FIREBASE_DATABASEURL} - initialize_app(FIREBASE_CREDENTIALS, FIREBASE_OPTIONS) - print("Firebase initialized successfully") + application.register_blueprint(bp, url_prefix='') + application.register_blueprint(auth, url_prefix='') + return True except Exception as e: - print(f"Could not initialize Firebase: {e}") - -# File Extensions for Upload Folder -ALLOWED_EXTENSIONS = {'txt', 'py'} - - -def allowed_file(filename): - """File Uploading Function""" - return '.' in filename and \ - filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + logger.warning("Could not initialize full features: %s", e) + return False -@app.route('/uploads/') -def uploaded_file(filename): - """File formats for upload folder""" - return send_from_directory(app.config['UPLOAD_FOLDER'], filename) - - -# Fallback routes — only registered when full blueprint features are unavailable, -# to avoid duplicate URL rule conflicts with the blueprint routes. -if not FULL_FEATURES: - @app.route('/') +def _register_fallback_routes(application): + """Register minimal routes when full features are unavailable.""" + @application.route('/') def index(): - return {'status': 'Smart-Google is working!', 'agent_user_id': app.config['AGENT_USER_ID']} + return {'status': 'Smart-Google is working!', 'agent_user_id': application.config['AGENT_USER_ID']} - @app.route('/health') + @application.route('/health') def health(): return jsonify({ 'status': 'degraded', @@ -140,14 +160,14 @@ def health(): try: from action_devices import onSync, actions, request_sync, report_state - @app.route('/devices') + @application.route('/devices') def devices(): try: return onSync() except Exception as e: return {'error': str(e)}, 500 - @app.route('/smarthome', methods=['POST']) + @application.route('/smarthome', methods=['POST']) def smarthome(): try: req_data = request.get_json() @@ -159,38 +179,33 @@ def smarthome(): except Exception as e: return {'error': str(e)}, 500 - @app.route('/sync') + @application.route('/sync') def sync(): try: - success = request_sync(app.config['API_KEY'], app.config['AGENT_USER_ID']) + success = request_sync(application.config['API_KEY'], application.config['AGENT_USER_ID']) state_result = report_state() return {'sync_requested': True, 'success': success, 'state_report': state_result} except Exception as e: return {'error': str(e)}, 500 except ImportError as e: - print(f"Could not import action_devices: {e}") + logger.warning("Could not import action_devices: %s", e) -if FULL_FEATURES: - try: - @app.before_first_request - def create_db_command(): - """Search for tables and if there is no data create new tables.""" - print('DB Engine: ' + app.config.get('SQLALCHEMY_DATABASE_URI', 'sqlite').split(':')[0]) - db.create_all(app=app) - print('Initialized the database.') - except Exception as e: - print(f"Could not set up database initialization: {e}") -if __name__ == '__main__': - print("Starting Smart-Google Flask Application") - print(f"Full features: {FULL_FEATURES}") +# --------------------------------------------------------------------------- +# Module-level application instance (used by Gunicorn and the test suite) +# --------------------------------------------------------------------------- +app = create_app() - if FULL_FEATURES: - try: - db.create_all(app=app) - except Exception: - pass +@app.route('/uploads/') +def uploaded_file(filename): + """Serve files from the upload folder.""" + return send_from_directory(app.config['UPLOAD_FOLDER'], filename) + + +if __name__ == '__main__': + logger.info("Starting Smart-Google Flask Application") + logger.info("Full features: %s", FULL_FEATURES) host = os.environ.get('FLASK_RUN_HOST', '127.0.0.1') app.run(host=host, port=5000, debug=False) diff --git a/auth.py b/auth.py index 950988f..c0f2692 100644 --- a/auth.py +++ b/auth.py @@ -2,6 +2,7 @@ # Code By DaTi_Co from flask import Blueprint, render_template, redirect, url_for, request, flash +from sqlalchemy import select from werkzeug.security import generate_password_hash, check_password_hash from flask_login import login_user, logout_user, login_required from models import db, User @@ -21,7 +22,7 @@ def login_post(): password = request.form.get('password') remember = bool(request.form.get('remember')) - user = User.query.filter_by(email=email).first() + user = db.session.execute(select(User).filter_by(email=email)).scalar_one_or_none() if not user or not check_password_hash(user.password, password): flash('Please check your login details and try again.') @@ -43,13 +44,13 @@ def signup_post(): name = request.form.get('name') password = request.form.get('password') # get user from database - user = User.query.filter_by(email=email).first() + user = db.session.execute(select(User).filter_by(email=email)).scalar_one_or_none() if user: flash('This Mail is used by another Person') return redirect(url_for('auth.signup')) - # If not User found Create new + # If no user found, create a new one with a strong password hash (default algorithm) new_user = User(email=email, name=name, - password=generate_password_hash(password, method='sha256')) + password=generate_password_hash(password)) db.session.add(new_user) db.session.commit() diff --git a/models.py b/models.py index 3df0674..792595f 100644 --- a/models.py +++ b/models.py @@ -1,4 +1,6 @@ # models.py +import time +from datetime import datetime, timezone from flask_sqlalchemy import SQLAlchemy from flask_login import UserMixin @@ -7,8 +9,6 @@ class User(UserMixin, db.Model): id = db.Column(db.Integer, primary_key=True) - username = db.Column(db.String(40), unique=True) # this will be removed - # email = db.Column(db.String(100), unique=True) password = db.Column(db.String(100)) name = db.Column(db.String(1000)) @@ -25,6 +25,40 @@ class Client(db.Model): _redirect_uris = db.Column(db.Text) _default_scopes = db.Column(db.Text) + # ------------------------------------------------------------------ + # Authlib ClientMixin interface + # ------------------------------------------------------------------ + def get_client_id(self): + return self.client_id + + def get_default_redirect_uri(self): + uris = self.redirect_uris + return uris[0] if uris else '' + + def get_allowed_scope(self, scope): + if not scope: + return '' + allowed = set(self.default_scopes) + return ' '.join(allowed & set(scope.split())) + + def check_redirect_uri(self, redirect_uri): + return redirect_uri in self.redirect_uris + + def check_client_secret(self, client_secret): + return self.client_secret == client_secret + + def check_endpoint_auth_method(self, method, endpoint): + return True + + def check_grant_type(self, grant_type): + return grant_type in ('authorization_code', 'refresh_token') + + def check_response_type(self, response_type): + return response_type == 'code' + + # ------------------------------------------------------------------ + # Legacy helpers (kept for template compatibility) + # ------------------------------------------------------------------ @property def client_type(self): return 'public' @@ -37,7 +71,7 @@ def redirect_uris(self): @property def default_redirect_uri(self): - return self.redirect_uris[0] + return self.redirect_uris[0] if self.redirect_uris else '' @property def default_scopes(self): @@ -68,6 +102,23 @@ def delete(self): db.session.commit() return self + # ------------------------------------------------------------------ + # Authlib AuthorizationCodeMixin interface + # ------------------------------------------------------------------ + def get_redirect_uri(self): + return self.redirect_uri or '' + + def get_scope(self): + return self._scopes or '' + + def get_auth_time(self): + return int(time.time()) + + def is_expired(self): + if self.expires is None: + return True + return datetime.now(timezone.utc).replace(tzinfo=None) > self.expires + @property def scopes(self): if self._scopes: @@ -94,6 +145,31 @@ class Token(db.Model): expires = db.Column(db.DateTime) _scopes = db.Column(db.Text) + # ------------------------------------------------------------------ + # Authlib TokenMixin interface + # ------------------------------------------------------------------ + def get_client_id(self): + return self.client_id + + def get_scope(self): + return self._scopes or '' + + def get_expires_at(self): + if self.expires is None: + return 0 + return int(self.expires.replace(tzinfo=timezone.utc).timestamp()) + + def is_expired(self): + if self.expires is None: + return True + return datetime.now(timezone.utc).replace(tzinfo=None) > self.expires + + def is_revoked(self): + return False + + def check_client(self, client): + return client.get_client_id() == self.client_id + @property def scopes(self): if self._scopes: diff --git a/my_oauth.py b/my_oauth.py index 7857db7..58f704d 100644 --- a/my_oauth.py +++ b/my_oauth.py @@ -1,90 +1,124 @@ # coding: utf-8 # Code By DaTi_Co -# OAuth2 -from datetime import datetime, timedelta -from flask_oauthlib.provider import OAuth2Provider +# OAuth2 – migrated from Flask-OAuthlib to Authlib +import logging +from datetime import datetime, timedelta, timezone + from flask import session -from models import db -from models import Client, Token, Grant, User +from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector +from authlib.integrations.flask_oauth2 import current_token # noqa: F401 – re-exported +from authlib.oauth2.rfc6749 import grants +from authlib.oauth2.rfc6750 import BearerTokenValidator + +from models import db, Client, Grant, Token, User + +logger = logging.getLogger(__name__) -oauth = OAuth2Provider() +# --------------------------------------------------------------------------- +# Authlib server & resource protector (initialised in create_app via init_oauth) +# --------------------------------------------------------------------------- +authorization = AuthorizationServer() +require_oauth = ResourceProtector() +# --------------------------------------------------------------------------- +# Helper: get the currently logged-in user from the session +# --------------------------------------------------------------------------- def get_current_user(): if 'id' in session: - uid = session['id'] - print(User.query.get(uid)) - return User.query.get(uid) + return db.session.get(User, session['id']) return None -@oauth.clientgetter -def load_client(client_id): - print("get client") - print(client_id) - print(Client.query.filter_by(client_id=client_id).first()) - return Client.query.filter_by(client_id=client_id).first() - - -@oauth.grantgetter -def load_grant(client_id, code): - print("grant getter") - return Grant.query.filter_by(client_id=client_id, code=code).first() - - -@oauth.grantsetter -def save_grant(client_id, code, request, *args, **kwargs): - # decide the expires time yourself - print("save grant") - expires = datetime.utcnow() + timedelta(seconds=100) - grant = Grant( - client_id=client_id, - code=code['code'], - redirect_uri=request.redirect_uri, - _scopes=' '.join(request.scopes), - user=get_current_user(), - expires=expires - ) - print(grant) - db.session.add(grant) - db.session.commit() - return grant - - -@oauth.tokengetter -def load_token(access_token=None, refresh_token=None): - print("token getter") - if access_token: - return Token.query.filter_by(access_token=access_token).first() - if refresh_token: - return Token.query.filter_by(refresh_token=refresh_token).first() +# --------------------------------------------------------------------------- +# Client / token query + save helpers required by AuthorizationServer +# --------------------------------------------------------------------------- +def _query_client(client_id): + return db.session.execute( + db.select(Client).filter_by(client_id=client_id) + ).scalar_one_or_none() -@oauth.tokensetter -def save_token(token, request, *args, **kwargs): - print("token setter") - toks = Token.query.filter_by( - client_id=request.client.client_id, - user_id=request.user.id - ) - print(toks) - # make sure that every client has only one token connected to a user - for t in toks: +def _save_token(token_data, request): + # Ensure each client/user pair has only one active token + existing = db.session.execute( + db.select(Token).filter_by( + client_id=request.client.client_id, + user_id=request.user.id, + ) + ).scalars().all() + for t in existing: db.session.delete(t) - expires_in = token.pop('expires_in') - expires = datetime.utcnow() + timedelta(seconds=expires_in) + expires_in = token_data.get('expires_in', 3600) + expires = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(seconds=expires_in) tok = Token( - access_token=token['access_token'], - refresh_token=token['refresh_token'], - token_type=token['token_type'], - _scopes=token['scope'], + access_token=token_data['access_token'], + refresh_token=token_data.get('refresh_token', ''), + token_type=token_data['token_type'], + _scopes=token_data.get('scope', ''), expires=expires, client_id=request.client.client_id, user_id=request.user.id, ) - print(tok) db.session.add(tok) db.session.commit() return tok + + +# --------------------------------------------------------------------------- +# Authorization Code grant +# --------------------------------------------------------------------------- +class _AuthorizationCodeGrant(grants.AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = [ + 'client_secret_basic', + 'client_secret_post', + 'none', + ] + + def save_authorization_code(self, code, request): + expires = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(seconds=100) + grant = Grant( + code=code, + client_id=request.client.client_id, + redirect_uri=request.redirect_uri or '', + _scopes=' '.join(request.scopes), + user=request.user, + expires=expires, + ) + db.session.add(grant) + db.session.commit() + logger.debug('Authorization code saved for client %s', request.client.client_id) + return grant + + def query_authorization_code(self, code, client): + return db.session.execute( + db.select(Grant).filter_by(code=code, client_id=client.client_id) + ).scalar_one_or_none() + + def delete_authorization_code(self, authorization_code): + db.session.delete(authorization_code) + db.session.commit() + + def authenticate_user(self, authorization_code): + return db.session.get(User, authorization_code.user_id) + + +# --------------------------------------------------------------------------- +# Bearer token validator for protected resources +# --------------------------------------------------------------------------- +class _BearerTokenValidator(BearerTokenValidator): + def authenticate_token(self, token_string): + return db.session.execute( + db.select(Token).filter_by(access_token=token_string) + ).scalar_one_or_none() + + +# --------------------------------------------------------------------------- +# Application-level initialiser – called from create_app() +# --------------------------------------------------------------------------- +def init_oauth(app): + authorization.init_app(app, query_client=_query_client, save_token=_save_token) + authorization.register_grant(_AuthorizationCodeGrant) + require_oauth.register_token_validator(_BearerTokenValidator()) diff --git a/requirements.txt b/requirements.txt index f8f5ed2..4cb817a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,17 @@ -gunicorn==20.1.0 -Flask==1.1.2 -Flask-MQTT==1.1.1 -Flask-OAuthlib==0.9.6 -Flask-SQLAlchemy==2.5.1 -SQLAlchemy==1.4.47 -Werkzeug==0.16.1 -Jinja2>=2.10.1,<3.0.0 -MarkupSafe>=1.1.1,<2.1.0 -itsdangerous>=0.24,<2.0 -six==1.14.0 -firebase_admin==3.2.1 -PyMySQL==1.0.2 -requests-oauthlib -python-dotenv==0.15.0 -Flask-Login==0.4.1 \ No newline at end of file +gunicorn==22.0.0 +Flask==3.0.3 +Flask-MQTT==1.3.0 +authlib==1.6.7 +Flask-SQLAlchemy==3.1.1 +SQLAlchemy==2.0.30 +Werkzeug==3.0.3 +Jinja2>=3.1.0 +MarkupSafe>=2.1.0 +itsdangerous>=2.1.0 +firebase_admin==6.5.0 +google-auth>=2.22.0 +PyMySQL==1.1.1 +python-dotenv==1.0.0 +Flask-Login==0.6.3 +requests>=2.31.0 +cryptography>=41.0.0 \ No newline at end of file diff --git a/routes.py b/routes.py index 8132e51..8a77ea7 100644 --- a/routes.py +++ b/routes.py @@ -3,9 +3,10 @@ from flask import Blueprint, current_app, request, jsonify, redirect, render_template, make_response from flask_login import login_required, current_user +from sqlalchemy import select from action_devices import onSync, report_state, request_sync, actions -from models import Client -from my_oauth import get_current_user, oauth +from models import Client, db +from my_oauth import get_current_user, authorization, require_oauth, current_token from notifications import is_mqtt_connected @@ -35,37 +36,40 @@ def profile(): @bp.route('/oauth/token', methods=['POST']) -@oauth.token_handler def access_token(): - return {'version': '0.1.0'} + return authorization.create_token_response() @bp.route('/oauth/authorize', methods=['GET', 'POST']) -# Both GET (render consent page) and POST (handle form submission) are required -# by the OAuth2 Authorization Code flow and the @oauth.authorize_handler decorator. -@oauth.authorize_handler -def authorize(*args, **kwargs): +def authorize(): user = get_current_user() if not user: return redirect('/') + if request.method == 'GET': - client_id = kwargs.get('client_id') - client = Client.query.filter_by(client_id=client_id).first() + try: + grant = authorization.validate_consent_request(end_user=user) + except Exception: + return redirect('/') + client_id = request.args.get('client_id') + client = db.session.execute( + select(Client).filter_by(client_id=client_id) + ).scalar_one_or_none() if client is None: return redirect('/') - kwargs['client'] = client - kwargs['user'] = user - return render_template('authorize.html', **kwargs) + return render_template('authorize.html', grant=grant, user=user, client=client) - confirm = request.form.get('confirm', 'no') - return confirm == 'yes' + confirmed = request.form.get('confirm', 'no') == 'yes' + return authorization.create_authorization_response( + grant_user=user if confirmed else None + ) @bp.route('/api/me') -@oauth.require_oauth() -def me(req): - user = req.user - return jsonify(username=user.username) +@require_oauth() +def me(): + user = current_token.user + return jsonify(email=user.email) @bp.route('/sync') diff --git a/runtime.txt b/runtime.txt index ebbcd4a..b884b0f 100644 --- a/runtime.txt +++ b/runtime.txt @@ -1 +1 @@ -python-3.8.9 \ No newline at end of file +python-3.12.2 \ No newline at end of file