1
0
mirror of https://github.com/craigerl/aprsd.git synced 2026-03-30 19:55:44 -04:00

refactor(threads): add daemon, period, Event-based shutdown to APRSDThread

- Add daemon=True class attribute (subclasses override to False)
- Add period=1 class attribute for wait interval
- Replace thread_stop boolean with _shutdown_event (threading.Event)
- Add wait() method for interruptible sleeps
- Update tests for new Event-based API

BREAKING: thread_stop boolean replaced with _shutdown_event.
Code checking thread.thread_stop directly must use thread._shutdown_event.is_set()
This commit is contained in:
Walter Boring 2026-03-24 11:54:29 -04:00
parent bc9b15d47a
commit b7a37322e1
2 changed files with 110 additions and 23 deletions

View File

@ -2,7 +2,6 @@ import abc
import datetime
import logging
import threading
import time
from typing import List
import wrapt
@ -13,21 +12,26 @@ LOG = logging.getLogger('APRSD')
class APRSDThread(threading.Thread, metaclass=abc.ABCMeta):
"""Base class for all threads in APRSD."""
# Class attributes - subclasses override as needed
daemon = True # Most threads are daemon threads
period = 1 # Default wait period in seconds
loop_count = 1
_pause = False
thread_stop = False
def __init__(self, name):
super().__init__(name=name)
self.thread_stop = False
# Set daemon from class attribute
self.daemon = self.__class__.daemon
# Set period from class attribute (can be overridden in __init__)
self.period = self.__class__.period
self._shutdown_event = threading.Event()
self.loop_count = 0
APRSDThreadList().add(self)
self._last_loop = datetime.datetime.now()
def _should_quit(self):
"""see if we have a quit message from the global queue."""
if self.thread_stop:
return True
return False
"""Check if thread should exit."""
return self._shutdown_event.is_set()
def pause(self):
"""Logically pause the processing of the main loop."""
@ -40,8 +44,21 @@ class APRSDThread(threading.Thread, metaclass=abc.ABCMeta):
self._pause = False
def stop(self):
"""Signal thread to stop. Returns immediately."""
LOG.debug(f"Stopping thread '{self.name}'")
self.thread_stop = True
self._shutdown_event.set()
def wait(self, timeout: float | None = None) -> bool:
"""Wait for shutdown signal or timeout.
Args:
timeout: Seconds to wait. Defaults to self.period.
Returns:
True if shutdown was signaled, False if timeout expired.
"""
wait_time = timeout if timeout is not None else self.period
return self._shutdown_event.wait(timeout=wait_time)
@abc.abstractmethod
def loop(self):
@ -64,7 +81,7 @@ class APRSDThread(threading.Thread, metaclass=abc.ABCMeta):
LOG.debug('Starting')
while not self._should_quit():
if self._pause:
time.sleep(1)
self.wait(timeout=1)
else:
self.loop_count += 1
can_loop = self.loop()

View File

@ -1,3 +1,4 @@
import datetime
import threading
import time
import unittest
@ -42,11 +43,9 @@ class TestAPRSDThread(unittest.TestCase):
"""Test thread initialization."""
thread = TestThread('TestThread1')
self.assertEqual(thread.name, 'TestThread1')
self.assertFalse(thread.thread_stop)
self.assertFalse(thread._shutdown_event.is_set())
self.assertFalse(thread._pause)
self.assertEqual(thread.loop_count, 1)
# Should be registered in thread list
self.assertEqual(thread.loop_count, 0) # Was 1, now starts at 0
thread_list = APRSDThreadList()
self.assertIn(thread, thread_list.threads_list)
@ -54,8 +53,7 @@ class TestAPRSDThread(unittest.TestCase):
"""Test _should_quit() method."""
thread = TestThread('TestThread2')
self.assertFalse(thread._should_quit())
thread.thread_stop = True
thread._shutdown_event.set()
self.assertTrue(thread._should_quit())
def test_pause_unpause(self):
@ -72,20 +70,93 @@ class TestAPRSDThread(unittest.TestCase):
def test_stop(self):
"""Test stop() method."""
thread = TestThread('TestThread4')
self.assertFalse(thread.thread_stop)
self.assertFalse(thread._shutdown_event.is_set())
thread.stop()
self.assertTrue(thread.thread_stop)
self.assertTrue(thread._shutdown_event.is_set())
def test_loop_age(self):
"""Test loop_age() method."""
import datetime
thread = TestThread('TestThread5')
age = thread.loop_age()
self.assertIsInstance(age, datetime.timedelta)
self.assertGreaterEqual(age.total_seconds(), 0)
def test_daemon_attribute_default(self):
"""Test that daemon attribute defaults to True."""
thread = TestThread('DaemonTest')
self.assertTrue(thread.daemon)
def test_daemon_attribute_override(self):
"""Test that daemon attribute can be overridden via class attribute."""
class NonDaemonThread(APRSDThread):
daemon = False
def loop(self):
return False
thread = NonDaemonThread('NonDaemonTest')
self.assertFalse(thread.daemon)
def test_period_attribute_default(self):
"""Test that period attribute defaults to 1."""
thread = TestThread('PeriodTest')
self.assertEqual(thread.period, 1)
def test_period_attribute_override(self):
"""Test that period attribute can be overridden via class attribute."""
class LongPeriodThread(APRSDThread):
period = 60
def loop(self):
return False
thread = LongPeriodThread('LongPeriodTest')
self.assertEqual(thread.period, 60)
def test_shutdown_event_exists(self):
"""Test that _shutdown_event is created."""
thread = TestThread('EventTest')
self.assertIsInstance(thread._shutdown_event, threading.Event)
self.assertFalse(thread._shutdown_event.is_set())
def test_wait_returns_false_on_timeout(self):
"""Test that wait() returns False when timeout expires."""
thread = TestThread('WaitTimeoutTest')
start = time.time()
result = thread.wait(timeout=0.1)
elapsed = time.time() - start
self.assertFalse(result)
self.assertGreaterEqual(elapsed, 0.1)
def test_wait_returns_true_when_stopped(self):
"""Test that wait() returns True immediately when stop() was called."""
thread = TestThread('WaitStopTest')
thread.stop()
start = time.time()
result = thread.wait(timeout=10)
elapsed = time.time() - start
self.assertTrue(result)
self.assertLess(elapsed, 1)
def test_wait_uses_period_by_default(self):
"""Test that wait() uses self.period when no timeout specified."""
class ShortPeriodThread(APRSDThread):
period = 0.1
def loop(self):
return False
thread = ShortPeriodThread('ShortPeriodTest')
start = time.time()
result = thread.wait()
elapsed = time.time() - start
self.assertFalse(result)
self.assertGreaterEqual(elapsed, 0.1)
self.assertLess(elapsed, 0.5)
def test_str(self):
"""Test __str__() method."""
thread = TestThread('TestThread6')
@ -253,10 +324,9 @@ class TestAPRSDThreadList(unittest.TestCase):
thread2 = TestThread('TestThread9')
thread_list.add(thread1)
thread_list.add(thread2)
thread_list.stop_all()
self.assertTrue(thread1.thread_stop)
self.assertTrue(thread2.thread_stop)
self.assertTrue(thread1._shutdown_event.is_set())
self.assertTrue(thread2._shutdown_event.is_set())
def test_pause_all(self):
"""Test pause_all() method."""