diff --git a/extensions/BUILD b/extensions/BUILD index d84881716..f2ab4c750 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -173,6 +173,79 @@ cc_test( ], ) +cc_library( + name = "network_ext_functions", + srcs = ["network_ext_functions.cc"], + hdrs = ["network_ext_functions.h"], + deps = [ + "//common:native_type", + "//common:typeinfo", + "//common:value", + "//net/base:ipaddress", + "//net/util:ipaddress_util", + "//runtime:function", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "network_ext", + srcs = ["network_ext.cc"], + hdrs = ["network_ext.h"], + deps = [ + ":network_ext_functions", + "//base:builtins", + "//checker:type_checker_builder", + "//common:decl", + "//common:native_type", + "//common:type", + "//common:value", + "//compiler", + "//internal:status_macros", + "//net/base:ipaddress", + "//runtime:function", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime:type_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "network_ext_test", + srcs = ["network_ext_test.cc"], + deps = [ + ":network_ext", + "//checker:validation_result", + "//common:ast", + "//common:minimal_descriptor_pool", + "//common:value", + "//compiler", + "//compiler:compiler_factory", + "//internal:status_macros", + "//internal:testing", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "//runtime:standard_functions", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + # New users should use ":regex_ext" instead. cc_library( name = "regex_functions", diff --git a/extensions/network_ext.cc b/extensions/network_ext.cc new file mode 100644 index 000000000..a032dec76 --- /dev/null +++ b/extensions/network_ext.cc @@ -0,0 +1,548 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/network_ext.h" + +#include +#include + +#include "net/base/ipaddress.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "base/builtins.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "extensions/network_ext_functions.h" +#include "internal/status_macros.h" +#include "runtime/function.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { +namespace { + +using ::cel::BinaryFunctionAdapter; +using ::cel::BoolValue; +using ::cel::MakeFunctionDecl; +using ::cel::MakeMemberOverloadDecl; +using ::cel::MakeOverloadDecl; +using ::cel::NativeTypeId; +using ::cel::OpaqueType; +using ::cel::OpaqueValue; +using ::cel::OpaqueValueContent; +using ::cel::OpaqueValueDispatcher; +using ::cel::StringValue; +using ::cel::Type; +using ::cel::TypeType; +using ::cel::UnaryFunctionAdapter; +using ::cel::UnsafeOpaqueValue; +using ::cel::Value; +using ::cel::builtin::kEqual; +using ::cel::builtin::kString; +using ::net_base::IPAddress; +using ::net_base::IsInitializedAddress; +using ::google::protobuf::Arena; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::MessageFactory; + +// Arena for static type instances +Arena* absl_nonnull BuiltinsArena() { + static absl::NoDestructor arena; + return arena.get(); +} + +// CEL Type Declarations +OpaqueType IpType() { + static const absl::NoDestructor kInstance( + BuiltinsArena(), "net.IP", std::vector{}); + return *kInstance; +} + +Type TypeOfIpType() { + static const absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), IpType())); + return *kInstance; +} + +OpaqueType CidrType() { + static const absl::NoDestructor kInstance( + BuiltinsArena(), "net.CIDR", std::vector{}); + return *kInstance; +} + +Type TypeOfCidrType() { + static const absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), CidrType())); + return *kInstance; +} + +// ----------------------------------------------------------------------------- +// Dispatcher for IpAddrRep (net.IP) +// ----------------------------------------------------------------------------- + +NativeTypeId IpAddrRep_GetTypeId(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return IpAddrRep::GetTypeId(); +} + +absl::string_view IpAddrRep_GetTypeName(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return "net.IP"; +} + +std::string IpAddrRep_DebugString(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return content.To()->DebugString(); +} + +absl::Status IpAddrRep_Equal(const OpaqueValueDispatcher*, + OpaqueValueContent content, + const OpaqueValue& other, const DescriptorPool*, + MessageFactory*, Arena*, Value* result) { + const IpAddrRep* self = content.To(); + const IpAddrRep* other_rep = IpAddrRep::Unwrap(other); + if (!other_rep) { + *result = BoolValue(false); + return absl::OkStatus(); + } + *result = BoolValue(self->Equals(*other_rep)); + return absl::OkStatus(); +} + +OpaqueValue IpAddrRep_Clone(const OpaqueValueDispatcher*, + OpaqueValueContent content, Arena* arena) { + const IpAddrRep* self = content.To(); + return IpAddrRep::Create(arena, self->addr()).GetOpaque(); +} + +OpaqueType IpAddrRep_GetRuntimeType(const OpaqueValueDispatcher*, + OpaqueValueContent) { + return IpType(); +} + +static const OpaqueValueDispatcher kIpAddrRepDispatcher = { + .get_type_id = IpAddrRep_GetTypeId, + .get_arena = nullptr, + .get_type_name = IpAddrRep_GetTypeName, + .debug_string = IpAddrRep_DebugString, + .get_runtime_type = IpAddrRep_GetRuntimeType, + .equal = IpAddrRep_Equal, + .clone = IpAddrRep_Clone, +}; + +// ----------------------------------------------------------------------------- +// Dispatcher for CidrRangeRep (net.CIDR) +// ----------------------------------------------------------------------------- + +NativeTypeId CidrRangeRep_GetTypeId(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return CidrRangeRep::GetTypeId(); +} + +absl::string_view CidrRangeRep_GetTypeName(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return "net.CIDR"; +} + +std::string CidrRangeRep_DebugString(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return content.To()->DebugString(); +} + +absl::Status CidrRangeRep_Equal(const OpaqueValueDispatcher*, + OpaqueValueContent content, + const OpaqueValue& other, const DescriptorPool*, + MessageFactory*, Arena*, Value* result) { + const CidrRangeRep* self = content.To(); + const CidrRangeRep* other_rep = CidrRangeRep::Unwrap(other); + if (!other_rep) { + *result = BoolValue(false); + return absl::OkStatus(); + } + *result = BoolValue(self->Equals(*other_rep)); + return absl::OkStatus(); +} + +OpaqueValue CidrRangeRep_Clone(const OpaqueValueDispatcher*, + OpaqueValueContent content, Arena* arena) { + const CidrRangeRep* self = content.To(); + return CidrRangeRep::Create(arena, self->host(), self->length()).GetOpaque(); +} + +OpaqueType CidrRangeRep_GetRuntimeType(const OpaqueValueDispatcher*, + OpaqueValueContent) { + return CidrType(); +} + +static const OpaqueValueDispatcher kCidrRangeRepDispatcher = { + .get_type_id = CidrRangeRep_GetTypeId, + .get_arena = nullptr, + .get_type_name = CidrRangeRep_GetTypeName, + .debug_string = CidrRangeRep_DebugString, + .get_runtime_type = CidrRangeRep_GetRuntimeType, + .equal = CidrRangeRep_Equal, + .clone = CidrRangeRep_Clone, +}; + +} // namespace + +// ----------------------------------------------------------------------------- +// IpAddrRep Method Implementations +// ----------------------------------------------------------------------------- +Value IpAddrRep::Create(Arena* arena, const IPAddress& addr) { + IpAddrRep* rep = Arena::Create(arena, addr); + return UnsafeOpaqueValue(&kIpAddrRepDispatcher, + OpaqueValueContent::From(rep)); +} + +const IpAddrRep* IpAddrRep::Unwrap(const Value& value) { + auto opaque = value.AsOpaque(); + if (!opaque.has_value() || opaque->GetTypeId() != IpAddrRep::GetTypeId()) { + return nullptr; + } + return opaque->content().To(); +} + +std::string IpAddrRep::DebugString() const { + if (!IsInitializedAddress(addr_)) { + return "ip()"; + } + return absl::StrCat("ip('", addr_.ToString(), "')"); +} + +// ----------------------------------------------------------------------------- +// CidrRangeRep Method Implementations +// ----------------------------------------------------------------------------- +Value CidrRangeRep::Create(Arena* arena, const IPAddress& host, + int length) { // Changed signature + CidrRangeRep* rep = Arena::Create( + arena, host, length); // Changed constructor call + return UnsafeOpaqueValue(&kCidrRangeRepDispatcher, + OpaqueValueContent::From(rep)); +} + +const CidrRangeRep* CidrRangeRep::Unwrap(const Value& value) { + auto opaque = value.AsOpaque(); + if (!opaque.has_value() || opaque->GetTypeId() != CidrRangeRep::GetTypeId()) { + return nullptr; + } + return opaque->content().To(); +} + +std::string CidrRangeRep::DebugString() const { + if (!IsInitializedAddress(host_) || + length_ < 0) { // Changed to use host_ and length_ + return "cidr()"; + } + return absl::StrCat("cidr('", host_.ToString(), "/", length_, + "')"); // Changed to use host_ and length_ +} + +// ----------------------------------------------------------------------------- +// CEL Extension Registration +// ----------------------------------------------------------------------------- + +absl::Status ConfigureNetworkFunctions(cel::TypeCheckerBuilder& builder) { + // Register Type Identifiers + CEL_RETURN_IF_ERROR( + builder.AddVariable(cel::MakeVariableDecl("net.IP", TypeOfIpType()))); + CEL_RETURN_IF_ERROR( + builder.AddVariable(cel::MakeVariableDecl("net.CIDR", TypeOfCidrType()))); + + CEL_ASSIGN_OR_RETURN( + auto decl_is_ip, + MakeFunctionDecl("isIP", MakeOverloadDecl("is_ip_string", cel::BoolType(), + cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_is_ip)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip, + MakeFunctionDecl( + "ip", MakeOverloadDecl("string_to_ip", IpType(), cel::StringType()), + MakeMemberOverloadDecl("cidr_ip", IpType(), CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip)); + + CEL_ASSIGN_OR_RETURN( + auto decl_is_cidr, + MakeFunctionDecl("isCIDR", + MakeOverloadDecl("is_cidr_string", cel::BoolType(), + cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_is_cidr)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr, + MakeFunctionDecl("cidr", MakeOverloadDecl("string_to_cidr", CidrType(), + cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_canonical, + MakeFunctionDecl("ip.isCanonical", + MakeOverloadDecl("ip_is_canonical_string", + cel::BoolType(), cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_canonical)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_family, + MakeFunctionDecl("family", MakeMemberOverloadDecl( + "ip_family", cel::IntType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_family)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_loopback, + MakeFunctionDecl( + "isLoopback", + MakeMemberOverloadDecl("ip_is_loopback", cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_loopback)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_global_unicast, + MakeFunctionDecl("isGlobalUnicast", + MakeMemberOverloadDecl("ip_is_global_unicast", + cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_global_unicast)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_link_local_multicast, + MakeFunctionDecl("isLinkLocalMulticast", + MakeMemberOverloadDecl("ip_is_link_local_multicast", + cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_link_local_multicast)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_link_local_unicast, + MakeFunctionDecl("isLinkLocalUnicast", + MakeMemberOverloadDecl("ip_is_link_local_unicast", + cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_link_local_unicast)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_unspecified, + MakeFunctionDecl("isUnspecified", + MakeMemberOverloadDecl("ip_is_unspecified", + cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_unspecified)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr_contains_ip, + MakeFunctionDecl( + "containsIP", + MakeMemberOverloadDecl("cidr_contains_ip_ip", cel::BoolType(), + CidrType(), IpType()), + MakeMemberOverloadDecl("cidr_contains_ip_string", cel::BoolType(), + CidrType(), cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr_contains_ip)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr_contains_cidr, + MakeFunctionDecl( + "containsCIDR", + MakeMemberOverloadDecl("cidr_contains_cidr_cidr", cel::BoolType(), + CidrType(), CidrType()), + MakeMemberOverloadDecl("cidr_contains_cidr_string", cel::BoolType(), + CidrType(), cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr_contains_cidr)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr_masked, + MakeFunctionDecl("masked", MakeMemberOverloadDecl( + "cidr_masked", CidrType(), CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr_masked)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr_prefix_length, + MakeFunctionDecl("prefixLength", + MakeMemberOverloadDecl("cidr_prefix_length", + cel::IntType(), CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr_prefix_length)); + + CEL_ASSIGN_OR_RETURN( + auto decl_string, + MakeFunctionDecl( + kString, + MakeMemberOverloadDecl("ip_to_string", cel::StringType(), IpType()), + MakeMemberOverloadDecl("cidr_to_string", cel::StringType(), + CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_string)); + + // Add Equality Operator overloads for net.IP and net.CIDR + CEL_ASSIGN_OR_RETURN( + auto decl_equals, + MakeFunctionDecl( + kEqual, + MakeOverloadDecl("ip_equal_ip", cel::BoolType(), IpType(), IpType()), + MakeOverloadDecl("cidr_equal_cidr", cel::BoolType(), CidrType(), + CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_equals)); + + return absl::OkStatus(); +} + +cel::CompilerLibrary NetworkCompilerLibrary() { + return cel::CompilerLibrary("cel.extensions.network", + ConfigureNetworkFunctions); +} + +absl::Status RegisterNetworkTypes(cel::TypeRegistry& registry, + const cel::RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(registry.RegisterType(IpType())); + CEL_RETURN_IF_ERROR(registry.RegisterType(CidrType())); + return absl::OkStatus(); +} +// Implementation for Opaque type equality +Value OpaqueEq(const Value& v1, const Value& v2, + const Function::InvokeContext& context) { + Value result; + absl::Status status = + v1.Equal(v2, context.descriptor_pool(), context.message_factory(), + context.arena(), &result); + if (!status.ok()) { + // This shouldn't happen if the types are supported by the dispatcher + return ErrorValue(status); + } + return result; +} + +absl::Status RegisterNetworkFunctions(cel::FunctionRegistry& registry, + const cel::RuntimeOptions& options) { + // ... other function registrations ... + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor("isIP", + false), + UnaryFunctionAdapter::WrapFunction(&NetIsIP))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor("ip", + false), + UnaryFunctionAdapter::WrapFunction( + &NetIPString))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isCIDR", false), + UnaryFunctionAdapter::WrapFunction( + &NetIsCIDR))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor("cidr", + false), + UnaryFunctionAdapter::WrapFunction( + &NetCIDRString))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "ip.isCanonical", false), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsCanonical))); + + // Register Member Functions for net.IP + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "family", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPFamily))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isLoopback", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsLoopback))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isGlobalUnicast", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsGlobalUnicast))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isLinkLocalMulticast", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsLinkLocalMulticast))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isLinkLocalUnicast", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsLinkLocalUnicast))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isUnspecified", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsUnspecified))); + + // Register Member Functions for net.CIDR + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor("containsIP", + true), + BinaryFunctionAdapter:: + WrapFunction(&NetCIDRContainsIP))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor("containsIP", + true), + BinaryFunctionAdapter:: + WrapFunction(&NetCIDRContainsIPString))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor("containsCIDR", true), + BinaryFunctionAdapter:: + WrapFunction(&NetCIDRContainsCIDR))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor("containsCIDR", true), + BinaryFunctionAdapter:: + WrapFunction(&NetCIDRContainsCIDRString))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor("ip", + true), + UnaryFunctionAdapter::WrapFunction( + &NetCIDRIP))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "masked", true), + UnaryFunctionAdapter::WrapFunction( + &NetCIDRMasked))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "prefixLength", true), + UnaryFunctionAdapter::WrapFunction( + &NetCIDRPrefixLength))); + + // Register the combined string function + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor(kString, + true), + UnaryFunctionAdapter::WrapFunction( + &NetToString))); + + // Register equality for IP and CIDR + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor(kEqual, + false), + BinaryFunctionAdapter::WrapFunction(&OpaqueEq))); + + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/network_ext.h b/extensions/network_ext.h new file mode 100644 index 000000000..1da10eb16 --- /dev/null +++ b/extensions/network_ext.h @@ -0,0 +1,38 @@ +// Copyright 2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_H_ + +#include "absl/status/status.h" +#include "compiler/compiler.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" + +namespace cel::extensions { + +// Provides a CEL compiler library for network functions. +cel::CompilerLibrary NetworkCompilerLibrary(); + +// Registers network function overloads with the function registry. +absl::Status RegisterNetworkFunctions(cel::FunctionRegistry& registry, + const cel::RuntimeOptions& options); + +// Registers network types with the type registry. +absl::Status RegisterNetworkTypes(cel::TypeRegistry& registry, + const cel::RuntimeOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_H_ diff --git a/extensions/network_ext_functions.cc b/extensions/network_ext_functions.cc new file mode 100644 index 000000000..bccce74f0 --- /dev/null +++ b/extensions/network_ext_functions.cc @@ -0,0 +1,369 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/network_ext_functions.h" + +#include +#include + +#include "net/base/ipaddress.h" +#include "net/util/ipaddress_util.h" +#include "absl/status/status.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "common/value.h" +#include "common/values/error_value.h" +#include "runtime/function.h" + +namespace cel::extensions { + +namespace { + +using ::cel::BoolValue; +using ::cel::ErrorValue; +using ::cel::Function; +using ::cel::IntValue; +using ::cel::StringValue; +using ::cel::Value; +using ::net_base::GetMappedIPv4Address; +using ::net_base::IPAddress; +using ::net_base::IPRange; +using ::net_base::IsAnyIPAddress; +using ::net_base::IsInitializedAddress; +using ::net_base::IsLoopbackIPAddress; +using ::net_base::IsProperSubRange; +using ::net_base::IsWithinSubnet; +using ::net_base::StringToIPAddress; +using ::net_util::IsLinkLocalIP; +using ::net_util::IsNonRoutableIP; + +// ----------------------------------------------------------------------------- +// Strict Parsing Helpers +// ----------------------------------------------------------------------------- + +bool IsStrictIP(const IPAddress& addr) { + if (!IsInitializedAddress(addr)) return false; + IPAddress unused; + // Check for IPv4-mapped IPv6 addresses. + if (GetMappedIPv4Address(addr, &unused)) { + return false; + } + // zone() is not a member of net_base::IPAddress + return true; +} + +// Helper to parse CIDR string into host and length without truncation +bool ParseCIDRWithoutTruncation(absl::string_view s, IPAddress* host, + int* length) { + std::vector parts = absl::StrSplit(s, '/'); + if (parts.size() != 2) { + return false; + } + if (!StringToIPAddress(parts[0], host)) { + return false; + } + if (!absl::SimpleAtoi(parts[1], length)) { + return false; + } + if ((*length < 0 || (*host).is_ipv4()) && + (*length > 32 || (*host).is_ipv6()) && *length > 128) { + return false; + } + return true; +} + +bool IsStrictCIDR(const IPAddress& host, int length) { + if (!IsInitializedAddress(host)) return false; + return IsStrictIP(host); +} + +} // namespace + +// ----------------------------------------------------------------------------- +// CEL Function Implementations +// ----------------------------------------------------------------------------- + +// isIP(string) -> bool +Value NetIsIP(const StringValue& str_val, + const Function::InvokeContext& context) { + IPAddress addr; + if (!StringToIPAddress(std::string(str_val.ToString()), &addr)) { + return BoolValue(false); + } + return BoolValue(IsStrictIP(addr)); +} + +// ip(string) -> net.IP +Value NetIPString(const StringValue& str_val, + const Function::InvokeContext& context) { + std::string str = std::string(str_val.ToString()); + IPAddress addr; + if (!StringToIPAddress(str, &addr)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("IP Address '", str, "' parse error"))); + } + if (!IsStrictIP(addr)) { + return ErrorValue(absl::InvalidArgumentError(absl::StrCat( + "IP Address '", str, "' is not a strict IP (e.g., mapped IPv4)"))); + } + return IpAddrRep::Create(context.arena(), addr); +} + +// isCIDR(string) -> bool +Value NetIsCIDR(const StringValue& str_val, + const Function::InvokeContext& context) { + std::string str = std::string(str_val.ToString()); + IPAddress host; + int length; + if (!ParseCIDRWithoutTruncation(str, &host, &length)) { + return BoolValue(false); + } + return BoolValue(IsStrictCIDR(host, length)); +} + +// cidr(string) -> net.CIDR +Value NetCIDRString(const StringValue& str_val, + const Function::InvokeContext& context) { + std::string str = std::string(str_val.ToString()); + IPAddress host; + int length; + if (!ParseCIDRWithoutTruncation(str, &host, &length)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("CIDR '", str, "' parse error"))); + } + + if (!IsStrictCIDR(host, length)) { + return ErrorValue(absl::InvalidArgumentError(absl::StrCat( + "CIDR '", str, "' is not a strict CIDR (e.g., mapped IPv4)"))); + } + return CidrRangeRep::Create(context.arena(), host, length); +} + +// .family() -> int +Value NetIPFamily(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + switch (rep->addr().address_family()) { + case AF_INET: + return IntValue(4); + case AF_INET6: + return IntValue(6); + default: + return ErrorValue(absl::InvalidArgumentError("Unknown address family")); + } +} + +// .isLoopback() -> bool +Value NetIPIsLoopback(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + return BoolValue(IsLoopbackIPAddress(rep->addr())); +} + +// .isGlobalUnicast() -> bool +Value NetIPIsGlobalUnicast(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + const IPAddress& addr = rep->addr(); + + if (IsAnyIPAddress(addr) || addr == IPAddress::Loopback4() || + addr == IPAddress::Loopback6() || IsLinkLocalIP(addr) || + IsNonRoutableIP(addr)) { + return BoolValue(false); + } + + if (addr.is_ipv4()) { + return BoolValue(!net_base::IsV4MulticastIPAddress(addr)); + } + + if (addr.is_ipv6()) { + in6_addr addr6 = addr.ipv6_address(); + if (IN6_IS_ADDR_MULTICAST(&addr6)) { + return BoolValue(false); + } + return BoolValue(true); + } + return BoolValue(false); +} + +Value NetIPIsLinkLocalMulticast(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !rep->addr().is_ipv6()) { + return BoolValue(false); + } + in6_addr addr6 = rep->addr().ipv6_address(); + return BoolValue(IN6_IS_ADDR_MC_LINKLOCAL(&addr6)); +} + +Value NetIPIsLinkLocalUnicast(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + return BoolValue(IsLinkLocalIP(rep->addr())); +} + +Value NetIPIsUnspecified(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep) { + return ErrorValue(absl::InvalidArgumentError("Invalid IP object")); + } + return BoolValue(IsAnyIPAddress(rep->addr())); +} + +Value NetIPIsCanonical(const StringValue& str_val, + const Function::InvokeContext& context) { + std::string str = std::string(str_val.ToString()); + IPAddress addr; + if (!StringToIPAddress(str, &addr)) { + return BoolValue(false); + } + if (!IsStrictIP(addr)) { + return BoolValue(false); + } + return BoolValue(addr.ToString() == str); +} + +Value NetCIDRContainsIP(const OpaqueValue& self, const OpaqueValue& other, + const Function::InvokeContext& context) { + const CidrRangeRep* self_rep = CidrRangeRep::Unwrap(self); + const IpAddrRep* other_rep = IpAddrRep::Unwrap(other); + if (!self_rep || !IsInitializedAddress(self_rep->host()) || + self_rep->length() < 0 || !other_rep || + !IsInitializedAddress(other_rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR or IP")); + } + return BoolValue(IsWithinSubnet(self_rep->ToIPRange(), other_rep->addr())); +} + +Value NetCIDRContainsIPString(const OpaqueValue& self, + const StringValue& other_str, + const Function::InvokeContext& context) { + const CidrRangeRep* self_rep = CidrRangeRep::Unwrap(self); + if (!self_rep || !IsInitializedAddress(self_rep->host()) || + self_rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + + std::string str = std::string(other_str.ToString()); + IPAddress other_addr; + if (!StringToIPAddress(str, &other_addr) || !IsStrictIP(other_addr)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("Invalid or non-strict IP string: ", str))); + } + return BoolValue(IsWithinSubnet(self_rep->ToIPRange(), other_addr)); +} + +Value NetCIDRContainsCIDR(const OpaqueValue& self, const OpaqueValue& other, + const Function::InvokeContext& context) { + const CidrRangeRep* self_rep = CidrRangeRep::Unwrap(self); + const CidrRangeRep* other_rep = CidrRangeRep::Unwrap(other); + if (!self_rep || !IsInitializedAddress(self_rep->host()) || + self_rep->length() < 0 || !other_rep || + !IsInitializedAddress(other_rep->host()) || other_rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + IPRange self_range = self_rep->ToIPRange(); + IPRange other_range = other_rep->ToIPRange(); + return BoolValue(self_range == other_range || + IsProperSubRange(self_range, other_range)); +} + +Value NetCIDRContainsCIDRString(const OpaqueValue& self, + const StringValue& other_str, + const Function::InvokeContext& context) { + const CidrRangeRep* self_rep = CidrRangeRep::Unwrap(self); + if (!self_rep || !IsInitializedAddress(self_rep->host()) || + self_rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + + std::string str = std::string(other_str.ToString()); + IPAddress other_host; + int other_length; + if (!ParseCIDRWithoutTruncation(str, &other_host, &other_length) || + !IsStrictCIDR(other_host, other_length)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("Invalid or non-strict CIDR string: ", str))); + } + IPRange self_range = self_rep->ToIPRange(); + IPRange other_range(other_host, other_length); + return BoolValue(self_range == other_range || + IsProperSubRange(self_range, other_range)); +} + +Value NetCIDRIP(const OpaqueValue& self, + const Function::InvokeContext& context) { + const CidrRangeRep* rep = CidrRangeRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->host()) || rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + return IpAddrRep::Create(context.arena(), rep->host()); +} + +Value NetCIDRMasked(const OpaqueValue& self, + const Function::InvokeContext& context) { + const CidrRangeRep* rep = CidrRangeRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->host()) || rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + IPRange masked_range = rep->ToIPRange(); + return CidrRangeRep::Create(context.arena(), masked_range.host(), + masked_range.length()); +} + +Value NetCIDRPrefixLength(const OpaqueValue& self, + const Function::InvokeContext& context) { + const CidrRangeRep* rep = CidrRangeRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->host()) || rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + return IntValue(rep->length()); +} + +Value NetToString(const OpaqueValue& self, + const Function::InvokeContext& context) { + if (const IpAddrRep* rep = IpAddrRep::Unwrap(self)) { + if (!IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + return StringValue::From(rep->addr().ToString(), context.arena()); + } + if (const CidrRangeRep* rep = CidrRangeRep::Unwrap(self)) { + if (!IsInitializedAddress(rep->host()) || rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + return StringValue::From( + absl::StrCat(rep->host().ToString(), "/", rep->length()), + context.arena()); + } + return ErrorValue( + absl::InvalidArgumentError("Unsupported type for string()")); +} + +} // namespace cel::extensions diff --git a/extensions/network_ext_functions.h b/extensions/network_ext_functions.h new file mode 100644 index 000000000..81c51e50c --- /dev/null +++ b/extensions/network_ext_functions.h @@ -0,0 +1,128 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// extensions/network_ext_functions.h + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_FUNCTIONS_H_ + +#include + +#include "net/base/ipaddress.h" +#include "common/native_type.h" +#include "common/typeinfo.h" +#include "common/value.h" +#include "runtime/function.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { + +// ... IpAddrRep and CidrRangeRep classes ... +class IpAddrRep { + public: + static cel::Value Create(google::protobuf::Arena* arena, + const net_base::IPAddress& addr); + static const IpAddrRep* Unwrap(const cel::Value& value); + IpAddrRep() = default; + explicit IpAddrRep(const net_base::IPAddress& addr) : addr_(addr) {} + const net_base::IPAddress& addr() const { return addr_; } + bool Equals(const IpAddrRep& other) const { return addr_ == other.addr_; } + std::string DebugString() const; + static cel::NativeTypeId GetTypeId() { return cel::TypeId(); } + + private: + net_base::IPAddress addr_; +}; + +class CidrRangeRep { + public: + static cel::Value Create(google::protobuf::Arena* arena, + const net_base::IPAddress& host, int length); + static const CidrRangeRep* Unwrap(const cel::Value& value); + + CidrRangeRep() = default; + explicit CidrRangeRep(const net_base::IPAddress& host, int length) + : host_(host), length_(length) {} + + const net_base::IPAddress& host() const { return host_; } + int length() const { return length_; } + + // Utility to get the net_base::IPRange (which will be truncated) + net_base::IPRange ToIPRange() const { + return net_base::IPRange(host_, length_); + } + + bool Equals(const CidrRangeRep& other) const { + return length_ == other.length_ && host_ == other.host_; + } + std::string DebugString() const; + + static cel::NativeTypeId GetTypeId() { return cel::TypeId(); } + + template + friend H AbslHashValue(H h, const CidrRangeRep& c) { + return H::combine(std::move(h), c.host_, c.length_); + } + + private: + net_base::IPAddress host_; + int length_ = -1; +}; + +// Declarations +cel::Value NetIsIP(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetIPString(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetIsCIDR(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRString(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetIPFamily(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsLoopback(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsGlobalUnicast(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsLinkLocalMulticast( + const cel::OpaqueValue& self, const cel::Function::InvokeContext& context); +cel::Value NetIPIsLinkLocalUnicast(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsUnspecified(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsCanonical(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRContainsIP(const cel::OpaqueValue& self, + const cel::OpaqueValue& other, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRContainsIPString(const cel::OpaqueValue& self, + const cel::StringValue& other_str, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRContainsCIDR(const cel::OpaqueValue& self, + const cel::OpaqueValue& other, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRContainsCIDRString( + const cel::OpaqueValue& self, const cel::StringValue& other_str, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRIP(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRMasked(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRPrefixLength(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetToString(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_FUNCTIONS_H_ diff --git a/extensions/network_ext_test.cc b/extensions/network_ext_test.cc new file mode 100644 index 000000000..73cf0fb2f --- /dev/null +++ b/extensions/network_ext_test.cc @@ -0,0 +1,394 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/network_ext.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/minimal_descriptor_pool.h" +#include "common/value.h" +#include "common/values/bool_value.h" +#include "common/values/error_value.h" +#include "common/values/int_value.h" +#include "common/values/string_value.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/standard_functions.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +// Includes for Compiler +#include "checker/validation_result.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::Activation; +using ::testing::Eq; +using ::testing::HasSubstr; + +class NetworkExtTest : public ::testing::Test { + protected: + NetworkExtTest() = default; + + void SetUp() override { + // 1. Configure the Compiler + auto compiler_builder = + cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool()); + ASSERT_THAT(compiler_builder.status(), IsOk()); + ASSERT_THAT((*compiler_builder)->AddLibrary(NetworkCompilerLibrary()), + IsOk()); + ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder)->Build()); + + // 2. Configure the Modern Runtime + cel::RuntimeOptions runtime_options; + // Wrap the raw pointer in a std::shared_ptr with a NO-OP DELETER + std::shared_ptr descriptor_pool( + cel::GetMinimalDescriptorPool(), [](const google::protobuf::DescriptorPool*) { + // Do nothing, as the pool is static. + }); + + auto runtime_builder = + cel::CreateRuntimeBuilder(descriptor_pool, runtime_options); + ASSERT_THAT(runtime_builder.status(), + IsOk()); // Check if CreateRuntimeBuilder succeeded + + ASSERT_THAT( + RegisterNetworkTypes(runtime_builder->type_registry(), runtime_options), + IsOk()); + ASSERT_THAT(RegisterNetworkFunctions(runtime_builder->function_registry(), + runtime_options), + IsOk()); + + ASSERT_THAT(cel::RegisterStandardFunctions( + runtime_builder->function_registry(), runtime_options), + IsOk()); + + // Build the runtime + ASSERT_OK_AND_ASSIGN(runtime_, std::move(*runtime_builder).Build()); + } + + // ... Evaluate() function and member variables ... + absl::StatusOr Evaluate(absl::string_view expr) { + auto validation_result = compiler_->Compile(expr); + CEL_RETURN_IF_ERROR(validation_result.status()); + + if (!validation_result->GetIssues().empty()) { + return absl::InvalidArgumentError( + validation_result->GetIssues()[0].message()); + } + + if (!validation_result->IsValid()) { + return absl::InternalError( + "Compilation produced an invalid AST without issues."); + } + + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, + validation_result->ReleaseAst()); + + if (ast == nullptr) { + return absl::InternalError("ValidationResult returned a null AST."); + } + + CEL_ASSIGN_OR_RETURN(auto program, runtime_->CreateProgram(std::move(ast))); + + Activation activation; + return program->Evaluate(&arena_, activation); + } + + std::unique_ptr compiler_; + std::unique_ptr runtime_; + google::protobuf::Arena arena_; +}; + +// --- Global Checks (isIP, isCIDR) --- +TEST_F(NetworkExtTest, IsIPValidIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isIP('1.2.3.4')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsIPValidIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isIP('2001:db8::1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsIPInvalid) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isIP('not.an.ip')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +TEST_F(NetworkExtTest, IsIPWithPort) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isIP('127.0.0.1:80')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +TEST_F(NetworkExtTest, IsCIDRValid) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isCIDR('10.0.0.0/8')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsCIDRInvalidMask) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isCIDR('10.0.0.0/999')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +// --- IP Constructors & Equality --- +TEST_F(NetworkExtTest, IPEqualityIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('127.0.0.1') == ip('127.0.0.1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IPInequality) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('127.0.0.1') == ip('1.2.3.4')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +TEST_F(NetworkExtTest, IPEqualityIPv6MixedCase) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('2001:db8::1') == ip('2001:DB8::1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +// --- String Conversion --- +TEST_F(NetworkExtTest, IPToStringIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('1.2.3.4').string()")); + ASSERT_TRUE(value.IsString()); + EXPECT_THAT(value.As()->ToString(), Eq("1.2.3.4")); +} + +TEST_F(NetworkExtTest, IPToStringIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('2001:db8::1').string()")); // .string() + ASSERT_TRUE(value.IsString()); + EXPECT_THAT(value.As()->ToString(), Eq("2001:db8::1")); +} + +TEST_F(NetworkExtTest, CIDRToStringIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('10.0.0.0/8').string()")); // .string() + ASSERT_TRUE(value.IsString()); + EXPECT_THAT(value.As()->ToString(), Eq("10.0.0.0/8")); +} + +TEST_F(NetworkExtTest, CIDRToStringIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('::1/128').string()")); // .string() + ASSERT_TRUE(value.IsString()); + EXPECT_THAT(value.As()->ToString(), Eq("::1/128")); +} + +// --- Family --- +TEST_F(NetworkExtTest, FamilyIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('127.0.0.1').family()")); + ASSERT_TRUE(value.IsInt()); + EXPECT_THAT(value.As()->NativeValue(), Eq(4)); +} + +TEST_F(NetworkExtTest, FamilyIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('::1').family()")); + ASSERT_TRUE(value.IsInt()); + EXPECT_THAT(value.As()->NativeValue(), Eq(6)); +} + +// --- Canonicalization --- +TEST_F(NetworkExtTest, IsCanonicalIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip.isCanonical('127.0.0.1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsCanonicalIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip.isCanonical('2001:db8::1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsCanonicalIPv6Uppercase) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip.isCanonical('2001:DB8::1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +TEST_F(NetworkExtTest, IsCanonicalIPv6Expanded) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip.isCanonical('2001:db8:0:0:0:0:0:1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +// --- IP Types (Loopback, Unspecified, etc) --- +TEST_F(NetworkExtTest, IsLoopbackIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('127.0.0.1').isLoopback()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsLoopbackIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('::1').isLoopback()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsUnspecifiedIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('0.0.0.0').isUnspecified()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsUnspecifiedIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('::').isUnspecified()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsGlobalUnicast) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('8.8.8.8').isGlobalUnicast()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsLinkLocalMulticast) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('ff02::1').isLinkLocalMulticast()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +// --- CIDR Accessors --- +TEST_F(NetworkExtTest, CIDRPrefixLength) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('192.168.0.0/24').prefixLength()")); + ASSERT_TRUE(value.IsInt()); + EXPECT_THAT(value.As()->NativeValue(), Eq(24)); +} + +TEST_F(NetworkExtTest, CIDRIPExtraction) { + ASSERT_OK_AND_ASSIGN( + auto value, Evaluate("cidr('192.168.0.0/24').ip() == ip('192.168.0.0')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, CIDRIPExtractionHostBitsSet) { + ASSERT_OK_AND_ASSIGN( + auto value, Evaluate("cidr('192.168.1.5/24').ip() == ip('192.168.1.5')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, CIDRMasked) { + ASSERT_OK_AND_ASSIGN( + auto value, + Evaluate("cidr('192.168.1.5/24').masked() == cidr('192.168.1.0/24')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, CIDRMaskedIdentity) { + ASSERT_OK_AND_ASSIGN( + auto value, + Evaluate("cidr('192.168.1.0/24').masked() == cidr('192.168.1.0/24')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +// --- Containment (IP in CIDR) --- +TEST_F(NetworkExtTest, ContainsIPSimple) { + ASSERT_OK_AND_ASSIGN( + auto value, Evaluate("cidr('10.0.0.0/8').containsIP(ip('10.1.2.3'))")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, ContainsIPStringOverload) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('10.0.0.0/8').containsIP('10.1.2.3')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +// ... other Contains tests ... + +// --- Runtime Errors --- +TEST_F(NetworkExtTest, ErrIPConstructorInvalid) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('999.999.999.999')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT( + value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("parse error"))); +} + +TEST_F(NetworkExtTest, ErrCIDRConstructorInvalid) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("cidr('1.2.3.4')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT( + value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("parse error"))); +} + +TEST_F(NetworkExtTest, ErrCIDRConstructorInvalidMask) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("cidr('10.0.0.0/999')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT( + value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("parse error"))); +} + +TEST_F(NetworkExtTest, ErrContainsIPStringInvalid) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('10.0.0.0/8').containsIP('not-an-ip')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT(value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid or non-strict IP string"))); +} + +TEST_F(NetworkExtTest, ErrContainsCIDRStringInvalid) { + ASSERT_OK_AND_ASSIGN( + auto value, Evaluate("cidr('10.0.0.0/8').containsCIDR('not-a-cidr')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT(value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid or non-strict CIDR string"))); +} + +} // namespace +} // namespace cel::extensions