diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 05592ba2f..97636a1d3 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -1,6 +1,7 @@ import json import logging import sys +from collections.abc import Sequence from functools import wraps from pathlib import Path from pprint import pprint @@ -11,6 +12,7 @@ from bluesky_stomp.models import Broker from observability_utils.tracing import setup_tracing from pydantic import ValidationError +from pydantic_settings.sources import PathType from requests.exceptions import ConnectionError from blueapi import __version__ @@ -18,7 +20,7 @@ from blueapi.client.client import BlueapiClient from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueskyRemoteControlError -from blueapi.config import ApplicationConfig, ConfigLoader +from blueapi.config import ApplicationConfig from blueapi.core import OTLP_EXPORT_ENABLED, DataEvent from blueapi.worker import ProgressEvent, Task, WorkerEvent @@ -26,30 +28,44 @@ from .updates import CliEventRenderer +def parse_path_type(path_type: PathType) -> list[Path]: + """ + Parse a PathType parameter and return a list of Path objects. + + :param path_type: The input which can be a Path, str, or a sequence of Path/str. + :return: A list of Path objects. + """ + if isinstance(path_type, str | Path): + # Single Path or string: Convert to Path and return as a single-element list + return [Path(path_type)] + + if isinstance(path_type, Sequence): + # Sequence of Paths/strings: Convert each element to a Path + return [Path(item) for item in path_type if isinstance(item, str | Path)] + + # If it doesn't match the expected types, raise an error + raise TypeError(f"Unsupported PathType: {type(path_type)}") + + @click.group(invoke_without_command=True) @click.version_option(version=__version__, prog_name="blueapi") @click.option( "-c", "--config", type=Path, help="Path to configuration YAML file", multiple=True ) @click.pass_context -def main(ctx: click.Context, config: Path | None | tuple[Path, ...]) -> None: - # if no command is supplied, run with the options passed - - config_loader = ConfigLoader(ApplicationConfig) - if config is not None: - configs = (config,) if isinstance(config, Path) else config - for path in configs: - if path.exists(): - config_loader.use_values_from_yaml(path) - else: - raise FileNotFoundError(f"Cannot find file: {path}") - +def main(ctx: click.Context, config: PathType) -> None: + # Override default yaml_file path in the model_config if `config` is provided + ApplicationConfig.model_config["yaml_file"] = config + app_config = ApplicationConfig() # Instantiates with customized sources + print(f"Loaded configuration {app_config}") ctx.ensure_object(dict) - loaded_config: ApplicationConfig = config_loader.load() + ctx.obj["config"] = app_config - ctx.obj["config"] = loaded_config + # note: this is the key result of the 'main' function, it loaded the config + # and due to 'pass context' flag above + # it's left for the handler of words that are later in the stdin logging.basicConfig( - format="%(asctime)s - %(message)s", level=loaded_config.logging.level + format="%(asctime)s - %(message)s", level=app_config.logging.level ) if ctx.invoked_subcommand is None: @@ -173,6 +189,7 @@ def listen_to_events(obj: dict) -> None: ) ) ) + fmt = obj["fmt"] def on_event( diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 468e00f98..2b3338ee6 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -1,16 +1,17 @@ -from collections.abc import Mapping from enum import Enum from pathlib import Path -from typing import Any, Generic, Literal, TypeVar +from typing import Literal -import yaml from bluesky_stomp.models import BasicAuthentication -from pydantic import BaseModel, Field, TypeAdapter, ValidationError +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, SettingsConfigDict, YamlConfigSettingsSource -from blueapi.utils import BlueapiBaseModel, InvalidConfigError +from blueapi.utils import BlueapiBaseModel LogLevel = Literal["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] +DEFAULT_PATH = Path("config.yaml") # Default YAML file path + class SourceKind(str, Enum): PLAN_FUNCTIONS = "planFunctions" @@ -77,7 +78,8 @@ class ScratchConfig(BlueapiBaseModel): repositories: list[ScratchRepository] = Field(default_factory=list) -class ApplicationConfig(BlueapiBaseModel): +# class ApplicationConfig(BaseSettings, cli_parse_args=True, cli_prog_name="blueapi"): +class ApplicationConfig(BaseSettings): """ Config for the worker application as a whole. Root of config tree. @@ -89,83 +91,18 @@ class ApplicationConfig(BlueapiBaseModel): api: RestConfig = Field(default_factory=RestConfig) scratch: ScratchConfig | None = None - def __eq__(self, other: object) -> bool: - if isinstance(other, ApplicationConfig): - return ( - (self.stomp == other.stomp) - & (self.env == other.env) - & (self.logging == other.logging) - & (self.api == other.api) - ) - return False - - -C = TypeVar("C", bound=BaseModel) - - -class ConfigLoader(Generic[C]): - """ - Small utility class for loading config from various sources. - You must define a config schema as a dataclass (or series of - nested dataclasses) that can then be loaded from some combination - of default values, dictionaries, YAML/JSON files etc. - """ - - def __init__(self, schema: type[C]) -> None: - self._adapter = TypeAdapter(schema) - self._values: dict[str, Any] = {} - - def use_values(self, values: Mapping[str, Any]) -> None: - """ - Use all values provided in the config, override any defaults - and values set by previous calls into this class. - - Args: - values (Mapping[str, Any]): Dictionary of override values, - does not need to be exhaustive - if defaults provided. - """ - - def recursively_update_map(old: dict[str, Any], new: Mapping[str, Any]) -> None: - for key in new: - if ( - key in old - and isinstance(old[key], dict) - and isinstance(new[key], dict) - ): - recursively_update_map(old[key], new[key]) - else: - old[key] = new[key] - - recursively_update_map(self._values, values) - - def use_values_from_yaml(self, path: Path) -> None: - """ - Use all values provided in a YAML/JSON file in the - config, override any defaults and values set by - previous calls into this class. - - Args: - path (Path): Path to YAML/JSON file - """ - - with path.open("r") as stream: - values = yaml.load(stream, yaml.Loader) - self.use_values(values) - - def load(self) -> C: - """ - Finalize and load the config as an instance of the `schema` - dataclass. - - Returns: - C: Dataclass instance holding config - """ - - try: - return self._adapter.validate_python(self._values) - except ValidationError as exc: - error_details = "\n".join(str(e) for e in exc.errors()) - raise InvalidConfigError( - f"Something is wrong with the configuration file: \n {error_details}" - ) from exc + model_config = SettingsConfigDict( + env_nested_delimiter="__", yaml_file=DEFAULT_PATH, yaml_file_encoding="utf-8" + ) + + @classmethod + def settings_customize_sources( + cls, init_settings, env_settings, file_secret_settings + ): + path = cls.model_config.get("yaml_file") + return ( + init_settings, + YamlConfigSettingsSource(settings_cls=cls, yaml_file=path), + env_settings, + file_secret_settings, + ) diff --git a/src/blueapi/service/config_manager.py b/src/blueapi/service/config_manager.py new file mode 100644 index 000000000..c99455976 --- /dev/null +++ b/src/blueapi/service/config_manager.py @@ -0,0 +1,26 @@ +# config_manager.py + +from blueapi.config import ApplicationConfig + + +class ConfigManager: + """Manages application configuration in a way that’s easy to test and mock.""" + + _config: ApplicationConfig + + def __init__(self, config: ApplicationConfig = None): + if config is None: + ApplicationConfig.model_config["yaml_file"] = None + config = ApplicationConfig() + self._config = config + + def get_config(self) -> ApplicationConfig: + """Retrieve the current configuration.""" + return self._config + + def set_config(self, new_config: ApplicationConfig): + """ + This is a setter function that the main process uses + to pass the config into the subprocess + """ + self._config = new_config diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index f4d819c26..9e638845b 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -9,20 +9,18 @@ from blueapi.config import ApplicationConfig, StompConfig from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream +from blueapi.service.config_manager import ConfigManager from blueapi.service.model import DeviceModel, PlanModel, WorkerTask from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask """This module provides interface between web application and underlying Bluesky -context and worker""" +context and worker +""" -_CONFIG: ApplicationConfig = ApplicationConfig() - - -def config() -> ApplicationConfig: - return _CONFIG +config_manager = ConfigManager() def set_config(new_config: ApplicationConfig): @@ -34,23 +32,23 @@ def set_config(new_config: ApplicationConfig): @cache def context() -> BlueskyContext: ctx = BlueskyContext() - ctx.with_config(config().env) + env_config = config_manager.get_config().env + ctx.with_config(env_config) return ctx @cache def worker() -> TaskWorker: - worker = TaskWorker( - context(), - broadcast_statuses=config().env.events.broadcast_status_events, - ) + env_config = config_manager.get_config().env + should_broadcast_status_events: bool = env_config.events.broadcast_status_events + worker = TaskWorker(context(), broadcast_statuses=should_broadcast_status_events) worker.start() return worker @cache def stomp_client() -> StompClient | None: - stomp_config: StompConfig | None = config().stomp + stomp_config: StompConfig | None = config_manager.get_config().stomp if stomp_config is not None: client = StompClient.for_broker( broker=Broker( @@ -79,7 +77,7 @@ def stomp_client() -> StompClient | None: def setup(config: ApplicationConfig) -> None: """Creates and starts a worker with supplied config""" - set_config(config) + config_manager.set_config(config) # Eagerly initialize worker and messaging connection diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index c81762542..b74cf5e08 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,5 +1,4 @@ from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig -from .invalid_config_error import InvalidConfigError from .modules import load_module_all from .serialization import serialize from .thread_exception import handle_all_exceptions @@ -11,5 +10,4 @@ "BlueapiBaseModel", "BlueapiModelConfig", "BlueapiPlanModelConfig", - "InvalidConfigError", ] diff --git a/src/blueapi/utils/invalid_config_error.py b/src/blueapi/utils/invalid_config_error.py deleted file mode 100644 index be99d5a9e..000000000 --- a/src/blueapi/utils/invalid_config_error.py +++ /dev/null @@ -1,3 +0,0 @@ -class InvalidConfigError(Exception): - def __init__(self, message="Configuration is invalid"): - super().__init__(message) diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 69351170b..304b52273 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -11,8 +11,7 @@ from bluesky_stomp.models import BasicAuthentication from pydantic import BaseModel, Field -from blueapi.config import ApplicationConfig, ConfigLoader -from blueapi.utils import InvalidConfigError +from blueapi.config import ApplicationConfig class Config(BaseModel): @@ -60,70 +59,6 @@ def default_yaml(package_root: Path) -> Path: return package_root.parent.parent / "config" / "defaults.yaml" -@pytest.mark.parametrize("schema", [ConfigWithDefaults, NestedConfigWithDefaults]) -def test_load_defaults(schema: type[Any]) -> None: - loader = ConfigLoader(schema) - assert loader.load() == schema() - - -def test_load_some_defaults() -> None: - loader = ConfigLoader(ConfigWithDefaults) - loader.use_values({"foo": 4}) - assert loader.load() == ConfigWithDefaults(foo=4) - - -def test_load_override_all() -> None: - loader = ConfigLoader(ConfigWithDefaults) - loader.use_values({"foo": 4, "bar": "hi"}) - assert loader.load() == ConfigWithDefaults(foo=4, bar="hi") - - -def test_load_override_all_nested() -> None: - loader = ConfigLoader(NestedConfig) - loader.use_values({"nested": {"foo": 4, "bar": "hi"}, "baz": True}) - assert loader.load() == NestedConfig(nested=Config(foo=4, bar="hi"), baz=True) - - -def test_load_defaultless_schema() -> None: - loader = ConfigLoader(Config) - with pytest.raises(InvalidConfigError): - loader.load() - - -def test_inject_values_into_defaultless_schema() -> None: - loader = ConfigLoader(Config) - loader.use_values({"foo": 4, "bar": "hi"}) - assert loader.load() == Config(foo=4, bar="hi") - - -def test_load_yaml(config_yaml: Path) -> None: - loader = ConfigLoader(Config) - loader.use_values_from_yaml(config_yaml) - assert loader.load() == Config(foo=5, bar="test string") - - -def test_load_yaml_nested(nested_config_yaml: Path) -> None: - loader = ConfigLoader(NestedConfig) - loader.use_values_from_yaml(nested_config_yaml) - assert loader.load() == NestedConfig( - nested=Config(foo=6, bar="other test string"), baz=True - ) - - -def test_load_yaml_override(override_config_yaml: Path) -> None: - loader = ConfigLoader(ConfigWithDefaults) - loader.use_values_from_yaml(override_config_yaml) - - assert loader.load() == ConfigWithDefaults(foo=7) - - -def test_error_thrown_if_schema_does_not_match_yaml(nested_config_yaml: Path) -> None: - loader = ConfigLoader(Config) - loader.use_values_from_yaml(nested_config_yaml) - with pytest.raises(InvalidConfigError): - loader.load() - - @mock.patch.dict(os.environ, {"FOO": "bar"}, clear=True) def test_auth_from_env(): auth = BasicAuthentication(username="${FOO}", password="baz") # type: ignore @@ -231,12 +166,11 @@ def test_config_yaml_parsed(temp_yaml_config_file): temp_yaml_file_path, config_data = temp_yaml_config_file # Initialize loader and load config from the YAML file - loader = ConfigLoader(ApplicationConfig) - loader.use_values_from_yaml(temp_yaml_file_path) - loaded_config = loader.load() + ApplicationConfig.model_config["yaml_file"] = temp_yaml_file_path + app_config = ApplicationConfig() # Instantiates with customized sources # Parse the loaded config JSON into a dictionary - target_dict_json = json.loads(loaded_config.model_dump_json()) + target_dict_json = json.loads(app_config.model_dump_json()) # Assert that config_data is a subset of target_dict_json assert is_subset(config_data, target_dict_json) @@ -311,17 +245,16 @@ def test_config_yaml_parsed_complete(temp_yaml_config_file: dict): temp_yaml_file_path, config_data = temp_yaml_config_file # Initialize loader and load config from the YAML file - loader = ConfigLoader(ApplicationConfig) - loader.use_values_from_yaml(temp_yaml_file_path) - loaded_config = loader.load() + ApplicationConfig.model_config["yaml_file"] = temp_yaml_file_path + app_config = ApplicationConfig() # Instantiates with customized sources # Parse the loaded config JSON into a dictionary - target_dict_json = json.loads(loaded_config.model_dump_json()) + target_dict_json = json.loads(app_config.model_dump_json()) - assert loaded_config.stomp is not None - assert loaded_config.stomp.auth is not None + assert app_config.stomp is not None + assert app_config.stomp.auth is not None assert ( - loaded_config.stomp.auth.password.get_secret_value() + app_config.stomp.auth.password.get_secret_value() == config_data["stomp"]["auth"]["password"] # noqa: E501 ) # Remove the password field to not compare it again in the full dict comparison