-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement get_model_path and the hash table generation, TODO add comp…
…arision in run/create/equivalent model funcs
- Loading branch information
1 parent
17e1109
commit a8047f2
Showing
6 changed files
with
124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com) | ||
# | ||
# MIT License | ||
from __future__ import annotations | ||
|
||
from pathlib import Path | ||
|
||
from oakutils.nodes.models._parsing import get_candidates as _get_candidates | ||
|
||
|
||
def get_model_path( | ||
model_type: str, | ||
model_attributes: list[str], | ||
shaves: int, | ||
) -> Path: | ||
""" | ||
Get the path to the model blob. | ||
Parameters | ||
---------- | ||
model_type : str | ||
The model type to get the path for. | ||
Examples include: ['gaussian', 'sobel'] | ||
model_attributes : list[str] | ||
The model attributes to get the path for. | ||
An example could be ['15'] for a gaussian model | ||
using a 15x15 kernel size. | ||
shaves : int | ||
The number of shaves the model was compiled for. | ||
Returns | ||
------- | ||
Path | ||
The path to the model blob. | ||
Raises | ||
------ | ||
FileNotFoundError | ||
If the returned model blob does not exists. | ||
ValueError | ||
If no model blob paths could be formed from the attributes and shaves. | ||
""" | ||
candidates = _get_candidates(model_type, model_attributes, shaves) | ||
if len(candidates) == 0: | ||
err_msg = f"No model blob paths could be formed from the attributes {model_attributes} and shaves {shaves}." | ||
raise ValueError(err_msg) | ||
blobpath = Path(candidates[0]) | ||
if not blobpath.exists(): | ||
err_msg = f"The model blob path {blobpath} does not exists." | ||
raise FileNotFoundError(err_msg) | ||
return blobpath |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com) | ||
# | ||
# MIT License | ||
from __future__ import annotations | ||
|
||
import hashlib | ||
import pickle | ||
import io | ||
from pathlib import Path | ||
|
||
from oakutils.blobs.models.bulk import ALL_MODELS | ||
|
||
|
||
def hash_file(file_path: Path) -> str: | ||
hasher = hashlib.md5() | ||
with file_path.open("rb") as file: | ||
for chunk in iter(lambda: file.read(io.DEFAULT_BUFFER_SIZE), b""): | ||
hasher.update(chunk) | ||
return hasher.hexdigest() | ||
|
||
|
||
def create_file_hash_table() -> None: | ||
hash_table: dict[str, str] = {} | ||
for blob_tuple in ALL_MODELS: | ||
for blob_path in blob_tuple: | ||
hash_table[blob_path] = hash_file(blob_path) | ||
table_file = Path(__file__).parent / "hash_table.pkl" | ||
with Path.open(table_file, "wb") as file: | ||
pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL) | ||
|
||
|
||
def compare_entry(entry: Path) -> bool: | ||
with Path.open(Path(__file__).parent / "hash_table.pkl", "rb") as file: | ||
table = pickle.load(file) | ||
return table[entry] == hash_file(entry) | ||
|
||
|
||
def create_bulk_hash_table() -> None: | ||
hash_table: dict[str, str] = {} | ||
for blob_tuple in ALL_MODELS: | ||
# get the stem file path without the suffix | ||
# then remove the _shavesN part at the end | ||
key = blob_tuple[0].stem[:-8] | ||
hashes = [hash_file(bp) for bp in blob_tuple] | ||
hash_table[key] = hash(tuple(hashes)) | ||
table_file = Path(__file__).parent / "bulk_hash_table.pkl" | ||
with Path.open(table_file, "wb") as file: | ||
pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL) | ||
|
||
|
||
def compare_bulk_entry(entry: tuple[Path, ...]) -> bool: | ||
key = entry[0].stem[:-8] | ||
with Path.open(Path(__file__).parent / "bulk_hash_table.pkl", "rb") as file: | ||
table = pickle.load(file) | ||
return hash(tuple(hash_file(bp) for bp in entry)) == table[key] |