From 1a5dea99602cff12d913782a9488e6686c1cf306 Mon Sep 17 00:00:00 2001 From: Chad Date: Wed, 7 May 2025 13:45:25 -0400 Subject: [PATCH] Allow decodePostflightRequest to decode json with strings Json decoder would fail to decode post flight json with quoted integers example: {"rules_received":"1","rules_processed":"1","machine_id":"serial_number"} --- go.mod | 1 + go.sum | 2 + moroz/svc_postflight.go | 5 ++- moroz/svc_postflight_test.go | 77 ++++++++++++++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 moroz/svc_postflight_test.go diff --git a/go.mod b/go.mod index 9977b8f..fe0a79f 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.1 require ( github.com/BurntSushi/toml v0.2.0 github.com/go-kit/kit v0.4.0 + github.com/goccy/go-yaml v1.17.1 github.com/gorilla/mux v1.6.1 github.com/kolide/kit v0.0.0-20180912215818-0c28f72eb2b0 github.com/oklog/run v1.0.0 diff --git a/go.sum b/go.sum index bfacc4e..665dc9b 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/go-logfmt/logfmt v0.3.0 h1:8HUsc87TaSWLKwrnumgC8/YconD2fJQsRJAsWaPg2i github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-stack/stack v1.7.0 h1:S04+lLfST9FvL8dl4R31wVUC/paZp/WQZbLmUgWboGw= github.com/go-stack/stack v1.7.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= +github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f h1:9oNbS1z4rVpbnkHBdPZU4jo9bSmrLpII768arSyMFgk= github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.1 h1:KOwqsTYZdeuMacU7CxjMNYEKeBvLbxW+psodrbcEa3A= diff --git a/moroz/svc_postflight.go b/moroz/svc_postflight.go index 5a45f00..75b4ffe 100644 --- a/moroz/svc_postflight.go +++ b/moroz/svc_postflight.go @@ -3,10 +3,11 @@ package moroz import ( "compress/zlib" "context" - "encoding/json" "net/http" "time" + "github.com/goccy/go-yaml" + "github.com/go-kit/kit/endpoint" "github.com/groob/moroz/santa" ) @@ -50,7 +51,7 @@ func decodePostflightRequest(ctx context.Context, r *http.Request) (interface{}, return nil, err } req := postflightRequest{MachineID: id} - if err := json.NewDecoder(zr).Decode(&req.payload); err != nil { + if err := yaml.NewDecoder(zr).Decode(&req.payload); err != nil { return nil, err } return req, nil diff --git a/moroz/svc_postflight_test.go b/moroz/svc_postflight_test.go new file mode 100644 index 0000000..936a9bd --- /dev/null +++ b/moroz/svc_postflight_test.go @@ -0,0 +1,77 @@ +package moroz + +import ( + "bytes" + "compress/zlib" + "context" + "net/http" + "testing" + + "github.com/gorilla/mux" +) + +func TestDecodePostflightRequest(t *testing.T) { + tests := []struct { + name string + inputJSON string + expectedID string + expectedError bool + }{ + { + name: "Valid JSON", + inputJSON: `{"rules_received":1,"rules_processed":1,"machine_id":"serial_number"}`, + expectedID: "serial_number", + expectedError: false, + }, + { + name: "Valid JSON with Strings", + inputJSON: `{"rules_received":"1","rules_processed":"1","machine_id":"serial_number"}`, + expectedID: "serial_number", + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + zw := zlib.NewWriter(&buf) + _, err := zw.Write([]byte(tt.inputJSON)) + if err != nil { + t.Fatalf("failed to write compressed data: %v", err) + } + zw.Close() + + req, err := http.NewRequest("POST", "/v1/santa/postflight/serial_number", &buf) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + vars := map[string]string{"id": "serial_number"} + req = mux.SetURLVars(req, vars) + + req.Header.Set("Content-Encoding", "deflate") + + result, err := decodePostflightRequest(context.Background(), req) + if tt.expectedError { + if err == nil { + t.Errorf("expected an error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + reqResult, ok := result.(postflightRequest) + if !ok { + t.Fatalf("expected postflightRequest, got %T", result) + } + + if reqResult.MachineID != tt.expectedID { + t.Errorf("expected MachineID %q, got %q", tt.expectedID, reqResult.MachineID) + } + }) + } +}