diff --git a/sdks/go/container/tools/pipeline_options.go b/sdks/go/container/tools/pipeline_options.go index 026fb31b0991..9d5bd5894eec 100644 --- a/sdks/go/container/tools/pipeline_options.go +++ b/sdks/go/container/tools/pipeline_options.go @@ -42,3 +42,53 @@ func MakePipelineOptionsFileAndEnvVar(options string) error { os.Setenv("PIPELINE_OPTIONS_FILE", f.Name()) return nil } + +type PipelineOptionsData struct { + Options LegacyOptionsData `json:"options"` + Experiments []string `json:"beam:option:experiments:v1"` +} + +type LegacyOptionsData struct { + Experiments []string `json:"experiments"` +} + +// GetExperiments extracts a string array of experiments from the pipeline +// options string (in JSON format) +// +// The JSON string can be in two styles: +// +// Legacy style: +// +// { +// "display_data": [ +// {...}, +// ], +// "options": { +// ... +// "experiments": [ +// ... +// ], +// } +// } +// +// URN style: +// +// { +// "beam:option:experiments:v1": [ +// ... +// ] +// } +func GetExperiments(options string) []string { + var opts PipelineOptionsData + err := json.Unmarshal([]byte(options), &opts) + if err != nil { + return nil + } + + // Check the legacy style experiments first + if opts.Options.Experiments != nil { + return opts.Options.Experiments + } + + return opts.Experiments +} diff --git a/sdks/go/container/tools/pipeline_options_test.go b/sdks/go/container/tools/pipeline_options_test.go index 7a0d7ebd5f09..7920eb0f35a2 100644 --- a/sdks/go/container/tools/pipeline_options_test.go +++ b/sdks/go/container/tools/pipeline_options_test.go @@ -56,3 +56,71 @@ func TestMakePipelineOptionsFileAndEnvVar(t *testing.T) { } os.Remove("pipeline_options.json") } + +func TestGetExperiments(t *testing.T) { + tests := []struct { + name string + inputOptions string + expectedExps []string + }{ + { + "no experiments", + `{"options": {"a": "b"}}`, + nil, + }, + { + "valid legacy experiments", + `{"options": {"experiments": ["a", "b"]}}`, + []string{"a", "b"}, + }, + { + "valid urn experiments", + `{"beam:option:experiments:v1": ["a", "b"]}`, + []string{"a", "b"}, + }, + { + "valid legacy and urn experiments; legacy first", + `{"options": {"experiments": ["c", "d"]}, "beam:option:experiments:v1": ["a", "b"]}`, + []string{"c", "d"}, + }, + { + "valid legacy and urn experiments; legacy first, even if empty", + `{"options": {"experiments": []}, "beam:option:experiments:v1": ["a", "b"]}`, + []string{}, + }, + { + "empty legacy experiments", + `{"options": {"experiments": []}}`, + []string{}, + }, + { + "empty urn experiments", + `{"beam:option:experiments:v1": []}`, + []string{}, + }, + { + "invalid json", + `{options: {"experiments": []}}`, + nil, + }, + { + "empty string", + "", + nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + exps := GetExperiments(test.inputOptions) + if len(exps) != len(test.expectedExps) { + t.Errorf("got: %v, want: %v", exps, test.expectedExps) + } + for i, v := range exps { + if v != test.expectedExps[i] { + t.Errorf("got: %v, want: %v", exps, test.expectedExps) + } + } + }) + } +} diff --git a/sdks/python/container/boot.go b/sdks/python/container/boot.go index 572dbf011134..d6a098dc01b2 100644 --- a/sdks/python/container/boot.go +++ b/sdks/python/container/boot.go @@ -117,37 +117,6 @@ func main() { } } -// The json string of pipeline options is in the following format. -// We only focus on experiments here. -// -// { -// "display_data": [ -// {...}, -// ], -// "options": { -// ... -// "experiments": [ -// ... -// ], -// } -// } -type PipelineOptionsData struct { - Options OptionsData `json:"options"` -} - -type OptionsData struct { - Experiments []string `json:"experiments"` -} - -func getExperiments(options string) []string { - var opts PipelineOptionsData - err := json.Unmarshal([]byte(options), &opts) - if err != nil { - return nil - } - return opts.Options.Experiments -} - func launchSDKProcess() error { ctx := grpcx.WriteWorkerID(context.Background(), *id) @@ -187,7 +156,7 @@ func launchSDKProcess() error { logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err) } - experiments := getExperiments(options) + experiments := tools.GetExperiments(options) pipNoBuildIsolation = false if slices.Contains(experiments, "pip_no_build_isolation") { pipNoBuildIsolation = true