Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewrite API keys implementation #190

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
318 changes: 179 additions & 139 deletions common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,160 +5,200 @@

import secrets
import yaml
from fastapi import Header, HTTPException, Request
from pydantic import BaseModel
from fastapi import Header, HTTPException
from pydantic import BaseModel, Field
from loguru import logger
from typing import Optional

from common.utils import coalesce


class AuthKeys(BaseModel):
"""
This class represents the authentication keys for the application.
It contains two types of keys: 'api_key' and 'admin_key'.
The 'api_key' is used for general API calls, while the 'admin_key'
is used for administrative tasks. The class also provides a method
to verify if a given key matches the stored 'api_key' or 'admin_key'.
"""

api_key: str
admin_key: str

def verify_key(self, test_key: str, key_type: str):
"""Verify if a given key matches the stored key."""
if key_type == "admin_key":
return test_key == self.admin_key
if key_type == "api_key":
# Admin keys are valid for all API calls
return test_key == self.api_key or test_key == self.admin_key
return False
from typing import Union
from enum import Flag, auto
from abc import ABC, abstractmethod

from common.utils import unwrap

# Global auth constants
AUTH_KEYS: Optional[AuthKeys] = None
DISABLE_AUTH: bool = False
__all__ = ["ROLE", "auth"]


def load_auth_keys(disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""
global AUTH_KEYS
global DISABLE_AUTH
# RBAC roles
class ROLE(Flag):
USER = auto()
ADMIN = auto()

DISABLE_AUTH = disable_from_config
if disable_from_config:
logger.warning(
"Disabling authentication makes your instance vulnerable. "
"Set the `disable_auth` flag to False in config.yml if you "
"want to share this instance with others."
)

return
class API_KEY(BaseModel):
"""stores an API key"""

key: str = Field(..., description="the API key value")
role: ROLE = Field()


class AUTH_PROVIDER(ABC):
@staticmethod
@abstractmethod
def add_api_key(role: ROLE) -> API_KEY:
"""add an API key"""

@staticmethod
@abstractmethod
def set_api_key(role: ROLE, api_key: str) -> API_KEY:
"""add an existing API key"""

@staticmethod
@abstractmethod
def remove_api_key(api_key: str) -> bool:
"""remove an API key"""

@staticmethod
@abstractmethod
def check_api_key(api_key: str) -> Union[API_KEY, None]:
"""check if an API key is valid"""

@staticmethod
@abstractmethod
def authenticate_api_key(api_key: str, role: ROLE) -> bool:
"""check if an api key has ROLE"""


try:
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
auth_keys_dict = yaml.safe_load(auth_file)
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
class SIMPLE_AUTH_PROVIDER(AUTH_PROVIDER):
api_keys: list[API_KEY] = []

def __init__(self) -> None:
try:
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
keys_dict: dict = unwrap(yaml.safe_load(auth_file), {})

# load legacy keys
admin_key = keys_dict.get("admin_key")
if admin_key:
self.set_api_key(ROLE.ADMIN, admin_key)

user_key = keys_dict.get("api_key")
if user_key:
self.set_api_key(ROLE.USER, admin_key)

# load new keys
for role in ROLE :
role_keys = keys_dict.get(f"{role.name.lower()}_keys")
if role_keys:
for key in role_keys:
self.set_api_key(role, key)

except FileNotFoundError:
file = {}

for role in ROLE :
file[f"{role.name.lower()}_keys"] = [self.add_api_key(role).key for i in range(3)]

print(file)
with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(file, auth_file, default_flow_style=False)

logger.info("API keys:")
for key in self.api_keys:
logger.info(f"{key.role.name} :\t {key.key}")
logger.info(
"If these keys get compromised, make sure to delete " +
"api_tokens.yml and restart the server. Have fun!"
)
AUTH_KEYS = new_auth_keys

with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)

logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"
f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
"If these keys get compromised, make sure to delete api_tokens.yml "
"and restart the server. Have fun!"
)


def get_key_permission(request: Request):
"""
Gets the key permission from a request.

Internal only! Use the depends functions for incoming requests.
"""

# Give full admin permissions if auth is disabled
if DISABLE_AUTH:
return "admin"

# Hyphens are okay here
test_key = coalesce(
request.headers.get("x-admin-key"),
request.headers.get("x-api-key"),
request.headers.get("authorization"),
)

if test_key is None:
raise ValueError("The provided authentication key is missing.")

if test_key.lower().startswith("bearer"):
test_key = test_key.split(" ")[1]

if AUTH_KEYS.verify_key(test_key, "admin_key"):
return "admin"
elif AUTH_KEYS.verify_key(test_key, "api_key"):
return "api"
else:
raise ValueError("The provided authentication key is invalid.")


async def check_api_key(
x_api_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the API key is valid."""

# Allow request if auth is disabled
if DISABLE_AUTH:
return

if x_api_key:
if not AUTH_KEYS.verify_key(x_api_key, "api_key"):
raise HTTPException(401, "Invalid API key")
return x_api_key

if authorization:
split_key = authorization.split(" ")
if len(split_key) < 2:
raise HTTPException(401, "Invalid API key")
if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
split_key[1], "api_key"
):
raise HTTPException(401, "Invalid API key")

return authorization
def add_api_key(self, role: ROLE) -> API_KEY:
return self.set_api_key(api_key=secrets.token_hex(16), role=role)

raise HTTPException(401, "Please provide an API key")
def set_api_key(self, role: ROLE, api_key: str) -> API_KEY:
key = API_KEY(key=api_key, role=role)
self.api_keys.append(key)
return key

def remove_api_key(self, api_key: str) -> bool:
for key in self.api_keys:
if key.key == api_key:
self.api_keys.remove(key)
return True
return False

async def check_admin_key(
x_admin_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the admin key is valid."""
def check_api_key(self, api_key: str) -> Union[API_KEY, None]:
for key in self.api_keys:
if key.key == api_key:
return key
return None

# Allow request if auth is disabled
if DISABLE_AUTH:
return
def authenticate_api_key(self, api_key: str, role: ROLE) -> bool:
key = self.check_api_key(api_key)
if not key:
return False
return key.role & role # if key.role in role

if x_admin_key:
if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"):
raise HTTPException(401, "Invalid admin key")
return x_admin_key

if authorization:
split_key = authorization.split(" ")
if len(split_key) < 2:
raise HTTPException(401, "Invalid admin key")
if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
split_key[1], "admin_key"
class NOAUTH_AUTH_PROVIDER(AUTH_PROVIDER):
def add_api_key(self, role: ROLE) -> API_KEY:
return API_KEY(key=secrets.token_hex(16), role=role)

def set_api_key(self, role: ROLE, api_key: str) -> API_KEY:
return API_KEY(key=secrets.token_hex(16), role=role)

def remove_api_key(self, api_key: str) -> bool:
return True

def check_api_key(self, api_key: str) -> Union[API_KEY, None]:
return API_KEY(key=secrets.token_hex(16), role=ROLE.ADMIN)

def authenticate_api_key(self, api_key: str, role: ROLE) -> bool:
return True


class AUTH_PROVIDER_CONTAINER:
provider: AUTH_PROVIDER

def load(self, disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""

# TODO: Make provider a paramater instead of disable_from_config
provider = "noauth" if disable_from_config else "simple"

# allows for more types of providers
provider_class = {
"noauth": NOAUTH_AUTH_PROVIDER,
"simple": SIMPLE_AUTH_PROVIDER,
}.get(provider)

if not provider_class:
raise Exception()

if provider_class == NOAUTH_AUTH_PROVIDER:
logger.warning(
"Disabling authentication makes your instance vulnerable. "
"Set the `disable_auth` flag to False in config.yml if you "
"want to share this instance with others."
)

self.provider = provider_class()

# by returning a dynamic dependency we can have one function
# where we can specify what roles can access the endpoint
def check_api_key(self, role: ROLE):
"""Check if the API key is valid."""

async def check(
x_api_key: str = Header(None), authorization: str = Header(None)
):
raise HTTPException(401, "Invalid admin key")
return authorization
if x_api_key:
key = self.provider.authenticate_api_key(x_api_key, role)
if not key:
raise HTTPException(401, "Invalid API key")
return key

if authorization:
split_key = authorization.split(" ")
if len(split_key) < 2:
raise HTTPException(401, "Invalid API key")
key = self.provider.authenticate_api_key(split_key[1], role)
if split_key[0].lower() != "bearer" or not key:
raise HTTPException(401, "Invalid API key")

return key

raise HTTPException(401, "Please provide an API key")

return check


raise HTTPException(401, "Please provide an admin key")
auth = AUTH_PROVIDER_CONTAINER()
Loading
Loading