From f4ae8e9a34a456b7c26755393e25e612012925a8 Mon Sep 17 00:00:00 2001 From: A Googler Date: Tue, 13 Feb 2024 21:16:26 -0800 Subject: [PATCH] Add support for testing rules_cc's new toolchains with rules_testing. PiperOrigin-RevId: 606847690 --- lib/private/analysis_test.bzl | 18 +++++++++++++++--- lib/private/target_subject.bzl | 32 +++++++++++++++++++++++--------- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/lib/private/analysis_test.bzl b/lib/private/analysis_test.bzl index c491ebb..52bd051 100644 --- a/lib/private/analysis_test.bzl +++ b/lib/private/analysis_test.bzl @@ -21,6 +21,7 @@ load("@bazel_skylib//lib:dicts.bzl", "dicts") load("@bazel_skylib//lib:types.bzl", "types") load("//lib:truth.bzl", "truth") load("//lib:util.bzl", "recursive_testing_aspect", "testing_aspect") +load("//lib/private:target_subject.bzl", "PROVIDER_SUBJECT_FACTORIES") load("//lib/private:util.bzl", "get_test_name_from_function") def _fail(env, msg): @@ -37,7 +38,7 @@ def _fail(env, msg): print(full_msg) env.failures.append(full_msg) -def _begin_analysis_test(ctx): +def _begin_analysis_test(ctx, provider_factories): """Begins a unit test. This should be the first function called in a unit test implementation @@ -48,6 +49,10 @@ def _begin_analysis_test(ctx): Args: ctx: The Starlark context. Pass the implementation function's `ctx` argument in verbatim. + provider_factories: List[struct(type, name, factory)] Additional provider + factories that should be available to `target.provider(...)`. + Eg. `[struct(type=FooInfo, name="FooInfo", factory=FooInfoFactory)]`. + Returns: An analysis_test "environment" struct. The following fields are public: @@ -86,6 +91,7 @@ def _begin_analysis_test(ctx): truth_env = struct( ctx = ctx, fail = lambda msg: _fail(failures_env, msg), + providers = PROVIDER_SUBJECT_FACTORIES + provider_factories, ) analysis_test_env = struct( ctx = ctx, @@ -126,7 +132,8 @@ def analysis_test( fragments = [], config_settings = {}, extra_target_under_test_aspects = [], - collect_actions_recursively = False): + collect_actions_recursively = False, + provider_factories = []): """Creates an analysis test from its implementation function. An analysis test verifies the behavior of a "real" rule target by examining @@ -189,6 +196,7 @@ def analysis_test( analysis test target itself (e.g. common attributes like `tags`, `target_compatible_with`, or attributes from `attrs`). Note that these are for the analysis test target itself, not the target under test. + fragments: An optional list of fragment names that can be used to give rules access to language-specific parts of configuration. config_settings: A dictionary of configuration settings to change for the target under @@ -202,6 +210,10 @@ def analysis_test( in addition to those set up by default for the test harness itself. collect_actions_recursively: If true, runs testing_aspect over all attributes, otherwise it is only applied to the target under test. + provider_factories: Optional[List[struct(type, name, factory)]] + Additional provider factories that should be available to + `target.provider(...)`. + Eg. `[struct(type=FooInfo, name="FooInfo", factory=FooInfoFactory)]`. Returns: (None) @@ -290,7 +302,7 @@ def analysis_test( ) def wrapped_impl(ctx): - env, target = _begin_analysis_test(ctx) + env, target = _begin_analysis_test(ctx, provider_factories) impl(env, target) return _end_analysis_test(env) diff --git a/lib/private/target_subject.bzl b/lib/private/target_subject.bzl index 47d8b94..bbb99f9 100644 --- a/lib/private/target_subject.bzl +++ b/lib/private/target_subject.bzl @@ -237,10 +237,12 @@ def _target_subject_provider(self, provider_key, factory = None): Returns: A subject wrapper of the provider value. """ + provider_name = str(factory) if not factory: - for key, value in _PROVIDER_SUBJECT_FACTORIES: - if key == provider_key: - factory = value + for provider in self.meta.env.providers: + if provider.type == provider_key: + factory = provider.factory + provider_name = provider.name break if not factory: @@ -249,7 +251,7 @@ def _target_subject_provider(self, provider_key, factory = None): return factory( info, - meta = self.meta.derive("provider({})".format(provider_key)), + meta = self.meta.derive("provider({})".format(provider_name)), ) def _target_subject_action_generating(self, short_path): @@ -385,11 +387,23 @@ def _target_subject_attr(self, name, *, factory = None): meta = self.meta.derive("attr({})".format(name)), ) -# Providers aren't hashable, so we have to use a list of (key, value) -_PROVIDER_SUBJECT_FACTORIES = [ - (InstrumentedFilesInfo, InstrumentedFilesInfoSubject.new), - (RunEnvironmentInfo, RunEnvironmentInfoSubject.new), - (testing.ExecutionInfo, ExecutionInfoSubject.new), +# Providers aren't hashable, so we have to use a list of structs. +PROVIDER_SUBJECT_FACTORIES = [ + struct( + type = InstrumentedFilesInfo, + name = str(InstrumentedFilesInfo), + factory = InstrumentedFilesInfoSubject.new, + ), + struct( + type = RunEnvironmentInfo, + name = str(RunEnvironmentInfo), + factory = RunEnvironmentInfoSubject, + ), + struct( + type = testing.ExecutionInfo, + name = str(testing.ExecutionInfo), + factory = ExecutionInfoSubject.new, + ), ] def _provider_name(provider):