Skip to content

Commit

Permalink
Implement get_model_path and the hash table generation, TODO add comp…
Browse files Browse the repository at this point in the history
…arision in run/create/equivalent model funcs
  • Loading branch information
justincdavis committed Aug 6, 2024
1 parent 17e1109 commit a8047f2
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/oakutils/blobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
Get the input and output layer data for a blob.
get_model_name
Get the name of a compiled model file.
get_model_path
Get the path to the model blob.
"""

Expand All @@ -60,6 +62,7 @@
get_output_layer_data,
)
from ._benchmark import BenchmarkData, Metric, benchmark_blob
from ._find import get_model_path

_log = logging.getLogger(__name__)

Expand All @@ -71,6 +74,7 @@
"get_input_layer_data",
"get_layer_data",
"get_output_layer_data",
"get_model_path",
"models",
]

Expand Down
52 changes: 52 additions & 0 deletions src/oakutils/blobs/_find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
from __future__ import annotations

from pathlib import Path

from oakutils.nodes.models._parsing import get_candidates as _get_candidates


def get_model_path(
model_type: str,
model_attributes: list[str],
shaves: int,
) -> Path:
"""
Get the path to the model blob.
Parameters
----------
model_type : str
The model type to get the path for.
Examples include: ['gaussian', 'sobel']
model_attributes : list[str]
The model attributes to get the path for.
An example could be ['15'] for a gaussian model
using a 15x15 kernel size.
shaves : int
The number of shaves the model was compiled for.
Returns
-------
Path
The path to the model blob.
Raises
------
FileNotFoundError
If the returned model blob does not exists.
ValueError
If no model blob paths could be formed from the attributes and shaves.
"""
candidates = _get_candidates(model_type, model_attributes, shaves)
if len(candidates) == 0:
err_msg = f"No model blob paths could be formed from the attributes {model_attributes} and shaves {shaves}."
raise ValueError(err_msg)
blobpath = Path(candidates[0])
if not blobpath.exists():
err_msg = f"The model blob path {blobpath} does not exists."
raise FileNotFoundError(err_msg)
return blobpath
13 changes: 13 additions & 0 deletions tests/blobs/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,16 @@
#
# MIT License
from __future__ import annotations

from pathlib import Path

from .hashs import create_file_hash_table, create_bulk_hash_table

# handle the creation of the hash files if they do not exists
hash_table_path = Path(__file__).parent / "hash_table.pkl"
if not hash_table_path.exists():
create_file_hash_table()

bulk_hash_table_path = Path(__file__).parent / "bulk_hash_table.pkl"
if not bulk_hash_table_path.exists():
create_bulk_hash_table()
Binary file added tests/blobs/models/bulk_hash_table.pkl
Binary file not shown.
Binary file added tests/blobs/models/hash_table.pkl
Binary file not shown.
55 changes: 55 additions & 0 deletions tests/blobs/models/hashs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
from __future__ import annotations

import hashlib
import pickle
import io
from pathlib import Path

from oakutils.blobs.models.bulk import ALL_MODELS


def hash_file(file_path: Path) -> str:
hasher = hashlib.md5()
with file_path.open("rb") as file:
for chunk in iter(lambda: file.read(io.DEFAULT_BUFFER_SIZE), b""):
hasher.update(chunk)
return hasher.hexdigest()


def create_file_hash_table() -> None:
hash_table: dict[str, str] = {}
for blob_tuple in ALL_MODELS:
for blob_path in blob_tuple:
hash_table[blob_path] = hash_file(blob_path)
table_file = Path(__file__).parent / "hash_table.pkl"
with Path.open(table_file, "wb") as file:
pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL)


def compare_entry(entry: Path) -> bool:
with Path.open(Path(__file__).parent / "hash_table.pkl", "rb") as file:
table = pickle.load(file)
return table[entry] == hash_file(entry)


def create_bulk_hash_table() -> None:
hash_table: dict[str, str] = {}
for blob_tuple in ALL_MODELS:
# get the stem file path without the suffix
# then remove the _shavesN part at the end
key = blob_tuple[0].stem[:-8]
hashes = [hash_file(bp) for bp in blob_tuple]
hash_table[key] = hash(tuple(hashes))
table_file = Path(__file__).parent / "bulk_hash_table.pkl"
with Path.open(table_file, "wb") as file:
pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL)


def compare_bulk_entry(entry: tuple[Path, ...]) -> bool:
key = entry[0].stem[:-8]
with Path.open(Path(__file__).parent / "bulk_hash_table.pkl", "rb") as file:
table = pickle.load(file)
return hash(tuple(hash_file(bp) for bp in entry)) == table[key]

0 comments on commit a8047f2

Please sign in to comment.