diff --git a/acquire/outputs/tar.py b/acquire/outputs/tar.py index 69b72906..f6fa1cd6 100644 --- a/acquire/outputs/tar.py +++ b/acquire/outputs/tar.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import io import tarfile from typing import TYPE_CHECKING, BinaryIO @@ -100,7 +101,54 @@ def write( if stat: info.mtime = stat.st_mtime - self.tar.addfile(info, fh) + # Inline version of Python stdlib's tarfile.addfile & tarfile.copyfileobj, + # to allow for padding and more control over the tar file writing. + self.tar._check("awx") + + if fh is None and info.isreg() and info.size != 0: + return + + tarinfo = copy.copy(info) + saved_offset = self.tar.offset + saved_filepos = self.tar.fileobj.tell() + + try: + buf = tarinfo.tobuf(self.tar.format, self.tar.encoding, self.tar.errors) + self.tar.fileobj.write(buf) + self.tar.offset += len(buf) + bufsize = self.tar.copybufsize or 16 * 1024 + + if fh is not None: + if tarinfo.size is None or tarinfo.size == 0: + return + + blocks, remainder = divmod(tarinfo.size, bufsize) + for _ in range(blocks): + buf = fh.read(size) + if len(buf) < size: + # PATCH; instead of raising an exception, pad the data to the desired length + buf += tarfile.NUL * (size - len(buf)) + self.tar.fileobj.write(buf) + + if remainder > 0: + buf = fh.read(remainder) + if len(buf) < remainder: + # PATCH; instead of raising an exception, pad the data to the desired length + buf += tarfile.NUL * (remainder - len(buf)) + self.tar.fileobj.write(buf) + + blocks, remainder = divmod(tarinfo.size, tarfile.BLOCKSIZE) + if remainder > 0: + self.tar.fileobj.write(tarfile.NUL * (tarfile.BLOCKSIZE - remainder)) + blocks += 1 + self.tar.offset += blocks * tarfile.BLOCKSIZE + + self.tar.members.append(tarinfo) + except Exception: + self.tar.fileobj.seek(saved_filepos) + self.tar.fileobj.truncate() + self.tar.offset = saved_offset + raise def close(self) -> None: """Closes the tar file.""" diff --git a/tests/test_outputs_tar.py b/tests/test_outputs_tar.py index 81059bb2..6a0965a1 100644 --- a/tests/test_outputs_tar.py +++ b/tests/test_outputs_tar.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io import tarfile from pathlib import Path from typing import TYPE_CHECKING @@ -63,3 +64,110 @@ def test_tar_output_encrypt(mock_fs: VirtualFilesystem, public_key: bytes, tmp_p with tarfile.open(name=decrypted_path, mode="r") as tar_file: assert entry.open().read() == tar_file.extractfile(entry_name).read() + + +def test_tar_output_race_condition_with_shrinking_file(tmp_path: Path, public_key: bytes) -> None: + class ShrinkingFile(io.BytesIO): + def __init__(self, data: bytes): + super().__init__(data) + + def read(self, size: int) -> bytes: + return super().read(size - 5) + + content = b"some text" + + content_padded = content[:-5] + tarfile.NUL * 5 + file = ShrinkingFile(content) + + tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key) + tar_output.write("file.log", file) + tar_output.close() + file.close() + + encrypted_stream = EncryptedFile(tar_output.path.open("rb"), Path("tests/_data/private_key.pem")) + decrypted_path = tmp_path / "decrypted.tar" + + # Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly + Path(decrypted_path).write_bytes(encrypted_stream.read()) + + with tarfile.open(name=decrypted_path, mode="r") as tar_file: + member = tar_file.getmember("file.log") + extracted = tar_file.extractfile(member).read() + # The content should be padded with zeros to match the original size, despite the fact that the file shrunk + assert extracted == content_padded + + +def test_tar_output_race_condition_with_growing_file(tmp_path: Path, public_key: bytes) -> None: + class GrowingFile(io.BytesIO): + def __init__(self, data: bytes): + super().__init__(data) + + def read(self, size: int) -> bytes: + return super().read(size) + b"FOX" + + content = b"some text" + + file = GrowingFile(content) + + tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key) + tar_output.write("file.log", file) + tar_output.close() + file.close() + + encrypted_stream = EncryptedFile(tar_output.path.open("rb"), Path("tests/_data/private_key.pem")) + decrypted_path = tmp_path / "decrypted.tar" + + # Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly + Path(decrypted_path).write_bytes(encrypted_stream.read()) + + with tarfile.open(name=decrypted_path, mode="r") as tar_file: + member = tar_file.getmember("file.log") + extracted = tar_file.extractfile(member).read() + # The content should match the original content, without the extra bytes + # because the file was read with the original size + assert extracted == content + + +def test_tar_output_exception_rollback(tmp_path: Path) -> None: + """Test that tar file is properly truncated when an exception occurs during writing.""" + + class FailingFile(io.BytesIO): + def __init__(self, data: bytes, fail_after_bytes: int = 5): + super().__init__(data) + self.fail_after_bytes = fail_after_bytes + self.bytes_read = 0 + + def read(self, size: int) -> bytes: + data = super().read(size) + self.bytes_read += len(data) + if self.bytes_read > self.fail_after_bytes: + raise IOError("Simulated I/O error during file read") + return data + + content = b"This is some test content that will fail during reading" + failing_file = FailingFile(content, fail_after_bytes=5) + + tar_output = TarOutput(tmp_path / "test.tar") + + successful_file = io.BytesIO(b"dissectftw") + tar_output.write("successful_file.txt", successful_file) + + file_size_before_failure = tar_output.tar.fileobj.tell() + members_count_before = len(tar_output.tar.members) + + # Attempt to write the failing file + with pytest.raises(IOError, match="Simulated I/O error during file read"): + tar_output.write("failing_file.txt", failing_file, size=len(content)) + + # Verify that the tar file was truncated back to its state before the failed write + assert tar_output.tar.fileobj.tell() == file_size_before_failure + assert len(tar_output.tar.members) == members_count_before + + tar_output.close() + + # Verify the tar file can still be opened and contains only the successful entry + with tarfile.open(tar_output.path) as tar_file: + members = tar_file.getmembers() + assert len(members) == 1 + assert members[0].name == "successful_file.txt" + assert tar_file.extractfile("successful_file.txt").read() == b"dissectftw"