Commit 3858bb3: fix linting
panford committed Nov 3, 2023
1 parent a945f49
Showing 3 changed files with 84 additions and 69 deletions.
99 changes: 54 additions & 45 deletions src/skai/model/data.py
@@ -171,40 +171,44 @@ def __init__(self):
        b'minor_damage':3,
        b'no_damage' :4}
    self.label_to_int_table = tf.lookup.StaticHashTable(
-        tf.lookup.KeyValueTensorInitializer(list(string_label_categories.keys()),
-                                            list(string_label_categories.values())),
-        default_value=-1
+        tf.lookup.KeyValueTensorInitializer(
+            list(string_label_categories.keys()),
+            list(string_label_categories.values())),
+        default_value=-1
    )
    self.int_to_label_table = tf.lookup.StaticHashTable(
-        tf.lookup.KeyValueTensorInitializer(list(string_label_categories.values()),
-                                            list(string_label_categories.keys())),
-        default_value='unknown'
+        tf.lookup.KeyValueTensorInitializer(
+            list(string_label_categories.values()),
+            list(string_label_categories.keys())),
+        default_value='unknown'
    )
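
For intuition, the two tables form an invertible label mapping with sentinel defaults. A minimal sketch of how lookups behave, using the tables as free-standing variables for brevity:

    # Sketch: round-trip lookups through the two StaticHashTables.
    label_int = label_to_int_table.lookup(tf.constant([b'no_damage']))  # [4]
    label_str = int_to_label_table.lookup(label_int)        # [b'no_damage']
    unknown = label_to_int_table.lookup(tf.constant([b'typo']))         # [-1]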
  def encode_example_ids(self, dataloader: Dataloader)-> Dataloader:
    """
-    Encode example IDs from hexadecimal strings to integers in a
-    TensorFlow DataLoader.
-    Description:
-      example_id are hexadecimal strings, eg. b0b947f423a1c77ac948c76f63fa8209.
-      This is encode by taking the int to base 16. This gives a long integer
-      representation, ie 125613911306676688688906949689977127817181202292590253,
-      which cannot be stored by a tensorflow tensor. This long integer can be
-      broken into smaller segments like [2323, 9023, 3403] using a combination of
-      integer division and modulo operations which can be reversed. The segments
-      are (pre-)padded to same size for all examples in a batch and initial size
-      before padding appended to segments. ie [0, 0, 2323, 9023, 3403, 3]
-    Args:
-    - dataloader: The TensorFlow DataLoader containing example IDs to be encoded.
-    Returns:
-    - dataloader: The modified TensorFlow DataLoader with encoded example IDs.
-    """
+    Encode example IDs from hexadecimal strings to integers in a
+    TensorFlow DataLoader.
+    Description:
+      Example IDs are hexadecimal strings, e.g.
+      b0b947f423a1c77ac948c76f63fa8209, and are encoded by parsing the hex
+      string as a base-16 integer. This gives a long integer representation,
+      e.g. 125613911306676688688906949689977127817181202292590253, which
+      cannot be stored in a TensorFlow tensor. The long integer is therefore
+      broken into smaller segments like [2323, 9023, 3403] using a
+      combination of integer division and modulo operations, which can be
+      reversed. The segments are (pre-)padded to the same size for all
+      examples in a batch, and the size before padding is appended to the
+      segments, i.e. [0, 0, 2323, 9023, 3403, 3].
+    Args:
+    - dataloader: The TensorFlow DataLoader containing example IDs to be
+        encoded.
+    Returns:
+    - dataloader: The modified TensorFlow DataLoader with encoded example
+        IDs.
+    """
return self._apply_map_to_features(dataloader,
self._convert_hex_strings_to_int,
'example_id')
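
For intuition, here is a minimal standalone sketch of the scheme the docstring describes, assuming the decimal segment size of 4 used by _convert_hex_strings_to_int below; it is illustrative rather than the module's exact code:

    # Illustrative sketch of the encoding described in the docstring.
    hex_id = 'b0b947f423a1c77ac948c76f63fa8209'
    as_int = int(hex_id, 16)        # one very long integer

    segments = []
    n = as_int
    while n > 0:
        segments.append(n % 10**4)  # take the last four decimal digits
        n //= 10**4                 # and drop them

    # The operations are reversible: fold the segments back together.
    restored = 0
    for seg in reversed(segments):
        restored = restored * 10**4 + seg
    assert restored == as_int
    assert format(restored, 'x') == hex_id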

def encode_string_labels(self, dataloader: Dataloader)-> Dataloader:
"""
Encode string data components to numerical values.
@@ -218,7 +222,7 @@ def encode_string_labels(self, dataloader: Dataloader)-> Dataloader:
return self._apply_map_to_features(dataloader,
self._convert_label_to_int,
'string_label')

def decode_example_ids(self, inputs: tf.Tensor | Dataloader):
"""
Decode example IDs from integers to hexadecimal strings in a batch.
@@ -255,22 +259,22 @@ def decode_string_labels(self, inputs: tf.Tensor | Dataloader):
'string_label')
else:
return self._convert_int_to_label(inputs)

  def _convert_hex_strings_to_int(self, hex_strings):
    """Converts hex strings to integer values, typically very long ones.
    These long integer values do not fit into a tensorflow tensor int datatype.
-    So the long integer is broken into segments using modulo technique and padding
-    to same size
+    So the long integer is broken into segments using the modulo technique and
+    padded to the same size.
    """
    segment_size = 4
    def split_long_integer(number):
      segments = []
      while number > 0:
-        segment = number % (10 ** segment_size) # Extract the last `segment_size` digits
+        segment = number % (10 ** segment_size)  # Get last `segment_size` digits
        segments.append(segment)
-        number //= 10 ** segment_size # Removes the last `segment_size` digits
+        number //= 10 ** segment_size  # Remove last `segment_size` digits
      return segments

output = []
for hex_string in hex_strings:
integer = int(hex_string.numpy(), 16)
@@ -279,7 +283,7 @@ def split_long_integer(number):
output.append(short_integers)
padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
output, padding='pre')
    return padded_sequences
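
For illustration, the pre-padding step behaves as follows; this assumes the standard pad_sequences default of left-padding with zeros to the longest sequence in the batch:

    # Sketch: pre-padding ragged segment lists to a common length.
    import tensorflow as tf

    batch = [[2323, 9023, 3403], [17, 42]]
    padded = tf.keras.preprocessing.sequence.pad_sequences(batch, padding='pre')
    # padded == [[2323, 9023, 3403],
    #            [   0,   17,   42]]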

def _convert_int_to_hex_strings(self, segments):
"""Converts integer segments to a long integer value
@@ -311,29 +315,31 @@ def long_integer_to_string(integer):
def _convert_label_to_int(self, string_labels):
"""Lookup integer values from string labels"""
return self.label_to_int_table.lookup(string_labels)

def _convert_int_to_label(self, int_labels):
"""Lookup string labels from integer keys"""
return self.int_to_label_table.lookup(int_labels)

  def _process_per_batch(self, batch, map_fn, feature):
    """Apply a map function to a batch of data."""
    for idx, examples in enumerate(batch):
      processed = map_fn(examples[feature])
      examples[feature] = processed

      if idx == 0:
        transformed_batch = tf.data.Dataset.from_tensor_slices(examples)
        continue
      # Dataset.concatenate returns a new dataset rather than mutating in
      # place, so the result must be reassigned.
      transformed_batch = transformed_batch.concatenate(
          tf.data.Dataset.from_tensor_slices(examples))
    return transformed_batch

-  def _apply_map_to_features(self, dataloader: Dataloader,
-                             map_fn: collections.abc.Callable[[tf.Tensor], tf.Tensor],
-                             feature: str):
+  def _apply_map_to_features(self,
+                             dataloader: Dataloader,
+                             map_fn: collections.abc.Callable[[tf.Tensor], tf.Tensor],
+                             feature: str):
    """
-    Apply a map function to a TensorFlow DataLoader and return the modified DataLoader.
+    Applies a map function to a TensorFlow DataLoader and returns
+    the modified DataLoader.
Args:
- dataloader: The TensorFlow DataLoader to apply the map function to.
@@ -345,23 +351,26 @@ def _apply_map_to_features(self, dataloader: Dataloader,
batch_size = dataloader.train_splits[0]._batch_size.numpy()

    dataloader.train_splits = [
-        self._process_per_batch(data, map_fn, feature) for data in dataloader.train_splits
+        self._process_per_batch(data, map_fn, feature)
+        for data in dataloader.train_splits
    ]
    dataloader.val_splits = [
-        self._process_per_batch(data, map_fn, feature) for data in dataloader.val_splits
+        self._process_per_batch(data, map_fn, feature)
+        for data in dataloader.val_splits
    ]
num_splits = len(dataloader.train_splits)
-    train_ds = gather_data_splits(
-        list(range(num_splits)), dataloader.train_splits)
-    val_ds = gather_data_splits(list(range(num_splits)), dataloader.val_splits)
+    train_ds = gather_data_splits(list(range(num_splits)),
+                                  dataloader.train_splits)
+    val_ds = gather_data_splits(list(range(num_splits)),
+                                dataloader.val_splits)
dataloader.train_ds = train_ds
dataloader.eval_ds['val'] = val_ds
for (k, v) in dataloader.eval_ds.items():
if k != 'val':
dataloader.eval_ds[k] = self._process_per_batch(v, map_fn, feature)
dataloader = apply_batch(dataloader, batch_size)
return dataloader


def gather_data_splits(
slice_idx: list[int],
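Taken together, a typical round trip through these helpers might look like the sketch below; the class name ExampleIdEncoder is hypothetical, since the diff shows only the methods, not the class declaration:

    # Hypothetical usage; the real class name is not visible in this diff.
    encoder = ExampleIdEncoder()
    dataloader = encoder.encode_example_ids(dataloader)    # hex ids -> segments
    dataloader = encoder.encode_string_labels(dataloader)  # b'no_damage' -> 4
    # ... downstream training on purely numeric features ...
    labels = encoder.decode_string_labels(dataloader)      # 4 -> b'no_damage'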
14 changes: 8 additions & 6 deletions src/skai/model/train.py
@@ -6,7 +6,7 @@

import logging as native_logging
import os

+import datetime
from absl import app
from absl import flags
from absl import logging
@@ -103,7 +103,7 @@ def main(_) -> None:
dataloader.train_ds.filter(
generate_bias_table_lib.filter_ids_fn(ids_tab)) for
ids_tab in sampling_policies.convert_ids_to_table(config.ids_dir)]
-  print("Ids dir: ", config.ids_dir)

model_params = models.ModelTrainingParameters(
model_name=config.model.name,
train_bias=config.train_bias,
@@ -128,10 +128,13 @@ def main(_) -> None:
  if FLAGS.is_vertex:
    job_id = os.path.basename(FLAGS.trial_name)
    output_dir = os.path.join(config.output_dir, job_id)
-    tf.io.gfile.makedirs(output_dir)
  else:
-    #TODO - Choose a diretory name in case vertex ai is not used in running experiments
-    output_dir = config.output_dir
+    # TODO - Maybe change directory name in case
+    # Vertex AI is not used in running experiments.
+    start_time = datetime.datetime.now()
+    timestamp = start_time.strftime('%Y-%m-%d-%H%M%S')
+    output_dir = f'{config.output_dir}_{timestamp}'
+  tf.io.gfile.makedirs(output_dir)
example_id_to_bias_table = None

if config.train_bias or (config.reweighting.do_reweighting and
@@ -224,4 +227,3 @@ def main(_) -> None:

if __name__ == '__main__':
app.run(main)
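
For reference, a small sketch of the timestamp format used by the new fallback output directory above; the bucket path is hypothetical:

    # Illustrative: the fallback output_dir produced by the else-branch.
    import datetime

    start_time = datetime.datetime(2023, 11, 3, 14, 25, 30)  # example instant
    timestamp = start_time.strftime('%Y-%m-%d-%H%M%S')       # '2023-11-03-142530'
    output_dir = f'gs://my-bucket/output_{timestamp}'        # hypothetical path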

40 changes: 22 additions & 18 deletions src/skai/model/train_strategy.py
@@ -1,11 +1,15 @@
"""Train Strategy file.
This creates the strategy for the specified accelerator: cpu, gpu or tpu.
"""
from typing import Union
import tensorflow as tf


_Strategy = Union[
    tf.distribute.Strategy,
    tf.distribute.MirroredStrategy,
    tf.distribute.TPUStrategy
]


@@ -19,18 +23,18 @@ def get_tpu_resolver():


def get_strategy(accelerator_type: str) -> _Strategy:
  """Gets the distributed training strategy for the given accelerator type.

  Args:
    accelerator_type: The accelerator type, one of cpu, gpu or tpu.

  Returns:
    MirroredStrategy if accelerator_type is gpu,
    TPUStrategy if accelerator_type is tpu,
    else the default Strategy.
  """
  if accelerator_type == 'gpu':
    return tf.distribute.MirroredStrategy()
  elif accelerator_type == 'tpu':
    resolver = get_tpu_resolver()
    return tf.distribute.TPUStrategy(resolver)
  return tf.distribute.get_strategy()
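
A minimal usage sketch, following standard tf.distribute practice; the model here is illustrative:

    # Sketch: build and compile a model under the selected strategy's scope.
    strategy = get_strategy('gpu')
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        model.compile(optimizer='adam', loss='mse')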
