-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
First draft model test rewrite, still missing equivalence tests
- Loading branch information
1 parent
345efb8
commit 2b8bde1
Showing
19 changed files
with
330 additions
and
510 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,6 +81,7 @@ ci = [ | |
] | ||
test = [ | ||
"pytest>=6.2.0", | ||
"stdlib_list>=0.10.0", | ||
] | ||
docs = [ | ||
"sphinx==6.2.1", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com) | ||
# | ||
# MIT License | ||
from __future__ import annotations | ||
|
||
from collections.abc import Callable | ||
from functools import partial | ||
|
||
from oakutils.nodes import get_nn_frame | ||
|
||
from .load import create_model, run_model | ||
|
||
|
||
def create_model_ghhs(createmodelfunc: Callable): | ||
for use_blur in [True, False]: | ||
for ks in [3, 5, 7, 9, 11, 13, 15]: | ||
for shave in [1, 2, 3, 4, 5, 6]: | ||
for use_gs in [True, False]: | ||
modelfunc = partial( | ||
createmodelfunc, | ||
blur_kernel_size=ks, | ||
shaves=shave, | ||
use_blur=use_blur, | ||
grayscale_out=use_gs, | ||
) | ||
assert create_model(modelfunc) == 0, f"Failed for {ks}, {shave}, {use_blur}, {use_gs}" | ||
return 0 | ||
|
||
|
||
def run_model_ghhs(createmodelfunc: Callable): | ||
for use_blur in [True, False]: | ||
for ks in [3, 5, 7, 9, 11, 13, 15]: | ||
for shave in [1, 2, 3, 4, 5, 6]: | ||
for use_gs in [True, False]: | ||
modelfunc = partial( | ||
createmodelfunc, | ||
blur_kernel_size=ks, | ||
shaves=shave, | ||
use_blur=use_blur, | ||
grayscale_out=use_gs, | ||
) | ||
channels = 1 if use_gs else 3 | ||
decodefunc = partial( | ||
get_nn_frame, | ||
channels=channels, | ||
) | ||
assert run_model(modelfunc, decodefunc) == 0, f"Failed for {ks}, {shave}, {use_blur}, {use_gs}" | ||
return 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com) | ||
# | ||
# MIT License | ||
from __future__ import annotations | ||
|
||
from collections.abc import Callable | ||
|
||
import depthai as dai | ||
from oakutils.nodes import create_color_camera, create_xout | ||
|
||
from ...device import get_device_count | ||
|
||
|
||
def create_model(modelfunc: Callable) -> int: | ||
pipeline = dai.Pipeline() | ||
cam = create_color_camera(pipeline) | ||
model = modelfunc(pipeline, cam.preview) | ||
xout_model = create_xout(pipeline, model.out, "model_out") | ||
|
||
all_nodes = [ | ||
cam, | ||
model, | ||
xout_model, | ||
] | ||
assert len(all_nodes) == 3 | ||
for node in all_nodes: | ||
assert node is not None | ||
|
||
return 0 | ||
|
||
|
||
def run_model(modelfunc: Callable, decodefunc: Callable) -> int: | ||
pipeline = dai.Pipeline() | ||
cam = create_color_camera(pipeline) | ||
model = modelfunc(pipeline, cam.preview) | ||
xout_model = create_xout(pipeline, model.out, "model_out") | ||
|
||
all_nodes = [ | ||
cam, | ||
model, | ||
xout_model, | ||
] | ||
assert len(all_nodes) == 3 | ||
for node in all_nodes: | ||
assert node is not None | ||
|
||
if get_device_count() == 0: | ||
return 0 | ||
|
||
with dai.Device(pipeline) as device: | ||
queue: dai.DataOutputQueue = device.getOutputQueue("model_out") | ||
|
||
while True: | ||
data = queue.get() | ||
frame = decodefunc(data) | ||
assert frame is not None | ||
break | ||
|
||
return 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.