Skip to content

Commit

Permalink
WIP OAuth implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
faisal-fawad committed Nov 2, 2024
1 parent ba24e9a commit b93ca43
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 32 deletions.
10 changes: 9 additions & 1 deletion imaginate_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@
from werkzeug.exceptions import HTTPException
from imaginate_api.date.routes import bp as date_routes
from imaginate_api.image.routes import bp as image_routes
from imaginate_api.user.routes import bp as user_routes
from imaginate_api.config import Config
from imaginate_api.extensions import login_manager
import os


def create_app():
app = Flask(__name__)
app.config.from_object(Config)
login_manager.init_app(app)
app.secret_key = os.getenv("FLASK_SECRET_KEY")
app.register_blueprint(date_routes, url_prefix="/date")
app.register_blueprint(image_routes, url_prefix="/image")
app.register_blueprint(user_routes, url_prefix="/user")
return app


Expand Down Expand Up @@ -39,8 +46,9 @@ def handle_exception(exc: HTTPException):

# Run app on invocation
if __name__ == "__main__":
if Config.DB_ENV == 'prod':
if app.config["DB_ENV"] == "prod":
from waitress import serve

serve(app, host="0.0.0.0", port=8080)
else:
app.run()
13 changes: 13 additions & 0 deletions imaginate_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,17 @@ class Config:
MONGO_TOKEN = os.getenv("MONGO_TOKEN")
PEXELS_TOKEN = os.getenv("PEXELS_TOKEN")
DB_ENV = get_db_env()
AUTH_PROVIDERS = {
"google": {
"client_id": os.getenv("GOOGLE_CLIENT_ID"),
"client_secret": os.getenv("GOOGLE_CLIENT_SECRET"),
"authorize_url": "https://accounts.google.com/o/oauth2/auth",
"token_url": "https://accounts.google.com/o/oauth2/token",
"user_info": {
"url": "https://www.googleapis.com/oauth2/v3/userinfo",
"data": lambda json: {"email": json["email"]},
},
"scopes": ["https://www.googleapis.com/auth/userinfo.email"],
}
}
TESTING = False
20 changes: 11 additions & 9 deletions imaginate_api/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@
import gridfs
from imaginate_api.config import Config
import sys
from flask_login import LoginManager


def connect_mongodb(conn_uri: str, db_name: str):
client = MongoClient(conn_uri)
client = MongoClient(conn_uri)

# If the connection was not established properly, an exception will be raised by this if statement
if db_name not in client.list_database_names():
print(f"Database \"{db_name}\" does not exist", file=sys.stderr)
sys.exit(1)
return client[db_name], gridfs.GridFS(client[db_name])
# If the connection was not established properly, an exception will be raised by this if statement
if db_name not in client.list_database_names():
print(f'Database "{db_name}" does not exist', file=sys.stderr)
sys.exit(1)

return client[db_name], gridfs.GridFS(client[db_name])


# Setup
print(f"Running in \"{Config.DB_ENV}\" environment")
db, fs = connect_mongodb(Config.MONGO_TOKEN, f"imaginate_{Config.DB_ENV}")
print(f'Running in "{Config.DB_ENV}" environment')
db, fs = connect_mongodb(Config.MONGO_TOKEN, f"imaginate_{Config.DB_ENV}")
login_manager = LoginManager()
64 changes: 64 additions & 0 deletions imaginate_api/schemas/user_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from bson.objectid import ObjectId
from flask_login import UserMixin
from imaginate_api.extensions import login_manager
from imaginate_api.extensions import db

# Specification: https://flask-login.readthedocs.io/en/latest/#
COLLECTION_NAME = "users"
COLLECTION = db[COLLECTION_NAME]


class User(UserMixin):
def __init__(self, user_data=None):
self.user_data = user_data or {}

@property
def is_authenticated(self):
return self.user_data.get("authenticated", False)

@property
def is_active(self):
return self.user_data.get("active", False)

@property
def is_anonymous(self):
return True # Always return True based on spec

def get_id(self):
return str(self.user_data["_id"])

def authenticate_user(self):
COLLECTION.update_one(
{"_id": self.user_data["_id"]}, {"$set": {"authenticated": True}}
)
self.user_data["authenticated"] = True

def deactivate_user(self):
COLLECTION.update_one({"_id": self.user_data["_id"]}, {"$set": {"active": False}})
self.user_data["active"] = False

# Create or find user by data -> email
@classmethod
def find_or_create_user(cls, data):
existing_user = COLLECTION.find_one({"email": data["email"]})
if existing_user:
return User(user_data=existing_user)

data["authenticated"] = False
data["active"] = True
new_user = COLLECTION.insert_one(data)
return User.get(new_user.inserted_id)

# Get user by ID
@classmethod
def get(cls, user_id):
user = COLLECTION.find_one({"_id": ObjectId(user_id)})
if not user:
return None
return cls(user_data=user)


# Callback function for Flask login library to load user from session user_id
@login_manager.user_loader
def load_user(user_id):
return User.get(user_id)
99 changes: 99 additions & 0 deletions imaginate_api/user/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from flask import Blueprint, abort, request, redirect, url_for, session, current_app
from flask_login import current_user, login_user
from imaginate_api.schemas.user_info import User
from http import HTTPStatus
from urllib.parse import urlencode
import secrets
import requests

bp = Blueprint("user", __name__)
reroute_url = "index" # Currently set to index, but will be changed to imaginate home page in the future


# Initiates the authorization process with the specified provider
@bp.route("/authorize/<provider>")
def user_authorize(provider):
if not current_user.is_anonymous:
return redirect(url_for("index"))

provider_data = current_app.config["AUTH_PROVIDERS"].get(provider)
if not provider_data:
abort(
HTTPStatus.NOT_FOUND,
description=f"Invalid provider, supports: {list(current_app.config["AUTH_PROVIDERS"].keys())}",
)

session["oauth_state"] = secrets.token_urlsafe(32)
print(url_for("user.user_callback", provider=provider, _external=True))
query = urlencode(
{
"client_id": provider_data["client_id"],
"redirect_uri": url_for("user.user_callback", provider=provider, _external=True),
"response_type": "code", # This tells the OAuth provider that we expect an authorization code to be returned
"scope": " ".join(provider_data["scopes"]),
"state": session["oauth_state"],
}
)

return redirect(f"{provider_data["authorize_url"]}?{query}")


# Handles the callback (i.e. redirection response) process with the specified provider
@bp.route("/callback/<provider>")
def user_callback(provider):
if not current_user.is_anonymous:
return redirect(url_for("index"))

provider_data = current_app.config["AUTH_PROVIDERS"].get(provider)
if not provider_data:
abort(
HTTPStatus.NOT_FOUND,
description=f"Invalid provider, supports: {list(current_app.config["AUTH_PROVIDERS"].keys())}",
)

# Unable to authenticate with the specified provider
if "error" in request.args:
for k, v in request.args.items():
if k.startswith("error"):
print(f"{k}: {v}") # Debug any errors by printing them
abort(HTTPStatus.BAD_REQUEST, description="Authentication error")

# Authorization does not match the specification we have set
if request.args["state"] != session.get("oauth_state") or "code" not in request.args:
abort(HTTPStatus.BAD_REQUEST, description="Authorization error")

# Get an access token from the authorization code
response = requests.post(
provider_data["token_url"],
data={
"client_id": provider_data["client_id"],
"client_secret": provider_data["client_secret"],
"code": request.args["code"],
"grant_type": "authorization_code",
"redirect_uri": url_for("user.user_callback", provider=provider, _external=True),
},
headers={"Accept": "application/json"},
)
if not response.ok:
abort(response.status_code, description="Authorization error")
response_data = response.json()
token = response_data.get("access_token")
if not token:
abort(HTTPStatus.UNAUTHORIZED, description="Authorization error")

# Get the requested data
response = requests.get(
provider_data["user_info"]["url"],
headers={"Authorization": f"Bearer {token}", "Accept": "application/json"},
)
if not response.ok:
abort(response.status_code, description="Authorization error")

# Login user and map requested data
user_data = provider_data["user_info"]["data"](response.json())
user = User.find_or_create_user(user_data)
success = login_user(user)
if success:
user.authenticate_user()

return redirect(url_for("index"))
59 changes: 37 additions & 22 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ image-handler-client = {git = "https://github.com/imaginate-ai/image-handler-cli
requests = "^2.32.3"
waitress = "^3.0.0"
setuptools = "^75.2.0"
flask-login = "^0.6.3"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.7.1"
Expand Down

0 comments on commit b93ca43

Please sign in to comment.