Skip to content

Commit

Permalink
Fix CI failures (#1898)
Browse files Browse the repository at this point in the history
* implement join_threads helper method

* simplify network_test.py

* update tox.ini
  • Loading branch information
zariiii9003 authored Nov 23, 2024
1 parent 805f3fb commit 33a1ec7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 76 deletions.
103 changes: 39 additions & 64 deletions test/network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@
import unittest

import can
from test.config import IS_PYPY

logging.getLogger(__file__).setLevel(logging.WARNING)


# make a random bool:
def rbool():
return bool(round(random.random()))


channel = "vcan0"
return random.choice([False, True])


class ControllerAreaNetworkTestCase(unittest.TestCase):
Expand Down Expand Up @@ -51,74 +49,51 @@ def tearDown(self):
# Restore the defaults
can.rc = self._can_rc

def producer(self, ready_event, msg_read):
self.client_bus = can.interface.Bus(channel=channel)
ready_event.wait()
for i in range(self.num_messages):
m = can.Message(
arbitration_id=self.ids[i],
is_remote_frame=self.remote_flags[i],
is_error_frame=self.error_flags[i],
is_extended_id=self.extended_flags[i],
data=self.data[i],
)
# logging.debug("writing message: {}".format(m))
if msg_read is not None:
# Don't send until the other thread is ready
msg_read.wait()
msg_read.clear()

self.client_bus.send(m)
def producer(self, channel: str):
with can.interface.Bus(channel=channel) as client_bus:
for i in range(self.num_messages):
m = can.Message(
arbitration_id=self.ids[i],
is_remote_frame=self.remote_flags[i],
is_error_frame=self.error_flags[i],
is_extended_id=self.extended_flags[i],
data=self.data[i],
)
client_bus.send(m)

def testProducer(self):
"""Verify that we can send arbitrary messages on the bus"""
logging.debug("testing producer alone")
ready = threading.Event()
ready.set()
self.producer(ready, None)

self.producer(channel="testProducer")
logging.debug("producer test complete")

def testProducerConsumer(self):
logging.debug("testing producer/consumer")
ready = threading.Event()
msg_read = threading.Event()

self.server_bus = can.interface.Bus(channel=channel, interface="virtual")

t = threading.Thread(target=self.producer, args=(ready, msg_read))
t.start()

# Ensure there are no messages on the bus
while True:
m = self.server_bus.recv(timeout=0.5)
if m is None:
print("No messages... lets go")
break
else:
self.fail("received messages before the test has started ...")
ready.set()
i = 0
while i < self.num_messages:
msg_read.set()
msg = self.server_bus.recv(timeout=0.5)
self.assertIsNotNone(msg, "Didn't receive a message")
# logging.debug("Received message {} with data: {}".format(i, msg.data))

self.assertEqual(msg.is_extended_id, self.extended_flags[i])
if not msg.is_remote_frame:
self.assertEqual(msg.data, self.data[i])
self.assertEqual(msg.arbitration_id, self.ids[i])

self.assertEqual(msg.is_error_frame, self.error_flags[i])
self.assertEqual(msg.is_remote_frame, self.remote_flags[i])

i += 1
t.join()

with contextlib.suppress(NotImplementedError):
self.server_bus.flush_tx_buffer()
self.server_bus.shutdown()
read_timeout = 2.0 if IS_PYPY else 0.5
channel = "testProducerConsumer"

with can.interface.Bus(channel=channel, interface="virtual") as server_bus:
t = threading.Thread(target=self.producer, args=(channel,))
t.start()

i = 0
while i < self.num_messages:
msg = server_bus.recv(timeout=read_timeout)
self.assertIsNotNone(msg, "Didn't receive a message")

self.assertEqual(msg.is_extended_id, self.extended_flags[i])
if not msg.is_remote_frame:
self.assertEqual(msg.data, self.data[i])
self.assertEqual(msg.arbitration_id, self.ids[i])

self.assertEqual(msg.is_error_frame, self.error_flags[i])
self.assertEqual(msg.is_remote_frame, self.remote_flags[i])

i += 1
t.join()

with contextlib.suppress(NotImplementedError):
server_bus.flush_tx_buffer()


if __name__ == "__main__":
Expand Down
40 changes: 31 additions & 9 deletions test/simplecyclic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
"""

import gc
import sys
import time
import traceback
import unittest
from threading import Thread
from time import sleep
from typing import List
from unittest.mock import MagicMock
Expand Down Expand Up @@ -87,6 +90,8 @@ def test_removing_bus_tasks(self):
# Note calling task.stop will remove the task from the Bus's internal task management list
task.stop()

self.join_threads([task.thread for task in tasks], 5.0)

assert len(bus._periodic_tasks) == 0
bus.shutdown()

Expand Down Expand Up @@ -115,8 +120,7 @@ def test_managed_tasks(self):
for task in tasks:
task.stop()

for task in tasks:
assert task.thread.join(5.0) is None, "Task didn't stop before timeout"
self.join_threads([task.thread for task in tasks], 5.0)

bus.shutdown()

Expand All @@ -142,9 +146,7 @@ def test_stopping_perodic_tasks(self):

# stop the other half using the bus api
bus.stop_all_periodic_tasks(remove_tasks=False)

for task in tasks:
assert task.thread.join(5.0) is None, "Task didn't stop before timeout"
self.join_threads([task.thread for task in tasks], 5.0)

# Tasks stopped via `stop_all_periodic_tasks` with remove_tasks=False should
# still be associated with the bus (e.g. for restarting)
Expand All @@ -161,7 +163,7 @@ def test_restart_perodic_tasks(self):
is_extended_id=False, arbitration_id=0x123, data=[0, 1, 2, 3, 4, 5, 6, 7]
)

def _read_all_messages(_bus: can.interfaces.virtual.VirtualBus) -> None:
def _read_all_messages(_bus: "can.interfaces.virtual.VirtualBus") -> None:
sleep(safe_timeout)
while not _bus.queue.empty():
_bus.recv(timeout=period)
Expand Down Expand Up @@ -207,9 +209,8 @@ def _read_all_messages(_bus: can.interfaces.virtual.VirtualBus) -> None:

# Stop all tasks and wait for the thread to exit
bus.stop_all_periodic_tasks()
if isinstance(task, can.broadcastmanager.ThreadBasedCyclicSendTask):
# Avoids issues where the thread is still running when the bus is shutdown
task.thread.join(safe_timeout)
# Avoids issues where the thread is still running when the bus is shutdown
self.join_threads([task.thread], 5.0)

@unittest.skipIf(IS_CI, "fails randomly when run on CI server")
def test_thread_based_cyclic_send_task(self):
Expand Down Expand Up @@ -288,6 +289,27 @@ def increment_first_byte(msg: can.Message) -> None:
self.assertEqual(b"\x06\x00\x00\x00\x00\x00\x00\x00", bytes(msg_list[5].data))
self.assertEqual(b"\x07\x00\x00\x00\x00\x00\x00\x00", bytes(msg_list[6].data))

@staticmethod
def join_threads(threads: List[Thread], timeout: float) -> None:
stuck_threads: List[Thread] = []
t0 = time.perf_counter()
for thread in threads:
time_left = timeout - (time.perf_counter() - t0)
if time_left > 0.0:
thread.join(time_left)
if thread.is_alive():
if platform.python_implementation() == "CPython":
# print thread frame to help with debugging
frame = sys._current_frames()[thread.ident]
traceback.print_stack(frame, file=sys.stderr)
stuck_threads.append(thread)
if stuck_threads:
err_message = (
f"Threads did not stop within {timeout:.1f} seconds: "
f"[{', '.join([str(t) for t in stuck_threads])}]"
)
raise RuntimeError(err_message)


if __name__ == "__main__":
unittest.main()
4 changes: 1 addition & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@ deps =
pyserial~=3.5
parameterized~=0.8
asammdf>=6.0; platform_python_implementation=="CPython" and python_version<"3.13"
pywin32>=305; platform_system=="Windows" and platform_python_implementation=="CPython" and python_version<"3.13"
pywin32>=305; platform_system=="Windows" and platform_python_implementation=="CPython" and python_version<"3.14"

commands =
pytest {posargs}

extras =
canalystii

recreate = True

[testenv:gh]
passenv =
CI
Expand Down

0 comments on commit 33a1ec7

Please sign in to comment.