From 340b49dd3293c822adb2988b1619129225dfdafe Mon Sep 17 00:00:00 2001 From: Tom Read Cutting Date: Fri, 9 Sep 2022 14:43:39 +0200 Subject: [PATCH] Ensure all IO errors are exhaustively handled Make sure all io errors returned by any library call go through zar's internal file io error handler. This is ensured by adding a build option to ensure a compile error if that isn't the case and adding it to the test suite. --- .github/workflows/ci.yml | 18 +++++ build.zig | 3 + src/archive/Archive.zig | 159 ++++++++++++++++++++++++--------------- src/main.zig | 3 +- 4 files changed, 119 insertions(+), 64 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 03eabf0..25f67fd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,3 +56,21 @@ jobs: - name: Build Redis macOS run: make CC="zig cc -target x86_64-macos" CXX="zig c++ -target x86_64-macos" AR="${GITHUB_WORKSPACE}/zig-out/bin/zar" RANLIB="zig ranlib" uname_S="Darwin" uname_M="x86_64" USE_JEMALLOC=no USE_SYSTEMD=no working-directory: redis + test_io_errors_handled: + name: Test handled errors are actually handled + timeout-minutes: 15 + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + steps: + - name: Checkout + uses: actions/checkout@v1 + with: + submodules: recursive + - name: Setup Zig + uses: goto-bus-stop/setup-zig@v1 + with: + version: master + - name: Build + run: zig build -Dtest-errors-handled=true diff --git a/build.zig b/build.zig index 6ca3a49..0640dec 100644 --- a/build.zig +++ b/build.zig @@ -30,6 +30,9 @@ pub fn build(b: *std.build.Builder) void { tests.addOptions("build_options", exe_options); tests_exe.addOptions("build_options", exe_options); + const test_errors_handled = b.option(bool, "test-errors-handled", "Compile with this to confirm zar sends all io errors through the io error handler") orelse false; + exe_options.addOption(bool, "test_errors_handled", test_errors_handled); + { const tracy = b.option([]const u8, "tracy", "Enable Tracy integration. Supply path to Tracy source"); const tracy_callstack = b.option(bool, "tracy-callstack", "Include callstack information with Tracy data. Does nothing if -Dtracy is not provided") orelse false; diff --git a/src/archive/Archive.zig b/src/archive/Archive.zig index 1b23aab..ec32446 100644 --- a/src/archive/Archive.zig +++ b/src/archive/Archive.zig @@ -2,6 +2,7 @@ const Archive = @This(); const builtin = @import("builtin"); const std = @import("std"); +const build_options = @import("build_options"); const trace = @import("../tracy.zig").trace; const traceNamed = @import("../tracy.zig").traceNamed; const fmt = std.fmt; @@ -107,7 +108,15 @@ pub const Operation = enum { // not so the caller will need to print appropriate error messages // themselves (if needed at all). pub const UnhandledError = ParseError || CriticalError; -pub const HandledError = IoError; +pub const HandledError = HandledIoError; + +// We can set this to true just to make Handled errors are actually handled at +// comptime! +pub const test_errors_handled = build_options.test_errors_handled; + +pub const HandledIoError = if (test_errors_handled) error{Handled} else IoError; + +pub const CreateError = error{}; pub const ParseError = error{ NotArchive, @@ -117,6 +126,8 @@ pub const ParseError = error{ }; pub const InsertError = error{}; +pub const DeleteError = error{}; +pub const FinalizeError = error{}; pub const CriticalError = error{ OutOfMemory, @@ -128,6 +139,7 @@ pub const IoError = error{ BrokenPipe, ConnectionResetByPeer, ConnectionTimedOut, + DiskQuota, InputOutput, IsDir, NotOpenForReading, @@ -145,10 +157,12 @@ pub const IoError = error{ FileNotFound, FileTooBig, InvalidUtf8, + LockViolation, NameTooLong, NoDevice, NoSpaceLeft, NotDir, + NotOpenForWriting, PathAlreadyExists, PipeBusy, ProcessFdQuotaExceeded, @@ -261,9 +275,11 @@ const ErrorContext = enum { opening, reading, seeking, + stat, + writing, }; -pub fn printFileIoError(comptime context: ErrorContext, file_name: []const u8, err: IoError) void { +pub fn printFileIoError(comptime context: ErrorContext, file_name: []const u8, err: IoError) HandledIoError { const context_str = @tagName(context); switch (err) { @@ -271,12 +287,17 @@ pub fn printFileIoError(comptime context: ErrorContext, file_name: []const u8, e error.FileNotFound => logger.err("Error " ++ context_str ++ " '{s}', file not found.", .{file_name}), else => logger.err("Error " ++ context_str ++ " '{s}'.", .{file_name}), } - return; + if (test_errors_handled) return error.Handled; + return err; } -pub fn handleFileIoError(comptime context: ErrorContext, file_name: []const u8, err_result: anytype) @TypeOf(err_result) { - _ = err_result catch |err| printFileIoError(context, file_name, err); - return err_result; +// The weird return type is so that we can distinguish between handled and unhandled IO errors, +// i.e. if test_errors_handled is set to true, and raw calls to io operations will return in a compile failure +pub fn handleFileIoError(comptime context: ErrorContext, file_name: []const u8, err_result: anytype) HandledIoError!@typeInfo(@TypeOf(err_result)).ErrorUnion.payload { + const unwrapped_result = err_result catch |err| { + return printFileIoError(context, file_name, err); + }; + return unwrapped_result; } // These are the defaults llvm ar uses (excepting windows) @@ -301,7 +322,7 @@ pub fn create( output_archive_type: ArchiveType, modifiers: Modifiers, created: bool, -) !Archive { +) (CreateError || HandledIoError)!Archive { return Archive{ .dir = dir, .file = file, @@ -312,7 +333,7 @@ pub fn create( .symbols = .{}, .file_name_to_index = .{}, .modifiers = modifiers, - .stat = try file.stat(), + .stat = try handleFileIoError(.stat, name, file.stat()), .created = created, }; } @@ -382,7 +403,7 @@ const TrackingBufferedWriter = tracking_buffered_writer.TrackingBufferedWriter(s // TODO: This needs to be integrated into the workflow // used for parsing. (use same error handling workflow etc.) /// Use same naming scheme for objects (as found elsewhere in the file). -pub fn finalize(self: *Archive, allocator: Allocator) !void { +pub fn finalize(self: *Archive, allocator: Allocator) (FinalizeError || HandledIoError || CriticalError)!void { const tracy = trace(@src()); defer tracy.end(); if (self.output_archive_type == .ambiguous) { @@ -392,13 +413,13 @@ pub fn finalize(self: *Archive, allocator: Allocator) !void { } // Overwrite all contents - try self.file.seekTo(0); + try handleFileIoError(.seeking, self.name, self.file.seekTo(0)); // We wrap the buffered writer so that can we can track file position more easily var buffered_writer = TrackingBufferedWriter{ .buffered_writer = std.io.bufferedWriter(self.file.writer()) }; const writer = buffered_writer.writer(); - try writer.writeAll(if (self.output_archive_type == .gnuthin) magic_thin else magic_string); + try handleFileIoError(.writing, self.name, writer.writeAll(if (self.output_archive_type == .gnuthin) magic_thin else magic_string)); const header_names = try allocator.alloc([16]u8, self.files.items.len); @@ -490,7 +511,10 @@ pub fn finalize(self: *Archive, allocator: Allocator) !void { defer allocator.free(name); // Edit the header - _ = try std.fmt.bufPrint(&(header_names[index]), "{s: <16}", .{name}); + _ = std.fmt.bufPrint(&(header_names[index]), "{s: <16}", .{name}) catch |e| switch (e) { + // Should be unreachable as the buffer should already definetely be large enough... + error.NoSpaceLeft => unreachable, + }; } // Write the symbol table itself @@ -512,15 +536,15 @@ pub fn finalize(self: *Archive, allocator: Allocator) !void { const magic: []const u8 = if (format == .gnu64) "/SYM64/" else "/"; - try writer.print(Header.format_string, .{ magic, symtab_time, 0, 0, 0, symbol_table_size }); + try handleFileIoError(.writing, self.name, writer.print(Header.format_string, .{ magic, symtab_time, 0, 0, 0, symbol_table_size })); { const tracy_scope_inner = traceNamed(@src(), "Write Symbol Count"); defer tracy_scope_inner.end(); if (format == .gnu64) { - try writer.writeIntBig(u64, @intCast(u64, self.symbols.items.len)); + try handleFileIoError(.writing, self.name, writer.writeIntBig(u64, @intCast(u64, self.symbols.items.len))); } else { - try writer.writeIntBig(u32, @intCast(u32, self.symbols.items.len)); + try handleFileIoError(.writing, self.name, writer.writeIntBig(u32, @intCast(u32, self.symbols.items.len))); } } @@ -550,14 +574,14 @@ pub fn finalize(self: *Archive, allocator: Allocator) !void { for (self.symbols.items) |symbol| { if (format == .gnu64) { - try writer.writeIntBig(i64, relative_file_offsets[symbol.file_index] + @intCast(i64, offset_to_files)); + try handleFileIoError(.writing, self.name, writer.writeIntBig(i64, relative_file_offsets[symbol.file_index] + @intCast(i64, offset_to_files))); } else { - try writer.writeIntBig(i32, relative_file_offsets[symbol.file_index] + @intCast(i32, offset_to_files)); + try handleFileIoError(.writing, self.name, writer.writeIntBig(i32, relative_file_offsets[symbol.file_index] + @intCast(i32, offset_to_files))); } } } - try writer.writeAll(symbol_table); + try handleFileIoError(.writing, self.name, writer.writeAll(symbol_table)); } // Write the string table itself @@ -565,9 +589,10 @@ pub fn finalize(self: *Archive, allocator: Allocator) !void { const tracy_scope = traceNamed(@src(), "Write String Table"); defer tracy_scope.end(); if (string_table.items.len != 0) { - while (string_table.items.len % self.output_archive_type.getAlignment() != 0) + while (string_table.items.len % self.output_archive_type.getAlignment() != 0) { try string_table.append('\n'); - try writer.print("//{s}{: <10}`\n{s}", .{ " " ** 46, string_table.items.len, string_table.items }); + } + try handleFileIoError(.writing, self.name, writer.print("//{s}{: <10}`\n{s}", .{ " " ** 46, string_table.items.len, string_table.items })); } } }, @@ -597,16 +622,16 @@ pub fn finalize(self: *Archive, allocator: Allocator) !void { int_size + // Int describing size of symbol table's strings symbol_table.len; // The lengths of strings themselves - try writer.print(Header.format_string, .{ "#1/12", symtab_time, 0, 0, 0, symbol_table_size }); + try handleFileIoError(.writing, self.name, writer.print(Header.format_string, .{ "#1/12", symtab_time, 0, 0, 0, symbol_table_size })); const endian = builtin.cpu.arch.endian(); if (format == .darwin64) { - try writer.writeAll(bsd_symdef_64_magic); - try writer.writeInt(u64, @intCast(u64, num_ranlib_bytes), endian); + try handleFileIoError(.writing, self.name, writer.writeAll(bsd_symdef_64_magic)); + try handleFileIoError(.writing, self.name, writer.writeInt(u64, @intCast(u64, num_ranlib_bytes), endian)); } else { - try writer.writeAll(bsd_symdef_magic ++ "\x00\x00\x00"); - try writer.writeInt(u32, @intCast(u32, num_ranlib_bytes), endian); + try handleFileIoError(.writing, self.name, writer.writeAll(bsd_symdef_magic ++ "\x00\x00\x00")); + try handleFileIoError(.writing, self.name, writer.writeInt(u32, @intCast(u32, num_ranlib_bytes), endian)); } const ranlibs = try allocator.alloc(Ranlib(IntType), self.symbols.items.len); @@ -624,15 +649,15 @@ pub fn finalize(self: *Archive, allocator: Allocator) !void { ranlibs[idx].ran_off = relative_file_offsets[symbol.file_index] + @intCast(i32, offset_to_files); } - try writer.writeAll(mem.sliceAsBytes(ranlibs)); + try handleFileIoError(.writing, self.name, writer.writeAll(mem.sliceAsBytes(ranlibs))); if (format == .darwin64) { - try writer.writeInt(u64, @intCast(u64, symbol_string_table_and_offsets.unpadded_symbol_table_length), endian); + try handleFileIoError(.writing, self.name, writer.writeInt(u64, @intCast(u64, symbol_string_table_and_offsets.unpadded_symbol_table_length), endian)); } else { - try writer.writeInt(u32, @intCast(u32, symbol_string_table_and_offsets.unpadded_symbol_table_length), endian); + try handleFileIoError(.writing, self.name, writer.writeInt(u32, @intCast(u32, symbol_string_table_and_offsets.unpadded_symbol_table_length), endian)); } - try writer.writeAll(symbol_table); + try handleFileIoError(.writing, self.name, writer.writeAll(symbol_table)); } }, // This needs to be able to tell whatsupp. @@ -653,7 +678,10 @@ pub fn finalize(self: *Archive, allocator: Allocator) !void { const padding = self.calculatePadding(buffered_writer.file_pos + header_buffer.len + file.name.len); // BSD format: Just write the length of the name in header - _ = try std.fmt.bufPrint(&(header_names[index]), "#1/{: <13}", .{file.name.len + padding}); + _ = std.fmt.bufPrint(&(header_names[index]), "#1/{: <13}", .{file.name.len + padding}) catch |e| switch (e) { + // Should be unreachable as the buffer should already definetely be large enough... + error.NoSpaceLeft => unreachable, + }; if (self.output_archive_type.isDarwin()) { var file_padding = file.contents.length % self.output_archive_type.getFileAlignment(); file_padding = (self.output_archive_type.getFileAlignment() - file_padding) % self.output_archive_type.getFileAlignment(); @@ -664,34 +692,37 @@ pub fn finalize(self: *Archive, allocator: Allocator) !void { } }; - _ = try std.fmt.bufPrint( + _ = std.fmt.bufPrint( &header_buffer, Header.format_string, .{ &header_names[index], file.contents.timestamp, file.contents.uid, file.contents.gid, file.contents.mode, file_length }, - ); + ) catch |e| switch (e) { + // Should be unreachable as the buffer should already definetely be large enough... + error.NoSpaceLeft => unreachable, + }; // TODO: handle errors - _ = try writer.write(&header_buffer); + _ = try handleFileIoError(.writing, self.name, writer.write(&header_buffer)); // Write the name of the file in the data section if (self.output_archive_type.isBsdLike()) { - try writer.writeAll(file.name); - try writer.writeByteNTimes(0, self.calculatePadding(buffered_writer.file_pos)); + try handleFileIoError(.writing, self.name, writer.writeAll(file.name)); + try handleFileIoError(.writing, self.name, writer.writeByteNTimes(0, self.calculatePadding(buffered_writer.file_pos))); } if (self.output_archive_type != .gnuthin) { - try file.contents.write(writer, null); - try writer.writeByteNTimes('\n', self.calculatePadding(buffered_writer.file_pos)); + try handleFileIoError(.writing, self.name, file.contents.write(writer, null)); + try handleFileIoError(.writing, self.name, writer.writeByteNTimes('\n', self.calculatePadding(buffered_writer.file_pos))); } } - try buffered_writer.flush(); + try handleFileIoError(.writing, self.name, buffered_writer.flush()); // Truncate the file size - try self.file.setEndPos(buffered_writer.file_pos); + try handleFileIoError(.writing, self.name, self.file.setEndPos(buffered_writer.file_pos)); } -pub fn deleteFiles(self: *Archive, file_names: []const []const u8) !void { +pub fn deleteFiles(self: *Archive, file_names: []const []const u8) (DeleteError || HandledIoError || CriticalError)!void { const tracy = trace(@src()); defer tracy.end(); // For the list of given file names, find the entry in self.files @@ -760,16 +791,16 @@ pub fn extract(self: *Archive, file_names: []const []const u8) !void { } } -pub fn addToSymbolTable(self: *Archive, allocator: Allocator, archived_file: *const ArchivedFile, file_index: usize, file: fs.File, file_offset: u32) (CriticalError || IoError)!void { +pub fn addToSymbolTable(self: *Archive, allocator: Allocator, archived_file: *const ArchivedFile, file_index: usize, file: fs.File, file_offset: u32) (CriticalError || HandledIoError)!void { // TODO: make this read directly from the file contents buffer! // Get the file magic - try file.seekTo(file_offset); + try handleFileIoError(.seeking, archived_file.name, file.seekTo(file_offset)); var magic: [4]u8 = undefined; - _ = try file.reader().read(&magic); + _ = try handleFileIoError(.reading, archived_file.name, file.reader().read(&magic)); - try file.seekTo(file_offset); + try handleFileIoError(.seeking, archived_file.name, file.seekTo(file_offset)); blk: { // TODO: Load object from memory (upstream zld) @@ -879,7 +910,7 @@ pub fn addToSymbolTable(self: *Archive, allocator: Allocator, archived_file: *co } } -pub fn insertFiles(self: *Archive, allocator: Allocator, file_names: []const []const u8) (InsertError || IoError || CriticalError)!void { +pub fn insertFiles(self: *Archive, allocator: Allocator, file_names: []const []const u8) (InsertError || HandledIoError || CriticalError)!void { const tracy = trace(@src()); defer tracy.end(); @@ -900,13 +931,13 @@ pub fn insertFiles(self: *Archive, allocator: Allocator, file_names: []const []c // FIXME: Currently windows doesnt support the Stat struct if (builtin.os.tag == .windows) { - const file_stats = try file.stat(); + const file_stats = try handleFileIoError(.stat, file_name, file.stat()); // Convert timestamp from ns to s mtime = file_stats.mtime; size = file_stats.size; mode = file_stats.mode; } else { - const file_stats = try std.os.fstat(file.handle); + const file_stats = try handleFileIoError(.stat, file_name, std.os.fstat(file.handle)); gid = file_stats.gid; uid = file_stats.uid; @@ -936,10 +967,15 @@ pub fn insertFiles(self: *Archive, allocator: Allocator, file_names: []const []c const timestamp = @intCast(u128, @divFloor(mtime, std.time.ns_per_s)); + // Extract critical error from error set - so IO errors can be handled seperately + const bytes_or_io_error = file.readToEndAllocOptions(allocator, std.math.maxInt(usize), size, @alignOf(u64), null) catch |e| switch (e) { + error.OutOfMemory => return error.OutOfMemory, + else => @errSetCast(IoError, e), + }; var archived_file = ArchivedFile{ .name = try allocator.dupe(u8, fs.path.basename(file_name)), .contents = Contents{ - .bytes = try file.readToEndAllocOptions(allocator, std.math.maxInt(usize), size, @alignOf(u64), null), + .bytes = try handleFileIoError(.reading, file_name, bytes_or_io_error), .length = size, .mode = mode, .timestamp = timestamp, @@ -967,7 +1003,7 @@ pub fn insertFiles(self: *Archive, allocator: Allocator, file_names: []const []c } } -pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || CriticalError)!void { +pub fn parse(self: *Archive, allocator: Allocator) (ParseError || HandledIoError || CriticalError)!void { const tracy = trace(@src()); defer tracy.end(); const reader = self.file.reader(); @@ -1135,8 +1171,7 @@ pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || Cri const archive_header = reader.readStruct(Header) catch |err| switch (err) { error.EndOfStream => break, else => { - printFileIoError(.reading, self.name, err); - return err; + return printFileIoError(.reading, self.name, err); }, }; @@ -1241,7 +1276,7 @@ pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || Cri var symbol_magic_check_buffer: [bsd_symdef_longest_magic]u8 = undefined; // TODO: handle not reading enough characters! - const chars_read = try reader.read(&symbol_magic_check_buffer); + const chars_read = try handleFileIoError(.reading, self.name, reader.read(&symbol_magic_check_buffer)); var sorted = false; @@ -1257,7 +1292,7 @@ pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || Cri } if (chars_read - magic_len > 0) { - try reader.context.seekBy(@intCast(i64, magic_len) - @intCast(i64, chars_read)); + try handleFileIoError(.seeking, self.name, reader.context.seekBy(@intCast(i64, magic_len) - @intCast(i64, chars_read))); } seek_forward_amount = seek_forward_amount - @intCast(u32, magic_len); @@ -1280,7 +1315,7 @@ pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || Cri } // TODO: error if negative (because spec defines this as a long, so should never be that large?) - const num_ranlib_bytes = try reader.readInt(IntType, endianess); + const num_ranlib_bytes = try handleFileIoError(.reading, self.name, reader.readInt(IntType, endianess)); seek_forward_amount = seek_forward_amount - @as(u32, @sizeOf(IntType)); // TODO: error if this doesn't divide properly? @@ -1289,12 +1324,12 @@ pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || Cri var ranlib_bytes = try allocator.alloc(u8, @intCast(u32, num_ranlib_bytes)); // TODO: error handling - _ = try reader.read(ranlib_bytes); + _ = try handleFileIoError(.reading, self.name, reader.read(ranlib_bytes)); seek_forward_amount = seek_forward_amount - @intCast(u32, num_ranlib_bytes); var ranlibs = mem.bytesAsSlice(Ranlib(IntType), ranlib_bytes); - const symbol_strings_length = try reader.readInt(u32, endianess); + const symbol_strings_length = try handleFileIoError(.reading, self.name, reader.readInt(u32, endianess)); // TODO: We don't really need this information, but maybe it could come in handy // later? _ = symbol_strings_length; @@ -1303,7 +1338,7 @@ pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || Cri const symbol_string_bytes = try allocator.alloc(u8, seek_forward_amount); seek_forward_amount = 0; - _ = try reader.read(symbol_string_bytes); + _ = try handleFileIoError(.reading, self.name, reader.read(symbol_string_bytes)); if (!self.modifiers.build_symbol_table) { for (ranlibs) |ranlib| { @@ -1322,11 +1357,11 @@ pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || Cri // We have a symbol table! } - try reader.context.seekBy(seek_forward_amount); + try handleFileIoError(.seeking, self.name, reader.context.seekBy(seek_forward_amount)); continue; } - try reader.context.seekTo(current_seek_pos); + try handleFileIoError(.seeking, self.name, reader.context.seekTo(current_seek_pos)); } const archive_name_buffer = try allocator.alloc(u8, archive_name_length); @@ -1357,7 +1392,7 @@ pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || Cri }, }; - const offset_hack = try reader.context.getPos(); + const offset_hack = try handleFileIoError(.seeking, self.name, reader.context.getPos()); if (self.inferred_archive_type == .gnuthin) { var thin_file = try handleFileIoError(.opening, trimmed_archive_name, self.dir.openFile(trimmed_archive_name, .{})); @@ -1369,13 +1404,13 @@ pub fn parse(self: *Archive, allocator: Allocator) (ParseError || IoError || Cri } if (self.modifiers.build_symbol_table) { - const post_offset_hack = try reader.context.getPos(); + const post_offset_hack = try handleFileIoError(.seeking, self.name, reader.context.getPos()); // TODO: Actually handle these errors! self.addToSymbolTable(allocator, &parsed_file, self.files.items.len, reader.context, @intCast(u32, offset_hack)) catch { return error.TODO; }; - try reader.context.seekTo(post_offset_hack); + try handleFileIoError(.seeking, self.name, reader.context.seekTo(post_offset_hack)); } try self.file_name_to_index.put(allocator, trimmed_archive_name, self.files.items.len); diff --git a/src/main.zig b/src/main.zig index 5938521..4f0f01c 100644 --- a/src/main.zig +++ b/src/main.zig @@ -160,8 +160,7 @@ fn openOrCreateFile(cwd: fs.Dir, archive_path: []const u8, print_creation_warnin return create_file_handle; }, else => { - Archive.printFileIoError(.opening, archive_path, err); - return err; + return Archive.printFileIoError(.opening, archive_path, err); }, }; return open_file_handle;