diff --git a/tftpy/CPFilelock.py b/tftpy/CPFilelock.py new file mode 100644 index 0000000..a4fafb0 --- /dev/null +++ b/tftpy/CPFilelock.py @@ -0,0 +1,69 @@ +import os +import platform +import time +import errno + +class FileLock: + """ A cross-platform file lock solution. """ + + def __init__(self, file_obj): + self.file_obj = file_obj + self.fd = file_obj.fileno() + self.is_windows = platform.system() == 'Windows' + + def acquire_shared(self): + if self.is_windows: + self._win32_lock(shared=True) + else: + self._posix_lock(shared=True) + + + def acquire_exclusive(self): + if self.is_windows: + self._win32_lock(shared=False) + else: + self._posix_lock(shared=False) + + + def release(self): + if self.is_windows: + self._win32_unlock() + else: + self._posix_unlock() + + def _win32_lock(self, shared=True): + lockfile = f"{self.file_obj.name}.lock" + while True: + try: + fd = os.open(lockfile, + os.O_CREAT | os.O_EXCL | os.O_RDWR) + os.close(fd) + break + except OSError as e: + if e.errno != errno.EEXIST: + raise + time.sleep(0.1) + + + def _win32_unlock(self): + try: + os.unlink(f"{self.file_obj.name}.lock") + except OSError: + pass + + + def _posix_lock(self, shared=True): + import fcntl + flags = fcntl.LOCK_SH if shared else fcntl.LOCK_EX + try: + fcntl.flock(self.fd, flags | fcntl.LOCK_NB) + except IOError as e: + raise OSError(f"Failed to acquire lock: {e}") + + + def _posix_unlock(self): + import fcntl + try: + fcntl.flock(self.fd, fcntl.LOCK_UN) + except IOError as e: + raise OSError(f"Failed to release lock: {e}") \ No newline at end of file diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py index 831b117..014465e 100644 --- a/tftpy/TftpContexts.py +++ b/tftpy/TftpContexts.py @@ -15,6 +15,7 @@ import socket import sys import time +from .CPFilelock import FileLock from .TftpPacketFactory import TftpPacketFactory from .TftpPacketTypes import * @@ -134,7 +135,7 @@ def __del__(self): def __enter__(self): log.debug("__enter__ on TftpContext") return self - + def __exit__(self, type, value, traceback): log.debug("__exit__ on TftpContext") self.end() @@ -165,7 +166,8 @@ def end(self, close_fileobj=True): if close_fileobj and self.fileobj is not None and not self.fileobj.closed: log.debug("self.fileobj is open - closing") if self.flock: - lockfile(self.fileobj, unlock=True) + if self.flock and hasattr(self, "filelock"): + self.filelock.release() self.fileobj.close() def gethost(self): @@ -335,7 +337,12 @@ def __init__( if self.flock: log.debug("locking input file %s", input) try: - lockfile(self.fileobj, shared=True, blocking=False) + self.filelock = FileLock(self.fileobj) + try: + self.filelock.acquire_shared() + except OSError as err: + log.error("Failed to acquire read lock on file %s", input) + raise except OSError as err: log.error("Failed to acquire read lock on file %s", input) raise @@ -346,7 +353,7 @@ def __init__( def __del__(self): log.debug("TftpContextClientUpload.__del__") super().__del__() - + def __str__(self): return f"{self.host}:{self.port} {self.state}" @@ -431,7 +438,12 @@ def __init__( if self.flock: log.debug("locking file for writing: %s", output) try: - lockfile(self.fileobj, shared=False, blocking=False) + self.filelock = FileLock(self.fileobj) + try: + self.filelock.acquire_exclusive() + except OSError as err: + log.error("Failed to acquire write lock on output file %s: %s", output, err) + raise except OSError as err: log.error("Failed to acquire write lock on output file %s: %s", output, err) raise