From e0bc35fbbaef028e17b53aad7132d9382e9de2e1 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Fri, 6 Sep 2024 22:51:17 +0100 Subject: [PATCH 1/6] rewrite API keys implementation - add RBAC support - support multiple API keys - keep backwards compatibility - easy extension to other authentication types --- common/auth.py | 319 +++++++++++++++++------------- endpoints/Kobold/router.py | 53 +++-- endpoints/OAI/router.py | 11 +- endpoints/OAI/utils/completion.py | 4 +- endpoints/core/router.py | 133 +++++++++---- main.py | 4 +- 6 files changed, 332 insertions(+), 192 deletions(-) diff --git a/common/auth.py b/common/auth.py index 174208de..194d2991 100644 --- a/common/auth.py +++ b/common/auth.py @@ -6,159 +6,202 @@ import secrets import yaml from fastapi import Header, HTTPException, Request -from pydantic import BaseModel +from pydantic import BaseModel, Field from loguru import logger -from typing import Optional - -from common.utils import coalesce - - -class AuthKeys(BaseModel): - """ - This class represents the authentication keys for the application. - It contains two types of keys: 'api_key' and 'admin_key'. - The 'api_key' is used for general API calls, while the 'admin_key' - is used for administrative tasks. The class also provides a method - to verify if a given key matches the stored 'api_key' or 'admin_key'. - """ - - api_key: str - admin_key: str - - def verify_key(self, test_key: str, key_type: str): - """Verify if a given key matches the stored key.""" - if key_type == "admin_key": - return test_key == self.admin_key - if key_type == "api_key": - # Admin keys are valid for all API calls - return test_key == self.api_key or test_key == self.admin_key - return False +from typing import Optional, Union +from enum import Flag, auto +from abc import ABC, abstractmethod +from common.utils import coalesce, unwrap -# Global auth constants -AUTH_KEYS: Optional[AuthKeys] = None -DISABLE_AUTH: bool = False +__all__ = ["ROLE", "auth"] -def load_auth_keys(disable_from_config: bool): - """Load the authentication keys from api_tokens.yml. If the file does not - exist, generate new keys and save them to api_tokens.yml.""" - global AUTH_KEYS - global DISABLE_AUTH +# RBAC roles +class ROLE(Flag): + USER = auto() + ADMIN = auto() - DISABLE_AUTH = disable_from_config - if disable_from_config: - logger.warning( - "Disabling authentication makes your instance vulnerable. " - "Set the `disable_auth` flag to False in config.yml if you " - "want to share this instance with others." - ) - return +class API_KEY(BaseModel): + """stores an API key""" + + key: str = Field(..., description="the API key value") + role: ROLE = Field() + + +class AUTH_PROVIDER(ABC): + @staticmethod + def add_api_key(role: ROLE) -> API_KEY: + """add an API key""" + + @staticmethod + def set_api_key(role: ROLE, api_key: str) -> API_KEY: + """add an existing API key""" + + @staticmethod + def remove_api_key(api_key: str) -> bool: + """remove an API key""" + + @staticmethod + def check_api_key(api_key: str) -> Union[API_KEY, None]: + """check if an API key is valid""" + + @staticmethod + def authenticate_api_key(api_key: str, role: ROLE) -> bool: + """check if an api key has ROLE""" + - try: - with open("api_tokens.yml", "r", encoding="utf8") as auth_file: - auth_keys_dict = yaml.safe_load(auth_file) - AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) - except FileNotFoundError: - new_auth_keys = AuthKeys( - api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16) +class SIMPLE_AUTH_PROVIDER(AUTH_PROVIDER): + api_keys: list[API_KEY] = [] + + def __init__(self) -> None: + try: + with open("api_tokens.yml", "r", encoding="utf8") as auth_file: + keys_dict: dict = yaml.safe_load(auth_file) + + # load legacy keys + admin_key = keys_dict.get("admin_key") + if admin_key: + self.set_api_key(ROLE.ADMIN, admin_key) + + admin_key = keys_dict.get("api_key") + if admin_key: + self.set_api_key(ROLE.USER, admin_key) + + # load new keys + admin_keys = keys_dict.get("admin_keys") + if admin_keys: + for key in admin_keys: + self.set_api_key(ROLE.ADMIN, key) + + user_keys = keys_dict.get("user_keys") + if admin_keys: + for key in admin_keys: + self.set_api_key(ROLE.ADMIN, key) + + except FileNotFoundError: + file = { + "admin_keys": [ + self.add_api_key(ROLE.ADMIN), + ], + "user_keys": [ + self.add_api_key(ROLE.USER), + ], + } + + with open("api_tokens.yml", "w", encoding="utf8") as auth_file: + yaml.safe_dump(file, auth_file, default_flow_style=False) + + logger.info("API keys:") + for key in self.api_keys: + logger.info(f"{key.role.name} :\t {key.key}") + logger.info( + "If these keys get compromised, make sure to delete api_tokens.yml and restart the server. Have fun!" ) - AUTH_KEYS = new_auth_keys - - with open("api_tokens.yml", "w", encoding="utf8") as auth_file: - yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False) - - logger.info( - f"Your API key is: {AUTH_KEYS.api_key}\n" - f"Your admin key is: {AUTH_KEYS.admin_key}\n\n" - "If these keys get compromised, make sure to delete api_tokens.yml " - "and restart the server. Have fun!" - ) - - -def get_key_permission(request: Request): - """ - Gets the key permission from a request. - - Internal only! Use the depends functions for incoming requests. - """ - - # Give full admin permissions if auth is disabled - if DISABLE_AUTH: - return "admin" - - # Hyphens are okay here - test_key = coalesce( - request.headers.get("x-admin-key"), - request.headers.get("x-api-key"), - request.headers.get("authorization"), - ) - - if test_key is None: - raise ValueError("The provided authentication key is missing.") - - if test_key.lower().startswith("bearer"): - test_key = test_key.split(" ")[1] - - if AUTH_KEYS.verify_key(test_key, "admin_key"): - return "admin" - elif AUTH_KEYS.verify_key(test_key, "api_key"): - return "api" - else: - raise ValueError("The provided authentication key is invalid.") - - -async def check_api_key( - x_api_key: str = Header(None), authorization: str = Header(None) -): - """Check if the API key is valid.""" - - # Allow request if auth is disabled - if DISABLE_AUTH: - return - - if x_api_key: - if not AUTH_KEYS.verify_key(x_api_key, "api_key"): - raise HTTPException(401, "Invalid API key") - return x_api_key - - if authorization: - split_key = authorization.split(" ") - if len(split_key) < 2: - raise HTTPException(401, "Invalid API key") - if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key( - split_key[1], "api_key" - ): - raise HTTPException(401, "Invalid API key") - return authorization + def add_api_key(self, role: ROLE) -> API_KEY: + return self.set_api_key(key=secrets.token_hex(16), role=role) - raise HTTPException(401, "Please provide an API key") + def set_api_key(self, role: ROLE, api_key: str) -> API_KEY: + key = API_KEY(key=api_key, role=role) + self.api_keys.append(key) + return key + def remove_api_key(self, api_key: str) -> bool: + for key in self.api_keys: + if key.key == api_key: + self.api_keys.remove(key) + return True + return False -async def check_admin_key( - x_admin_key: str = Header(None), authorization: str = Header(None) -): - """Check if the admin key is valid.""" + def check_api_key(self, api_key: str) -> Union[API_KEY, None]: + for key in self.api_keys: + if key.key == api_key: + return key + return None - # Allow request if auth is disabled - if DISABLE_AUTH: - return + def authenticate_api_key(self, api_key: str, role: ROLE) -> bool: + key = self.check_api_key(api_key) + print(f"#### {key=}") + if not key: + return False + return key.role & role # if key.role in role - if x_admin_key: - if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"): - raise HTTPException(401, "Invalid admin key") - return x_admin_key - if authorization: - split_key = authorization.split(" ") - if len(split_key) < 2: - raise HTTPException(401, "Invalid admin key") - if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key( - split_key[1], "admin_key" +class NOAUTH_AUTH_PROVIDER(AUTH_PROVIDER): + def add_api_key(self, role: ROLE) -> API_KEY: + return API_KEY(key=secrets.token_hex(16), role=role) + + def set_api_key(self, role: ROLE, api_key: str) -> API_KEY: + return API_KEY(key=secrets.token_hex(16), role=role) + + def remove_api_key(self, api_key: str) -> bool: + return True + + def check_api_key(self, api_key: str) -> Union[API_KEY, None]: + return API_KEY(key=secrets.token_hex(16), role=ROLE.ADMIN) + + def authenticate_api_key(self, api_key: str, role: ROLE) -> bool: + return True + + +class AUTH_PROVIDER_CONTAINER: + provider: AUTH_PROVIDER + + def load(self, disable_from_config: bool): + """Load the authentication keys from api_tokens.yml. If the file does not + exist, generate new keys and save them to api_tokens.yml.""" + + # TODO: Make provider a paramater instead of disable_from_config + provider = "noauth" if disable_from_config else "simple" + + # allows for more types of providers + provider_class = { + "noauth": NOAUTH_AUTH_PROVIDER, + "simple": SIMPLE_AUTH_PROVIDER, + }.get(provider) + + if not provider_class: + raise Exception() + + if provider_class == NOAUTH_AUTH_PROVIDER: + logger.warning( + "Disabling authentication makes your instance vulnerable. " + "Set the `disable_auth` flag to False in config.yml if you " + "want to share this instance with others." + ) + + self.provider = provider_class() + + # by returning a dynamic dependency we can have one function where we can specify what roles can access the endpoint + def check_api_key(self, role: ROLE): + """Check if the API key is valid.""" + + async def check( + x_api_key: str = Header(None), authorization: str = Header(None) ): - raise HTTPException(401, "Invalid admin key") - return authorization + if x_api_key: + if not self.provider.authenticate_api_key(x_api_key, role): + raise HTTPException(401, "Invalid API key") + return x_api_key + + if authorization: + split_key = authorization.split(" ") + if len(split_key) < 2: + raise HTTPException(401, "Invalid API key") + if split_key[ + 0 + ].lower() != "bearer" or not self.provider.authenticate_api_key( + split_key[1], role + ): + raise HTTPException(401, "Invalid API key") + + return authorization + + raise HTTPException(401, "Please provide an API key") + + return check + - raise HTTPException(401, "Please provide an admin key") +auth = AUTH_PROVIDER_CONTAINER() diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index 334bae29..d80aeb21 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -3,7 +3,7 @@ from sse_starlette import EventSourceResponse from common import model -from common.auth import check_api_key +from common.auth import auth, ROLE from common.model import check_model_container from common.utils import unwrap from endpoints.core.utils.model import get_current_model @@ -45,7 +45,10 @@ def setup(): @kai_router.post( "/generate", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: response = await get_generation(data, request) @@ -55,7 +58,10 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: @extra_kai_router.post( "/generate/stream", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse: response = EventSourceResponse(stream_generation(data, request), ping=maxsize) @@ -65,7 +71,10 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe @extra_kai_router.post( "/abort", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def abort_generate(data: AbortRequest) -> AbortResponse: response = await abort_generation(data.genkey) @@ -75,11 +84,17 @@ async def abort_generate(data: AbortRequest) -> AbortResponse: @extra_kai_router.get( "/generate/check", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) @extra_kai_router.post( "/generate/check", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: response = await generation_status(data.genkey) @@ -88,7 +103,11 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: @kai_router.get( - "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] + "/model", + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def current_model() -> CurrentModelResponse: """Fetches the current model and who owns it.""" @@ -99,7 +118,10 @@ async def current_model() -> CurrentModelResponse: @extra_kai_router.post( "/tokencount", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: raw_tokens = model.container.encode_tokens(data.prompt) @@ -109,15 +131,24 @@ async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: @kai_router.get( "/config/max_length", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) @kai_router.get( "/config/max_context_length", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) @extra_kai_router.get( "/true_max_context_length", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def get_max_length() -> MaxLengthResponse: """Fetches the max length of the model.""" diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 12c95a26..e63e5005 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -4,7 +4,7 @@ from sys import maxsize from common import model -from common.auth import check_api_key +from common.auth import auth, ROLE from common.model import check_embeddings_container, check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.tabby_config import config @@ -43,7 +43,7 @@ def setup(): # Completions endpoint @router.post( "/v1/completions", - dependencies=[Depends(check_api_key)], + dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) async def completion_request( request: Request, data: CompletionRequest @@ -93,7 +93,7 @@ async def completion_request( # Chat completions endpoint @router.post( "/v1/chat/completions", - dependencies=[Depends(check_api_key)], + dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) async def chat_completion_request( request: Request, data: ChatCompletionRequest @@ -153,7 +153,10 @@ async def chat_completion_request( # Embeddings endpoint @router.post( "/v1/embeddings", - dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_embeddings_container), + ], ) async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse: embeddings_task = asyncio.create_task(get_embeddings(data, request)) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index d2795455..5a7900ec 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -14,7 +14,7 @@ from loguru import logger from common import model -from common.auth import get_key_permission +from common.auth import auth, ROLE from common.networking import ( get_generator_error, handle_request_disconnect, @@ -117,7 +117,7 @@ async def load_inline_model(model_name: str, request: Request): return # Inline model loading isn't enabled or the user isn't an admin - if not get_key_permission(request) == "admin": + if not auth.provider.check_api_key(request).role == ROLE.ADMIN: error_message = handle_request_error( f"Unable to switch model to {model_name} because " + "an admin key isn't provided", diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 1f9d1948..013efd66 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -1,11 +1,12 @@ import asyncio import pathlib +from typing import Annotated from sys import maxsize from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from common import model, sampling -from common.auth import check_admin_key, check_api_key, get_key_permission +from common.auth import auth, ROLE from common.downloader import hf_repo_download from common.model import check_embeddings_container, check_model_container from common.networking import handle_request_error, run_with_request_disconnect @@ -53,9 +54,12 @@ async def healthcheck(): # Model list endpoint -@router.get("/v1/models", dependencies=[Depends(check_api_key)]) -@router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) -async def list_models(request: Request) -> ModelList: +@router.get("/v1/models") +@router.get("/v1/model/list") +async def list_models( + request: Request, + api_key: Annotated[str, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) -> ModelList: """ Lists all models in the model directory. @@ -67,7 +71,7 @@ async def list_models(request: Request) -> ModelList: draft_model_dir = config.draft_model.get("draft_model_dir") - if get_key_permission(request) == "admin": + if auth.provider.check_api_key(request).role == ROLE.ADMIN: models = get_model_list(model_path.resolve(), draft_model_dir) else: models = await get_current_model_list() @@ -81,7 +85,10 @@ async def list_models(request: Request) -> ModelList: # Currently loaded model endpoint @router.get( "/v1/model", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def current_model() -> ModelCard: """Returns the currently loaded model.""" @@ -89,7 +96,10 @@ async def current_model() -> ModelCard: return get_current_model() -@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/model/draft/list", + dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) async def list_draft_models(request: Request) -> ModelList: """ Lists all draft models in the model directory. @@ -97,7 +107,7 @@ async def list_draft_models(request: Request) -> ModelList: Requires an admin key to see all draft models. """ - if get_key_permission(request) == "admin": + if auth.provider.check_api_key(request).role == ROLE.ADMIN: draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models") draft_model_path = pathlib.Path(draft_model_dir) @@ -109,7 +119,7 @@ async def list_draft_models(request: Request) -> ModelList: # Load model endpoint -@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) +@router.post("/v1/model/load", dependencies=[Depends(auth.check_api_key(ROLE.ADMIN))]) async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: """Loads a model into the model container. This returns an SSE stream.""" @@ -153,14 +163,17 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: # Unload model endpoint @router.post( "/v1/model/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def unload_model(): """Unloads the currently loaded model.""" await model.unload_model(skip_wait=True) -@router.post("/v1/download", dependencies=[Depends(check_admin_key)]) +@router.post("/v1/download", dependencies=[Depends(auth.check_api_key(ROLE.ADMIN))]) async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse: """Downloads a model from HuggingFace.""" @@ -182,8 +195,12 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes # Lora list endpoint -@router.get("/v1/loras", dependencies=[Depends(check_api_key)]) -@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/loras", dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))] +) +@router.get( + "/v1/lora/list", dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))] +) async def list_all_loras(request: Request) -> LoraList: """ Lists all LoRAs in the lora directory. @@ -191,7 +208,7 @@ async def list_all_loras(request: Request) -> LoraList: Requires an admin key to see all LoRAs. """ - if get_key_permission(request) == "admin": + if auth.provider.check_api_key(request).role == ROLE.ADMIN: lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras")) loras = get_lora_list(lora_path.resolve()) else: @@ -203,7 +220,10 @@ async def list_all_loras(request: Request) -> LoraList: # Currently loaded loras endpoint @router.get( "/v1/lora", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def active_loras() -> LoraList: """Returns the currently loaded loras.""" @@ -214,7 +234,10 @@ async def active_loras() -> LoraList: # Load lora endpoint @router.post( "/v1/lora/load", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: """Loads a LoRA into the model container.""" @@ -249,7 +272,10 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: # Unload lora endpoint @router.post( "/v1/lora/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def unload_loras(): """Unloads the currently loaded loras.""" @@ -257,7 +283,10 @@ async def unload_loras(): await model.unload_loras() -@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/model/embedding/list", + dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) async def list_embedding_models(request: Request) -> ModelList: """ Lists all embedding models in the model directory. @@ -265,7 +294,7 @@ async def list_embedding_models(request: Request) -> ModelList: Requires an admin key to see all embedding models. """ - if get_key_permission(request) == "admin": + if auth.provider.check_api_key(request).role == ROLE.ADMIN: embedding_model_dir = unwrap( config.embeddings.get("embedding_model_dir"), "models" ) @@ -280,7 +309,10 @@ async def list_embedding_models(request: Request) -> ModelList: @router.get( "/v1/model/embedding", - dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_embeddings_container), + ], ) async def get_embedding_model() -> ModelCard: """Returns the currently loaded embedding model.""" @@ -289,7 +321,9 @@ async def get_embedding_model() -> ModelCard: return models.data[0] -@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)]) +@router.post( + "/v1/model/embedding/load", dependencies=[Depends(auth.check_api_key(ROLE.ADMIN))] +) async def load_embedding_model( request: Request, data: EmbeddingModelLoadRequest ) -> ModelLoadResponse: @@ -337,7 +371,10 @@ async def load_embedding_model( @router.post( "/v1/model/embedding/unload", - dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.ADMIN)), + Depends(check_embeddings_container), + ], ) async def unload_embedding_model(): """Unloads the current embedding model.""" @@ -348,7 +385,10 @@ async def unload_embedding_model(): # Encode tokens endpoint @router.post( "/v1/token/encode", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: """Encodes a string or chat completion messages into tokens.""" @@ -378,7 +418,10 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: # Decode tokens endpoint @router.post( "/v1/token/decode", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: """Decodes tokens into a string.""" @@ -389,7 +432,10 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: return response -@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/auth/permission", + dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) async def key_permission(request: Request) -> AuthPermissionResponse: """ Gets the access level/permission of a provided key in headers. @@ -409,8 +455,13 @@ async def key_permission(request: Request) -> AuthPermissionResponse: raise HTTPException(400, error_message) from exc -@router.get("/v1/templates", dependencies=[Depends(check_api_key)]) -@router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/templates", dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))] +) +@router.get( + "/v1/template/list", + dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) async def list_templates(request: Request) -> TemplateList: """ Get a list of all templates. @@ -419,7 +470,7 @@ async def list_templates(request: Request) -> TemplateList: """ template_strings = [] - if get_key_permission(request) == "admin": + if auth.provider.check_api_key(request).role == ROLE.ADMIN: templates = get_all_templates() template_strings = [template.stem for template in templates] else: @@ -431,7 +482,10 @@ async def list_templates(request: Request) -> TemplateList: @router.post( "/v1/template/switch", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def switch_template(data: TemplateSwitchRequest): """Switch the currently loaded template.""" @@ -457,7 +511,10 @@ async def switch_template(data: TemplateSwitchRequest): @router.post( "/v1/template/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], + dependencies=[ + Depends(auth.check_api_key(ROLE.ADMIN)), + Depends(check_model_container), + ], ) async def unload_template(): """Unloads the currently selected template""" @@ -466,8 +523,14 @@ async def unload_template(): # Sampler override endpoints -@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) -@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/sampling/overrides", + dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) +@router.get( + "/v1/sampling/override/list", + dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: """ List all currently applied sampler overrides. @@ -475,7 +538,7 @@ async def list_sampler_overrides(request: Request) -> SamplerOverrideListRespons Requires an admin key to see all override presets. """ - if get_key_permission(request) == "admin": + if auth.provider.check_api_key(request).role == ROLE.ADMIN: presets = sampling.get_all_presets() else: presets = [] @@ -487,7 +550,7 @@ async def list_sampler_overrides(request: Request) -> SamplerOverrideListRespons @router.post( "/v1/sampling/override/switch", - dependencies=[Depends(check_admin_key)], + dependencies=[Depends(auth.check_api_key(ROLE.ADMIN))], ) async def switch_sampler_override(data: SamplerOverrideSwitchRequest): """Switch the currently loaded override preset""" @@ -516,7 +579,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest): @router.post( "/v1/sampling/override/unload", - dependencies=[Depends(check_admin_key)], + dependencies=[Depends(auth.check_api_key(ROLE.ADMIN))], ) async def unload_sampler_override(): """Unloads the currently selected override preset""" diff --git a/main.py b/main.py index 740e1d05..58337a00 100644 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ from common import gen_logging, sampling, model from common.args import convert_args_to_dict, init_argparser -from common.auth import load_auth_keys +from common.auth import auth from common.logger import setup_logger from common.networking import is_port_in_use from common.signals import signal_handler @@ -50,7 +50,7 @@ async def entrypoint_async(): port = fallback_port # Initialize auth keys - load_auth_keys(unwrap(config.network.get("disable_auth"), False)) + auth.load(unwrap(config.network.get("disable_auth"), False)) # Override the generation log options if given if config.logging: From 8a1b82eae159d69325050ac4601d620f0284b9e2 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Fri, 6 Sep 2024 23:03:55 +0100 Subject: [PATCH 2/6] remove print statement --- common/auth.py | 1 - 1 file changed, 1 deletion(-) diff --git a/common/auth.py b/common/auth.py index 194d2991..b73c47ad 100644 --- a/common/auth.py +++ b/common/auth.py @@ -123,7 +123,6 @@ def check_api_key(self, api_key: str) -> Union[API_KEY, None]: def authenticate_api_key(self, api_key: str, role: ROLE) -> bool: key = self.check_api_key(api_key) - print(f"#### {key=}") if not key: return False return key.role & role # if key.role in role From 2f60a64d588898fa1db643f58d6240c95ce48066 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Fri, 6 Sep 2024 23:23:29 +0100 Subject: [PATCH 3/6] fix permissions check and ruff format --- common/auth.py | 10 +++++++--- endpoints/core/router.py | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/common/auth.py b/common/auth.py index b73c47ad..6b31ca02 100644 --- a/common/auth.py +++ b/common/auth.py @@ -5,14 +5,13 @@ import secrets import yaml -from fastapi import Header, HTTPException, Request +from fastapi import Header, HTTPException from pydantic import BaseModel, Field from loguru import logger -from typing import Optional, Union +from typing import Union from enum import Flag, auto from abc import ABC, abstractmethod -from common.utils import coalesce, unwrap __all__ = ["ROLE", "auth"] @@ -32,22 +31,27 @@ class API_KEY(BaseModel): class AUTH_PROVIDER(ABC): @staticmethod + @abstractmethod def add_api_key(role: ROLE) -> API_KEY: """add an API key""" @staticmethod + @abstractmethod def set_api_key(role: ROLE, api_key: str) -> API_KEY: """add an existing API key""" @staticmethod + @abstractmethod def remove_api_key(api_key: str) -> bool: """remove an API key""" @staticmethod + @abstractmethod def check_api_key(api_key: str) -> Union[API_KEY, None]: """check if an API key is valid""" @staticmethod + @abstractmethod def authenticate_api_key(api_key: str, role: ROLE) -> bool: """check if an api key has ROLE""" diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 013efd66..5bb9613f 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -436,7 +436,7 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: "/v1/auth/permission", dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) -async def key_permission(request: Request) -> AuthPermissionResponse: +async def key_permission(request: Request, api_key: Annotated[str, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))],) -> AuthPermissionResponse: """ Gets the access level/permission of a provided key in headers. @@ -447,7 +447,7 @@ async def key_permission(request: Request) -> AuthPermissionResponse: """ try: - permission = get_key_permission(request) + permission = auth.provider.check_api_key(api_key).role.name return AuthPermissionResponse(permission=permission) except ValueError as exc: error_message = handle_request_error(str(exc)).error.message From 54418d0401be7c2962a3a6013444f6b37c023fe9 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sat, 7 Sep 2024 00:11:35 +0100 Subject: [PATCH 4/6] fix endpoints with internal ROLE checks --- common/auth.py | 22 +++++----- endpoints/OAI/router.py | 19 +++++---- endpoints/OAI/utils/completion.py | 4 +- endpoints/core/router.py | 69 ++++++++++++++++--------------- 4 files changed, 59 insertions(+), 55 deletions(-) diff --git a/common/auth.py b/common/auth.py index 6b31ca02..a46a505d 100644 --- a/common/auth.py +++ b/common/auth.py @@ -80,7 +80,7 @@ def __init__(self) -> None: self.set_api_key(ROLE.ADMIN, key) user_keys = keys_dict.get("user_keys") - if admin_keys: + if user_keys: for key in admin_keys: self.set_api_key(ROLE.ADMIN, key) @@ -101,7 +101,8 @@ def __init__(self) -> None: for key in self.api_keys: logger.info(f"{key.role.name} :\t {key.key}") logger.info( - "If these keys get compromised, make sure to delete api_tokens.yml and restart the server. Have fun!" + "If these keys get compromised, make sure to delete \ + api_tokens.yml and restart the server. Have fun!" ) def add_api_key(self, role: ROLE) -> API_KEY: @@ -177,7 +178,8 @@ def load(self, disable_from_config: bool): self.provider = provider_class() - # by returning a dynamic dependency we can have one function where we can specify what roles can access the endpoint + # by returning a dynamic dependency we can have one function + # where we can specify what roles can access the endpoint def check_api_key(self, role: ROLE): """Check if the API key is valid.""" @@ -185,22 +187,20 @@ async def check( x_api_key: str = Header(None), authorization: str = Header(None) ): if x_api_key: - if not self.provider.authenticate_api_key(x_api_key, role): + key = self.provider.authenticate_api_key(x_api_key, role) + if not key: raise HTTPException(401, "Invalid API key") - return x_api_key + return key if authorization: split_key = authorization.split(" ") if len(split_key) < 2: raise HTTPException(401, "Invalid API key") - if split_key[ - 0 - ].lower() != "bearer" or not self.provider.authenticate_api_key( - split_key[1], role - ): + key = self.provider.authenticate_api_key(split_key[1], role) + if split_key[0].lower() != "bearer" or not key: raise HTTPException(401, "Invalid API key") - return authorization + return key raise HTTPException(401, "Please provide an API key") diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index e63e5005..801e0cb9 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize +from typing import Annotated from common import model from common.auth import auth, ROLE @@ -41,12 +42,11 @@ def setup(): # Completions endpoint -@router.post( - "/v1/completions", - dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], -) +@router.post("/v1/completions") async def completion_request( - request: Request, data: CompletionRequest + request: Request, + data: CompletionRequest, + user_role: Annotated[ROLE, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) -> CompletionResponse: """ Generates a completion from a prompt. @@ -55,7 +55,7 @@ async def completion_request( """ if data.model: - await load_inline_model(data.model, request) + await load_inline_model(data.model, user_role) else: await check_model_container() @@ -93,10 +93,11 @@ async def completion_request( # Chat completions endpoint @router.post( "/v1/chat/completions", - dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) async def chat_completion_request( - request: Request, data: ChatCompletionRequest + request: Request, + data: ChatCompletionRequest, + user_role: Annotated[ROLE, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) -> ChatCompletionResponse: """ Generates a chat completion from a prompt. @@ -105,7 +106,7 @@ async def chat_completion_request( """ if data.model: - await load_inline_model(data.model, request) + await load_inline_model(data.model, user_role) else: await check_model_container() diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 5a7900ec..93fdd5fa 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -109,7 +109,7 @@ async def _stream_collector( await gen_queue.put(e) -async def load_inline_model(model_name: str, request: Request): +async def load_inline_model(model_name: str, user_role: ROLE): """Load a model from the data.model parameter""" # Return if the model container already exists @@ -117,7 +117,7 @@ async def load_inline_model(model_name: str, request: Request): return # Inline model loading isn't enabled or the user isn't an admin - if not auth.provider.check_api_key(request).role == ROLE.ADMIN: + if user_role == ROLE.ADMIN: error_message = handle_request_error( f"Unable to switch model to {model_name} because " + "an admin key isn't provided", diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 5bb9613f..6d2cd4cc 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -58,7 +58,7 @@ async def healthcheck(): @router.get("/v1/model/list") async def list_models( request: Request, - api_key: Annotated[str, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], + user_role: Annotated[ROLE, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) -> ModelList: """ Lists all models in the model directory. @@ -71,7 +71,7 @@ async def list_models( draft_model_dir = config.draft_model.get("draft_model_dir") - if auth.provider.check_api_key(request).role == ROLE.ADMIN: + if user_role == ROLE.ADMIN: models = get_model_list(model_path.resolve(), draft_model_dir) else: models = await get_current_model_list() @@ -98,16 +98,18 @@ async def current_model() -> ModelCard: @router.get( "/v1/model/draft/list", - dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) -async def list_draft_models(request: Request) -> ModelList: +async def list_draft_models( + request: Request, + user_role: Annotated[ROLE, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) -> ModelList: """ Lists all draft models in the model directory. Requires an admin key to see all draft models. """ - if auth.provider.check_api_key(request).role == ROLE.ADMIN: + if user_role == ROLE.ADMIN: draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models") draft_model_path = pathlib.Path(draft_model_dir) @@ -195,20 +197,19 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes # Lora list endpoint -@router.get( - "/v1/loras", dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))] -) -@router.get( - "/v1/lora/list", dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))] -) -async def list_all_loras(request: Request) -> LoraList: +@router.get("/v1/loras") +@router.get("/v1/lora/list") +async def list_all_loras( + request: Request, + user_role: Annotated[ROLE, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) -> LoraList: """ Lists all LoRAs in the lora directory. Requires an admin key to see all LoRAs. """ - if auth.provider.check_api_key(request).role == ROLE.ADMIN: + if user_role == ROLE.ADMIN: lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras")) loras = get_lora_list(lora_path.resolve()) else: @@ -283,18 +284,18 @@ async def unload_loras(): await model.unload_loras() -@router.get( - "/v1/model/embedding/list", - dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], -) -async def list_embedding_models(request: Request) -> ModelList: +@router.get("/v1/model/embedding/list") +async def list_embedding_models( + request: Request, + user_role: Annotated[ROLE, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) -> ModelList: """ Lists all embedding models in the model directory. Requires an admin key to see all embedding models. """ - if auth.provider.check_api_key(request).role == ROLE.ADMIN: + if user_role == ROLE.ADMIN: embedding_model_dir = unwrap( config.embeddings.get("embedding_model_dir"), "models" ) @@ -436,7 +437,10 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: "/v1/auth/permission", dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) -async def key_permission(request: Request, api_key: Annotated[str, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))],) -> AuthPermissionResponse: +async def key_permission( + request: Request, + user_role: Annotated[ROLE, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) -> AuthPermissionResponse: """ Gets the access level/permission of a provided key in headers. @@ -455,14 +459,12 @@ async def key_permission(request: Request, api_key: Annotated[str, Depends(auth. raise HTTPException(400, error_message) from exc -@router.get( - "/v1/templates", dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))] -) -@router.get( - "/v1/template/list", - dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], -) -async def list_templates(request: Request) -> TemplateList: +@router.get("/v1/templates") +@router.get("/v1/template/list") +async def list_templates( + request: Request, + user_role: Annotated[ROLE, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) -> TemplateList: """ Get a list of all templates. @@ -470,7 +472,7 @@ async def list_templates(request: Request) -> TemplateList: """ template_strings = [] - if auth.provider.check_api_key(request).role == ROLE.ADMIN: + if user_role == ROLE.ADMIN: templates = get_all_templates() template_strings = [template.stem for template in templates] else: @@ -525,20 +527,21 @@ async def unload_template(): # Sampler override endpoints @router.get( "/v1/sampling/overrides", - dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) @router.get( "/v1/sampling/override/list", - dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) -async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: +async def list_sampler_overrides( + request: Request, + user_role: Annotated[ROLE, Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], +) -> SamplerOverrideListResponse: """ List all currently applied sampler overrides. Requires an admin key to see all override presets. """ - if auth.provider.check_api_key(request).role == ROLE.ADMIN: + if user_role == ROLE.ADMIN: presets = sampling.get_all_presets() else: presets = [] From 2889470e867d5021692d312d6241e8e8619c09fc Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sat, 7 Sep 2024 00:15:50 +0100 Subject: [PATCH 5/6] minor fixes --- endpoints/OAI/utils/completion.py | 2 +- endpoints/core/router.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 93fdd5fa..51125339 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -14,7 +14,7 @@ from loguru import logger from common import model -from common.auth import auth, ROLE +from common.auth import ROLE from common.networking import ( get_generator_error, handle_request_disconnect, diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 6d2cd4cc..cecc7911 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -435,7 +435,6 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: @router.get( "/v1/auth/permission", - dependencies=[Depends(auth.check_api_key(ROLE.USER | ROLE.ADMIN))], ) async def key_permission( request: Request, @@ -451,8 +450,7 @@ async def key_permission( """ try: - permission = auth.provider.check_api_key(api_key).role.name - return AuthPermissionResponse(permission=permission) + return AuthPermissionResponse(permission=user_role) except ValueError as exc: error_message = handle_request_error(str(exc)).error.message From 708a9e020a72a92bd9c090e4b95817c3630bb0db Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sat, 7 Sep 2024 01:06:04 +0100 Subject: [PATCH 6/6] make api_tokens.yml dynamic based on ROLES --- common/auth.py | 40 +++++++++++++++++----------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/common/auth.py b/common/auth.py index a46a505d..27d7c28f 100644 --- a/common/auth.py +++ b/common/auth.py @@ -12,6 +12,7 @@ from enum import Flag, auto from abc import ABC, abstractmethod +from common.utils import unwrap __all__ = ["ROLE", "auth"] @@ -62,38 +63,31 @@ class SIMPLE_AUTH_PROVIDER(AUTH_PROVIDER): def __init__(self) -> None: try: with open("api_tokens.yml", "r", encoding="utf8") as auth_file: - keys_dict: dict = yaml.safe_load(auth_file) + keys_dict: dict = unwrap(yaml.safe_load(auth_file), {}) # load legacy keys admin_key = keys_dict.get("admin_key") if admin_key: self.set_api_key(ROLE.ADMIN, admin_key) - admin_key = keys_dict.get("api_key") - if admin_key: + user_key = keys_dict.get("api_key") + if user_key: self.set_api_key(ROLE.USER, admin_key) # load new keys - admin_keys = keys_dict.get("admin_keys") - if admin_keys: - for key in admin_keys: - self.set_api_key(ROLE.ADMIN, key) - - user_keys = keys_dict.get("user_keys") - if user_keys: - for key in admin_keys: - self.set_api_key(ROLE.ADMIN, key) + for role in ROLE : + role_keys = keys_dict.get(f"{role.name.lower()}_keys") + if role_keys: + for key in role_keys: + self.set_api_key(role, key) except FileNotFoundError: - file = { - "admin_keys": [ - self.add_api_key(ROLE.ADMIN), - ], - "user_keys": [ - self.add_api_key(ROLE.USER), - ], - } + file = {} + + for role in ROLE : + file[f"{role.name.lower()}_keys"] = [self.add_api_key(role).key for i in range(3)] + print(file) with open("api_tokens.yml", "w", encoding="utf8") as auth_file: yaml.safe_dump(file, auth_file, default_flow_style=False) @@ -101,12 +95,12 @@ def __init__(self) -> None: for key in self.api_keys: logger.info(f"{key.role.name} :\t {key.key}") logger.info( - "If these keys get compromised, make sure to delete \ - api_tokens.yml and restart the server. Have fun!" + "If these keys get compromised, make sure to delete " + + "api_tokens.yml and restart the server. Have fun!" ) def add_api_key(self, role: ROLE) -> API_KEY: - return self.set_api_key(key=secrets.token_hex(16), role=role) + return self.set_api_key(api_key=secrets.token_hex(16), role=role) def set_api_key(self, role: ROLE, api_key: str) -> API_KEY: key = API_KEY(key=api_key, role=role)