Skip to content

Commit

Permalink
Split main into functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrost-mo committed Nov 8, 2024
1 parent 710f737 commit 47ada4e
Showing 1 changed file with 57 additions and 49 deletions.
106 changes: 57 additions & 49 deletions src/CSET/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,39 @@ def main():
Handles argument parsing, setting up logging, top level error capturing,
and execution of the desired subcommand.
"""
parser = setup_argument_parser()
cli_args = sys.argv[1:] + shlex.split(os.getenv("CSET_ADDOPTS", ""))
args, unparsed_args = parser.parse_known_args(cli_args)
setup_logging(args.verbose)

# Down here so runs after logging is setup.
logging.debug("CLI Arguments: %s", cli_args)

if args.subparser is None:
print("Please choose a command.", file=sys.stderr)
parser.print_usage()
sys.exit(127)

try:
# Execute the specified subcommand.
args.func(args, unparsed_args)
except ArgumentError as err:
# Error message for when needed template variables are missing.
print(err, file=sys.stderr)
parser.print_usage()
sys.exit(127)
except Exception as err:
# Provide slightly nicer error messages for unhandled exceptions.
print(err, file=sys.stderr)
# Display the time and full traceback when debug logging.
logging.debug("An unhandled exception occurred.")
if logging.root.isEnabledFor(logging.DEBUG):
raise
sys.exit(1)


def setup_argument_parser() -> argparse.ArgumentParser:
"""Create argument parser for CSET CLI."""
parser = argparse.ArgumentParser(
prog="cset", description="Convective Scale Evaluation Tool"
)
Expand Down Expand Up @@ -138,68 +171,43 @@ def main():
)
parser_recipe_id.set_defaults(func=_recipe_id_command)

cli_args = sys.argv[1:] + shlex.split(os.getenv("CSET_ADDOPTS", ""))
args, unparsed_args = parser.parse_known_args(cli_args)

# Setup logging.
logging.captureWarnings(True)
loglevel = calculate_loglevel(args)
logger = logging.getLogger()
logger.setLevel(min(loglevel, logging.INFO))
stderr_log = logging.StreamHandler()
stderr_log.addFilter(lambda record: record.levelno >= loglevel)
stderr_log.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(stderr_log)

# Down here so runs after logging is setup.
logging.debug("CLI Arguments: %s", cli_args)
return parser

if args.subparser is None:
print("Please choose a command.", file=sys.stderr)
parser.print_usage()
sys.exit(127)

try:
# Execute the specified subcommand.
args.func(args, unparsed_args)
except ArgumentError as err:
# Error message for when needed template variables are missing.
print(err, file=sys.stderr)
parser.print_usage()
sys.exit(127)
except Exception as err:
# Provide slightly nicer error messages for unhandled exceptions.
print(err, file=sys.stderr)
# Display the time and full traceback when debug logging.
logging.debug("An unhandled exception occurred.")
if logging.root.isEnabledFor(logging.DEBUG):
raise
sys.exit(1)


def calculate_loglevel(args) -> int:
"""Calculate the logging level to apply.
def setup_logging(verbosity: int):
"""Configure logging level, format and output stream.
Level is based on verbose argument and the LOGLEVEL environment variable.
"""
logging.captureWarnings(True)

# Calculate logging level.
try:
name_to_level = logging.getLevelNamesMapping()
except AttributeError:
# logging.getLevelNamesMapping() is python 3.11 or newer. Using
# implementation detail for older versions.
name_to_level = logging._nameToLevel
# Level from CLI flags.
if args.verbose >= 2:
loglevel = logging.DEBUG
elif args.verbose == 1:
loglevel = logging.INFO
if verbosity >= 2:
cli_loglevel = logging.DEBUG
elif verbosity == 1:
cli_loglevel = logging.INFO
else:
loglevel = logging.WARNING
return min(
loglevel,
# Level from environment variable.
name_to_level.get(os.getenv("LOGLEVEL"), logging.ERROR),
)
cli_loglevel = logging.WARNING
# Level from environment variable.
env_loglevel = name_to_level.get(os.getenv("LOGLEVEL"), logging.ERROR)
loglevel = min(cli_loglevel, env_loglevel)

# Configure the root logger.
logger = logging.getLogger()
# Record everything at least INFO for the log file.
logger.setLevel(min(loglevel, logging.INFO))
stderr_log = logging.StreamHandler()
# Filter stderr log to just what is requested.
stderr_log.addFilter(lambda record: record.levelno >= loglevel)
stderr_log.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(stderr_log)


def _bake_command(args, unparsed_args):
Expand Down

0 comments on commit 47ada4e

Please sign in to comment.