@@ -5,26 +5,72 @@ const paths = @import("paths.zig");
55
66pub const logger = std .log .scoped (.lib_proxy );
77
8- const dll_name = " winhttp" ;
8+ const dll_names : [] const [: 0 ] const u8 = &.{ "version" , " winhttp" } ;
99
10- const iter_proxy_funcs = std .mem .splitScalar (u8 , @embedFile ("symbols/" ++ dll_name ++ ".txt" ), '\n ' );
10+ const DllName = blk : {
11+ var fields : []const std.builtin.Type.EnumField = &.{};
1112
12- const ProxyFuncAddrs = blk : {
13- @setEvalBranchQuota (8000 );
13+ for (dll_names , 0.. ) | dll_name , i | {
14+ fields = fields ++ .{std.builtin.Type.EnumField {
15+ .name = dll_name ,
16+ .value = i ,
17+ }};
18+ }
1419
15- var fields : []const std.builtin.Type.StructField = &.{};
20+ break :blk @Type (.{ .@"enum" = .{
21+ .fields = fields ,
22+ .decls = &.{},
23+ } });
24+ };
1625
17- var funcs = iter_proxy_funcs ;
26+ fn proxyFunctions (comptime dll_name : []const u8 ) []const []const u8 {
27+ var buf : []const []const u8 = &.{};
28+
29+ var funcs = std .mem .splitScalar (u8 , @embedFile ("symbols/" ++ dll_name ++ ".txt" ), '\n ' );
1830 while (funcs .next ()) | name | {
31+ if (name .len == 0 ) {
32+ continue ;
33+ }
1934 if (std .mem .indexOfScalar (u8 , name , ' ' ) != null ) {
2035 @compileError ("proxy function name \" " ++ name ++ "\" contains whitespace" );
2136 }
37+ buf = buf ++ .{name };
38+ }
39+
40+ return buf ;
41+ }
42+
43+ const DllIncludes = blk : {
44+ var fields : []const std.builtin.Type.StructField = &.{};
45+
46+ for (std .meta .fieldNames (ProxyFuncAddrs )) | func_name | {
47+ fields = fields ++ .{std.builtin.Type.StructField {
48+ .name = func_name ,
49+ .type = bool ,
50+ .default_value_ptr = & false ,
51+ .is_comptime = false ,
52+ .alignment = 0 ,
53+ }};
54+ }
55+
56+ break :blk @Type (.{ .@"struct" = .{
57+ .layout = .@"packed" ,
58+ .fields = fields ,
59+ .decls = &.{},
60+ .is_tuple = false ,
61+ } });
62+ };
63+
64+ const EachDllIncludes = blk : {
65+ var fields : []const std.builtin.Type.StructField = &.{};
66+
67+ for (dll_names ) | dll_name | {
2268 fields = fields ++ .{std.builtin.Type.StructField {
23- .name = @ptrCast ( name ++ .{ 0 }) ,
24- .type = std . os . windows . FARPROC ,
25- .default_value_ptr = null ,
69+ .name = dll_name ,
70+ .type = DllIncludes ,
71+ .default_value_ptr = &.{} ,
2672 .is_comptime = false ,
27- .alignment = @alignOf ( std . os . windows . FARPROC ) ,
73+ .alignment = 0 ,
2874 }};
2975 }
3076
@@ -36,30 +82,91 @@ const ProxyFuncAddrs = blk: {
3682 } });
3783};
3884
39- var proxy_func_addrs : ProxyFuncAddrs = undefined ;
85+ const each_dll_includes = blk : {
86+ @setEvalBranchQuota (8000 );
4087
41- comptime {
88+ var includes : EachDllIncludes = &.{};
89+
90+ for (dll_names ) | dll_name | {
91+ for (proxyFunctions (dll_name )) | name | {
92+ @field (@field (includes , name ), dll_name ) = true ;
93+ }
94+ }
95+
96+ break :blk includes ;
97+ };
98+
99+ const ProxyFuncAddrs = blk : {
42100 @setEvalBranchQuota (8000 );
43- var funcs = iter_proxy_funcs ;
44- while (funcs .next ()) | name | {
101+
102+ var fields : []const std.builtin.Type.StructField = &.{};
103+
104+ for (dll_names ) | dll_name | {
105+ for (proxyFunctions (dll_name )) | name | {
106+ const FuncAddr = ? * fn () callconv (.c ) void ;
107+ fields = fields ++ .{std.builtin.Type.StructField {
108+ .name = @ptrCast (name ++ .{0 }),
109+ .type = FuncAddr ,
110+ .default_value_ptr = @ptrCast (&@as (FuncAddr , null )),
111+ .is_comptime = false ,
112+ .alignment = 0 ,
113+ }};
114+ }
115+ }
116+
117+ break :blk @Type (.{ .@"struct" = .{
118+ .layout = .auto ,
119+ .fields = fields ,
120+ .decls = &.{},
121+ .is_tuple = false ,
122+ } });
123+ };
124+
125+ var proxy_func_addrs : ProxyFuncAddrs = .{};
126+
127+ fn panicUnlinkedFunction (name : []const u8 ) noreturn {
128+ @branchHint (.cold );
129+ std .debug .panic ("Attempted to call unlinked function {s}" , .{name });
130+ }
131+
132+ comptime {
133+ for (std .meta .fieldNames (ProxyFuncAddrs )) | name | {
45134 @export (& struct {
46135 fn f () callconv (.c ) void {
47- return @as (* fn () callconv (.c ) void , @ptrCast (@field (proxy_func_addrs , name )))();
136+ if (@field (proxy_func_addrs , name )) | func | {
137+ return func ();
138+ } else {
139+ panicUnlinkedFunction (name );
140+ }
48141 }
49142 }.f , .{ .name = name });
50143 }
51144}
52145
53- fn panicLocateFunction (path : []const u16 , name : []const u8 ) noreturn {
146+ fn getDllIncludes (dll_name : DllName ) * const DllIncludes {
147+ inline for (dll_names ) | other_dll_name | {
148+ if (@field (DllName , other_dll_name ) == dll_name ) {
149+ return &@field (each_dll_includes , other_dll_name );
150+ }
151+ }
152+ unreachable ;
153+ }
154+
155+ fn logUnlinkableFunction (name : []const u8 , path : []const u16 ) void {
54156 @branchHint (.cold );
55- std . debug . panic ("Failed to locate function {s} in {s}" , .{ name , std .unicode .fmtUtf16Le (path ) });
157+ logger . warn ("Failed to locate function {s} in {s}" , .{ name , std .unicode .fmtUtf16Le (path ) });
56158}
57159
58- fn loadFunctions (dll : std.os.windows.HMODULE , path : []const u16 ) void {
160+ fn loadFunctions (dll : std.os.windows.HMODULE , path : []const u16 , dll_name : DllName ) void {
161+ const includes = getDllIncludes (dll_name );
59162 inline for (comptime std .meta .fieldNames (ProxyFuncAddrs )) | field | {
60- @field (proxy_func_addrs , field ) = std .os .windows .kernel32 .GetProcAddress (dll , field ) orelse {
61- panicLocateFunction (path , field );
62- };
163+ if (@field (includes , field )) {
164+ if (std .os .windows .kernel32 .GetProcAddress (dll , field )) | ptr | {
165+ @field (proxy_func_addrs , field ) = @ptrCast (ptr );
166+ } else {
167+ logUnlinkableFunction (field , path );
168+ }
169+ }
63170 }
64171}
65172
@@ -80,18 +187,29 @@ fn empty(comptime T: type) *[0:0]T {
80187 return @constCast (&[_ :0 ]T {});
81188}
82189
190+ fn findDllMatch (module_name : []const u16 ) ? DllName {
191+ if (module_name .len > 4 and eqlIgnoreCase (module_name [module_name .len - 4 .. ], ".dll" )) {
192+ const module_name_stripped = module_name [0 .. module_name .len - 4 ];
193+ inline for (dll_names ) | dll_name | {
194+ if (eqlIgnoreCase (module_name_stripped , dll_name )) {
195+ return @field (DllName , dll_name );
196+ }
197+ }
198+ }
199+ return null ;
200+ }
201+
83202pub fn loadProxy (module : std.os.windows.HMODULE ) ! void {
84203 var module_path_buf = paths.ModulePathBuf {};
85204 const module_path = (try module_path_buf .get (module )).? ;
86205
87206 const module_name = paths .getFileName (u16 , module_path );
88207
89- const proxy_name = dll_name ++ ".dll" ;
90- if (! eqlIgnoreCase (module_name , proxy_name )) {
91- logger .debug ("{s} is not supported for proxying" , .{std .unicode .fmtUtf16Le (module_name )});
208+ const dll_name = findDllMatch (module_name ) orelse {
209+ logger .debug ("{} is not supported for proxying" , .{std .unicode .fmtUtf16Le (module_name )});
92210 return error .UnsupportedName ;
93- }
94- logger .debug ("Detected injection as supported proxy. Loading actual." , .{});
211+ };
212+ logger .debug ("Detected injection as supported proxy {} . Loading actual." , .{std . unicode . fmtUtf16Le ( module_name ) });
95213
96214 // sys_len includes null-terminator
97215 const sys_len = std .os .windows .kernel32 .GetSystemDirectoryW (empty (u16 ), 0 );
@@ -106,11 +224,11 @@ pub fn loadProxy(module: std.os.windows.HMODULE) !void {
106224 sys_full_path_buf [sys_len + module_name .len ] = 0 ;
107225 const sys_full_path = sys_full_path_buf [0 .. sys_len + module_name .len :0 ];
108226
109- logger .debug ("Looking for actual DLL at {s }" , .{std .unicode .fmtUtf16Le (sys_full_path )});
227+ logger .debug ("Looking for actual DLL at {}" , .{std .unicode .fmtUtf16Le (sys_full_path )});
110228
111229 const handle = try std .os .windows .LoadLibraryW (sys_full_path );
112230
113- loadFunctions (handle , sys_full_path );
231+ loadFunctions (handle , sys_full_path , dll_name );
114232}
115233
116234test {
0 commit comments