From 8dea56ff0090b1e345a371d5e705c50c1bdd741f Mon Sep 17 00:00:00 2001 From: arjunchainani Date: Wed, 18 Sep 2024 23:02:17 -0500 Subject: [PATCH] changed config file reading to work with expected labels from oracle --- resspect/scripts/run_loop.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/resspect/scripts/run_loop.py b/resspect/scripts/run_loop.py index 97950e5f..492fd31d 100644 --- a/resspect/scripts/run_loop.py +++ b/resspect/scripts/run_loop.py @@ -66,10 +66,6 @@ def run_loop(args): """ # set training sample variable - - # TODO: add code for reading from config file - # -> grab n_classes parameter for learn_loop() - # -> create dict to map label names to corresponding numerical labels # gets class names from config file and generates dictionary with one-hot encoded numerical labels n_classes = 0 @@ -78,16 +74,11 @@ def run_loop(args): if args.config is not None: with open(args.config, "r") as config: info = [line for line in config] - for label in info[0].split(','): - class_info[label] = n_classes + for label in info[0].split(';'): + class_info[label.split(":")[0]] = label.split(":")[1] n_classes += 1 - - for num, label in enumerate(class_info.values()): - encoded = [0 for _ in range(n_classes)] - encoded[label] = 1 - class_info[list(class_info)[num]] = encoded - - print(f'class_info: {class_info}') + + print(f'class_info: {class_info}') # TODO: remove these print(f'n_classes: {n_classes}') if args.training == 'original':