Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tftpy/CPFilelock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import platform
import time
import errno

class FileLock:
""" A cross-platform file lock solution. """

def __init__(self, file_obj):
self.file_obj = file_obj
self.fd = file_obj.fileno()
self.is_windows = platform.system() == 'Windows'

def acquire_shared(self):
if self.is_windows:
self._win32_lock(shared=True)
else:
self._posix_lock(shared=True)


def acquire_exclusive(self):
if self.is_windows:
self._win32_lock(shared=False)
else:
self._posix_lock(shared=False)


def release(self):
if self.is_windows:
self._win32_unlock()
else:
self._posix_unlock()

def _win32_lock(self, shared=True):
lockfile = f"{self.file_obj.name}.lock"
while True:
try:
fd = os.open(lockfile,
os.O_CREAT | os.O_EXCL | os.O_RDWR)
os.close(fd)
break
except OSError as e:
if e.errno != errno.EEXIST:
raise
time.sleep(0.1)


def _win32_unlock(self):
try:
os.unlink(f"{self.file_obj.name}.lock")
except OSError:
pass


def _posix_lock(self, shared=True):
import fcntl
flags = fcntl.LOCK_SH if shared else fcntl.LOCK_EX
try:
fcntl.flock(self.fd, flags | fcntl.LOCK_NB)
except IOError as e:
raise OSError(f"Failed to acquire lock: {e}")


def _posix_unlock(self):
import fcntl
try:
fcntl.flock(self.fd, fcntl.LOCK_UN)
except IOError as e:
raise OSError(f"Failed to release lock: {e}")
22 changes: 17 additions & 5 deletions tftpy/TftpContexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import socket
import sys
import time
from .CPFilelock import FileLock

from .TftpPacketFactory import TftpPacketFactory
from .TftpPacketTypes import *
Expand Down Expand Up @@ -134,7 +135,7 @@ def __del__(self):
def __enter__(self):
log.debug("__enter__ on TftpContext")
return self

def __exit__(self, type, value, traceback):
log.debug("__exit__ on TftpContext")
self.end()
Expand Down Expand Up @@ -165,7 +166,8 @@ def end(self, close_fileobj=True):
if close_fileobj and self.fileobj is not None and not self.fileobj.closed:
log.debug("self.fileobj is open - closing")
if self.flock:
lockfile(self.fileobj, unlock=True)
if self.flock and hasattr(self, "filelock"):
self.filelock.release()
self.fileobj.close()

def gethost(self):
Expand Down Expand Up @@ -335,7 +337,12 @@ def __init__(
if self.flock:
log.debug("locking input file %s", input)
try:
lockfile(self.fileobj, shared=True, blocking=False)
self.filelock = FileLock(self.fileobj)
try:
self.filelock.acquire_shared()
except OSError as err:
log.error("Failed to acquire read lock on file %s", input)
raise
except OSError as err:
log.error("Failed to acquire read lock on file %s", input)
raise
Expand All @@ -346,7 +353,7 @@ def __init__(
def __del__(self):
log.debug("TftpContextClientUpload.__del__")
super().__del__()

def __str__(self):
return f"{self.host}:{self.port} {self.state}"

Expand Down Expand Up @@ -431,7 +438,12 @@ def __init__(
if self.flock:
log.debug("locking file for writing: %s", output)
try:
lockfile(self.fileobj, shared=False, blocking=False)
self.filelock = FileLock(self.fileobj)
try:
self.filelock.acquire_exclusive()
except OSError as err:
log.error("Failed to acquire write lock on output file %s: %s", output, err)
raise
except OSError as err:
log.error("Failed to acquire write lock on output file %s: %s", output, err)
raise
Expand Down