changeset 24:20c8ec56e447

logfile: Pull the logfile thread out of Logger. This enables automatic garbage collection of Logger instances, since a running thread no longer holds a reference to a Logger's self. Exclusive management of logfile state moves into the module-level _writer_thread function, which now opens the file and writes to it until it is told to stop by receiving the poison pill.
author Paul Fisher <paul@pfish.zone>
date Sun, 10 Nov 2019 23:07:11 -0500
parents 88249e451566
children a4147ecb18b3
files weather_server/locations.py weather_server/logfile.py weather_server/logfile_test.py
diffstat 3 files changed, 158 insertions(+), 102 deletions(-)
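For readers unfamiliar with the pattern the message names, here is a minimal, self-contained sketch of a poison-pill worker thread (the names worker and q are illustrative, not from this changeset): the thread owns its resources and exits only when it dequeues a unique sentinel object.

import queue
import threading

_POISON = object()  # unique sentinel; identity comparison cannot collide with real items

def worker(q: queue.SimpleQueue) -> None:
    while True:
        item = q.get()
        if item is _POISON:
            return  # told to stop; release any thread-owned resources here
        print('handling', item)  # stand-in for real work

q = queue.SimpleQueue()
t = threading.Thread(target=worker, args=(q,), daemon=True)
t.start()
q.put('a reading')
q.put(_POISON)  # ask the worker to stop...
t.join()        # ...and wait until it has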
--- a/weather_server/locations.py	Sun Nov 10 19:42:04 2019 -0500
+++ b/weather_server/locations.py	Sun Nov 10 23:07:11 2019 -0500
@@ -31,7 +31,7 @@
                 'location', 'name', fallback='Weather station')
             self.tz_name = parser.get('location', 'timezone', fallback='UTC')
             self.password = parser.get('location', 'password')
-            self.logger = logfile.Logger.create(
+            self.logger = logfile.Logger(
                 str(root / LOG), sample_field='sample_time')
         except (IOError, KeyError, configparser.Error):
             raise ConfigError("Couldn't load location info.")
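The call-site change above is the visible API shift: plain construction replaces the interned create() factory, so each Location owns an independently collectible Logger. A hedged end-to-end sketch under the new API (the path and row values are invented; Logger, write_rows, data, and close all appear in this diff):

from weather_server import logfile

logger = logfile.Logger('/tmp/weather.log', sample_field='sample_time')
try:
    logger.write_rows([dict(sample_time=1, temp_c=21.5)])
    print(logger.data)  # a tuple of the accepted rows
finally:
    logger.close()  # enqueues the poison pill and joins the writer thread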
--- a/weather_server/logfile.py	Sun Nov 10 19:42:04 2019 -0500
+++ b/weather_server/logfile.py	Sun Nov 10 23:07:11 2019 -0500
@@ -31,50 +31,33 @@
         self.future = futures.Future()
 
 
-# probably handle file-writing with a queue that reports back its progress
+# Poison pill to tell a logger thread to stop.
+_POISON = object()
+
 
 class Logger:
-    """Logger which handles reading/writing a temperature log for one process.
-    """
-
-    instance_lock = threading.Lock()
-    instances: t.Dict[str, 'Logger'] = {}
-
-    @classmethod
-    def create(
-        cls,
-        filename: str,
-        *,
-        sample_field: str,
-    ) -> 'Logger':
-        """Creates a single shared instance of a logger for the given file."""
-        try:
-            instance = cls.instances[filename]
-        except KeyError:
-            with cls.instance_lock:
-                try:
-                    instance = cls.instances[filename]
-                except KeyError:
-                    cls.instances[filename] = Logger(
-                        filename,
-                        sample_field=sample_field)
-                    instance = cls.instances[filename]
-        if instance._sample_field != sample_field:
-            raise ValueError(
-                'Existing instance has different sample field: '
-                '{!r} != {!r}'.format(instance._sample_field, sample_field))
-        return instance
+    """Logger which handles reading/writing one temperature log file."""
 
     def __init__(self, filename: str, *, sample_field: str):
-        """You should probably call .create() instead."""
+        """Creates a new Logger for the given file.
+
+        Args:
+            filename: The filename to open, or create if not already there.
+            sample_field: The name of the field whose strictly-increasing
+                value is used to ensure that no duplicate writes occur.
+        """
         self._sample_field = sample_field
-        self._file = _open_or_create(filename)
-        self._data: t.List[t.Dict[str, t.Any], ...] = []
         self._queue = queue.SimpleQueue()
-        self._last_size = 0
-        self._lock_status: t.Optional[int] = None
-        self._writer_thread = threading.Thread(target=self._writer)
+        # Create a Future that will be resolved once the file is opened
+        # (or fails to be opened).
+        writer_started = futures.Future()
+        self._writer_thread = threading.Thread(
+            name=f'{filename!r} writer thread',
+            target=lambda: _writer_thread(
+                filename, self._queue, sample_field, writer_started),
+            daemon=True)
         self._writer_thread.start()
+        writer_started.result()
 
     @property
     def data(self) -> t.Tuple[t.Dict[str, t.Any], ...]:
@@ -87,49 +70,65 @@
         self._queue.put(req)
         return req.future.result()
 
-    _POISON = object()
+    def __del__(self):
+        self.close()
 
     def close(self):
-        self._queue.put(self._POISON)
+        self._queue.put(_POISON)
         self._writer_thread.join()
 
-    def _writer(self) -> None:
+
+def _writer_thread(
+    filename: str,
+    q: queue.SimpleQueue,
+    sample_field: str,
+    started: futures.Future,
+) -> None:
+    if not started.set_running_or_notify_cancel():
+        return
+    try:
+        file = _open_or_create(filename)
+        started.set_result(None)
+    except BaseException as e:
+        started.set_exception(e)
+        return
+    with file:
         running = True
+        data: t.List[t.Dict[str, object]] = []
         while running:
-            item = self._queue.get()
-            if item is self._POISON:
+            item = q.get()
+            if item is _POISON:
                 # _POISON is the sentinel that makes us stop.
                 running = False
             elif isinstance(item, _ReadRequest):
                 if not item.future.set_running_or_notify_cancel():
                     continue
                 try:
-                    with self._file_lock(fcntl.LOCK_SH):
-                        self._catch_up()
+                    with _file_lock(file, fcntl.LOCK_SH):
+                        data.extend(_catch_up(file))
                 except BaseException as x:
                     item.future.set_exception(x)
                 else:
-                    item.future.set_result(tuple(self._data))
+                    item.future.set_result(tuple(data))
             elif isinstance(item, _WriteRequest):
                 if not item.future.set_running_or_notify_cancel():
                     continue
                 try:
-                    with self._file_lock(fcntl.LOCK_EX):
-                        self._catch_up()
+                    with _file_lock(file, fcntl.LOCK_EX):
+                        data.extend(_catch_up(file))
                         # Since we're at the last good point, truncate after.
-                        self._file.truncate(self._file.tell())
-                        if not self._data:
+                        file.truncate(file.tell())
+                        if not data:
                             last = None
                         else:
-                            last = self._data[-1][self._sample_field]
+                            last = data[-1][sample_field]
                         for entry in item.entries:
-                            entry_key = entry[self._sample_field]
+                            entry_key = entry[sample_field]
                             if last is None or last < entry_key:
-                                self._file.write(common.bson_encode(entry))
-                                self._data.append(entry)
+                                file.write(common.bson_encode(entry))
+                                data.append(entry)
                                 last = entry_key
-                        self._file.flush()
-                        self._last_size = self._file.tell()
+                        file.flush()
                 except BaseException as x:
                     item.future.set_exception(x)
                 else:
@@ -137,43 +136,40 @@
             else:
                 raise AssertionError(
                     'Unexpected item {!r} in the queue'.format(item))
-        self._file.close()
+
+
+@contextlib.contextmanager
+def _file_lock(file: t.BinaryIO, operation: int) -> t.Iterator[None]:
+    assert operation in (fcntl.LOCK_SH, fcntl.LOCK_EX), 'Invalid operation.'
+    fcntl.flock(file, operation)
+    try:
+        yield
+    finally:
+        fcntl.flock(file, fcntl.LOCK_UN)
+
+
+def _size(file: t.BinaryIO) -> int:
+    return os.stat(file.fileno()).st_size
+
 
-    def _catch_up(self) -> None:
-        """Reads data and advances the file pointer to the end of the file."""
-        assert self._lock_status is not None, 'The lock must be held.'
-        size = self._size()
-        if size == self._last_size:
-            return
-        last_good = self._file.tell()
-        try:
-            items = bson.decode_file_iter(
-                self._file, codec_options=common.BSON_OPTIONS)
-            for item in items:
-                last_good = self._file.tell()
-                self._data.append(item)
-        except bson.InvalidBSON:
-            pass  # We have reached the last valid document.  Bail.
-        # Seek back to immediately after the end of the last valid doc.
-        self._last_size = last_good
-        self._file.seek(last_good, os.SEEK_SET)
-
-    def fileno(self) -> int:
-        return self._file.fileno()
-
-    def _size(self) -> int:
-        return os.stat(self.fileno()).st_size
-
-    @contextlib.contextmanager
-    def _file_lock(self, operation: int):
-        assert operation in (fcntl.LOCK_SH, fcntl.LOCK_EX), 'Invalid operation.'
-        fcntl.flock(self, operation)
-        self._lock_status = operation
-        try:
-            yield
-        finally:
-            self._lock_status = None
-            fcntl.flock(self, fcntl.LOCK_UN)
+def _catch_up(file: t.BinaryIO) -> t.Iterable[t.Dict[str, object]]:
+    """Reads data and advances the file pointer to the end of the file."""
+    size = _size(file)
+    pointer = file.tell()
+    if size == pointer:
+        return ()
+    output: t.List[t.Dict[str, object]] = []
+    try:
+        items = bson.decode_file_iter(
+            file, codec_options=common.BSON_OPTIONS)
+        for item in items:
+            pointer = file.tell()
+            output.append(item)
+    except bson.InvalidBSON:
+        pass  # We have reached the last valid document.  Bail.
+    # Seek back to immediately after the end of the last valid doc.
+    file.seek(pointer, os.SEEK_SET)
+    return output
 
 
 def _open_or_create(path: str) -> t.BinaryIO:
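The startup handshake above, in which __init__ blocks on writer_started.result() until the thread has opened the file (or re-raises the open failure in the caller), is worth isolating. A stripped-down sketch of the same pattern, where open_resource is a stand-in for _open_or_create:

from concurrent import futures
import threading

def open_resource():
    # Stand-in for _open_or_create; make this raise OSError to watch
    # the failure surface in started.result() below.
    return object()

def run(started: futures.Future) -> None:
    if not started.set_running_or_notify_cancel():
        return  # the caller cancelled before the thread got going
    try:
        resource = open_resource()
    except BaseException as e:
        started.set_exception(e)  # re-raised from started.result()
        return
    started.set_result(None)  # startup succeeded; unblock the caller
    # ... a real writer would now loop over its queue, using `resource` ...

started = futures.Future()
threading.Thread(target=run, args=(started,), daemon=True).start()
started.result()  # returns once the resource is open, or raises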
--- a/weather_server/logfile_test.py	Sun Nov 10 19:42:04 2019 -0500
+++ b/weather_server/logfile_test.py	Sun Nov 10 23:07:11 2019 -0500
@@ -1,7 +1,9 @@
 import contextlib
 import datetime
+import os.path
 import pathlib
 import tempfile
+import threading
 import unittest
 
 import bson
@@ -9,7 +11,6 @@
 
 from . import common
 from . import logfile
-from . import types
 
 
 def ts(n):
@@ -35,6 +36,20 @@
         with contextlib.closing(lg) as logger:
             self.assertEqual(logger.data, ())
 
+    def test_fails_to_open(self):
+        with self.assertRaises(OSError):
+            logfile.Logger(
+                os.path.join(
+                    self.temp_dir.name,
+                    'nonexistent-directory',
+                    'bogus-filename'),
+                sample_field='unimportant')
+
+    def test_del(self):
+        lg = logfile.Logger(
+            str(self.log_path), sample_field='x')
+        del lg
+
     def test_loading(self):
         with self.log_path.open('wb') as outfile:
             outfile.write(common.bson_encode(dict(
@@ -72,14 +87,34 @@
         )) as logger:
             logger.write_rows([
                 # Ignored, since it's older than the newest entry.
-                types.Reading(ts(100), 999, 666, ts(101)).as_dict(),
-                types.Reading(ts(125), 333, 777, ts(200)).as_dict(),
+                dict(
+                    sample_time=ts(100),
+                    temp_c=999,
+                    rh_pct=666,
+                    ingest_time=ts(101),
+                ),
+                dict(
+                    sample_time=ts(125),
+                    temp_c=333,
+                    rh_pct=777,
+                    ingest_time=ts(200),
+                ),
             ])
             self.assertEqual(
                 logger.data,
                 (
-                    types.Reading(ts(123), 420, 69, ts(125)).as_dict(),
-                    types.Reading(ts(125), 333, 777, ts(200)).as_dict(),
+                    dict(
+                        sample_time=ts(123),
+                        temp_c=420,
+                        rh_pct=69,
+                        ingest_time=ts(125),
+                    ),
+                    dict(
+                        sample_time=ts(125),
+                        temp_c=333,
+                        rh_pct=777,
+                        ingest_time=ts(200),
+                    ),
                 )
             )
 
@@ -104,8 +139,18 @@
             sample_field='sample_time',
         )) as logger:
             logger.write_rows([
-                types.Reading(ts(100), 999, 666, ts(101)).as_dict(),
-                types.Reading(ts(125), 333, 777, ts(200)).as_dict(),
+                dict(
+                    sample_time=ts(100),
+                    temp_c=999,
+                    rh_pct=666,
+                    ingest_time=ts(101),
+                ),
+                dict(
+                    sample_time=ts(125),
+                    temp_c=333,
+                    rh_pct=777,
+                    ingest_time=ts(200),
+                ),
             ])
             with self.log_path.open('ab') as outfile:
                 outfile.write(common.bson_encode(dict(
@@ -116,9 +161,24 @@
                 )))
                 outfile.flush()
             self.assertEqual(logger.data, (
-                types.Reading(ts(100), 999, 666, ts(101)).as_dict(),
-                types.Reading(ts(125), 333, 777, ts(200)).as_dict(),
-                types.Reading(ts(1024), 256, 128, ts(4096)).as_dict(),
+                dict(
+                    sample_time=ts(100),
+                    temp_c=999,
+                    rh_pct=666,
+                    ingest_time=ts(101),
+                ),
+                dict(
+                    sample_time=ts(125),
+                    temp_c=333,
+                    rh_pct=777,
+                    ingest_time=ts(200),
+                ),
+                dict(
+                    sample_time=ts(1024),
+                    temp_c=256,
+                    rh_pct=128,
+                    ingest_time=ts(4096),
+                ),
             ))
 
     def read_bsons(self):