diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c827551b..e1fe9d86 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,42 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true +env: + SIMULACRON_VERSION: 0.10.0 + GO_VERSION: 1.24.2 jobs: + dependencies: + name: Fetch dependencies + runs-on: ubuntu-latest + steps: + - uses: actions/cache@v4 + id: restore-go + with: + path: ~/deps/godl + key: ${{ runner.os }}-deps-godl-${{ env.GO_VERSION }} + + - uses: actions/cache@v4 + id: restore-simulacron + with: + path: ~/deps/simulacron + key: ${{ runner.os }}-deps-simulacron-${{ env.SIMULACRON_VERSION }} + + - if: ${{ steps.restore-go.outputs.cache-hit != 'true' }} + name: Download Go + continue-on-error: true + run: | + mkdir -p ~/deps/godl + cd ~/deps/godl + wget -O go.tar.gz https://go.dev/dl/go${{ env.GO_VERSION }}.linux-amd64.tar.gz + + - if: ${{ steps.restore-simulacron.outputs.cache-hit != 'true' }} + name: Download simulacron + continue-on-error: true + run: | + mkdir -p ~/deps/simulacron + cd ~/deps/simulacron + wget -O simulacron.jar https://github.com/datastax/simulacron/releases/download/${{ env.SIMULACRON_VERSION }}/simulacron-standalone-${{ env.SIMULACRON_VERSION }}.jar + # Runs a NoSQLBench job in docker-compose with 3 proxy nodes # Verifies the written data matches in both ORIGIN and TARGET clusters nosqlbench-tests: @@ -20,7 +55,7 @@ jobs: - name: Start docker-compose id: compose run: | - docker compose -f docker-compose-tests.yml up --abort-on-container-exit --exit-code-from=nosqlbench + docker compose -f docker-compose-tests.yml up --abort-on-container-exit --exit-code-from=nosqlbench4 - name: Test Summary if: ${{ failure() }} run: | @@ -29,15 +64,20 @@ jobs: # Runs all the unit tests under the proxy module (all the *_test.go files) unit-tests: name: Unit Tests + needs: dependencies runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + - uses: actions/cache@v4 + id: restore-go + with: + path: ~/deps/godl + key: ${{ runner.os }}-deps-godl-${{ env.GO_VERSION }} - name: Run run: | sudo apt update sudo apt -y install default-jre gcc git wget - wget https://go.dev/dl/go1.24.2.linux-amd64.tar.gz - sudo tar -xzf go*.tar.gz -C /usr/local/ + sudo tar -xzf ~/deps/godl/go.tar.gz -C /usr/local/ export PATH=$PATH:/usr/local/go/bin export PATH=$PATH:`go env GOPATH`/bin go install github.com/jstemmer/go-junit-report/v2@latest @@ -52,21 +92,35 @@ jobs: # These tests use Simulacron and in-memory CQLServer integration-tests-mock: name: Mock Tests + needs: dependencies runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + + - uses: actions/cache@v4 + id: restore-go + with: + path: ~/deps/godl + key: ${{ runner.os }}-deps-godl-${{ env.GO_VERSION }} + + - uses: actions/cache@v4 + id: restore-simulacron + with: + path: ~/deps/simulacron + key: ${{ runner.os }}-deps-simulacron-${{ env.SIMULACRON_VERSION }} + - name: Run run: | sudo apt update sudo apt -y install openjdk-8-jdk gcc git wget - wget https://go.dev/dl/go1.24.2.linux-amd64.tar.gz - sudo tar -xzf go*.tar.gz -C /usr/local/ + sudo tar -xzf ~/deps/godl/go.tar.gz -C /usr/local/ export PATH=$PATH:/usr/local/go/bin export PATH=$PATH:`go env GOPATH`/bin go install github.com/jstemmer/go-junit-report/v2@latest - wget https://github.com/datastax/simulacron/releases/download/0.10.0/simulacron-standalone-0.10.0.jar - export SIMULACRON_PATH=`pwd`/simulacron-standalone-0.10.0.jar + cp ~/deps/simulacron/simulacron.jar . + export SIMULACRON_PATH=`pwd`/simulacron.jar go test -timeout 180m -v 2>&1 ./integration-tests | go-junit-report -set-exit-code -iocopy -out report-integration-mock.xml + - name: Test Summary uses: test-summary/action@v1 if: always() @@ -76,26 +130,63 @@ jobs: # Runs integration tests using CCM integration-tests-ccm: name: CCM Tests + needs: dependencies runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + cassandra_version: [ '2.2.19', '3.11.19', '4.1.9', '5.0.6', 'dse-4.8.16', 'dse-5.1.48', 'dse-6.8.61' ] steps: - uses: actions/checkout@v2 + + - uses: actions/cache@v4 + id: restore-go + with: + path: ~/deps/godl + key: ${{ runner.os }}-deps-godl-${{ env.GO_VERSION }} + + - uses: actions/cache@v4 + id: restore-cache-ccm + with: + path: ~/.ccm/repository + key: ${{ runner.os }}-ccm-${{ matrix.cassandra_version }} + - name: Run run: | sudo apt update + sudo apt -y install openjdk-8-jdk gcc git wget pip - wget https://go.dev/dl/go1.24.2.linux-amd64.tar.gz - sudo tar -xzf go*.tar.gz -C /usr/local/ - export PATH=$PATH:/usr/local/go/bin - export PATH=$PATH:`go env GOPATH`/bin + sudo apt -y install openjdk-11-jdk gcc git wget pip + export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 + export JAVA8_HOME=/usr/lib/jvm/java-8-openjdk-amd64 + export JAVA11_HOME=/usr/lib/jvm/java-11-openjdk-amd64 export PATH=$JAVA_HOME/bin:$PATH java -version + + sudo tar -xzf ~/deps/godl/go.tar.gz -C /usr/local/ + export PATH=$PATH:/usr/local/go/bin + export PATH=$PATH:`go env GOPATH`/bin + go install github.com/jstemmer/go-junit-report/v2@latest - pip install ccm + + CCM_VERSION="0e20102c1cad99104969239f1ac375b6fcaa7bbc" + export CCM_VERSION + echo "Install CCM ${CCM_VERSION}" + pip install "git+https://github.com/apache/cassandra-ccm.git@${CCM_VERSION}" + which ccm sudo ln -s /home/runner/.local/bin/ccm /usr/local/bin/ccm /usr/local/bin/ccm list - go test -timeout 180m -v 2>&1 ./integration-tests -RUN_MOCKTESTS=false -RUN_CCMTESTS=true | go-junit-report -set-exit-code -iocopy -out report-integration-ccm.xml + + go test -timeout 180m -v 2>&1 ./integration-tests -RUN_MOCKTESTS=false -RUN_CCMTESTS=true -CASSANDRA_VERSION=${{ matrix.cassandra_version }} | go-junit-report -set-exit-code -iocopy -out report-integration-ccm.xml + + - uses: actions/cache/save@v4 + if: always() + with: + path: ~/.ccm/repository + key: ${{ runner.os }}-ccm-${{ matrix.cassandra_version }} + - name: Test Summary uses: test-summary/action@v1 if: always() @@ -105,21 +196,31 @@ jobs: # Runs the mock tests with go's race checker to spot potential data races race-checker: name: Race Checker + needs: dependencies runs-on: ubuntu-latest if: ${{ false }} # temporarily disabled steps: - uses: actions/checkout@v2 + - uses: actions/cache@v4 + id: restore-simulacron + with: + path: ~/deps/simulacron + key: ${{ runner.os }}-deps-simulacron-${{ env.SIMULACRON_VERSION }} + - uses: actions/cache@v4 + id: restore-go + with: + path: ~/deps/godl + key: ${{ runner.os }}-deps-godl-${{ env.GO_VERSION }} - name: Run run: | sudo apt update sudo apt -y install openjdk-8-jdk gcc git pip wget - wget https://go.dev/dl/go1.24.2.linux-amd64.tar.gz - sudo tar -xzf go*.tar.gz -C /usr/local/ + sudo tar -xzf ~/deps/godl/go.tar.gz -C /usr/local/ export PATH=$PATH:/usr/local/go/bin export PATH=$PATH:`go env GOPATH`/bin go install github.com/jstemmer/go-junit-report/v2@latest - wget https://github.com/datastax/simulacron/releases/download/0.10.0/simulacron-standalone-0.10.0.jar - export SIMULACRON_PATH=`pwd`/simulacron-standalone-0.10.0.jar + cp ~/deps/simulacron/simulacron.jar . + export SIMULACRON_PATH=`pwd`/simulacron.jar go test -race -timeout 180m -v 2>&1 ./integration-tests | go-junit-report -set-exit-code -iocopy -out report-integration-race.xml - name: Test Summary uses: test-summary/action@v1 @@ -130,16 +231,21 @@ jobs: # Performs static analysis to check for things like context leaks go-vet: name: Go Vet + needs: dependencies runs-on: ubuntu-latest if: ${{ false }} # temporarily disabled steps: - uses: actions/checkout@v2 + - uses: actions/cache@v4 + id: restore-go + with: + path: ~/deps/godl + key: ${{ runner.os }}-deps-godl-${{ env.GO_VERSION }} - name: Run run: | sudo apt update sudo apt -y install openjdk-8-jdk gcc git pip wget - wget https://go.dev/dl/go1.24.2.linux-amd64.tar.gz - sudo tar -xzf go*.tar.gz -C /usr/local/ + sudo tar -xzf ~/deps/godl/go.tar.gz -C /usr/local/ export PATH=$PATH:/usr/local/go/bin export PATH=$PATH:`go env GOPATH`/bin go vet ./... \ No newline at end of file diff --git a/CHANGELOG/CHANGELOG-2.3.md b/CHANGELOG/CHANGELOG-2.3.md index 1071ef23..60074206 100644 --- a/CHANGELOG/CHANGELOG-2.3.md +++ b/CHANGELOG/CHANGELOG-2.3.md @@ -4,11 +4,6 @@ Changelog for the ZDM Proxy, new PRs should update the `unreleased` section. When cutting a new release, update the `unreleased` heading to the tag being generated and date, like `## vX.Y.Z - YYYY-MM-DD` and create a new placeholder section for `unreleased` entries. -## Unreleased - -* [#150](https://github.com/datastax/zdm-proxy/issues/150): CQL request tracing -* [#154](https://github.com/datastax/zdm-proxy/issues/154): Support CQL request compression - --- ## v2.3.4 - 2025-05-29 diff --git a/CHANGELOG/CHANGELOG-2.4.md b/CHANGELOG/CHANGELOG-2.4.md new file mode 100644 index 00000000..1c12b4eb --- /dev/null +++ b/CHANGELOG/CHANGELOG-2.4.md @@ -0,0 +1,20 @@ +# Changelog + +Changelog for the ZDM Proxy, new PRs should update the `unreleased` section. + +When cutting a new release, update the `unreleased` heading to the tag being generated and date, like `## vX.Y.Z - YYYY-MM-DD` and create a new placeholder section for `unreleased` entries. + +--- + +## v2.4.0 - 2026-01-16 + +### New Features + +* [#150](https://github.com/datastax/zdm-proxy/issues/150): CQL request tracing +* [#154](https://github.com/datastax/zdm-proxy/issues/154): Support CQL request compression +* [#157](https://github.com/datastax/zdm-proxy/pull/157): Support protocol v5 +* [#157](https://github.com/datastax/zdm-proxy/pull/157): New Configuration setting to block specific protocol versions + +### Improvements + +* [#157](https://github.com/datastax/zdm-proxy/pull/157): Improvements to CI so we can find regressions with multiple C* versions before merging a PR \ No newline at end of file diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 16cb4348..dbc0971b 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -6,6 +6,18 @@ Build artifacts are available at [Docker Hub](https://hub.docker.com/repository/ For additional details on the changes included in a specific release, see the associated CHANGELOG-x.x.md file. +## v2.4.0 - 2026-01-16 + +Support LZ4 and snappy compression. + +Support protocol v5. + +New configuration setting `ZDM_BLOCKED_PROTOCOL_VERSIONS` to block specific protocol versions at the proxy level. + +Send request id in the request payload (currently supported by Astra only). + +[Changelog](CHANGELOG/CHANGELOG-2.4.md#v240---2026-01-16) + ## v2.3.4 - 2025-05-29 Fix CQL stream ID validation for internal heartbeat mechanism. diff --git a/compose/nosqlbench-entrypoint.sh b/compose/nosqlbench-entrypoint.sh deleted file mode 100755 index c49e66e9..00000000 --- a/compose/nosqlbench-entrypoint.sh +++ /dev/null @@ -1,78 +0,0 @@ -#!/bin/sh -apk add --no-cache netcat-openbsd -apk add py3-pip -pip install cqlsh - -function test_conn() { - nc -z -v $1 9042; - while [ $? -ne 0 ]; - do echo "CQL port not ready on $1"; - sleep 10; - nc -z -v $1 9042; - done -} - -# Wait for clusters and proxy to be responsive -test_conn zdm_tests_origin -test_conn zdm_tests_target -test_conn zdm_tests_proxy - -set -e - -echo "Creating schema" -cat /source/nb-tests/schema.cql | cqlsh zdm_tests_proxy - -echo "Running NoSQLBench RAMPUP job" -java -jar /nb.jar \ - --show-stacktraces \ - /source/nb-tests/cql-nb-activity.yaml \ - rampup \ - driver=cqld3 \ - hosts=zdm_tests_proxy \ - localdc=datacenter1 \ - errors=retry \ - -v - -echo "Running NoSQLBench WRITE job" -java -jar /nb.jar \ - --show-stacktraces \ - /source/nb-tests/cql-nb-activity.yaml \ - write \ - driver=cqld3 \ - hosts=zdm_tests_proxy \ - localdc=datacenter1 \ - errors=retry \ - -v - -echo "Running NoSQLBench READ job" -java -jar /nb.jar \ - --show-stacktraces \ - /source/nb-tests/cql-nb-activity.yaml \ - read \ - driver=cqld3 \ - hosts=zdm_tests_proxy \ - localdc=datacenter1 \ - errors=retry \ - -v - -echo "Running NoSQLBench VERIFY job on ORIGIN" -java -jar /nb.jar \ - --show-stacktraces \ - --report-csv-to /source/verify-origin \ - /source/nb-tests/cql-nb-activity.yaml \ - verify \ - driver=cqld3 \ - hosts=zdm_tests_origin \ - localdc=datacenter1 \ - -v - -echo "Running NoSQLBench VERIFY job on TARGET" -java -jar /nb.jar \ - --show-stacktraces \ - --report-csv-to /source/verify-target \ - /source/nb-tests/cql-nb-activity.yaml \ - verify \ - driver=cqld3 \ - hosts=zdm_tests_target \ - localdc=datacenter1 \ - -v \ No newline at end of file diff --git a/compose/nosqlbench4-entrypoint.sh b/compose/nosqlbench4-entrypoint.sh new file mode 100755 index 00000000..f8e7e503 --- /dev/null +++ b/compose/nosqlbench4-entrypoint.sh @@ -0,0 +1,45 @@ +#!/bin/sh + +# Block until the given file appears or the given timeout is reached. +# Exit status is 0 iff the file exists. +wait_file() { + local file="$1"; shift + local wait_seconds="${1:-10}"; shift # 10 seconds as default timeout + test $wait_seconds -lt 1 && echo 'At least 1 second is required' && return 1 + + until test $((wait_seconds--)) -eq 0 -o -e "$file" ; do sleep 1; done + + test $wait_seconds -ge 0 # equivalent: let ++wait_seconds +} + +donefile=/source/donefile + +wait_file "$donefile" 1200 || { + echo "donefile missing after waiting for 1200 seconds: '$donefile'" + exit 1 +} +echo "File found" + +set -e + +echo "Running NoSQLBench VERIFY job on ORIGIN" +java -jar /nb.jar \ + --show-stacktraces \ + --report-csv-to /source/verify-origin \ + /source/nb-tests/cql-nb-activity.yaml \ + verify \ + driver=cqld3 \ + hosts=zdm_tests_origin \ + localdc=datacenter1 \ + -vv + +echo "Running NoSQLBench VERIFY job on TARGET" +java -jar /nb.jar \ + --show-stacktraces \ + --report-csv-to /source/verify-target \ + /source/nb-tests/cql-nb-activity.yaml \ + verify \ + driver=cqld3 \ + hosts=zdm_tests_target \ + localdc=datacenter1 \ + -vv \ No newline at end of file diff --git a/compose/nosqlbench5-entrypoint.sh b/compose/nosqlbench5-entrypoint.sh new file mode 100755 index 00000000..f747eca8 --- /dev/null +++ b/compose/nosqlbench5-entrypoint.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +apt-get update +apt-get install -y netcat-openbsd + +function test_conn() { + nc -z -v $1 9042; + while [ $? -ne 0 ]; + do echo "CQL port not ready on $1"; + sleep 10; + nc -z -v $1 9042; + done +} + +# Wait for clusters and proxy to be responsive +test_conn zdm_tests_origin +test_conn zdm_tests_target +test_conn zdm_tests_proxy + +set -e + +echo "Running NoSQLBench SCHEMA job" +java -jar /nb.jar \ + --show-stacktraces \ + /source/nb-tests/cql-nb-activity.yaml \ + schema \ + driver=cqld4 \ + hosts=zdm_tests_proxy \ + localdc=datacenter1 \ + errors=retry \ + --log-level-override com.datastax.oss.driver:INFO,com.datastax.oss.driver.internal.core.session.PoolManager:DEBUG,com.datastax.oss.driver.internal.core.pool.ChannelPool:DEBUG,com.datastax.oss.driver.internal.core.metadata.NodeStateManager:DEBUG,com.datastax.oss.driver.internal.core.metadata.MetadataManager:DEBUG,com.datastax.oss.driver.internal.core.util.concurrent.Reconnection:DEBUG \ + -vv + +echo "Running NoSQLBench RAMPUP job" +java -jar /nb.jar \ + --show-stacktraces \ + /source/nb-tests/cql-nb-activity.yaml \ + rampup \ + driver=cqld4 \ + hosts=zdm_tests_proxy \ + localdc=datacenter1 \ + errors=retry \ + --log-level-override com.datastax.oss.driver:INFO,com.datastax.oss.driver.internal.core.session.PoolManager:DEBUG,com.datastax.oss.driver.internal.core.pool.ChannelPool:DEBUG,com.datastax.oss.driver.internal.core.metadata.NodeStateManager:DEBUG,com.datastax.oss.driver.internal.core.metadata.MetadataManager:DEBUG,com.datastax.oss.driver.internal.core.util.concurrent.Reconnection:DEBUG \ + -vv + +echo "Running NoSQLBench WRITE job" +java -jar /nb.jar \ + --show-stacktraces \ + /source/nb-tests/cql-nb-activity.yaml \ + write \ + driver=cqld4 \ + hosts=zdm_tests_proxy \ + localdc=datacenter1 \ + errors=retry \ + --log-level-override com.datastax.oss.driver:INFO,com.datastax.oss.driver.internal.core.session.PoolManager:DEBUG,com.datastax.oss.driver.internal.core.pool.ChannelPool:DEBUG,com.datastax.oss.driver.internal.core.metadata.NodeStateManager:DEBUG,com.datastax.oss.driver.internal.core.metadata.MetadataManager:DEBUG,com.datastax.oss.driver.internal.core.util.concurrent.Reconnection:DEBUG \ + -vv + +echo "Running NoSQLBench READ job" +java -jar /nb.jar \ + --show-stacktraces \ + /source/nb-tests/cql-nb-activity.yaml \ + read \ + driver=cqld4 \ + hosts=zdm_tests_proxy \ + localdc=datacenter1 \ + errors=retry \ + --log-level-override com.datastax.oss.driver:INFO,com.datastax.oss.driver.internal.core.session.PoolManager:DEBUG,com.datastax.oss.driver.internal.core.pool.ChannelPool:DEBUG,com.datastax.oss.driver.internal.core.metadata.NodeStateManager:DEBUG,com.datastax.oss.driver.internal.core.metadata.MetadataManager:DEBUG,com.datastax.oss.driver.internal.core.util.concurrent.Reconnection:DEBUG \ + -vv + +touch /source/donefile + +# don't exit otherwise the verification step on the other container won't run +sleep 600 \ No newline at end of file diff --git a/docker-compose-tests.yml b/docker-compose-tests.yml index b5ab5b17..b6d7389b 100644 --- a/docker-compose-tests.yml +++ b/docker-compose-tests.yml @@ -11,17 +11,25 @@ networks: services: origin: - image: cassandra:3.11.13 + image: cassandra:4.1.10 container_name: zdm_tests_origin restart: unless-stopped + command: + - cassandra + - -f + - -Dcassandra.ring_delay_ms=1000 networks: proxy: ipv4_address: 192.168.100.101 target: - image: cassandra:3.11.13 + image: cassandra:5.0.6 container_name: zdm_tests_target restart: unless-stopped + command: + - cassandra + - -f + - -Dcassandra.ring_delay_ms=1000 networks: proxy: ipv4_address: 192.168.100.102 @@ -42,14 +50,26 @@ services: proxy: ipv4_address: 192.168.100.103 - nosqlbench: + nosqlbench5: + image: nosqlbench/nosqlbench:5.21.7 + container_name: zdm_tests_nb5 + tty: true + volumes: + - .:/source + entrypoint: + - /source/compose/nosqlbench5-entrypoint.sh + networks: + proxy: + ipv4_address: 192.168.100.104 + + nosqlbench4: image: nosqlbench/nosqlbench:4.15.101 - container_name: zdm_tests_nb + container_name: zdm_tests_nb4 tty: true volumes: - .:/source entrypoint: - - /source/compose/nosqlbench-entrypoint.sh + - /source/compose/nosqlbench4-entrypoint.sh networks: proxy: - ipv4_address: 192.168.100.104 \ No newline at end of file + ipv4_address: 192.168.100.105 \ No newline at end of file diff --git a/go.mod b/go.mod index ab585f92..884d7a6e 100644 --- a/go.mod +++ b/go.mod @@ -4,17 +4,17 @@ go 1.24 require ( github.com/antlr4-go/antlr/v4 v4.13.1 - github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d - github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e + github.com/apache/cassandra-gocql-driver/v2 v2.0.0 + github.com/datastax/go-cassandra-native-protocol v0.0.0-20240903140133-605a850e203b github.com/google/uuid v1.1.1 github.com/jpillora/backoff v1.0.0 github.com/kelseyhightower/envconfig v1.4.0 github.com/mcuadros/go-defaults v1.2.0 github.com/prometheus/client_golang v1.11.1 github.com/prometheus/client_model v0.2.0 - github.com/rs/zerolog v1.20.0 + github.com/rs/zerolog v1.34.0 github.com/sirupsen/logrus v1.6.0 - github.com/stretchr/testify v1.8.0 + github.com/stretchr/testify v1.9.0 golang.org/x/time v0.12.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -25,16 +25,17 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.0 // indirect github.com/golang/snappy v0.0.3 // indirect - github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect - github.com/kr/pretty v0.2.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect - github.com/pierrec/lz4/v4 v4.0.3 // indirect + github.com/pierrec/lz4/v4 v4.1.8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/common v0.26.0 // indirect github.com/prometheus/procfs v0.6.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect - golang.org/x/sys v0.3.0 // indirect + golang.org/x/sys v0.12.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect ) diff --git a/go.sum b/go.sum index e273a478..3c7a7f80 100644 --- a/go.sum +++ b/go.sum @@ -6,19 +6,18 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/apache/cassandra-gocql-driver/v2 v2.0.0 h1:Omnzb1Z/P90Dr2TbVNu54ICQL7TKVIIsJO231w484HU= +github.com/apache/cassandra-gocql-driver/v2 v2.0.0/go.mod h1:QH/asJjB3mHvY6Dot6ZKMMpTcOrWJ8i9GhsvG1g0PK4= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= -github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= -github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= -github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d h1:UnPtAA8Ux3GvHLazSSUydERFuoQRyxHrB8puzXyjXIE= -github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d/go.mod h1:6FzirJfdffakAVqmHjwVfFkpru/gNbIazUOK5rIhndc= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/datastax/go-cassandra-native-protocol v0.0.0-20240903140133-605a850e203b h1:o7DLYw053jrHE9ii7pO4t/5GT6d/s6Eko+Szzj4j894= +github.com/datastax/go-cassandra-native-protocol v0.0.0-20240903140133-605a850e203b/go.mod h1:6FzirJfdffakAVqmHjwVfFkpru/gNbIazUOK5rIhndc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -29,8 +28,7 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e h1:SroDcndcOU9BVAduPf/PXihXoR2ZYTQYLXbupbqxAyQ= -github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -44,7 +42,6 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -57,8 +54,6 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= -github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -73,11 +68,15 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mcuadros/go-defaults v1.2.0 h1:FODb8WSf0uGaY8elWJAkoLL0Ri6AlZ1bFlenk56oZtc= @@ -90,8 +89,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/pierrec/lz4/v4 v4.0.3 h1:vNQKSVZNYUEAvRY9FaUXAF1XPbSOHJtDTiP41kzDz2E= -github.com/pierrec/lz4/v4 v4.0.3/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4= +github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -115,23 +114,26 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= -github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -142,12 +144,10 @@ golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -161,15 +161,15 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -193,6 +193,5 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integration-tests/asyncreads_test.go b/integration-tests/asyncreads_test.go index d487f2aa..988148b7 100644 --- a/integration-tests/asyncreads_test.go +++ b/integration-tests/asyncreads_test.go @@ -3,22 +3,25 @@ package integration_tests import ( "context" "fmt" + "sync" + "testing" + "time" + + "github.com/apache/cassandra-gocql-driver/v2" "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/integration-tests/setup" - "github.com/datastax/zdm-proxy/integration-tests/simulacron" - "github.com/datastax/zdm-proxy/integration-tests/utils" - "github.com/datastax/zdm-proxy/proxy/pkg/config" - "github.com/gocql/gocql" "github.com/rs/zerolog" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "sync" - "testing" - "time" + + "github.com/datastax/zdm-proxy/integration-tests/env" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/simulacron" + "github.com/datastax/zdm-proxy/integration-tests/utils" + "github.com/datastax/zdm-proxy/proxy/pkg/config" ) func TestAsyncReadError(t *testing.T) { @@ -49,7 +52,7 @@ func TestAsyncReadError(t *testing.T) { require.Nil(t, err) client := client.NewCqlClient("127.0.0.1:14002", nil) - cqlClientConn, err := client.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0) + cqlClientConn, err := client.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) require.Nil(t, err) defer cqlClientConn.Close() @@ -58,7 +61,7 @@ func TestAsyncReadError(t *testing.T) { Options: nil, } - rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion4, 0, queryMsg)) + rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg)) require.Nil(t, err) require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode) rowsMsg, ok := rsp.Body.Message.(*message.RowsResult) @@ -95,7 +98,7 @@ func TestAsyncReadHighLatency(t *testing.T) { require.Nil(t, err) client := client.NewCqlClient("127.0.0.1:14002", nil) - cqlClientConn, err := client.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0) + cqlClientConn, err := client.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) require.Nil(t, err) defer cqlClientConn.Close() @@ -105,7 +108,7 @@ func TestAsyncReadHighLatency(t *testing.T) { } now := time.Now() - rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion4, 0, queryMsg)) + rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg)) require.Less(t, time.Now().Sub(now).Milliseconds(), int64(500)) require.Nil(t, err) require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode) @@ -143,7 +146,7 @@ func TestAsyncExhaustedStreamIds(t *testing.T) { require.Nil(t, err) client := client.NewCqlClient("127.0.0.1:14002", nil) - cqlClientConn, err := client.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0) + cqlClientConn, err := client.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) require.Nil(t, err) defer cqlClientConn.Close() @@ -169,7 +172,7 @@ func TestAsyncExhaustedStreamIds(t *testing.T) { go func() { defer wg.Done() for j := 0; j < totalRequests/workers; j++ { - rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion4, 0, queryMsg)) + rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg)) assert.Nil(t, err) if err != nil { continue @@ -302,14 +305,14 @@ func TestAsyncReadsRequestTypes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := client.NewCqlClient("127.0.0.1:14002", nil) - cqlClientConn, err := client.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0) + cqlClientConn, err := client.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) require.Nil(t, err) defer cqlClientConn.Close() err = testSetup.Origin.DeleteLogs() require.Nil(t, err) err = testSetup.Target.DeleteLogs() require.Nil(t, err) - f := frame.NewFrame(primitive.ProtocolVersion4, 0, tt.msg) + f := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, tt.msg) rsp, err := cqlClientConn.SendAndReceive(f) require.Nil(t, err) require.NotNil(t, rsp) @@ -324,7 +327,7 @@ func TestAsyncReadsRequestTypes(t *testing.T) { ResultMetadataId: preparedResult.ResultMetadataId, Options: nil, } - f = frame.NewFrame(primitive.ProtocolVersion4, 0, execute) + f = frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, execute) rsp, err = cqlClientConn.SendAndReceive(f) require.Nil(t, err) require.NotNil(t, rsp) diff --git a/integration-tests/auth_test.go b/integration-tests/auth_test.go index c0b1f627..b1af9692 100644 --- a/integration-tests/auth_test.go +++ b/integration-tests/auth_test.go @@ -3,18 +3,21 @@ package integration_tests import ( "context" "fmt" + "strings" + "sync" + "testing" + "time" + "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/require" + + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/datastax/zdm-proxy/proxy/pkg/health" - "github.com/stretchr/testify/require" - "strings" - "sync" - "testing" - "time" ) func TestAuth(t *testing.T) { @@ -499,7 +502,6 @@ func TestAuth(t *testing.T) { originAddress := "127.0.1.1" targetAddress := "127.0.1.2" - version := primitive.ProtocolVersion4 for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -536,7 +538,7 @@ func TestAuth(t *testing.T) { client.NewDriverConnectionInitializationHandler("target", "dc2", func(_ string) {}), } - err = testSetup.Start(nil, false, primitive.ProtocolVersion4) + err = testSetup.Start(nil, false, env.DefaultProtocolVersion) require.Nil(t, err) proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) @@ -566,7 +568,7 @@ func TestAuth(t *testing.T) { require.Nil(t, err, "client connection failed: %v", err) defer cqlConn.Close() - err = cqlConn.InitiateHandshake(primitive.ProtocolVersion4, 0) + err = cqlConn.InitiateHandshake(env.DefaultProtocolVersion, 0) originRequestsByConn := originRequestHandler.GetRequests() targetRequestsByConn := targetRequestHandler.GetRequests() @@ -586,7 +588,7 @@ func TestAuth(t *testing.T) { Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelOne}, } - response, err := cqlConn.SendAndReceive(frame.NewFrame(version, 0, query)) + response, err := cqlConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersion, 0, query)) require.Nil(t, err, "query request send failed: %s", err) require.Equal(t, primitive.OpCodeResult, response.Body.Message.GetOpCode(), response.Body.Message) diff --git a/integration-tests/basicbatch_test.go b/integration-tests/basicbatch_test.go index 1d265fbb..cbde0f5b 100644 --- a/integration-tests/basicbatch_test.go +++ b/integration-tests/basicbatch_test.go @@ -2,28 +2,25 @@ package integration_tests import ( "fmt" - "github.com/datastax/zdm-proxy/integration-tests/env" + "testing" + + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/utils" - "github.com/stretchr/testify/require" - "testing" - "github.com/gocql/gocql" + "github.com/apache/cassandra-gocql-driver/v2" ) // BasicBatch tests basic batch statement functionality // The test runs a basic batch statement, which includes an insert and update, // and then runs an insert and update after to make sure it works func TestBasicBatch(t *testing.T) { - if !env.RunCcmTests { - t.Skip("Test requires CCM, set RUN_CCMTESTS env variable to TRUE") - } - - proxyInstance, err := NewProxyInstanceForGlobalCcmClusters() + proxyInstance, err := NewProxyInstanceForGlobalCcmClusters(t) require.Nil(t, err) defer proxyInstance.Shutdown() - originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters() + originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters(t) require.Nil(t, err) // Initialize test data diff --git a/integration-tests/basicselect_test.go b/integration-tests/basicselect_test.go index dba6f24f..5e9bf8cd 100644 --- a/integration-tests/basicselect_test.go +++ b/integration-tests/basicselect_test.go @@ -2,26 +2,25 @@ package integration_tests import ( "fmt" + "testing" + + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/utils" - "github.com/stretchr/testify/require" - "testing" ) func TestSaiSelect(t *testing.T) { - if !env.RunCcmTests { - t.Skip("Test requires CCM, set RUN_CCMTESTS env variable to TRUE") - } if !(env.IsDse && env.CompareServerVersion("6.9") >= 0) { t.Skip("Test requires DSE 6.9 cluster") } - proxyInstance, err := NewProxyInstanceForGlobalCcmClusters() + proxyInstance, err := NewProxyInstanceForGlobalCcmClusters(t) require.Nil(t, err) defer proxyInstance.Shutdown() - originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters() + originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters(t) require.Nil(t, err) // Initialize test data diff --git a/integration-tests/basicupdate_test.go b/integration-tests/basicupdate_test.go index d302035c..aa72c504 100644 --- a/integration-tests/basicupdate_test.go +++ b/integration-tests/basicupdate_test.go @@ -2,13 +2,17 @@ package integration_tests import ( "fmt" + "testing" + + gocql "github.com/apache/cassandra-gocql-driver/v2" + "github.com/apache/cassandra-gocql-driver/v2/lz4" + "github.com/apache/cassandra-gocql-driver/v2/snappy" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/utils" - "github.com/gocql/gocql" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "testing" ) // BasicUpdate tests if update queries run correctly @@ -16,15 +20,11 @@ import ( // performs an update where through the proxy // then loads the unloaded data into the destination func TestBasicUpdate(t *testing.T) { - if !env.RunCcmTests { - t.Skip("Test requires CCM, set RUN_CCMTESTS env variable to TRUE") - } - - proxyInstance, err := NewProxyInstanceForGlobalCcmClusters() + proxyInstance, err := NewProxyInstanceForGlobalCcmClusters(t) require.Nil(t, err) defer proxyInstance.Shutdown() - originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters() + originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters(t) require.Nil(t, err) // Initialize test data @@ -64,18 +64,14 @@ func TestBasicUpdate(t *testing.T) { } func TestCompression(t *testing.T) { - if !env.RunCcmTests { - t.Skip("Test requires CCM, set RUN_CCMTESTS env variable to TRUE") - } - log.SetLevel(log.TraceLevel) defer log.SetLevel(log.InfoLevel) - proxyInstance, err := NewProxyInstanceForGlobalCcmClusters() + proxyInstance, err := NewProxyInstanceForGlobalCcmClusters(t) require.Nil(t, err) defer proxyInstance.Shutdown() - originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters() + originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters(t) require.Nil(t, err) // Initialize test data @@ -86,30 +82,40 @@ func TestCompression(t *testing.T) { // Seed originCluster and targetCluster w/ schema and data setup.SeedData(originCluster.GetSession(), targetCluster.GetSession(), setup.TasksModel, data) - // Connect to proxy as a "client" - cluster := utils.NewCluster("127.0.0.1", "", "", 14002) - cluster.Compressor = gocql.SnappyCompressor{} - proxy, err := cluster.CreateSession() - - if err != nil { - t.Log("Unable to connect to proxy session.") - t.Fatal(err) - } - defer proxy.Close() - - // Run query on proxied connection - err = proxy.Query(fmt.Sprintf("UPDATE %s.%s SET task = 'terrance' WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91;", setup.TestKeyspace, setup.TasksModel)).Exec() - if err != nil { - t.Log("Mid-migration update failed.") - t.Fatal(err) + compressors := []gocql.Compressor{snappy.SnappyCompressor{}, lz4.LZ4Compressor{}} + + for _, compressor := range compressors { + t.Run(compressor.Name(), func(t *testing.T) { + // Connect to proxy as a "client" + cluster := utils.NewCluster("127.0.0.1", "", "", 14002) + if !env.IsDse && env.CompareServerVersion("4.0.0") >= 0 && compressor.Name() == "snappy" { + cluster.ProtoVersion = 4 // v5 doesn't support snappy + } + cluster.Compressor = compressor + cluster.Logger = gocql.NewLogger(gocql.LogLevelDebug) + proxy, err := cluster.CreateSession() + + if err != nil { + t.Log("Unable to connect to proxy session.") + t.Fatal(err) + } + defer proxy.Close() + + // Run query on proxied connection + err = proxy.Query(fmt.Sprintf("UPDATE %s.%s SET task = 'terrance' WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91;", setup.TestKeyspace, setup.TasksModel)).Exec() + if err != nil { + t.Log("Mid-migration update failed.") + t.Fatal(err) + } + + // Assertions! + itr := targetCluster.GetSession().Query(fmt.Sprintf("SELECT * FROM %s.%s WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91;", setup.TestKeyspace, setup.TasksModel)).Iter() + row := make(map[string]interface{}) + + require.True(t, itr.MapScan(row)) + task := setup.MapToTask(row) + + setup.AssertEqual(t, "terrance", task.Task) + }) } - - // Assertions! - itr := targetCluster.GetSession().Query(fmt.Sprintf("SELECT * FROM %s.%s WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91;", setup.TestKeyspace, setup.TasksModel)).Iter() - row := make(map[string]interface{}) - - require.True(t, itr.MapScan(row)) - task := setup.MapToTask(row) - - setup.AssertEqual(t, "terrance", task.Task) } diff --git a/integration-tests/batch_test.go b/integration-tests/batch_test.go index 84054b60..16e3c2cc 100644 --- a/integration-tests/batch_test.go +++ b/integration-tests/batch_test.go @@ -1,10 +1,10 @@ package integration_tests import ( + "github.com/apache/cassandra-gocql-driver/v2" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" - "github.com/gocql/gocql" "github.com/stretchr/testify/require" "testing" ) diff --git a/integration-tests/ccm/ccm.go b/integration-tests/ccm/ccm.go index 35379b6e..cfee2f02 100644 --- a/integration-tests/ccm/ccm.go +++ b/integration-tests/ccm/ccm.go @@ -4,11 +4,12 @@ import ( "context" "errors" "fmt" - log "github.com/sirupsen/logrus" "os/exec" "runtime" "strings" "time" + + log "github.com/sirupsen/logrus" ) const cmdTimeout = 5 * time.Minute @@ -88,9 +89,9 @@ func UpdateConf(yamlChanges ...string) (string, error) { func Start(jvmArgs ...string) (string, error) { newJvmArgs := make([]string, len(jvmArgs)*2) - for i := 0; i < len(newJvmArgs); i += 2 { - newJvmArgs[i] = "--jvm_arg" - newJvmArgs[i+1] = jvmArgs[i] + for i := 0; i < len(jvmArgs); i++ { + newJvmArgs[i*2] = "--jvm_arg" + newJvmArgs[i*2+1] = jvmArgs[i] } if runtime.GOOS == "windows" { diff --git a/integration-tests/ccm/cluster.go b/integration-tests/ccm/cluster.go index d2b388ef..68c5c8b9 100644 --- a/integration-tests/ccm/cluster.go +++ b/integration-tests/ccm/cluster.go @@ -2,8 +2,10 @@ package ccm import ( "fmt" + + "github.com/apache/cassandra-gocql-driver/v2" + "github.com/datastax/zdm-proxy/integration-tests/env" - "github.com/gocql/gocql" ) type Cluster struct { @@ -15,6 +17,8 @@ type Cluster struct { startNodeIndex int session *gocql.Session + + singleNode bool } func newCluster(name string, version string, isDse bool, startNodeIndex int, numberOfSeedNodes int) *Cluster { @@ -26,6 +30,7 @@ func newCluster(name string, version string, isDse bool, startNodeIndex int, num numberOfSeedNodes: numberOfSeedNodes, startNodeIndex: startNodeIndex, session: nil, + singleNode: numberOfSeedNodes == 1, } } @@ -84,7 +89,7 @@ func (ccmCluster *Cluster) Create(numberOfNodes int, start bool) error { } if start { - _, err = Start() + _, err = Start(fmt.Sprintf("-Dcassandra.ring_delay_ms=%v", ccmCluster.GetDelayMs())) if err != nil { Remove(ccmCluster.name) @@ -118,7 +123,7 @@ func (ccmCluster *Cluster) Start(jvmArgs ...string) error { if err != nil { return err } - _, err = Start(jvmArgs...) + _, err = Start(append(jvmArgs, fmt.Sprintf("-Dcassandra.ring_delay_ms=%v", ccmCluster.GetDelayMs()))...) return err } @@ -147,6 +152,7 @@ func (ccmCluster *Cluster) Remove() error { func (ccmCluster *Cluster) AddNode(index int) error { ccmCluster.SwitchToThis() + ccmCluster.singleNode = false nodeIndex := ccmCluster.startNodeIndex + index _, err := Add( false, @@ -161,7 +167,8 @@ func (ccmCluster *Cluster) AddNode(index int) error { func (ccmCluster *Cluster) StartNode(index int, jvmArgs ...string) error { ccmCluster.SwitchToThis() nodeIndex := ccmCluster.startNodeIndex + index - _, err := StartNode(fmt.Sprintf("node%d", nodeIndex), jvmArgs...) + _, err := StartNode(fmt.Sprintf("node%d", nodeIndex), + append(jvmArgs, fmt.Sprintf("-Dcassandra.ring_delay_ms=%v", ccmCluster.GetDelayMs()))...) return err } @@ -178,3 +185,11 @@ func (ccmCluster *Cluster) RemoveNode(index int) error { _, err := RemoveNode(fmt.Sprintf("node%d", nodeIndex)) return err } + +func (ccmCluster *Cluster) GetDelayMs() int { + if ccmCluster.singleNode { + return 1000 + } else { + return 10000 + } +} diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index 94d3502f..a30531f0 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -4,21 +4,24 @@ import ( "bufio" "bytes" "context" + "sync/atomic" + "testing" + "time" + cqlClient "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/rs/zerolog" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/client" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" - "github.com/rs/zerolog" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "sync/atomic" - "testing" - "time" ) func TestGoCqlConnect(t *testing.T) { @@ -67,7 +70,7 @@ func TestCannotConnectWithoutControlConnection(t *testing.T) { for i := 0; i < 1000; i++ { // connect to proxy as a "client" client := cqlClient.NewCqlClient("127.0.0.1:14002", nil) - conn, err := client.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0) + conn, err := client.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) require.Nil(t, err) _ = conn.Close() } @@ -139,7 +142,7 @@ func TestControlConnectionProtocolVersionNegotiation(t *testing.T) { Query: "SELECT * FROM test", Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelOne}, } - rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion3, 0, queryMsg)) + rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(negotiatedProto, 0, queryMsg)) if err != nil { t.Fatal("query failed:", err) } @@ -192,13 +195,6 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) { expectedVersion primitive.ProtocolVersion errExpected string }{ - { - "request v5, response v4", - primitive.ProtocolVersion5, - "4", - primitive.ProtocolVersion4, - "Invalid or unsupported protocol version (5)", - }, { "request v1, response v4", primitive.ProtocolVersion(0x1), @@ -228,7 +224,7 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) { testSetup.Origin.CqlServer.RequestHandlers = []cqlClient.RequestHandler{cqlClient.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {})} testSetup.Target.CqlServer.RequestHandlers = []cqlClient.RequestHandler{cqlClient.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {})} - err = testSetup.Start(cfg, false, primitive.ProtocolVersion3) + err = testSetup.Start(cfg, false, env.DefaultProtocolVersion) require.Nil(t, err) testClient, err := client.NewTestClient(context.Background(), "127.0.0.1:14002") @@ -257,14 +253,6 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { errExpected string } tests := []*test{ - { - "DSE_V2 request, v5 returned, v4 expected", - primitive.ProtocolVersionDse2, - "4", - primitive.ProtocolVersion5, - primitive.ProtocolVersion4, - "Invalid or unsupported protocol version (5)", - }, { "DSE_V2 request, v1 returned, v4 expected", primitive.ProtocolVersionDse2, @@ -299,7 +287,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { testSetup.Origin.CqlServer.RequestRawHandlers = []cqlClient.RawRequestHandler{rawHandler} testSetup.Target.CqlServer.RequestRawHandlers = []cqlClient.RawRequestHandler{rawHandler} - err = testSetup.Start(cfg, false, primitive.ProtocolVersion4) + err = testSetup.Start(cfg, false, env.DefaultProtocolVersion) require.Nil(t, err) testClient, err := client.NewTestClient(context.Background(), "127.0.0.1:14002") @@ -335,9 +323,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { func createFrameWithUnsupportedVersion(version primitive.ProtocolVersion, streamId int16, isResponse bool) ([]byte, error) { mostSimilarVersion := version - if version > primitive.ProtocolVersionDse2 { - mostSimilarVersion = primitive.ProtocolVersionDse2 - } else if version < primitive.ProtocolVersion2 { + if version < primitive.ProtocolVersion2 { mostSimilarVersion = primitive.ProtocolVersion2 } @@ -394,7 +380,7 @@ func TestHandlingOfInternalHeartbeat(t *testing.T) { // Connect to proxy as a "client" proxyClient := cqlClient.NewCqlClient("127.0.0.1:14002", nil) - cqlClientConn, err := proxyClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0) + cqlClientConn, err := proxyClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) require.Nil(t, err) defer cqlClientConn.Close() @@ -403,7 +389,7 @@ func TestHandlingOfInternalHeartbeat(t *testing.T) { Options: nil, } - _, err = cqlClientConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion4, 0, queryMsg)) + _, err = cqlClientConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg)) require.Nil(t, err) // sleep longer than heartbeat interval @@ -412,7 +398,7 @@ func TestHandlingOfInternalHeartbeat(t *testing.T) { err = testSetup.Target.DeleteLogs() require.Nil(t, err) - _, err = cqlClientConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion4, 0, queryMsg)) + _, err = cqlClientConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg)) require.Nil(t, err) err = buffWriter.Flush() diff --git a/integration-tests/controlconn_test.go b/integration-tests/controlconn_test.go index 02b60add..aee2a9eb 100644 --- a/integration-tests/controlconn_test.go +++ b/integration-tests/controlconn_test.go @@ -3,24 +3,26 @@ package integration_tests import ( "context" "fmt" + "net" + "sort" + "sync" + "sync/atomic" + "testing" + "time" + "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/zdmproxy" - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "net" - "sort" - "sync" - "sync/atomic" - "testing" - "time" ) func TestGetHosts(t *testing.T) { @@ -465,7 +467,7 @@ func TestConnectionAssignment(t *testing.T) { queryString := fmt.Sprintf("INSERT INTO testconnections_%d (a) VALUES ('a')", i) openConnectionAndSendRequestFunc := func() { - cqlConn, err := testClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 1) + cqlConn, err := testClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 1) require.Nil(t, err, "testClient setup failed: %v", err) defer cqlConn.Close() @@ -474,7 +476,7 @@ func TestConnectionAssignment(t *testing.T) { Options: nil, } - queryFrame := frame.NewFrame(primitive.ProtocolVersion4, 5, queryMsg) + queryFrame := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 5, queryMsg) _, err = cqlConn.SendAndReceive(queryFrame) require.Nil(t, err) } @@ -597,7 +599,7 @@ func TestRefreshTopologyEventHandler(t *testing.T) { Port: 9042, }, } - topologyEventFrame := frame.NewFrame(primitive.ProtocolVersion4, -1, topologyEvent) + topologyEventFrame := frame.NewFrame(env.DefaultProtocolVersion, -1, topologyEvent) err = serverConn.Send(topologyEventFrame) require.Nil(t, err) @@ -759,7 +761,7 @@ func TestRefreshTopologyEventHandler(t *testing.T) { newRegisterHandler(&originRegisterMessages, originRegisterLock), createMutableHandler(originHandler)} testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ newRegisterHandler(&targetRegisterMessages, targetRegisterLock), createMutableHandler(targetHandler)} - err = testSetup.Start(conf, false, primitive.ProtocolVersion4) + err = testSetup.Start(conf, false, env.DefaultProtocolVersion) require.Nil(t, err) checkRegisterMessages(t, originRegisterMessages, originRegisterLock) checkRegisterMessages(t, targetRegisterMessages, targetRegisterLock) diff --git a/integration-tests/cqlserver/client.go b/integration-tests/cqlserver/client.go index 5fc8ba7a..d02850b5 100644 --- a/integration-tests/cqlserver/client.go +++ b/integration-tests/cqlserver/client.go @@ -3,6 +3,7 @@ package cqlserver import ( "context" "fmt" + "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/primitive" ) diff --git a/integration-tests/cqlserver/cluster.go b/integration-tests/cqlserver/cluster.go index 0801c0bf..7a125933 100644 --- a/integration-tests/cqlserver/cluster.go +++ b/integration-tests/cqlserver/cluster.go @@ -3,9 +3,10 @@ package cqlserver import ( "context" "fmt" + "time" + "github.com/datastax/go-cassandra-native-protocol/client" log "github.com/sirupsen/logrus" - "time" ) type Cluster struct { @@ -43,7 +44,7 @@ func NewCqlServerCluster(listenAddr string, port int, username string, password } func (recv *Cluster) Start() error { - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + ctx, _ := context.WithTimeout(context.Background(), 100*time.Second) return recv.CqlServer.Start(ctx) } diff --git a/integration-tests/env/vars.go b/integration-tests/env/vars.go index 17bd0e9b..cf7212c5 100644 --- a/integration-tests/env/vars.go +++ b/integration-tests/env/vars.go @@ -2,11 +2,15 @@ package env import ( "flag" + "fmt" "math/rand" "os" + "slices" "strconv" "strings" "time" + + "github.com/datastax/go-cassandra-native-protocol/primitive" ) const ( @@ -18,17 +22,26 @@ var Rand = rand.New(rand.NewSource(time.Now().UTC().UnixNano())) var ServerVersion string var CassandraVersion string var DseVersion string +var ServerVersionLogStr string var IsDse bool var RunCcmTests bool var RunMockTests bool var RunAllTlsTests bool var Debug bool +var SupportedProtocolVersions []primitive.ProtocolVersion +var AllProtocolVersions []primitive.ProtocolVersion = []primitive.ProtocolVersion{ + primitive.ProtocolVersion2, primitive.ProtocolVersion3, primitive.ProtocolVersion4, + primitive.ProtocolVersion5, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2, +} +var DefaultProtocolVersion primitive.ProtocolVersion +var DefaultProtocolVersionSimulacron primitive.ProtocolVersion +var DefaultProtocolVersionTestClient primitive.ProtocolVersion func InitGlobalVars() { flags := map[string]interface{}{ "CASSANDRA_VERSION": flag.String( "CASSANDRA_VERSION", - getEnvironmentVariableOrDefault("CASSANDRA_VERSION", "3.11.7"), + getEnvironmentVariableOrDefault("CASSANDRA_VERSION", "5.0.6"), "CASSANDRA_VERSION"), "DSE_VERSION": flag.String( @@ -70,8 +83,36 @@ func InitGlobalVars() { IsDse = true ServerVersion = DseVersion } else { - ServerVersion = CassandraVersion - IsDse = false + split := strings.SplitAfter(CassandraVersion, "dse-") + if len(split) == 2 { + IsDse = true + ServerVersion = split[1] + DseVersion = ServerVersion + CassandraVersion = "" + } else { + ServerVersion = CassandraVersion + IsDse = false + } + } + + SupportedProtocolVersions = supportedProtocolVersions() + + ServerVersionLogStr = serverVersionLogString() + + DefaultProtocolVersion = ComputeDefaultProtocolVersion() + + if DefaultProtocolVersion <= primitive.ProtocolVersion2 { + DefaultProtocolVersionSimulacron = primitive.ProtocolVersion3 + } else if DefaultProtocolVersion >= primitive.ProtocolVersion5 { + DefaultProtocolVersionSimulacron = primitive.ProtocolVersion4 + } else { + DefaultProtocolVersionSimulacron = DefaultProtocolVersion + } + + if DefaultProtocolVersion.SupportsModernFramingLayout() { + DefaultProtocolVersionTestClient = primitive.ProtocolVersion4 + } else { + DefaultProtocolVersionTestClient = DefaultProtocolVersion } if strings.ToLower(runCcmTests) == "true" { @@ -143,3 +184,81 @@ func getEnvironmentVariableBoolOrDefault(key string, defaultValue bool) bool { return defaultValue } } + +func SupportsProtocolVersion(protoVersion primitive.ProtocolVersion) bool { + return slices.Contains(SupportedProtocolVersions, protoVersion) +} + +func supportedProtocolVersions() []primitive.ProtocolVersion { + v := parseVersion(ServerVersion) + if IsDse { + if v[0] >= 6 { + return []primitive.ProtocolVersion{ + primitive.ProtocolVersion3, primitive.ProtocolVersion4, + primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2} + } + if v[0] >= 5 { + return []primitive.ProtocolVersion{ + primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1} + } + + if v[0] >= 4 { + return []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3} + } + } else { + if v[0] >= 4 { + return []primitive.ProtocolVersion{ + primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5} + } + if v[0] >= 3 { + return []primitive.ProtocolVersion{ + primitive.ProtocolVersion3, primitive.ProtocolVersion4} + } + if v[0] >= 2 { + if v[1] >= 2 { + return []primitive.ProtocolVersion{ + primitive.ProtocolVersion2, primitive.ProtocolVersion3, primitive.ProtocolVersion4} + } + + if v[1] >= 1 { + return []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3} + } + + if v[1] >= 0 { + return []primitive.ProtocolVersion{primitive.ProtocolVersion2} + } + } + } + + panic(fmt.Sprintf("Unsupported server version IsDse=%v Version=%v", IsDse, ServerVersion)) +} + +func serverVersionLogString() string { + if IsDse { + return fmt.Sprintf("dse-%v", ServerVersion) + } else { + return ServerVersion + } +} + +func ProtocolVersionStr(v primitive.ProtocolVersion) string { + switch v { + case primitive.ProtocolVersionDse1: + return "DSEv1" + case primitive.ProtocolVersionDse2: + return "DSEv2" + } + return strconv.Itoa(int(v)) +} + +func ComputeDefaultProtocolVersion() primitive.ProtocolVersion { + orderedProtocolVersions := []primitive.ProtocolVersion{ + primitive.ProtocolVersionDse2, primitive.ProtocolVersionDse1, primitive.ProtocolVersion5, + primitive.ProtocolVersion4, primitive.ProtocolVersion3, primitive.ProtocolVersion2} + for _, v := range orderedProtocolVersions { + if SupportsProtocolVersion(v) { + return v + } + } + panic(fmt.Sprintf("Unable to compute protocol version for server version %v", ServerVersionLogStr)) +} diff --git a/integration-tests/events_test.go b/integration-tests/events_test.go index af16da10..4f80cce2 100644 --- a/integration-tests/events_test.go +++ b/integration-tests/events_test.go @@ -3,24 +3,22 @@ package integration_tests import ( "context" "fmt" + "testing" + "time" + "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/ccm" "github.com/datastax/zdm-proxy/integration-tests/client" "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" - "github.com/stretchr/testify/require" - "testing" - "time" ) // TestSchemaEvents tests the schema event message handling func TestSchemaEvents(t *testing.T) { - if !env.RunCcmTests { - t.Skip("Test requires CCM, set RUN_CCMTESTS env variable to TRUE") - } - - originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters() + originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters(t) require.Nil(t, err) tests := []struct { @@ -42,7 +40,7 @@ func TestSchemaEvents(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - proxyInstance, err := NewProxyInstanceForGlobalCcmClusters() + proxyInstance, err := NewProxyInstanceForGlobalCcmClusters(t) require.Nil(t, err) defer proxyInstance.Shutdown() @@ -56,10 +54,10 @@ func TestSchemaEvents(t *testing.T) { require.True(t, err == nil, "unable to connect to test client: %v", err) defer testClientForSchemaChange.Shutdown() - err = testClientForEvents.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = testClientForEvents.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionTestClient, false) require.True(t, err == nil, "could not perform handshake: %v", err) - err = testClientForSchemaChange.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = testClientForSchemaChange.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionTestClient, false) require.True(t, err == nil, "could not perform handshake: %v", err) // send REGISTER to proxy @@ -70,7 +68,7 @@ func TestSchemaEvents(t *testing.T) { primitive.EventTypeTopologyChange}, } - response, _, err := testClientForEvents.SendMessage(context.Background(), primitive.ProtocolVersion4, registerMsg) + response, _, err := testClientForEvents.SendMessage(context.Background(), env.DefaultProtocolVersionTestClient, registerMsg) require.True(t, err == nil, "could not send register frame: %v", err) _, ok := response.Body.Message.(*message.Ready) @@ -82,7 +80,7 @@ func TestSchemaEvents(t *testing.T) { "WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor':1};", env.Rand.Uint64()), } - response, _, err = testClientForSchemaChange.SendMessage(context.Background(), primitive.ProtocolVersion4, createKeyspaceMessage) + response, _, err = testClientForSchemaChange.SendMessage(context.Background(), env.DefaultProtocolVersionTestClient, createKeyspaceMessage) require.True(t, err == nil, "could not send create keyspace request: %v", err) _, ok = response.Body.Message.(*message.SchemaChangeResult) @@ -107,11 +105,7 @@ func TestSchemaEvents(t *testing.T) { // TestTopologyStatusEvents tests the topology and status events handling func TestTopologyStatusEvents(t *testing.T) { - if !env.RunCcmTests { - t.Skip("Test requires CCM, set RUN_CCMTESTS env variable to TRUE") - } - - tempCcmSetup, err := setup.NewTemporaryCcmTestSetup(true, false) + tempCcmSetup, err := setup.NewTemporaryCcmTestSetup(t, true, false) require.Nil(t, err) defer tempCcmSetup.Cleanup() @@ -147,7 +141,7 @@ func TestTopologyStatusEvents(t *testing.T) { require.True(t, err == nil, "unable to connect to test client: %v", err) defer testClientForEvents.Shutdown() - err = testClientForEvents.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = testClientForEvents.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionTestClient, false) require.True(t, err == nil, "could not perform handshake: %v", err) registerMsg := &message.Register{ @@ -157,7 +151,7 @@ func TestTopologyStatusEvents(t *testing.T) { primitive.EventTypeTopologyChange}, } - response, _, err := testClientForEvents.SendMessage(context.Background(), primitive.ProtocolVersion4, registerMsg) + response, _, err := testClientForEvents.SendMessage(context.Background(), env.DefaultProtocolVersionTestClient, registerMsg) require.True(t, err == nil, "could not send register frame: %v", err) _, ok := response.Body.Message.(*message.Ready) diff --git a/integration-tests/functioncalls_test.go b/integration-tests/functioncalls_test.go index 96fbbfeb..20ebfc7b 100644 --- a/integration-tests/functioncalls_test.go +++ b/integration-tests/functioncalls_test.go @@ -4,18 +4,21 @@ import ( "context" "encoding/base64" "encoding/json" + "regexp" + "testing" + "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/datacodec" "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/integration-tests/setup" - "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/google/uuid" "github.com/stretchr/testify/require" - "regexp" - "testing" + + "github.com/datastax/zdm-proxy/integration-tests/env" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/simulacron" ) type param struct { @@ -153,7 +156,7 @@ func TestNowFunctionReplacementSimpleStatement(t *testing.T) { defer simulacronSetup.Cleanup() testClient := client.NewCqlClient("127.0.0.1:14002", nil) - cqlConn, err := testClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 1) + cqlConn, err := testClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 1) require.Nil(t, err, "testClient setup failed: %v", err) defer cqlConn.Close() @@ -165,7 +168,7 @@ func TestNowFunctionReplacementSimpleStatement(t *testing.T) { Options: test.queryOpts, } - f := frame.NewFrame(primitive.ProtocolVersion4, 2, queryMsg) + f := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 2, queryMsg) _, err := cqlConn.SendAndReceive(f) require.Nil(tt, err) @@ -1356,7 +1359,7 @@ func TestNowFunctionReplacementPreparedStatement(t *testing.T) { defer simulacronSetup.Cleanup() testClient := client.NewCqlClient("127.0.0.1:14002", nil) - cqlConn, err := testClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 1) + cqlConn, err := testClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 1) require.Nil(t, err, "testClient setup failed: %v", err) defer cqlConn.Close() @@ -1418,7 +1421,7 @@ func TestNowFunctionReplacementPreparedStatement(t *testing.T) { Query: test.originalQuery, } - f := frame.NewFrame(primitive.ProtocolVersion4, 0, queryMsg) + f := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg) resp, err := cqlConn.SendAndReceive(f) require.Nil(t, err) @@ -1465,7 +1468,7 @@ func TestNowFunctionReplacementPreparedStatement(t *testing.T) { ResultMetadataId: prepared.ResultMetadataId, Options: queryOpts, } - f = frame.NewFrame(primitive.ProtocolVersion4, 0, executeMsg) + f = frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, executeMsg) resp, err = cqlConn.SendAndReceive(f) require.Nil(t, err) @@ -1577,7 +1580,7 @@ func TestNowFunctionReplacementPreparedStatement(t *testing.T) { ResultMetadataId: prepared.ResultMetadataId, Options: queryOptsNamed, } - f = frame.NewFrame(primitive.ProtocolVersion4, 0, executeMsg) + f = frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, executeMsg) _, err = cqlConn.SendAndReceive(f) require.Nil(t, err) @@ -2172,7 +2175,7 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { defer simulacronSetup.Cleanup() testClient := client.NewCqlClient("127.0.0.1:14002", nil) - cqlConn, err := testClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 1) + cqlConn, err := testClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 1) require.Nil(t, err, "testClient setup failed: %v", err) defer cqlConn.Close() @@ -2254,7 +2257,7 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { if !p.isReplacedNow { codec, err := datacodec.NewCodec(p.dataType) require.Nil(t, err) - value, err := codec.Encode(p.value, primitive.ProtocolVersion4) + value, err := codec.Encode(p.value, env.DefaultProtocolVersionSimulacron) require.Nil(t, err) positionalValues = append(positionalValues, primitive.NewValue(value)) } @@ -2280,7 +2283,7 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { prepareMsg := &message.Prepare{ Query: childStatement.originalQuery, } - f := frame.NewFrame(primitive.ProtocolVersion4, 0, prepareMsg) + f := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, prepareMsg) resp, err := cqlConn.SendAndReceive(f) require.Nil(t, err) prepared, ok := resp.Body.Message.(*message.PreparedResult) @@ -2306,7 +2309,7 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { Children: batchChildStatements, } - f := frame.NewFrame(primitive.ProtocolVersion4, 0, batchMsg) + f := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, batchMsg) resp, err := cqlConn.SendAndReceive(f) require.Nil(t, err) diff --git a/integration-tests/main_test.go b/integration-tests/main_test.go index c7195f42..d067d6b9 100644 --- a/integration-tests/main_test.go +++ b/integration-tests/main_test.go @@ -1,20 +1,20 @@ package integration_tests import ( + "os" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/datastax/zdm-proxy/integration-tests/ccm" "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/proxy/pkg/zdmproxy" - "github.com/gocql/gocql" - log "github.com/sirupsen/logrus" - "os" - "testing" ) func TestMain(m *testing.M) { env.InitGlobalVars() - gocql.TimeoutLimit = 5 if env.Debug { log.SetLevel(log.DebugLevel) } else { @@ -24,13 +24,13 @@ func TestMain(m *testing.M) { os.Exit(RunTests(m)) } -func SetupOrGetGlobalCcmClusters() (*ccm.Cluster, *ccm.Cluster, error) { - originCluster, err := setup.GetGlobalTestClusterOrigin() +func SetupOrGetGlobalCcmClusters(t *testing.T) (*ccm.Cluster, *ccm.Cluster, error) { + originCluster, err := setup.GetGlobalTestClusterOrigin(t) if err != nil { return nil, nil, err } - targetCluster, err := setup.GetGlobalTestClusterTarget() + targetCluster, err := setup.GetGlobalTestClusterTarget(t) if err != nil { return nil, nil, err } @@ -43,8 +43,8 @@ func RunTests(m *testing.M) int { return m.Run() } -func NewProxyInstanceForGlobalCcmClusters() (*zdmproxy.ZdmProxy, error) { - originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters() +func NewProxyInstanceForGlobalCcmClusters(t *testing.T) (*zdmproxy.ZdmProxy, error) { + originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters(t) if err != nil { return nil, err } diff --git a/integration-tests/metrics_test.go b/integration-tests/metrics_test.go index 3da06989..01dc3af5 100644 --- a/integration-tests/metrics_test.go +++ b/integration-tests/metrics_test.go @@ -2,25 +2,28 @@ package integration_tests import ( "fmt" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" + "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/prometheus/client_golang/prometheus/promhttp" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/datastax/zdm-proxy/proxy/pkg/httpzdmproxy" "github.com/datastax/zdm-proxy/proxy/pkg/metrics" - "github.com/prometheus/client_golang/prometheus/promhttp" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "net/http" - "sort" - "strconv" - "strings" - "sync" - "testing" - "time" ) var nodeMetrics = []metrics.Metric{ @@ -69,17 +72,21 @@ var proxyMetrics = []metrics.Metric{ var allMetrics = append(proxyMetrics, nodeMetrics...) -var insertQuery = frame.NewFrame( - primitive.ProtocolVersion4, - client.ManagedStreamId, - &message.Query{Query: "INSERT INTO ks1.t1"}, -) +func getInsertQuery() *frame.Frame { + return frame.NewFrame( + env.DefaultProtocolVersion, + client.ManagedStreamId, + &message.Query{Query: "INSERT INTO ks1.t1"}, + ) +} -var selectQuery = frame.NewFrame( - primitive.ProtocolVersion4, - client.ManagedStreamId, - &message.Query{Query: "SELECT * FROM ks1.t1"}, -) +func getSelectQuery() *frame.Frame { + return frame.NewFrame( + env.DefaultProtocolVersion, + client.ManagedStreamId, + &message.Query{Query: "SELECT * FROM ks1.t1"}, + ) +} func testMetrics(t *testing.T, metricsHandler *httpzdmproxy.HandlerWithFallback) { @@ -125,7 +132,7 @@ func testMetrics(t *testing.T, metricsHandler *httpzdmproxy.HandlerWithFallback) testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{client.RegisterHandler, client.HeartbeatHandler, client.HandshakeHandler, client.NewSystemTablesHandler("cluster1", "dc1"), handleReads, handleWrites} testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{client.RegisterHandler, client.HeartbeatHandler, client.HandshakeHandler, client.NewSystemTablesHandler("cluster2", "dc2"), handleReads, handleWrites} - err = testSetup.Start(conf, false, primitive.ProtocolVersion4) + err = testSetup.Start(conf, false, env.DefaultProtocolVersion) require.Nil(t, err) wg := &sync.WaitGroup{} @@ -143,7 +150,7 @@ func testMetrics(t *testing.T, metricsHandler *httpzdmproxy.HandlerWithFallback) lines := GatherMetrics(t, conf, false) checkMetrics(t, false, lines, conf.ReadMode, 0, 0, 0, 0, 0, 0, 0, 0, true, true, originEndpoint, targetEndpoint, asyncEndpoint, 0, 0, 0) - err = testSetup.Client.Connect(primitive.ProtocolVersion4) + err = testSetup.Client.Connect(env.DefaultProtocolVersion) require.Nil(t, err) clientConn := testSetup.Client.CqlConnection @@ -155,7 +162,7 @@ func testMetrics(t *testing.T, metricsHandler *httpzdmproxy.HandlerWithFallback) // but all of these are "system" requests so not tracked checkMetrics(t, true, lines, conf.ReadMode, 1, 1, 1, expectedAsyncConnections, 0, 0, 0, 0, true, true, originEndpoint, targetEndpoint, asyncEndpoint, 0, 0, 0) - _, err = clientConn.SendAndReceive(insertQuery) + _, err = clientConn.SendAndReceive(getInsertQuery()) require.Nil(t, err) lines = GatherMetrics(t, conf, true) @@ -166,7 +173,7 @@ func testMetrics(t *testing.T, metricsHandler *httpzdmproxy.HandlerWithFallback) // only QUERY is tracked checkMetrics(t, true, lines, conf.ReadMode, 1, 1, 1, expectedAsyncConnections, 1, 0, 0, 0, true, true, originEndpoint, targetEndpoint, asyncEndpoint, 0, 0, 0) - _, err = clientConn.SendAndReceive(selectQuery) + _, err = clientConn.SendAndReceive(getSelectQuery()) require.Nil(t, err) lines = GatherMetrics(t, conf, true) diff --git a/integration-tests/noresponsefromcluster_test.go b/integration-tests/noresponsefromcluster_test.go index d5810b4d..233f0db0 100644 --- a/integration-tests/noresponsefromcluster_test.go +++ b/integration-tests/noresponsefromcluster_test.go @@ -2,14 +2,17 @@ package integration_tests import ( "context" + "strings" + "testing" + "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/client" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" - "github.com/stretchr/testify/require" - "strings" - "testing" ) func TestAtLeastOneClusterReturnsNoResponse(t *testing.T) { @@ -23,7 +26,7 @@ func TestAtLeastOneClusterReturnsNoResponse(t *testing.T) { defer testClient.Shutdown() - err = testClient.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = testClient.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionSimulacron, false) require.True(t, err == nil, "No-auth handshake failed: %s", err) queryPrimeNoResponse := @@ -82,7 +85,7 @@ func TestAtLeastOneClusterReturnsNoResponse(t *testing.T) { PositionalValues: []*primitive.Value{primitive.NewValue([]byte("john"))}, }, } - response, _, err := testClient.SendMessage(context.Background(), primitive.ProtocolVersion4, query) + response, _, err := testClient.SendMessage(context.Background(), env.DefaultProtocolVersionSimulacron, query) require.True(t, response == nil, "a response has been received") require.True(t, err != nil, "no error has been received, but the request should have failed") diff --git a/integration-tests/options_test.go b/integration-tests/options_test.go index c7bf25a2..bb68b3d3 100644 --- a/integration-tests/options_test.go +++ b/integration-tests/options_test.go @@ -1,13 +1,15 @@ package integration_tests import ( + "testing" + "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" - "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/stretchr/testify/require" - "testing" + + "github.com/datastax/zdm-proxy/integration-tests/env" + "github.com/datastax/zdm-proxy/integration-tests/setup" ) func TestOptionsShouldComeFromTarget(t *testing.T) { @@ -19,10 +21,10 @@ func TestOptionsShouldComeFromTarget(t *testing.T) { testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{client.RegisterHandler, newOptionsHandler(map[string][]string{"FROM": {"origin"}}), client.HandshakeHandler, client.NewSystemTablesHandler("cluster2", "dc2")} testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{client.RegisterHandler, newOptionsHandler(map[string][]string{"FROM": {"target"}}), client.HandshakeHandler, client.NewSystemTablesHandler("cluster1", "dc1")} - err = testSetup.Start(conf, true, primitive.ProtocolVersion4) + err = testSetup.Start(conf, true, env.DefaultProtocolVersion) require.Nil(t, err) - request := frame.NewFrame(primitive.ProtocolVersion4, client.ManagedStreamId, &message.Options{}) + request := frame.NewFrame(env.DefaultProtocolVersion, client.ManagedStreamId, &message.Options{}) response, err := testSetup.Client.CqlConnection.SendAndReceive(request) require.Nil(t, err) require.IsType(t, &message.Supported{}, response.Body.Message) @@ -40,10 +42,10 @@ func TestCommonCompressionAlgorithms(t *testing.T) { testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{client.RegisterHandler, newOptionsHandler(map[string][]string{"COMPRESSION": {"snappy"}}), client.HandshakeHandler, client.NewSystemTablesHandler("cluster2", "dc2")} testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{client.RegisterHandler, newOptionsHandler(map[string][]string{"COMPRESSION": {"snappy", "lz4"}}), client.HandshakeHandler, client.NewSystemTablesHandler("cluster1", "dc1")} - err = testSetup.Start(conf, true, primitive.ProtocolVersion4) + err = testSetup.Start(conf, true, env.DefaultProtocolVersion) require.Nil(t, err) - request := frame.NewFrame(primitive.ProtocolVersion4, client.ManagedStreamId, &message.Options{}) + request := frame.NewFrame(env.DefaultProtocolVersion, client.ManagedStreamId, &message.Options{}) response, err := testSetup.Client.CqlConnection.SendAndReceive(request) require.Nil(t, err) require.IsType(t, &message.Supported{}, response.Body.Message) diff --git a/integration-tests/prepared_statements_test.go b/integration-tests/prepared_statements_test.go index 53a50294..78774e9f 100644 --- a/integration-tests/prepared_statements_test.go +++ b/integration-tests/prepared_statements_test.go @@ -4,22 +4,25 @@ import ( "bytes" "context" "fmt" + "sync" + "testing" + "time" + client2 "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/rs/zerolog" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/client" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" - "github.com/rs/zerolog" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "sync" - "testing" - "time" ) func TestPreparedIdProxyCacheMiss(t *testing.T) { @@ -33,7 +36,7 @@ func TestPreparedIdProxyCacheMiss(t *testing.T) { defer testClient.Shutdown() - err = testClient.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = testClient.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionSimulacron, false) require.True(t, err == nil, "No-auth handshake failed: %s", err) preparedId := []byte{143, 7, 36, 50, 225, 104, 157, 89, 199, 177, 239, 231, 82, 201, 142, 253} @@ -42,7 +45,7 @@ func TestPreparedIdProxyCacheMiss(t *testing.T) { QueryId: preparedId, ResultMetadataId: nil, } - response, requestStreamId, err := testClient.SendMessage(context.Background(), primitive.ProtocolVersion4, executeMsg) + response, requestStreamId, err := testClient.SendMessage(context.Background(), env.DefaultProtocolVersionSimulacron, executeMsg) require.True(t, err == nil, "execute request send failed: %s", err) require.True(t, response != nil, "response received was null") @@ -74,7 +77,7 @@ func TestPreparedIdPreparationMismatch(t *testing.T) { defer testClient.Shutdown() - err = testClient.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = testClient.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionSimulacron, false) require.True(t, err == nil, "No-auth handshake failed: %s", err) tests := map[string]struct { @@ -138,7 +141,7 @@ func TestPreparedIdPreparationMismatch(t *testing.T) { Keyspace: "", } - response, requestStreamId, err := testClient.SendMessage(context.Background(), primitive.ProtocolVersion4, prepareMsg) + response, requestStreamId, err := testClient.SendMessage(context.Background(), env.DefaultProtocolVersionSimulacron, prepareMsg) require.True(t, err == nil, "prepare request send failed: %s", err) preparedResponse, ok := response.Body.Message.(*message.PreparedResult) @@ -153,7 +156,7 @@ func TestPreparedIdPreparationMismatch(t *testing.T) { ResultMetadataId: preparedResponse.ResultMetadataId, } - response, requestStreamId, err = testClient.SendMessage(context.Background(), primitive.ProtocolVersion4, executeMsg) + response, requestStreamId, err = testClient.SendMessage(context.Background(), env.DefaultProtocolVersionSimulacron, executeMsg) require.True(t, err == nil, "execute request send failed: %s", err) if test.expectedUnprepared { @@ -329,7 +332,7 @@ func TestPreparedIdReplacement(t *testing.T) { test.expectedBatchQuery, targetPreparedId, targetBatchPreparedId, targetKey, targetValue, map[string]interface{}{}, false, test.expectedVariables, test.expectedBatchPreparedStmtVariables, dualReadsEnabled && test.read)} - err = testSetup.Start(conf, true, primitive.ProtocolVersion4) + err = testSetup.Start(conf, true, env.DefaultProtocolVersion) require.Nil(t, err) prepareMsg := &message.Prepare{ @@ -342,7 +345,7 @@ func TestPreparedIdReplacement(t *testing.T) { } prepareResp, err := testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 10, prepareMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 10, prepareMsg)) require.Nil(t, err) preparedResult, ok := prepareResp.Body.Message.(*message.PreparedResult) @@ -350,6 +353,8 @@ func TestPreparedIdReplacement(t *testing.T) { require.Equal(t, originPreparedId, preparedResult.PreparedQueryId) + metadataId := preparedResult.ResultMetadataId + var batchPrepareMsg *message.Prepare var expectedBatchPrepareMsg *message.Prepare if test.batchQuery != "" { @@ -358,7 +363,7 @@ func TestPreparedIdReplacement(t *testing.T) { expectedBatchPrepareMsg = batchPrepareMsg.DeepCopy() expectedBatchPrepareMsg.Query = test.expectedBatchQuery prepareResp, err = testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 10, batchPrepareMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 10, batchPrepareMsg)) require.Nil(t, err) preparedResult, ok = prepareResp.Body.Message.(*message.PreparedResult) @@ -369,12 +374,12 @@ func TestPreparedIdReplacement(t *testing.T) { executeMsg := &message.Execute{ QueryId: originPreparedId, - ResultMetadataId: nil, + ResultMetadataId: metadataId, Options: &message.QueryOptions{}, } executeResp, err := testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 20, executeMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 20, executeMsg)) require.Nil(t, err) rowsResult, ok := executeResp.Body.Message.(*message.RowsResult) @@ -410,7 +415,7 @@ func TestPreparedIdReplacement(t *testing.T) { } batchResp, err := testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 30, batchMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 30, batchMsg)) require.Nil(t, err) batchResult, ok := batchResp.Body.Message.(*message.VoidResult) @@ -695,7 +700,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { test.batchQuery, targetPreparedId, targetBatchPreparedId, targetKey, targetValue, targetCtx, test.targetUnprepared, nil, nil, dualReadsEnabled && test.read)} - err = testSetup.Start(conf, true, primitive.ProtocolVersion4) + err = testSetup.Start(conf, true, env.DefaultProtocolVersion) require.Nil(t, err) prepareMsg := &message.Prepare{ @@ -704,7 +709,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { } prepareResp, err := testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 10, prepareMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 10, prepareMsg)) require.Nil(t, err) preparedResult, ok := prepareResp.Body.Message.(*message.PreparedResult) @@ -714,12 +719,12 @@ func TestUnpreparedIdReplacement(t *testing.T) { executeMsg := &message.Execute{ QueryId: originPreparedId, - ResultMetadataId: nil, + ResultMetadataId: preparedResult.ResultMetadataId, Options: &message.QueryOptions{}, } executeResp, err := testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 20, executeMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 20, executeMsg)) require.Nil(t, err) unPreparedResult, ok := executeResp.Body.Message.(*message.Unprepared) @@ -728,7 +733,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { require.Equal(t, originPreparedId, unPreparedResult.Id) prepareResp, err = testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 10, prepareMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 10, prepareMsg)) require.Nil(t, err) preparedResult, ok = prepareResp.Body.Message.(*message.PreparedResult) @@ -737,7 +742,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { require.Equal(t, originPreparedId, preparedResult.PreparedQueryId) executeResp, err = testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 20, executeMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 20, executeMsg)) require.Nil(t, err) rowsResult, ok := executeResp.Body.Message.(*message.RowsResult) @@ -749,7 +754,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { batchPrepareMsg = prepareMsg.DeepCopy() batchPrepareMsg.Query = test.batchQuery prepareResp, err = testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 10, batchPrepareMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 10, batchPrepareMsg)) require.Nil(t, err) preparedResult, ok = prepareResp.Body.Message.(*message.PreparedResult) @@ -779,7 +784,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { } batchResp, err := testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 30, batchMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 30, batchMsg)) require.Nil(t, err) unPreparedResult, ok := batchResp.Body.Message.(*message.Unprepared) @@ -788,7 +793,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { require.Equal(t, originBatchPreparedId, unPreparedResult.Id) prepareResp, err = testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 10, batchPrepareMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 10, batchPrepareMsg)) require.Nil(t, err) preparedResult, ok = prepareResp.Body.Message.(*message.PreparedResult) @@ -797,7 +802,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { require.Equal(t, originBatchPreparedId, preparedResult.PreparedQueryId) batchResp, err = testSetup.Client.CqlConnection.SendAndReceive( - frame.NewFrame(primitive.ProtocolVersion4, 30, batchMsg)) + frame.NewFrame(env.DefaultProtocolVersion, 30, batchMsg)) require.Nil(t, err) batchResult, ok := batchResp.Body.Message.(*message.VoidResult) @@ -1034,7 +1039,7 @@ func NewPreparedTestHandler( lock.Unlock() return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.PreparedResult{ PreparedQueryId: prepId, - ResultMetadataId: nil, + ResultMetadataId: prepId, VariablesMetadata: variablesMetadata, ResultMetadata: rowsMetadata, }) diff --git a/integration-tests/protocolversions_test.go b/integration-tests/protocolversions_test.go index 4da2980d..c87e6eda 100644 --- a/integration-tests/protocolversions_test.go +++ b/integration-tests/protocolversions_test.go @@ -2,17 +2,22 @@ package integration_tests import ( "context" + "errors" "fmt" + "net" + "slices" + "strings" + "testing" + "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/integration-tests/setup" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" - "net" - "slices" - "testing" + + "github.com/datastax/zdm-proxy/integration-tests/setup" ) // Test that proxy can establish connectivity with ORIGIN and TARGET @@ -32,16 +37,25 @@ func TestProtocolNegotiationDifferentClusters(t *testing.T) { }{ { name: "OriginV2_TargetV2_ClientV2", - proxyMaxProtoVer: "2", + proxyMaxProtoVer: "", proxyOriginContConnVer: primitive.ProtocolVersion2, proxyTargetContConnVer: primitive.ProtocolVersion2, originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, clientProtoVer: primitive.ProtocolVersion2, }, + { + name: "OriginV23_TargetV345_ClientV3", + proxyMaxProtoVer: "", + proxyOriginContConnVer: primitive.ProtocolVersion3, + proxyTargetContConnVer: primitive.ProtocolVersion5, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + clientProtoVer: primitive.ProtocolVersion3, + }, { name: "OriginV2_TargetV2_ClientV2_ProxyControlConnNegotiation", - proxyMaxProtoVer: "4", + proxyMaxProtoVer: "", proxyOriginContConnVer: primitive.ProtocolVersion2, proxyTargetContConnVer: primitive.ProtocolVersion2, originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, @@ -50,7 +64,7 @@ func TestProtocolNegotiationDifferentClusters(t *testing.T) { }, { name: "OriginV2_TargetV23_ClientV2", - proxyMaxProtoVer: "3", + proxyMaxProtoVer: "", proxyOriginContConnVer: primitive.ProtocolVersion2, proxyTargetContConnVer: primitive.ProtocolVersion3, originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, @@ -59,7 +73,7 @@ func TestProtocolNegotiationDifferentClusters(t *testing.T) { }, { name: "OriginV23_TargetV2_ClientV2", - proxyMaxProtoVer: "3", + proxyMaxProtoVer: "", proxyOriginContConnVer: primitive.ProtocolVersion3, proxyTargetContConnVer: primitive.ProtocolVersion2, originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, @@ -69,42 +83,60 @@ func TestProtocolNegotiationDifferentClusters(t *testing.T) { { // most common setup with OSS Cassandra name: "OriginV345_TargetV345_ClientV4", - proxyMaxProtoVer: "DseV2", - proxyOriginContConnVer: primitive.ProtocolVersion4, - proxyTargetContConnVer: primitive.ProtocolVersion4, + proxyMaxProtoVer: "", + proxyOriginContConnVer: primitive.ProtocolVersion5, + proxyTargetContConnVer: primitive.ProtocolVersion5, originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, clientProtoVer: primitive.ProtocolVersion4, }, + { + name: "OriginV345_TargetV345_ClientV5", + proxyMaxProtoVer: "", + proxyOriginContConnVer: primitive.ProtocolVersion5, + proxyTargetContConnVer: primitive.ProtocolVersion5, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + clientProtoVer: primitive.ProtocolVersion5, + }, { // most common setup with DSE name: "OriginV345_TargetV34Dse1Dse2_ClientV4", - proxyMaxProtoVer: "DseV2", - proxyOriginContConnVer: primitive.ProtocolVersion4, + proxyMaxProtoVer: "", + proxyOriginContConnVer: primitive.ProtocolVersion5, proxyTargetContConnVer: primitive.ProtocolVersionDse2, originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2}, clientProtoVer: primitive.ProtocolVersion4, }, { - name: "OriginV2_TargetV3_ClientV2", - proxyMaxProtoVer: "3", + name: "OriginV234Dse1Dse2_TargetV345_ClientV4", + proxyMaxProtoVer: "", + proxyOriginContConnVer: primitive.ProtocolVersionDse2, + proxyTargetContConnVer: primitive.ProtocolVersion5, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + clientProtoVer: primitive.ProtocolVersion4, + }, + { + name: "OriginV2_TargetV345_FailClient", + proxyMaxProtoVer: "", proxyOriginContConnVer: primitive.ProtocolVersion2, - proxyTargetContConnVer: primitive.ProtocolVersion3, + proxyTargetContConnVer: primitive.ProtocolVersion5, originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, - targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, clientProtoVer: primitive.ProtocolVersion2, // client connection should fail as there is no common protocol version between origin and target failClientConnect: true, }, { - name: "OriginV3_TargetV3_ClientV3_Too_Low_Proto_Configured", + name: "OriginV3_TargetV3_Too_Low_Proto_Configured", proxyMaxProtoVer: "2", proxyOriginContConnVer: primitive.ProtocolVersion3, proxyTargetContConnVer: primitive.ProtocolVersion3, originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, clientProtoVer: primitive.ProtocolVersion2, - // client proxy startup, because configured protocol version is too low + // fail proxy control connection, because configured protocol version is too low failProxyStartup: true, }, } @@ -113,6 +145,7 @@ func TestProtocolNegotiationDifferentClusters(t *testing.T) { targetAddress := "127.0.1.2" serverConf := setup.NewTestConfig(originAddress, targetAddress) proxyConf := setup.NewTestConfig(originAddress, targetAddress) + log.SetLevel(log.TraceLevel) queryInsert := &message.Query{ Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters @@ -123,7 +156,9 @@ func TestProtocolNegotiationDifferentClusters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - proxyConf.ControlConnMaxProtocolVersion = test.proxyMaxProtoVer + if test.proxyMaxProtoVer != "" { + proxyConf.ControlConnMaxProtocolVersion = test.proxyMaxProtoVer + } testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) require.Nil(t, err) @@ -180,6 +215,152 @@ func TestProtocolNegotiationDifferentClusters(t *testing.T) { } } +// Test that proxy blocks protocol versions when configured to do so +func TestProtocolNegotiationBlockedVersions(t *testing.T) { + tests := []struct { + name string + clusterProtoVers []primitive.ProtocolVersion + blockedProtoVers string + clientProtoVer primitive.ProtocolVersion + failClientConnect bool + }{ + { + name: "ClusterV2_BlockedV2_ClientFail", + clusterProtoVers: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + blockedProtoVers: "v2", + failClientConnect: true, + }, + { + name: "ClusterV2V3V4_BlockedV2_ClientV4", + clusterProtoVers: []primitive.ProtocolVersion{0x2, 0x3, 0x4}, + blockedProtoVers: "v2", + clientProtoVer: 0x4, + }, + { + name: "ClusterV2V3V4V5_BlockedV5_ClientV4", + clusterProtoVers: []primitive.ProtocolVersion{0x2, 0x3, 0x4, 0x5}, + blockedProtoVers: "v5", + clientProtoVer: 0x4, + }, + { + name: "ClusterV2V3V4V5_BlockedV4V5_ClientV3", + clusterProtoVers: []primitive.ProtocolVersion{0x2, 0x3, 0x4, 0x5}, + blockedProtoVers: "v4,v5", + clientProtoVer: 0x3, + }, + { + name: "ClusterV2V3V4V5_BlockedV2V3V4V5_ClientFail", + clusterProtoVers: []primitive.ProtocolVersion{0x2, 0x3, 0x4, 0x5}, + blockedProtoVers: "2,3,4,5", + failClientConnect: true, + }, + { + name: "ClusterV2V3V4DseV1DseV2_BlockedV4V5DseV1_ClientDseV2", + clusterProtoVers: []primitive.ProtocolVersion{0x2, 0x3, 0x4, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2}, + blockedProtoVers: "V4,V5,DseV1", + clientProtoVer: primitive.ProtocolVersionDse2, + }, + { + name: "ClusterV2V3V4DseV1_BlockedV5_ClientDseV1", + clusterProtoVers: []primitive.ProtocolVersion{0x2, 0x3, 0x4, primitive.ProtocolVersionDse1}, + blockedProtoVers: "V5", + clientProtoVer: primitive.ProtocolVersionDse1, + }, + { + name: "ClusterV2V3V4DseV1_BlockedDseV1_ClientV4", + clusterProtoVers: []primitive.ProtocolVersion{0x2, 0x3, 0x4, primitive.ProtocolVersionDse1}, + blockedProtoVers: "dsev1", + clientProtoVer: 0x4, + }, + } + + originAddress := "127.0.1.1" + targetAddress := "127.0.1.2" + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + log.SetLevel(log.TraceLevel) + + queryInsert := &message.Query{ + Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters + } + querySelect := &message.Query{ + Query: "SELECT * FROM test_ks.test", + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxyConf.BlockedProtocolVersions = test.blockedProtoVers + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewProtocolNegotiationRequestHandler("origin", "dc1", originAddress, test.clusterProtoVers) + targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, test.clusterProtoVers) + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, 0) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + require.Nil(t, err) + + cqlConn, clientProtoVer, err := connectWithNegotiation(testSetup.Client.CqlClient, context.Background()) + if cqlConn != nil { + defer cqlConn.Close() + } + if test.failClientConnect { + require.NotNil(t, err) + return + } + require.Nil(t, err) + + require.Equal(t, test.clientProtoVer, clientProtoVer) + + response, err := cqlConn.SendAndReceive(frame.NewFrame(test.clientProtoVer, 0, queryInsert)) + require.Nil(t, err) + require.IsType(t, &message.VoidResult{}, response.Body.Message) + + response, err = cqlConn.SendAndReceive(frame.NewFrame(test.clientProtoVer, 0, querySelect)) + require.Nil(t, err) + resultSet := response.Body.Message.(*message.RowsResult).Data + require.Equal(t, 1, len(resultSet)) + }) + } +} + +func connectWithNegotiation(cqlClient *client.CqlClient, ctx context.Context) (*client.CqlClientConnection, primitive.ProtocolVersion, error) { + orderedProtoVersions := []primitive.ProtocolVersion{ + primitive.ProtocolVersionDse2, primitive.ProtocolVersionDse1, primitive.ProtocolVersion5, + primitive.ProtocolVersion4, primitive.ProtocolVersion3, primitive.ProtocolVersion2} + + for _, protoVersion := range orderedProtoVersions { + conn, err := cqlClient.ConnectAndInit(ctx, protoVersion, 0) + if err != nil { + if conn != nil { + conn.Close() + } + if strings.Contains(strings.ToLower(err.Error()), "handler closed") { + continue + } + return nil, 0, fmt.Errorf("negotiate error: %w", err) + } + return conn, protoVersion, nil + } + return nil, 0, errors.New("all protocol versions failed") +} + type ProtocolNegotiationRequestHandler struct { cluster string datacenter string diff --git a/integration-tests/read_test.go b/integration-tests/read_test.go index 909874ac..50c548a6 100644 --- a/integration-tests/read_test.go +++ b/integration-tests/read_test.go @@ -2,18 +2,20 @@ package integration_tests import ( "fmt" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" - "github.com/stretchr/testify/require" - "net" - "strings" - "testing" ) -var rpcAddressExpectedPrimed = net.IPv4(192, 168, 1, 1) -var rpcAddressExpectedProxy = net.IPv4(127, 0, 0, 1) +var rpcAddressExpectedPrimed = net.IP{192, 168, 1, 1} +var rpcAddressExpectedProxy = net.IP{127, 0, 0, 1} var rows = simulacron.NewRowsResult( map[string]simulacron.DataType{ @@ -65,13 +67,13 @@ func testForwardDecisionsForReads(t *testing.T, primaryCluster string, systemQue } expectedProxyRow := map[string]interface{}{ - "rpc_address": rpcAddressExpectedProxy.String(), + "rpc_address": rpcAddressExpectedProxy, } expectedAliasedProxyRow := map[string]interface{}{ - "addr": rpcAddressExpectedProxy.String(), + "addr": rpcAddressExpectedProxy, } expectedPrimedRow := map[string]interface{}{ - "rpc_address": rpcAddressExpectedPrimed.String(), + "rpc_address": rpcAddressExpectedPrimed, } tests := []struct { diff --git a/integration-tests/runner_test.go b/integration-tests/runner_test.go index 0cdbb5d6..ea08c1fe 100644 --- a/integration-tests/runner_test.go +++ b/integration-tests/runner_test.go @@ -3,9 +3,20 @@ package integration_tests import ( "context" "fmt" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + "github.com/datastax/go-cassandra-native-protocol/message" - "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/jpillora/backoff" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/client" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" @@ -14,15 +25,6 @@ import ( "github.com/datastax/zdm-proxy/proxy/pkg/metrics" "github.com/datastax/zdm-proxy/proxy/pkg/runner" "github.com/datastax/zdm-proxy/proxy/pkg/zdmproxy" - "github.com/jpillora/backoff" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "net/http" - "strings" - "sync" - "sync/atomic" - "testing" - "time" ) /* @@ -213,7 +215,7 @@ func testMetricsWithUnavailableNode( queryMsg := &message.Query{ Query: "SELECT * FROM table1", } - _, _, _ = testClient.SendMessage(context.Background(), primitive.ProtocolVersion4, queryMsg) + _, _, _ = testClient.SendMessage(context.Background(), env.DefaultProtocolVersionSimulacron, queryMsg) utils.RequireWithRetries(t, func() (err error, fatal bool) { // expect connection failure to origin cluster diff --git a/integration-tests/setup/data.go b/integration-tests/setup/data.go index f9e74bbd..b57c157c 100644 --- a/integration-tests/setup/data.go +++ b/integration-tests/setup/data.go @@ -3,7 +3,7 @@ package setup import ( "fmt" - "github.com/gocql/gocql" + "github.com/apache/cassandra-gocql-driver/v2" log "github.com/sirupsen/logrus" ) diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index 3ea21c88..b7336ecd 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -2,17 +2,19 @@ package setup import ( "context" + "math" + "sync" + "testing" + "github.com/datastax/go-cassandra-native-protocol/primitive" + log "github.com/sirupsen/logrus" + "github.com/datastax/zdm-proxy/integration-tests/ccm" "github.com/datastax/zdm-proxy/integration-tests/cqlserver" "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/datastax/zdm-proxy/proxy/pkg/zdmproxy" - log "github.com/sirupsen/logrus" - "math" - "sync" - "testing" ) type TestCluster interface { @@ -27,7 +29,10 @@ var createdGlobalClusters = false var globalCcmClusterOrigin *ccm.Cluster var globalCcmClusterTarget *ccm.Cluster -func GetGlobalTestClusterOrigin() (*ccm.Cluster, error) { +func GetGlobalTestClusterOrigin(t *testing.T) (*ccm.Cluster, error) { + if !env.RunCcmTests { + t.Skip("Skipping CCM tests, RUN_CCMTESTS is set false") + } if createdGlobalClusters { return globalCcmClusterOrigin, nil } @@ -47,7 +52,10 @@ func GetGlobalTestClusterOrigin() (*ccm.Cluster, error) { return globalCcmClusterOrigin, nil } -func GetGlobalTestClusterTarget() (*ccm.Cluster, error) { +func GetGlobalTestClusterTarget(t *testing.T) (*ccm.Cluster, error) { + if !env.RunCcmTests { + t.Skip("Skipping CCM tests, RUN_CCMTESTS is set false") + } if createdGlobalClusters { return globalCcmClusterTarget, nil } @@ -198,7 +206,10 @@ type CcmTestSetup struct { Proxy *zdmproxy.ZdmProxy } -func NewTemporaryCcmTestSetup(start bool, createProxy bool) (*CcmTestSetup, error) { +func NewTemporaryCcmTestSetup(t *testing.T, start bool, createProxy bool) (*CcmTestSetup, error) { + if !env.RunCcmTests { + t.Skip("Skipping CCM tests, RUN_CCMTESTS is set false") + } firstClusterId := env.Rand.Uint64() % (math.MaxUint64 - 1) origin, err := ccm.GetNewCluster(firstClusterId, 20, env.OriginNodes, start) if err != nil { diff --git a/integration-tests/shutdown_test.go b/integration-tests/shutdown_test.go index b6b44cda..c205c545 100644 --- a/integration-tests/shutdown_test.go +++ b/integration-tests/shutdown_test.go @@ -4,23 +4,26 @@ import ( "context" "errors" "fmt" + "math/rand" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + client2 "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/rs/zerolog" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/client" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/proxy/pkg/config" - "github.com/rs/zerolog" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "math/rand" - "runtime" - "sync" - "sync/atomic" - "testing" - "time" ) func TestShutdownInFlightRequests(t *testing.T) { @@ -55,7 +58,7 @@ func TestShutdownInFlightRequests(t *testing.T) { }() cqlClient := client2.NewCqlClient("127.0.0.1:14002", nil) - cqlConn, err := cqlClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0) + cqlConn, err := cqlClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) if err != nil { t.Fatalf("could not connect: %v", err) } @@ -88,15 +91,15 @@ func TestShutdownInFlightRequests(t *testing.T) { beginTimestamp := time.Now() - reqFrame := frame.NewFrame(primitive.ProtocolVersion4, 2, queryMsg1) + reqFrame := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 2, queryMsg1) inflightRequest, err := cqlConn.Send(reqFrame) require.Nil(t, err) - reqFrame2 := frame.NewFrame(primitive.ProtocolVersion4, 3, queryMsg2) + reqFrame2 := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 3, queryMsg2) inflightRequest2, err := cqlConn.Send(reqFrame2) require.Nil(t, err) - reqFrame3 := frame.NewFrame(primitive.ProtocolVersion4, 4, queryMsg3) + reqFrame3 := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 4, queryMsg3) inflightRequest3, err := cqlConn.Send(reqFrame3) require.Nil(t, err) @@ -125,7 +128,7 @@ func TestShutdownInFlightRequests(t *testing.T) { default: } - reqFrame4 := frame.NewFrame(primitive.ProtocolVersion4, 5, queryMsg1) + reqFrame4 := frame.NewFrame(env.DefaultProtocolVersionSimulacron, 5, queryMsg1) inflightRequest4, err := cqlConn.Send(reqFrame4) require.Nil(t, err) @@ -236,7 +239,7 @@ func TestStressShutdown(t *testing.T) { require.Nil(t, err) defer cqlConn.Shutdown() - err = cqlConn.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = cqlConn.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionSimulacron, false) require.Nil(t, err) // create a channel that will receive errors from goroutines that are sending requests, @@ -284,7 +287,7 @@ func TestStressShutdown(t *testing.T) { case <-defaultHandshakeDoneCh: return default: - rspFrame, _, err := tempCqlConn.SendMessage(context.Background(), primitive.ProtocolVersion4, &message.Options{}) + rspFrame, _, err := tempCqlConn.SendMessage(context.Background(), env.DefaultProtocolVersionSimulacron, &message.Options{}) if err != nil { if !shutdownProxyTriggered.Load().(bool) { errChan <- fmt.Errorf("[%v] unexpected error in heartbeat: %w", id, err) @@ -311,7 +314,7 @@ func TestStressShutdown(t *testing.T) { case <-time.After(time.Duration(r) * time.Millisecond): case <-globalCtx.Done(): } - err = tempCqlConn.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = tempCqlConn.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionSimulacron, false) defaultHandshakeDoneCh <- true optionsWg.Wait() _ = tempCqlConn.Shutdown() @@ -336,7 +339,7 @@ func TestStressShutdown(t *testing.T) { Query: "SELECT * FROM system.local", Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelLocalOne}, } - rsp, _, err := cqlConn.SendMessage(context.Background(), primitive.ProtocolVersion4, queryMsg) + rsp, _, err := cqlConn.SendMessage(context.Background(), env.DefaultProtocolVersionSimulacron, queryMsg) if err != nil { if !shutdownProxyTriggered.Load().(bool) { diff --git a/integration-tests/simulacron/api.go b/integration-tests/simulacron/api.go index d0478ab9..6b4e28d1 100644 --- a/integration-tests/simulacron/api.go +++ b/integration-tests/simulacron/api.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/gocql/gocql" "time" + + "github.com/apache/cassandra-gocql-driver/v2" + "github.com/datastax/go-cassandra-native-protocol/primitive" ) type When interface { @@ -384,3 +386,11 @@ func when(out map[string]interface{}) When { when.out = out return when } + +func SupportsProtocolVersion(version primitive.ProtocolVersion) bool { + if version == primitive.ProtocolVersion3 || version == primitive.ProtocolVersion4 { + return true + } + + return false +} diff --git a/integration-tests/simulacron/cluster.go b/integration-tests/simulacron/cluster.go index 6423c833..2d394337 100644 --- a/integration-tests/simulacron/cluster.go +++ b/integration-tests/simulacron/cluster.go @@ -3,8 +3,8 @@ package simulacron import ( "encoding/json" "fmt" + "github.com/apache/cassandra-gocql-driver/v2" "github.com/datastax/zdm-proxy/integration-tests/env" - "github.com/gocql/gocql" "net" "strings" ) diff --git a/integration-tests/streamid_test.go b/integration-tests/streamid_test.go index 17f01271..a0c6151b 100644 --- a/integration-tests/streamid_test.go +++ b/integration-tests/streamid_test.go @@ -3,19 +3,22 @@ package integration_tests import ( "context" "fmt" + "strings" + "sync" + "testing" + "time" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/client" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/proxy/pkg/metrics" "github.com/datastax/zdm-proxy/proxy/pkg/runner" - "github.com/stretchr/testify/require" - "strings" - "sync" - "testing" - "time" ) type resources struct { @@ -101,7 +104,7 @@ func TestStreamIdsMetrics(t *testing.T) { defer resources.close() assertUsedStreamIds := initAsserts(resources.setup, metricsPrefix) - asyncQuery := asyncContextWrap(resources.testClient) + asyncQuery := asyncContextWrap(env.DefaultProtocolVersionSimulacron, resources.testClient) for idx, query := range testCase.queries { replacedQuery := fmt.Sprintf(query, formatName(t)) testCase.queries[idx] = replacedQuery @@ -125,7 +128,7 @@ func TestStreamIdsMetrics(t *testing.T) { // asyncContextWrap is a higher-order function that holds a reference to the test client and returns a function that // actually executes the query in an asynchronous fashion and returns an WaitGroup for synchronization -func asyncContextWrap(testClient *client.TestClient) func(t *testing.T, query string, repeat int) *sync.WaitGroup { +func asyncContextWrap(version primitive.ProtocolVersion, testClient *client.TestClient) func(t *testing.T, query string, repeat int) *sync.WaitGroup { run := func(t *testing.T, query string, repeat int) *sync.WaitGroup { // WaitGroup for controlling the dispatched/sent queries dispatchedWg := &sync.WaitGroup{} @@ -137,7 +140,7 @@ func asyncContextWrap(testClient *client.TestClient) func(t *testing.T, query st go func(testClient *client.TestClient, dispatched *sync.WaitGroup, returned *sync.WaitGroup) { defer returnedWg.Done() dispatchedWg.Done() - executeQuery(t, testClient, query) + executeQuery(t, version, testClient, query) }(testClient, dispatchedWg, returnedWg) } dispatchedWg.Wait() @@ -160,7 +163,7 @@ func setupResources(t *testing.T, testSetup *setup.SimulacronTestSetup, metricsP testClient, err := client.NewTestClientWithRequestTimeout(context.Background(), fmt.Sprintf("127.0.0.1:%v", proxyPort), 10*time.Second) require.Nil(t, err) - testClient.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion3, false) + testClient.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionSimulacron, false) return &resources{ setup: &setup.SimulacronTestSetup{ @@ -190,11 +193,11 @@ func primeClustersWithDelay(setup *setup.SimulacronTestSetup, query string) { // executeQuery sends the query string in a Frame message through the test client and handles any failures internally // by failing the tests, otherwise, returns the response to the caller -func executeQuery(t *testing.T, client *client.TestClient, query string) *frame.Frame { +func executeQuery(t *testing.T, version primitive.ProtocolVersion, client *client.TestClient, query string) *frame.Frame { q := &message.Query{ Query: query, } - response, _, err := client.SendMessage(context.Background(), primitive.ProtocolVersion4, q) + response, _, err := client.SendMessage(context.Background(), version, q) if err != nil { t.Fatal("query failed:", err) } diff --git a/integration-tests/stress_test.go b/integration-tests/stress_test.go index 772f402f..8a53b2f2 100644 --- a/integration-tests/stress_test.go +++ b/integration-tests/stress_test.go @@ -4,24 +4,23 @@ import ( "context" "errors" "fmt" - "github.com/datastax/zdm-proxy/integration-tests/env" - "github.com/datastax/zdm-proxy/integration-tests/setup" - "github.com/gocql/gocql" - "github.com/rs/zerolog" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "sync" "sync/atomic" "testing" "time" + + "github.com/apache/cassandra-gocql-driver/v2" + "github.com/rs/zerolog" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/datastax/zdm-proxy/integration-tests/env" + "github.com/datastax/zdm-proxy/integration-tests/setup" ) func TestSimultaneousConnections(t *testing.T) { - if !env.RunCcmTests { - t.Skip("Test requires CCM, set RUN_CCMTESTS env variable to TRUE") - } - ccmSetup, err := setup.NewTemporaryCcmTestSetup(false, false) + ccmSetup, err := setup.NewTemporaryCcmTestSetup(t, false, false) require.Nil(t, err) defer ccmSetup.Cleanup() err = ccmSetup.Origin.UpdateConf("authenticator: PasswordAuthenticator") @@ -72,6 +71,8 @@ func TestSimultaneousConnections(t *testing.T) { fatalErr := errors.New("fatal err") spawnGoroutinesWg := &sync.WaitGroup{} + var sessions []*gocql.Session + sessionsLock := &sync.Mutex{} for i := 0; i < parallelSessionGoroutines; i++ { spawnGoroutinesWg.Add(1) go func() { @@ -79,7 +80,6 @@ func TestSimultaneousConnections(t *testing.T) { for i := 0; i < numberOfSessionsPerGoroutine; i++ { goCqlCluster := gocql.NewCluster("localhost") goCqlCluster.Port = 14002 - goCqlCluster.ProtoVersion = 4 goCqlCluster.Authenticator = gocql.PasswordAuthenticator{ Username: "cassandra", Password: "cassandra", @@ -93,13 +93,19 @@ func TestSimultaneousConnections(t *testing.T) { errChan <- fmt.Errorf("%w: %v", fatalErr, err.Error()) return } - defer goCqlSession.Close() + sessionsLock.Lock() + sessions = append(sessions, goCqlSession) + sessionsLock.Unlock() requestWg.Add(1) go func() { defer requestWg.Done() for testCtx.Err() == nil { qCtx, fn := context.WithTimeout(testCtx, 10*time.Second) - q := goCqlSession.Query("SELECT * FROM system_schema.keyspaces").WithContext(qCtx) + qry := "SELECT * FROM system_schema.keyspaces" + if (!env.IsDse && env.CompareServerVersion("3.0.0") < 0) || (env.IsDse && env.CompareServerVersion("5.0.0") < 0) { + qry = "SELECT * FROM system.schema_keyspaces" + } + q := goCqlSession.Query(qry).WithContext(qCtx) err := q.Exec() fn() if errors.Is(err, gocql.ErrSessionClosed) { @@ -119,6 +125,13 @@ func TestSimultaneousConnections(t *testing.T) { go func() { defer wg.Done() defer close(errChan) + defer func() { + sessionsLock.Lock() + for _, session := range sessions { + session.Close() + } + sessionsLock.Unlock() + }() spawnGoroutinesWg.Wait() select { case <-time.After(13 * time.Second): @@ -151,6 +164,7 @@ func TestSimultaneousConnections(t *testing.T) { requestWg.Wait() }() + errCounter := 0 for { err, ok := <-errChan if !ok { @@ -166,7 +180,14 @@ func TestSimultaneousConnections(t *testing.T) { assert.Failf(t, "error before shutdown, deadlock?", "%v", err.Error()) testCancelFn() } else { - t.Log(err) + if errors.Is(err, gocql.ErrNoConnections) { + if errCounter%20 == 0 { + t.Log(err) + } + errCounter++ + } else { + t.Log(err) + } } } } diff --git a/integration-tests/tls_test.go b/integration-tests/tls_test.go index 94508de7..73871c94 100644 --- a/integration-tests/tls_test.go +++ b/integration-tests/tls_test.go @@ -5,22 +5,24 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io/ioutil" + "path/filepath" + "strings" + "testing" + "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/integration-tests/env" - "github.com/datastax/zdm-proxy/integration-tests/setup" - "github.com/datastax/zdm-proxy/integration-tests/utils" - "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/rs/zerolog" zerologger "github.com/rs/zerolog/log" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" - "io/ioutil" - "path/filepath" - "strings" - "testing" + + "github.com/datastax/zdm-proxy/integration-tests/env" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/utils" + "github.com/datastax/zdm-proxy/proxy/pkg/config" ) type clusterTlsConfiguration struct { @@ -107,7 +109,9 @@ const ( // Runs only when the full test suite is executed func TestTls_OneWayOrigin_OneWayTarget(t *testing.T) { - + if env.CompareServerVersion("6.0.0") < 0 && env.CompareServerVersion("5.0.0") >= 0 && env.IsDse { + t.Skipf("skip tls tests for dse 5.1, for some unknown reason TLS errors are happening, revisit this if there is a user report about this") + } essentialTest := false skipNonEssentialTests(essentialTest, t) @@ -184,7 +188,9 @@ func TestTls_OneWayOrigin_OneWayTarget(t *testing.T) { // Runs only when the full test suite is executed func TestTls_MutualOrigin_MutualTarget(t *testing.T) { - + if env.CompareServerVersion("6.0.0") < 0 && env.CompareServerVersion("5.0.0") >= 0 && env.IsDse { + t.Skipf("skip tls tests for dse 5.1, for some unknown reason TLS errors are happening, revisit this if there is a user report about this") + } essentialTest := false skipNonEssentialTests(essentialTest, t) @@ -261,7 +267,9 @@ func TestTls_MutualOrigin_MutualTarget(t *testing.T) { // Always runs func TestTls_OneWayOrigin_MutualTarget(t *testing.T) { - + if env.CompareServerVersion("6.0.0") < 0 && env.CompareServerVersion("5.0.0") >= 0 && env.IsDse { + t.Skipf("skip tls tests for dse 5.1, for some unknown reason TLS errors are happening, revisit this if there is a user report about this") + } essentialTest := true skipNonEssentialTests(essentialTest, t) @@ -468,6 +476,9 @@ func TestTls_OneWayOrigin_MutualTarget(t *testing.T) { // Runs only when the full test suite is executed func TestTls_ExpiredCA(t *testing.T) { + if env.CompareServerVersion("6.0.0") < 0 && env.CompareServerVersion("5.0.0") >= 0 && env.IsDse { + t.Skipf("skip tls tests for dse 5.1, for some unknown reason TLS errors are happening, revisit this if there is a user report about this") + } essentialTest := false skipNonEssentialTests(essentialTest, t) @@ -513,7 +524,9 @@ func TestTls_ExpiredCA(t *testing.T) { // Runs only when the full test suite is executed func TestTls_MutualOrigin_OneWayTarget(t *testing.T) { - + if env.CompareServerVersion("6.0.0") < 0 && env.CompareServerVersion("5.0.0") >= 0 && env.IsDse { + t.Skipf("skip tls tests for dse 5.1, for some unknown reason TLS errors are happening, revisit this if there is a user report about this") + } essentialTest := false skipNonEssentialTests(essentialTest, t) @@ -590,7 +603,9 @@ func TestTls_MutualOrigin_OneWayTarget(t *testing.T) { // Runs only when the full test suite is executed func TestTls_NoOrigin_OneWayTarget(t *testing.T) { - + if env.CompareServerVersion("6.0.0") < 0 && env.CompareServerVersion("5.0.0") >= 0 && env.IsDse { + t.Skipf("skip tls tests for dse 5.1, for some unknown reason TLS errors are happening, revisit this if there is a user report about this") + } essentialTest := false skipNonEssentialTests(essentialTest, t) @@ -667,7 +682,9 @@ func TestTls_NoOrigin_OneWayTarget(t *testing.T) { // Always runs func TestTls_NoOrigin_MutualTarget(t *testing.T) { - + if env.CompareServerVersion("6.0.0") < 0 && env.CompareServerVersion("5.0.0") >= 0 && env.IsDse { + t.Skipf("skip tls tests for dse 5.1, for some unknown reason TLS errors are happening, revisit this if there is a user report about this") + } essentialTest := true skipNonEssentialTests(essentialTest, t) @@ -744,7 +761,9 @@ func TestTls_NoOrigin_MutualTarget(t *testing.T) { // Runs only when the full test suite is executed func TestTls_OneWayOrigin_NoTarget(t *testing.T) { - + if env.CompareServerVersion("6.0.0") < 0 && env.CompareServerVersion("5.0.0") >= 0 && env.IsDse { + t.Skipf("skip tls tests for dse 5.1, for some unknown reason TLS errors are happening, revisit this if there is a user report about this") + } essentialTest := false skipNonEssentialTests(essentialTest, t) @@ -821,7 +840,9 @@ func TestTls_OneWayOrigin_NoTarget(t *testing.T) { // Runs only when the full test suite is executed func TestTls_MutualOrigin_NoTarget(t *testing.T) { - + if env.CompareServerVersion("6.0.0") < 0 && env.CompareServerVersion("5.0.0") >= 0 && env.IsDse { + t.Skipf("skip tls tests for dse 5.1, for some unknown reason TLS errors are happening, revisit this if there is a user report about this") + } essentialTest := false skipNonEssentialTests(essentialTest, t) @@ -903,11 +924,7 @@ func skipNonEssentialTests(essentialTest bool, t *testing.T) { } func setupOriginAndTargetClusters(clusterConf clusterTlsConfiguration, t *testing.T) (*setup.CcmTestSetup, error) { - if !env.RunCcmTests { - t.Skip("Test requires CCM, set RUN_CCMTESTS env variable to TRUE") - } - - ccmSetup, err := setup.NewTemporaryCcmTestSetup(false, false) + ccmSetup, err := setup.NewTemporaryCcmTestSetup(t, false, false) if ccmSetup == nil { return nil, fmt.Errorf("ccm setup could not be created and is nil") } @@ -1239,7 +1256,7 @@ func applyProxyClientTlsConfiguration(expiredCa bool, incorrectCa bool, isMutual func createTestClientConnection(endpoint string, tlsCfg *tls.Config) (*client.CqlClientConnection, error) { testClient := client.NewCqlClient(endpoint, nil) testClient.TLSConfig = tlsCfg - return testClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 1) + return testClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersion, 1) } func sendRequest(cqlConn *client.CqlClientConnection, cqlRequest string, isSchemaChange bool, t *testing.T) { @@ -1250,7 +1267,7 @@ func sendRequest(cqlConn *client.CqlClientConnection, cqlRequest string, isSchem }, } - queryFrame := frame.NewFrame(primitive.ProtocolVersion4, 0, requestMsg) + queryFrame := frame.NewFrame(env.DefaultProtocolVersion, 0, requestMsg) response, err := cqlConn.SendAndReceive(queryFrame) require.Nil(t, err) diff --git a/integration-tests/unavailablenode_test.go b/integration-tests/unavailablenode_test.go index 7d3f0417..3245c470 100644 --- a/integration-tests/unavailablenode_test.go +++ b/integration-tests/unavailablenode_test.go @@ -3,16 +3,19 @@ package integration_tests import ( "context" "fmt" + "strings" + "testing" + "time" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/require" + "github.com/datastax/zdm-proxy/integration-tests/client" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/utils" - "github.com/stretchr/testify/require" - "strings" - "testing" - "time" ) // TestUnavailableNode tests if the proxy closes the client connection correctly when either cluster node connection is closed @@ -30,7 +33,7 @@ func TestUnavailableNode(t *testing.T) { require.True(t, err == nil, "testClient setup failed: %s", err) defer testClient.Shutdown() - err = testClient.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = testClient.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionSimulacron, false) require.True(t, err == nil, "No-auth handshake failed: %s", err) switch clusterNotResponding { @@ -55,7 +58,7 @@ func TestUnavailableNode(t *testing.T) { responsePtr := new(*frame.Frame) errPtr := new(error) utils.RequireWithRetries(t, func() (err error, fatal bool) { - *responsePtr, _, *errPtr = testClient.SendMessage(context.Background(), primitive.ProtocolVersion4, query) + *responsePtr, _, *errPtr = testClient.SendMessage(context.Background(), env.DefaultProtocolVersionSimulacron, query) if *responsePtr != nil { _, ok := (*responsePtr).Body.Message.(*message.Overloaded) if !ok { @@ -83,11 +86,11 @@ func TestUnavailableNode(t *testing.T) { require.True(t, err == nil, "newTestClient setup failed: %s", err) defer newTestClient.Shutdown() - err = newTestClient.PerformDefaultHandshake(context.Background(), primitive.ProtocolVersion4, false) + err = newTestClient.PerformDefaultHandshake(context.Background(), env.DefaultProtocolVersionSimulacron, false) require.True(t, err == nil, "No-auth handshake failed: %s", err) // send same query on the new connection and this time it should succeed - response, _, err = newTestClient.SendMessage(context.Background(), primitive.ProtocolVersion4, query) + response, _, err = newTestClient.SendMessage(context.Background(), env.DefaultProtocolVersionSimulacron, query) require.True(t, err == nil, "Query failed: %v", err) require.Equal( diff --git a/integration-tests/utils/testutils.go b/integration-tests/utils/testutils.go index 2c050ecd..3e8a1517 100644 --- a/integration-tests/utils/testutils.go +++ b/integration-tests/utils/testutils.go @@ -4,8 +4,8 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/apache/cassandra-gocql-driver/v2" "github.com/datastax/zdm-proxy/proxy/pkg/health" - "github.com/gocql/gocql" "github.com/rs/zerolog" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" diff --git a/integration-tests/virtualization_test.go b/integration-tests/virtualization_test.go index 243080e4..69423999 100644 --- a/integration-tests/virtualization_test.go +++ b/integration-tests/virtualization_test.go @@ -3,28 +3,32 @@ package integration_tests import ( "context" "fmt" + "math/big" + "math/rand" + "net" + "sort" + "strings" + "sync" + "testing" + "time" + + "github.com/apache/cassandra-gocql-driver/v2" "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/datacodec" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/datastax/zdm-proxy/integration-tests/ccm" "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/datastax/zdm-proxy/proxy/pkg/zdmproxy" - "github.com/gocql/gocql" - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "math/big" - "math/rand" - "net" - "sort" - "strings" - "sync" - "testing" - "time" ) type connectObserver struct { @@ -136,11 +140,10 @@ func TestVirtualizationNumberOfConnections(t *testing.T) { if !exists { counter = 0 } + counter++ + hostsMap[hostAddr.String()] = counter if observedConnect.Err != nil { errors = append(errors, observedConnect.Err) - } else { - counter++ - hostsMap[hostAddr.String()] = counter } hostsMapLock.Unlock() } @@ -184,10 +187,6 @@ func TestVirtualizationNumberOfConnections(t *testing.T) { } func TestVirtualizationTokenAwareness(t *testing.T) { - if !env.RunCcmTests { - t.Skip("Test requires CCM, set RUN_CCMTESTS env variable to TRUE") - } - type test struct { name string proxyIndexes []int @@ -251,9 +250,9 @@ func TestVirtualizationTokenAwareness(t *testing.T) { }, } - origin, err := setup.GetGlobalTestClusterOrigin() + origin, err := setup.GetGlobalTestClusterOrigin(t) require.Nil(t, err) - target, err := setup.GetGlobalTestClusterTarget() + target, err := setup.GetGlobalTestClusterTarget(t) require.Nil(t, err) err = origin.GetSession().Query( @@ -377,377 +376,483 @@ CREATE TABLE system.local ( ) */ func TestInterceptedQueries(t *testing.T) { - testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodes(t, false, false, 3) - require.Nil(t, err) - defer testSetup.Cleanup() - - expectedLocalCols := []string{ - "key", "bootstrapped", "broadcast_address", "cluster_name", "cql_version", "data_center", "dse_version", "graph", - "host_id", "listen_address", "partitioner", "rack", "release_version", "rpc_address", "schema_version", "tokens", - "truncated_at", - } + for _, v := range env.AllProtocolVersions { + t.Run(v.String(), func(t *testing.T) { + var cleanupFn func() + originName := "" + var originSetup, targetSetup setup.TestCluster + var expectedLocalCols, expectedPeersCols []string + var expectedLocalVals [][]interface{} + var isCcm bool + + hostId1 := uuid.NewSHA1(uuid.Nil, net.ParseIP("127.0.0.1")) + primitiveHostId1 := primitive.UUID(hostId1) + hostId2 := uuid.NewSHA1(uuid.Nil, net.ParseIP("127.0.0.2")) + primitiveHostId2 := primitive.UUID(hostId2) + hostId3 := uuid.NewSHA1(uuid.Nil, net.ParseIP("127.0.0.3")) + primitiveHostId3 := primitive.UUID(hostId3) + + if !simulacron.SupportsProtocolVersion(v) { + if !env.SupportsProtocolVersion(v) { + t.Skipf("proto version %v not supported in current ccm cluster version %v", v.String(), env.ServerVersionLogStr) + } + cleanupFn = func() {} + var err error + originSetup, err = setup.GetGlobalTestClusterOrigin(t) + require.Nil(t, err) + originName = originSetup.(*ccm.Cluster).GetId() + targetSetup, err = setup.GetGlobalTestClusterTarget(t) + require.Nil(t, err) + if env.CompareServerVersion("4.0.0") < 0 { + // add thrift_version column + expectedLocalCols = []string{ + "key", "bootstrapped", "broadcast_address", "cluster_name", "cql_version", "data_center", + "gossip_generation", "host_id", "listen_address", "native_protocol_version", "partitioner", + "rack", "release_version", "rpc_address", "schema_version", "thrift_version", "tokens", "truncated_at", + } + expectedLocalVals = [][]interface{}{ + { + "local", "COMPLETED", net.ParseIP("127.0.0.1").To4(), originName, "3.4.7", "datacenter1", 1764262829, primitiveHostId1, + net.ParseIP("127.0.0.1").To4(), env.ProtocolVersionStr(env.ComputeDefaultProtocolVersion()), + "org.apache.cassandra.dht.Murmur3Partitioner", "rack0", env.CassandraVersion, net.ParseIP("127.0.0.1").To4(), nil, + "20", []string{"1241"}, nil, + }, + } + } else { + expectedLocalCols = []string{ + "key", "bootstrapped", "broadcast_address", "cluster_name", "cql_version", "data_center", + "gossip_generation", "host_id", "listen_address", "native_protocol_version", "partitioner", + "rack", "release_version", "rpc_address", "schema_version", "tokens", "truncated_at", + } + expectedLocalVals = [][]interface{}{ + { + "local", "COMPLETED", net.ParseIP("127.0.0.1").To4(), originName, "3.4.7", "datacenter1", 1764262829, primitiveHostId1, + net.ParseIP("127.0.0.1").To4(), env.ProtocolVersionStr(v), "org.apache.cassandra.dht.Murmur3Partitioner", "rack0", env.CassandraVersion, net.ParseIP("127.0.0.1").To4(), nil, + []string{"1241"}, nil, + }, + } + } - expectedPeersCols := []string{ - "peer", "data_center", "dse_version", "graph", "host_id", "preferred_ip", "rack", "release_version", "rpc_address", - "schema_version", "tokens", - } + expectedPeersCols = []string{ + "peer", "data_center", "host_id", "preferred_ip", "rack", "release_version", "rpc_address", + "schema_version", "tokens", + } + isCcm = true + } else { + testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodes(t, false, false, 3) + require.Nil(t, err) + cleanupFn = testSetup.Cleanup + originName = testSetup.Origin.Name + originSetup = testSetup.Origin + targetSetup = testSetup.Target + expectedLocalCols = []string{ + "key", "bootstrapped", "broadcast_address", "cluster_name", "cql_version", "data_center", "dse_version", "graph", + "host_id", "listen_address", "partitioner", "rack", "release_version", "rpc_address", "schema_version", "tokens", + "truncated_at", + } - hostId1 := uuid.NewSHA1(uuid.Nil, net.ParseIP("127.0.0.1")) - primitiveHostId1 := primitive.UUID(hostId1) - hostId2 := uuid.NewSHA1(uuid.Nil, net.ParseIP("127.0.0.2")) - primitiveHostId2 := primitive.UUID(hostId2) - hostId3 := uuid.NewSHA1(uuid.Nil, net.ParseIP("127.0.0.3")) - primitiveHostId3 := primitive.UUID(hostId3) - - numTokens := 8 - - type testDefinition struct { - query string - expectedCols []string - expectedValues [][]interface{} - errExpected message.Message - proxyInstanceCount int - connectProxyIndex int - } + expectedPeersCols = []string{ + "peer", "data_center", "dse_version", "graph", "host_id", "preferred_ip", "rack", "release_version", "rpc_address", + "schema_version", "tokens", + } + isCcm = false + } + defer cleanupFn() + + numTokens := 8 + + type testDefinition struct { + query string + expectedCols []string + expectedValuesSimulacron [][]interface{} + expectedValuesCcm [][]interface{} + errExpected message.Message + proxyInstanceCount int + connectProxyIndex int + } - tests := []testDefinition{ - { - query: "SELECT * FROM system.local", - expectedCols: expectedLocalCols, - expectedValues: [][]interface{}{ + tests := []testDefinition{ { - "local", "COMPLETED", net.ParseIP("127.0.0.1").To4(), testSetup.Origin.Name, "3.2.0", "dc1", env.DseVersion, false, primitiveHostId1, - net.ParseIP("127.0.0.1").To4(), "org.apache.cassandra.dht.Murmur3Partitioner", "rack0", env.CassandraVersion, net.ParseIP("127.0.0.1").To4(), nil, - []string{"1241"}, nil, + query: "SELECT * FROM system.local", + expectedCols: expectedLocalCols, + expectedValuesSimulacron: [][]interface{}{ + { + "local", "COMPLETED", net.ParseIP("127.0.0.1").To4(), originName, "3.2.0", "dc1", env.DseVersion, false, primitiveHostId1, + net.ParseIP("127.0.0.1").To4(), "org.apache.cassandra.dht.Murmur3Partitioner", "rack0", env.CassandraVersion, net.ParseIP("127.0.0.1").To4(), nil, + []string{"1241"}, nil, + }, + }, + expectedValuesCcm: expectedLocalVals, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT rack FROM system.local", - expectedCols: []string{"rack"}, - expectedValues: [][]interface{}{ { - "rack0", + query: "SELECT rack FROM system.local", + expectedCols: []string{"rack"}, + expectedValuesSimulacron: [][]interface{}{ + { + "rack0", + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT rack as r FROM system.local", - expectedCols: []string{"r"}, - expectedValues: [][]interface{}{ { - "rack0", + query: "SELECT rack as r FROM system.local", + expectedCols: []string{"r"}, + expectedValuesSimulacron: [][]interface{}{ + { + "rack0", + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT count(*) FROM system.local", - expectedCols: []string{"count"}, - expectedValues: [][]interface{}{ { - int32(1), + query: "SELECT count(*) FROM system.local", + expectedCols: []string{"count"}, + expectedValuesSimulacron: [][]interface{}{ + { + int32(1), + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT dsa, key, asd FROM system.local", - expectedCols: nil, - expectedValues: nil, - errExpected: &message.Invalid{ErrorMessage: "Undefined column name dsa"}, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT dsa FROM system.local", - expectedCols: nil, - expectedValues: nil, - errExpected: &message.Invalid{ErrorMessage: "Undefined column name dsa"}, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT key, asd FROM system.local", - expectedCols: nil, - expectedValues: nil, - errExpected: &message.Invalid{ErrorMessage: "Undefined column name asd"}, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT rack as r, count(*) as c, rack FROM system.peers", - expectedCols: []string{"r", "c", "rack"}, - expectedValues: [][]interface{}{ { - "rack0", int32(2), "rack0", + query: "SELECT dsa, key, asd FROM system.local", + expectedCols: nil, + expectedValuesSimulacron: nil, + errExpected: &message.Invalid{ErrorMessage: "Undefined column name dsa"}, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT * FROM system.peers", - expectedCols: expectedPeersCols, - expectedValues: [][]interface{}{ { - net.ParseIP("127.0.0.2").To4(), "dc1", env.DseVersion, false, primitiveHostId2, net.ParseIP("127.0.0.2").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.2").To4(), nil, []string{"1234"}, + query: "SELECT dsa FROM system.local", + expectedCols: nil, + expectedValuesSimulacron: nil, + errExpected: &message.Invalid{ErrorMessage: "Undefined column name dsa"}, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, { - net.ParseIP("127.0.0.3").To4(), "dc1", env.DseVersion, false, primitiveHostId3, net.ParseIP("127.0.0.3").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.3").To4(), nil, []string{"1234"}, + query: "SELECT key, asd FROM system.local", + expectedCols: nil, + expectedValuesSimulacron: nil, + errExpected: &message.Invalid{ErrorMessage: "Undefined column name asd"}, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT * FROM system.peers", - expectedCols: expectedPeersCols, - expectedValues: [][]interface{}{ { - net.ParseIP("127.0.0.1").To4(), "dc1", env.DseVersion, false, primitiveHostId1, net.ParseIP("127.0.0.1").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.1").To4(), nil, []string{"1234"}, + query: "SELECT rack as r, count(*) as c, rack FROM system.peers", + expectedCols: []string{"r", "c", "rack"}, + expectedValuesSimulacron: [][]interface{}{ + { + "rack0", int32(2), "rack0", + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, { - net.ParseIP("127.0.0.3").To4(), "dc1", env.DseVersion, false, primitiveHostId3, net.ParseIP("127.0.0.3").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.3").To4(), nil, []string{"1234"}, + query: "SELECT * FROM system.peers", + expectedCols: expectedPeersCols, + expectedValuesSimulacron: [][]interface{}{ + { + net.ParseIP("127.0.0.2").To4(), "dc1", env.DseVersion, false, primitiveHostId2, net.ParseIP("127.0.0.2").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.2").To4(), nil, []string{"1234"}, + }, + { + net.ParseIP("127.0.0.3").To4(), "dc1", env.DseVersion, false, primitiveHostId3, net.ParseIP("127.0.0.3").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.3").To4(), nil, []string{"1234"}, + }, + }, + expectedValuesCcm: [][]interface{}{ + { + net.ParseIP("127.0.0.2").To4(), "datacenter1", primitiveHostId2, net.ParseIP("127.0.0.2").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.2").To4(), nil, []string{"1234"}, + }, + { + net.ParseIP("127.0.0.3").To4(), "datacenter1", primitiveHostId3, net.ParseIP("127.0.0.3").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.3").To4(), nil, []string{"1234"}, + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 1, - }, - { - query: "SELECT * FROM system.peers", - expectedCols: expectedPeersCols, - expectedValues: [][]interface{}{}, - errExpected: nil, - proxyInstanceCount: 1, - connectProxyIndex: 0, - }, - { - query: "SELECT rack FROM system.peers", - expectedCols: []string{"rack"}, - expectedValues: [][]interface{}{ { - "rack0", + query: "SELECT * FROM system.peers", + expectedCols: expectedPeersCols, + expectedValuesSimulacron: [][]interface{}{ + { + net.ParseIP("127.0.0.1").To4(), "dc1", env.DseVersion, false, primitiveHostId1, net.ParseIP("127.0.0.1").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.1").To4(), nil, []string{"1234"}, + }, + { + net.ParseIP("127.0.0.3").To4(), "dc1", env.DseVersion, false, primitiveHostId3, net.ParseIP("127.0.0.3").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.3").To4(), nil, []string{"1234"}, + }, + }, + expectedValuesCcm: [][]interface{}{ + { + net.ParseIP("127.0.0.1").To4(), "datacenter1", primitiveHostId1, net.ParseIP("127.0.0.1").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.1").To4(), nil, []string{"1234"}, + }, + { + net.ParseIP("127.0.0.3").To4(), "datacenter1", primitiveHostId3, net.ParseIP("127.0.0.3").To4(), "rack0", env.CassandraVersion, net.ParseIP("127.0.0.3").To4(), nil, []string{"1234"}, + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 1, }, { - "rack0", + query: "SELECT * FROM system.peers", + expectedCols: expectedPeersCols, + expectedValuesSimulacron: [][]interface{}{}, + errExpected: nil, + proxyInstanceCount: 1, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT rack as r FROM system.peers", - expectedCols: []string{"r"}, - expectedValues: [][]interface{}{ { - "rack0", + query: "SELECT rack FROM system.peers", + expectedCols: []string{"rack"}, + expectedValuesSimulacron: [][]interface{}{ + { + "rack0", + }, + { + "rack0", + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, { - "rack0", + query: "SELECT rack as r FROM system.peers", + expectedCols: []string{"r"}, + expectedValuesSimulacron: [][]interface{}{ + { + "rack0", + }, + { + "rack0", + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT peer, count(*) FROM system.peers", - expectedCols: []string{"peer", "count"}, - expectedValues: [][]interface{}{ { - net.ParseIP("127.0.0.2").To4(), int32(2), + query: "SELECT peer, count(*) FROM system.peers", + expectedCols: []string{"peer", "count"}, + expectedValuesSimulacron: [][]interface{}{ + { + net.ParseIP("127.0.0.2").To4(), int32(2), + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT peer, count(*), count(*) as c, peer as p FROM system.peers", - expectedCols: []string{"peer", "count", "c", "p"}, - expectedValues: [][]interface{}{ { - nil, int32(0), int32(0), nil, + query: "SELECT peer, count(*), count(*) as c, peer as p FROM system.peers", + expectedCols: []string{"peer", "count", "c", "p"}, + expectedValuesSimulacron: [][]interface{}{ + { + nil, int32(0), int32(0), nil, + }, + }, + errExpected: nil, + proxyInstanceCount: 1, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 1, - connectProxyIndex: 0, - }, - { - query: "SELECT count(*) FROM system.peers", - expectedCols: []string{"count"}, - expectedValues: [][]interface{}{ { - int32(2), + query: "SELECT count(*) FROM system.peers", + expectedCols: []string{"count"}, + expectedValuesSimulacron: [][]interface{}{ + { + int32(2), + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT count(*) FROM system.peers", - expectedCols: []string{"count"}, - expectedValues: [][]interface{}{ { - int32(0), + query: "SELECT count(*) FROM system.peers", + expectedCols: []string{"count"}, + expectedValuesSimulacron: [][]interface{}{ + { + int32(0), + }, + }, + errExpected: nil, + proxyInstanceCount: 1, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 1, - connectProxyIndex: 0, - }, - { - query: "SELECT asd, peer, dsa FROM system.peers", - expectedCols: nil, - expectedValues: nil, - errExpected: &message.Invalid{ErrorMessage: "Undefined column name asd"}, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT asd FROM system.peers", - expectedCols: nil, - expectedValues: nil, - errExpected: &message.Invalid{ErrorMessage: "Undefined column name asd"}, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT peer, dsa FROM system.peers", - expectedCols: nil, - expectedValues: nil, - errExpected: &message.Invalid{ErrorMessage: "Undefined column name dsa"}, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - { - query: "SELECT peer as p, count(*) as c, peer FROM system.peers", - expectedCols: []string{"p", "c", "peer"}, - expectedValues: [][]interface{}{ { - net.ParseIP("127.0.0.2").To4(), int32(2), net.ParseIP("127.0.0.2").To4(), + query: "SELECT asd, peer, dsa FROM system.peers", + expectedCols: nil, + expectedValuesSimulacron: nil, + errExpected: &message.Invalid{ErrorMessage: "Undefined column name asd"}, + proxyInstanceCount: 3, + connectProxyIndex: 0, }, - }, - errExpected: nil, - proxyInstanceCount: 3, - connectProxyIndex: 0, - }, - } + { + query: "SELECT asd FROM system.peers", + expectedCols: nil, + expectedValuesSimulacron: nil, + errExpected: &message.Invalid{ErrorMessage: "Undefined column name asd"}, + proxyInstanceCount: 3, + connectProxyIndex: 0, + }, + { + query: "SELECT peer, dsa FROM system.peers", + expectedCols: nil, + expectedValuesSimulacron: nil, + errExpected: &message.Invalid{ErrorMessage: "Undefined column name dsa"}, + proxyInstanceCount: 3, + connectProxyIndex: 0, + }, + { + query: "SELECT peer as p, count(*) as c, peer FROM system.peers", + expectedCols: []string{"p", "c", "peer"}, + expectedValuesSimulacron: [][]interface{}{ + { + net.ParseIP("127.0.0.2").To4(), int32(2), net.ParseIP("127.0.0.2").To4(), + }, + }, + errExpected: nil, + proxyInstanceCount: 3, + connectProxyIndex: 0, + }, + } - checkRowsResultFunc := func(t *testing.T, testVars testDefinition, queryResponseFrame *frame.Frame) { - queryRowsResult, ok := queryResponseFrame.Body.Message.(*message.RowsResult) - require.True(t, ok, queryResponseFrame.Body.Message) - require.Equal(t, len(testVars.expectedValues), len(queryRowsResult.Data)) - var resultCols []string - for _, colMetadata := range queryRowsResult.Metadata.Columns { - resultCols = append(resultCols, colMetadata.Name) - } - require.Equal(t, testVars.expectedCols, resultCols) - for i, row := range queryRowsResult.Data { - require.Equal(t, len(testVars.expectedValues[i]), len(row)) - for j, value := range row { - dcodec, err := datacodec.NewCodec(queryRowsResult.Metadata.Columns[j].Type) - require.Nil(t, err) - var dest interface{} - wasNull, err := dcodec.Decode(value, &dest, primitive.ProtocolVersion4) - require.Nil(t, err) - switch queryRowsResult.Metadata.Columns[j].Name { - case "schema_version": - require.IsType(t, primitive.UUID{}, dest) - require.NotNil(t, dest) - require.NotEqual(t, primitive.UUID{}, dest) - case "tokens": - tokens, ok := dest.([]*string) - require.True(t, ok) - require.Equal(t, numTokens, len(tokens)) - for _, token := range tokens { - require.NotNil(t, token) - require.NotEqual(t, "", *token) - } - default: - if wasNull { - require.Nil(t, testVars.expectedValues[i][j], queryRowsResult.Metadata.Columns[j].Name) - } else { - require.Equal(t, testVars.expectedValues[i][j], dest, queryRowsResult.Metadata.Columns[j].Name) + checkRowsResultFunc := func(t *testing.T, testVars testDefinition, queryResponseFrame *frame.Frame) { + queryRowsResult, ok := queryResponseFrame.Body.Message.(*message.RowsResult) + require.True(t, ok, queryResponseFrame.Body.Message) + if env.IsDse && isCcm { + // skip validation of columns when DSE is used with CCM, maybe we can add DSE columns here in the future + return + } + expectedVals := testVars.expectedValuesSimulacron + if isCcm && testVars.expectedValuesCcm != nil { + expectedVals = testVars.expectedValuesCcm + } + require.Equal(t, len(expectedVals), len(queryRowsResult.Data)) + var resultCols []string + for _, colMetadata := range queryRowsResult.Metadata.Columns { + resultCols = append(resultCols, colMetadata.Name) + } + require.Equal(t, testVars.expectedCols, resultCols) + for i, row := range queryRowsResult.Data { + require.Equal(t, len(expectedVals[i]), len(row)) + for j, value := range row { + dcodec, err := datacodec.NewCodec(queryRowsResult.Metadata.Columns[j].Type) + require.Nil(t, err) + var dest interface{} + wasNull, err := dcodec.Decode(value, &dest, queryResponseFrame.Header.Version) + require.Nil(t, err) + switch queryRowsResult.Metadata.Columns[j].Name { + case "schema_version": + require.IsType(t, primitive.UUID{}, dest) + require.NotNil(t, dest) + require.NotEqual(t, primitive.UUID{}, dest) + case "tokens": + tokens, ok := dest.([]*string) + require.True(t, ok) + require.Equal(t, numTokens, len(tokens)) + for _, token := range tokens { + require.NotNil(t, token) + require.NotEqual(t, "", *token) + } + case "gossip_generation": + gossip, ok := dest.(int32) + require.True(t, ok) + require.NotNil(t, gossip) + require.Greater(t, gossip, int32(0)) + case "cql_version": + cqlV, ok := dest.(string) + require.True(t, ok) + require.NotNil(t, cqlV) + require.NotEqual(t, "", cqlV) + case "thrift_version": + thriftV, ok := dest.(string) + require.True(t, ok) + require.NotNil(t, thriftV) + require.NotEqual(t, "", thriftV) + default: + if wasNull { + require.Nil(t, expectedVals[i][j], queryRowsResult.Metadata.Columns[j].Name) + } else { + require.Equal(t, expectedVals[i][j], dest, queryRowsResult.Metadata.Columns[j].Name) + } + } } } } - } - } - for _, testVars := range tests { - t.Run(fmt.Sprintf("%s_proxy%d_%dtotalproxies", testVars.query, testVars.connectProxyIndex, testVars.proxyInstanceCount), func(t *testing.T) { - proxyAddresses := []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"} - if testVars.proxyInstanceCount == 1 { - proxyAddresses = []string{"127.0.0.1"} - } else if testVars.proxyInstanceCount != 3 { - require.Fail(t, "unsupported proxy instance count %v", testVars.proxyInstanceCount) - } - proxyAddressToConnect := fmt.Sprintf("127.0.0.%v", testVars.connectProxyIndex+1) - proxy, err := LaunchProxyWithTopologyConfig( - strings.Join(proxyAddresses, ","), testVars.connectProxyIndex, - proxyAddressToConnect, numTokens, testSetup.Origin, testSetup.Target) - require.Nil(t, err) - defer proxy.Shutdown() - - testClient := client.NewCqlClient(fmt.Sprintf("%v:14002", proxyAddressToConnect), nil) - cqlConnection, err := testClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0) - require.Nil(t, err) - defer cqlConnection.Close() - - queryMsg := &message.Query{ - Query: testVars.query, - Options: nil, - } - queryFrame := frame.NewFrame(primitive.ProtocolVersion4, 0, queryMsg) - queryResponseFrame, err := cqlConnection.SendAndReceive(queryFrame) - require.Nil(t, err) - if testVars.errExpected != nil { - require.Equal(t, testVars.errExpected, queryResponseFrame.Body.Message) - } else { - checkRowsResultFunc(t, testVars, queryResponseFrame) - } + for _, testVars := range tests { + t.Run(fmt.Sprintf("%s_proxy%d_%dtotalproxies", testVars.query, testVars.connectProxyIndex, testVars.proxyInstanceCount), func(t *testing.T) { + proxyAddresses := []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"} + if testVars.proxyInstanceCount == 1 { + proxyAddresses = []string{"127.0.0.1"} + } else if testVars.proxyInstanceCount != 3 { + require.Fail(t, "unsupported proxy instance count %v", testVars.proxyInstanceCount) + } + proxyAddressToConnect := fmt.Sprintf("127.0.0.%v", testVars.connectProxyIndex+1) + proxy, err := LaunchProxyWithTopologyConfig( + strings.Join(proxyAddresses, ","), testVars.connectProxyIndex, + proxyAddressToConnect, numTokens, originSetup, targetSetup) + require.Nil(t, err) + defer proxy.Shutdown() + + testClient := client.NewCqlClient(fmt.Sprintf("%v:14002", proxyAddressToConnect), nil) + testClient.ReadTimeout = 1 * time.Second + cqlConnection, err := testClient.ConnectAndInit(context.Background(), v, 0) + require.Nil(t, err) + defer cqlConnection.Close() + + queryMsg := &message.Query{ + Query: testVars.query, + Options: nil, + } + queryFrame := frame.NewFrame(v, 0, queryMsg) + queryResponseFrame, err := cqlConnection.SendAndReceive(queryFrame) + require.Nil(t, err) + if testVars.errExpected != nil { + require.Equal(t, testVars.errExpected, queryResponseFrame.Body.Message) + } else { + checkRowsResultFunc(t, testVars, queryResponseFrame) + } - prepareMsg := &message.Prepare{ - Query: testVars.query, - Keyspace: "", - } - prepareFrame := frame.NewFrame(primitive.ProtocolVersion4, 0, prepareMsg) - prepareResponseFrame, err := cqlConnection.SendAndReceive(prepareFrame) - require.Nil(t, err) - if testVars.errExpected != nil { - require.Equal(t, testVars.errExpected, prepareResponseFrame.Body.Message) - } else { - preparedMsg, ok := prepareResponseFrame.Body.Message.(*message.PreparedResult) - require.True(t, ok, prepareResponseFrame.Body.Message) - executeMsg := &message.Execute{ - QueryId: preparedMsg.PreparedQueryId, - ResultMetadataId: preparedMsg.ResultMetadataId, - Options: nil, - } - executeFrame := frame.NewFrame(primitive.ProtocolVersion4, 0, executeMsg) - executeResponseFrame, err := cqlConnection.SendAndReceive(executeFrame) - require.Nil(t, err) - checkRowsResultFunc(t, testVars, executeResponseFrame) + prepareMsg := &message.Prepare{ + Query: testVars.query, + Keyspace: "", + } + prepareFrame := frame.NewFrame(v, 0, prepareMsg) + prepareResponseFrame, err := cqlConnection.SendAndReceive(prepareFrame) + require.Nil(t, err) + if testVars.errExpected != nil { + require.Equal(t, testVars.errExpected, prepareResponseFrame.Body.Message) + } else { + preparedMsg, ok := prepareResponseFrame.Body.Message.(*message.PreparedResult) + require.True(t, ok, prepareResponseFrame.Body.Message) + executeMsg := &message.Execute{ + QueryId: preparedMsg.PreparedQueryId, + ResultMetadataId: preparedMsg.ResultMetadataId, + Options: nil, + } + executeFrame := frame.NewFrame(v, 0, executeMsg) + executeResponseFrame, err := cqlConnection.SendAndReceive(executeFrame) + require.Nil(t, err) + checkRowsResultFunc(t, testVars, executeResponseFrame) + } + }) } }) } + } func TestVirtualizationPartitioner(t *testing.T) { @@ -846,7 +951,7 @@ func TestVirtualizationPartitioner(t *testing.T) { client.NewDriverConnectionInitializationHandler("target", "dc2", func(_ string) {}), } - err = testSetup.Start(nil, false, primitive.ProtocolVersion4) + err = testSetup.Start(nil, false, env.DefaultProtocolVersion) require.Nil(t, err) validatePartitionerFromSystemLocal(t, originAddress+":9042", credentials, originPartitioner) @@ -968,7 +1073,7 @@ func computeReplicas(n int, numTokens int) []*replica { func validatePartitionerFromSystemLocal(t *testing.T, remoteEndpoint string, credentials *client.AuthCredentials, expectedPartitioner string) { testClient := client.NewCqlClient(remoteEndpoint, credentials) - cqlConn, err := testClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 1) + cqlConn, err := testClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersion, 1) require.Nil(t, err, "testClient setup failed", err) require.NotNil(t, cqlConn, "cql connection could not be opened") defer func() { @@ -984,7 +1089,7 @@ func validatePartitionerFromSystemLocal(t *testing.T, remoteEndpoint string, cre }, } - queryFrame := frame.NewFrame(primitive.ProtocolVersion4, 0, requestMsg) + queryFrame := frame.NewFrame(env.DefaultProtocolVersion, 0, requestMsg) response, err := cqlConn.SendAndReceive(queryFrame) require.Nil(t, err) diff --git a/integration-tests/write_test.go b/integration-tests/write_test.go index cc04fddf..dd76c377 100644 --- a/integration-tests/write_test.go +++ b/integration-tests/write_test.go @@ -6,11 +6,11 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "github.com/apache/cassandra-gocql-driver/v2" "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" - "github.com/gocql/gocql" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "regexp" diff --git a/integration-tests/writecoalescer_test.go b/integration-tests/writecoalescer_test.go new file mode 100644 index 00000000..d2c9c8e5 --- /dev/null +++ b/integration-tests/writecoalescer_test.go @@ -0,0 +1,311 @@ +package integration_tests + +import ( + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/require" + + "github.com/datastax/zdm-proxy/integration-tests/setup" +) + +// TestWriteCoalescerHandlesWrittenFalse tests that the write coalescer correctly handles +// the case when the segment writer returns written=false, which happens when a frame +// cannot fit in the current segment payload buffer and needs to be written later. +func TestWriteCoalescerHandlesWrittenFalse(t *testing.T) { + // Create a config with very small write buffer sizes to force the written=false condition + conf := setup.NewTestConfig("127.0.1.1", "127.0.1.2") + + // Set extremely small buffer sizes and reduce workers to trigger written=false more frequently + conf.RequestWriteBufferSizeBytes = 256 // Extremely small buffer to force frequent flushes + conf.ResponseWriteBufferSizeBytes = 256 + conf.RequestResponseMaxWorkers = 2 // Very few workers to increase contention + conf.WriteMaxWorkers = 2 + conf.ReadMaxWorkers = 2 + + testSetup, err := setup.NewCqlServerTestSetup(t, conf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + // Create request handlers that capture all requests and return successful responses + originRequestHandler := NewRequestCapturingHandler() + targetRequestHandler := NewRequestCapturingHandler() + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc2", func(_ string) {}), + } + + err = testSetup.Start(nil, false, primitive.ProtocolVersion5) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(conf) + require.Nil(t, err) + require.NotNil(t, proxy) + defer proxy.Shutdown() + + testSetup.Client.CqlClient.ReadTimeout = 5 * time.Second // Short timeout to fail fast and expose bugs quickly + cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion5, client.ManagedStreamId) + require.Nil(t, err, "client connection failed: %v", err) + defer cqlConn.Close() + + // Spawn multiple goroutines that concurrently send INSERT queries + // This should trigger the written=false condition and expose any race conditions + numGoroutines := 5 + queriesPerGoroutine := 10 + var wg sync.WaitGroup + errorsChan := make(chan error, numGoroutines*queriesPerGoroutine) + + for g := 0; g < numGoroutines; g++ { + wg.Add(1) + goroutineId := g + go func() { + defer wg.Done() + + for i := 0; i < queriesPerGoroutine; i++ { + // Create queries with large payloads to exceed the small buffer + largeValue := make([]byte, 400) // 400 bytes of data + for j := range largeValue { + largeValue[j] = byte('A' + (j % 26)) + } + + queryMsg := &message.Query{ + Query: fmt.Sprintf("INSERT INTO test.table (id, data) VALUES (%d, '%s')", + goroutineId*queriesPerGoroutine+i, string(largeValue)), + Options: &message.QueryOptions{ + Consistency: primitive.ConsistencyLevelOne, + }, + } + + queryFrame := frame.NewFrame(primitive.ProtocolVersion5, int16((goroutineId*queriesPerGoroutine+i)%100+1), queryMsg) + responseFrame, err := cqlConn.SendAndReceive(queryFrame) + if err != nil { + errorsChan <- fmt.Errorf("goroutine %d query %d failed: %v", goroutineId, i, err) + return + } + if responseFrame == nil { + errorsChan <- fmt.Errorf("goroutine %d query %d returned nil response", goroutineId, i) + return + } + + // Verify we got a successful response + if _, ok := responseFrame.Body.Message.(*message.VoidResult); !ok { + errorsChan <- fmt.Errorf("goroutine %d query %d did not return VoidResult", goroutineId, i) + return + } + } + }() + } + + wg.Wait() + close(errorsChan) + + // Check for errors from goroutines + var errors []error + for err := range errorsChan { + errors = append(errors, err) + } + require.Empty(t, errors, "Encountered errors during concurrent writes: %v", errors) + + totalQueries := numGoroutines * queriesPerGoroutine + + // Verify that all queries were received by both origin and target + originRequests := originRequestHandler.GetQueryRequests() + targetRequests := targetRequestHandler.GetQueryRequests() + + require.GreaterOrEqual(t, len(originRequests), totalQueries, + "origin should have received at least %d queries, got %d", totalQueries, len(originRequests)) + require.GreaterOrEqual(t, len(targetRequests), totalQueries, + "target should have received at least %d queries, got %d", totalQueries, len(targetRequests)) +} + +// TestWriteCoalescerMultipleFramesInSegment tests that multiple frames can be written +// to a segment payload when they fit, and that leftover frames are properly handled. +func TestWriteCoalescerMultipleFramesInSegment(t *testing.T) { + conf := setup.NewTestConfig("127.0.1.1", "127.0.1.2") + + // Set very small buffer sizes and reduce workers to maximize contention + conf.RequestWriteBufferSizeBytes = 512 // Small buffer to force frequent flushes + conf.ResponseWriteBufferSizeBytes = 512 + conf.RequestResponseMaxWorkers = 2 + conf.WriteMaxWorkers = 2 + conf.ReadMaxWorkers = 2 + + testSetup, err := setup.NewCqlServerTestSetup(t, conf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewRequestCapturingHandler() + targetRequestHandler := NewRequestCapturingHandler() + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc2", func(_ string) {}), + } + + err = testSetup.Start(nil, false, primitive.ProtocolVersion5) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(conf) + require.Nil(t, err) + require.NotNil(t, proxy) + defer proxy.Shutdown() + + testSetup.Client.CqlClient.ReadTimeout = 5 * time.Second // Short timeout to fail fast and expose bugs quickly + cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion5, client.ManagedStreamId) + require.Nil(t, err, "client connection failed: %v", err) + defer cqlConn.Close() + + // Spawn multiple goroutines sending bursts of queries concurrently + numGoroutines := 8 + queriesPerGoroutine := 15 + var wg sync.WaitGroup + errorsChan := make(chan error, numGoroutines*queriesPerGoroutine) + + for g := 0; g < numGoroutines; g++ { + wg.Add(1) + goroutineId := g + go func() { + defer wg.Done() + + for i := 0; i < queriesPerGoroutine; i++ { + // Create moderately-sized INSERT queries with variable length data + dataSize := 200 + (i * 10) // Variable size from 200 to 340 bytes + largeValue := make([]byte, dataSize) + for j := range largeValue { + largeValue[j] = byte('A' + (j % 26)) + } + + queryMsg := &message.Query{ + Query: fmt.Sprintf("INSERT INTO test.table (id, data) VALUES (%d, '%s')", + goroutineId*queriesPerGoroutine+i, string(largeValue)), + Options: &message.QueryOptions{ + Consistency: primitive.ConsistencyLevelOne, + }, + } + + queryFrame := frame.NewFrame(primitive.ProtocolVersion5, int16((goroutineId*queriesPerGoroutine+i)%100+1), queryMsg) + responseFrame, err := cqlConn.SendAndReceive(queryFrame) + if err != nil { + errorsChan <- fmt.Errorf("goroutine %d query %d failed: %v", goroutineId, i, err) + return + } + if responseFrame == nil { + errorsChan <- fmt.Errorf("goroutine %d query %d returned nil response", goroutineId, i) + return + } + + // Verify we got a successful response + if _, ok := responseFrame.Body.Message.(*message.VoidResult); !ok { + errorsChan <- fmt.Errorf("goroutine %d query %d did not return VoidResult", goroutineId, i) + return + } + } + }() + } + + wg.Wait() + close(errorsChan) + + // Check for errors from goroutines + var errors []error + for err := range errorsChan { + errors = append(errors, err) + } + require.Empty(t, errors, "Encountered errors during concurrent writes: %v", errors) + + totalQueries := numGoroutines * queriesPerGoroutine + originRequests := originRequestHandler.GetQueryRequests() + targetRequests := targetRequestHandler.GetQueryRequests() + + require.GreaterOrEqual(t, len(originRequests), totalQueries, + "origin should have received at least %d queries, got %d", totalQueries, len(originRequests)) + require.GreaterOrEqual(t, len(targetRequests), totalQueries, + "target should have received at least %d queries, got %d", totalQueries, len(targetRequests)) +} + +// RequestCapturingHandler captures all incoming requests for verification +type RequestCapturingHandler struct { + lock *sync.Mutex + requests []*frame.Frame +} + +func NewRequestCapturingHandler() *RequestCapturingHandler { + return &RequestCapturingHandler{ + lock: &sync.Mutex{}, + requests: make([]*frame.Frame, 0), + } +} + +func (recv *RequestCapturingHandler) HandleRequest( + request *frame.Frame, + _ *client.CqlServerConnection, + _ client.RequestHandlerContext) (response *frame.Frame) { + + recv.lock.Lock() + recv.requests = append(recv.requests, request) + recv.lock.Unlock() + + // Return appropriate response based on request type + switch msg := request.Body.Message.(type) { + case *message.Query: + // Let system table queries pass through to the next handler + q := strings.ToLower(strings.TrimSpace(msg.Query)) + if strings.Contains(q, "system.local") || strings.Contains(q, "system.peers") { + return nil // Let the system tables handler deal with it + } + // Return a void result for non-system queries + return frame.NewFrame( + request.Header.Version, + request.Header.StreamId, + &message.VoidResult{}, + ) + default: + // For other request types, return nil (let other handlers deal with it) + return nil + } +} + +func (recv *RequestCapturingHandler) GetQueryRequests() []*frame.Frame { + recv.lock.Lock() + defer recv.lock.Unlock() + + queries := make([]*frame.Frame, 0) + for _, req := range recv.requests { + if _, ok := req.Body.Message.(*message.Query); ok { + queries = append(queries, req) + } + } + return queries +} + +func (recv *RequestCapturingHandler) GetAllRequests() []*frame.Frame { + recv.lock.Lock() + defer recv.lock.Unlock() + + result := make([]*frame.Frame, len(recv.requests)) + copy(result, recv.requests) + return result +} + +func (recv *RequestCapturingHandler) Clear() { + recv.lock.Lock() + defer recv.lock.Unlock() + recv.requests = make([]*frame.Frame, 0) +} diff --git a/nb-tests/cql-nb-activity.yaml b/nb-tests/cql-nb-activity.yaml index 4c1d19f8..c17b83dc 100644 --- a/nb-tests/cql-nb-activity.yaml +++ b/nb-tests/cql-nb-activity.yaml @@ -5,12 +5,37 @@ bindings: rw_value: Hash(); <int>>; ToString() -> String scenarios: + schema: run driver=cqld4 tags=phase:schema threads==1 cycles=UNDEF rampup: run driver=cqld4 tags=phase:rampup cycles=20000 write: run driver=cqld4 tags=phase:write cycles=20000 read: run driver=cqld4 tags=phase:read cycles=20000 - verify: run driver=cqld4 tags=phase:verify errors=warn,unverified->count compare=all cycles=20000 + verify: run driver=cqld3 tags=phase:verify errors=warn,unverified->count compare=all cycles=20000 blocks: + - name: schema + tags: + phase: schema + params: + prepared: false + statements: + - drop-keyspace: | + drop keyspace if exists <>; + tags: + name: drop-keyspace + - create-keyspace: | + create keyspace if not exists <> + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '<>'} + AND durable_writes = true; + tags: + name: create-keyspace + - create-table: | + create table if not exists <>.<> ( + key int, + value text, + PRIMARY KEY (key) + ); + tags: + name: create-table - name: rampup tags: phase: rampup diff --git a/nb-tests/cql-starter.yaml b/nb-tests/cql-starter.yaml new file mode 100644 index 00000000..a81ac3ea --- /dev/null +++ b/nb-tests/cql-starter.yaml @@ -0,0 +1,96 @@ +description: | + A cql-starter workload. + * Cassandra: 3.x, 4.x. + * DataStax Enterprise: 6.8.x. + * DataStax Astra. + +scenarios: + default: + schema: run driver=cql tags==block:schema threads==1 cycles==UNDEF + rampup: run driver=cql tags==block:rampup cycles===TEMPLATE(rampup-cycles,1) threads=auto + main: run driver=cql tags==block:"main.*" cycles===TEMPLATE(main-cycles,10) threads=auto + # rampdown: run driver=cql tags==block:rampdown threads==1 cycles==UNDEF + astra: + schema: run driver=cql tags==block:schema_astra threads==1 cycles==UNDEF + rampup: run driver=cql tags==block:rampup cycles===TEMPLATE(rampup-cycles,10) threads=auto + main: run driver=cql tags==block:"main.*" cycles===TEMPLATE(main-cycles,10) threads=auto + basic_check: + schema: run driver=cql tags==block:schema threads==1 cycles==UNDEF + rampup: run driver=cql tags==block:rampup cycles===TEMPLATE(rampup-cycles,10) threads=auto + main: run driver=cql tags==block:"main.*" cycles===TEMPLATE(main-cycles,10) threads=auto + +params: + a_param: "value" + +bindings: + machine_id: ElapsedNanoTime(); ToHashedUUID() -> java.util.UUID + message: Discard(); FirstLines('data/cql-starter-message.txt'); + rampup_message: ToString(); + time: ElapsedNanoTime(); Mul(1000); ToJavaInstant(); + ts: ElapsedNanoTime(); Mul(1000); + + +blocks: + schema: + params: + prepared: false + ops: + create_keyspace: | + create keyspace if not exists TEMPLATE(keyspace,starter) + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 'TEMPLATE(rf,1)'} + AND durable_writes = true; + create_table: | + create table if not exists TEMPLATE(keyspace,starter).TEMPLATE(table,cqlstarter) ( + machine_id UUID, + message text, + time timestamp, + PRIMARY KEY ((machine_id), time) + ) WITH CLUSTERING ORDER BY (time DESC); + + schema_astra: + params: + prepared: false + ops: + create_table_astra: | + create table if not exists TEMPLATE(keyspace,starter).TEMPLATE(table,cqlstarter) ( + machine_id UUID, + message text, + time timestamp, + PRIMARY KEY ((machine_id), time) + ) WITH CLUSTERING ORDER BY (time DESC); + + rampup: + params: + cl: TEMPLATE(write_cl,LOCAL_QUORUM) + idempotent: true + instrument: true + ops: + insert_rampup: | + insert into TEMPLATE(keyspace,starter).TEMPLATE(table,cqlstarter) (machine_id, message, time) + values ({machine_id}, {rampup_message}, {time}) using timestamp {ts}; + + rampdown: + ops: + truncate_table: | + truncate table TEMPLATE(keyspace,starter).TEMPLATE(table,cqlstarter); + + main_read: + params: + ratio: TEMPLATE(read_ratio,1) + cl: TEMPLATE(read_cl,LOCAL_QUORUM) + idempotent: true + instrument: true + ops: + select_read: | + select * from TEMPLATE(keyspace,starter).TEMPLATE(table,cqlstarter) + where machine_id={machine_id}; + main_write: + params: + ratio: TEMPLATE(write_ratio,9) + cl: TEMPLATE(write_cl,LOCAL_QUORUM) + idempotent: true + instrument: true + ops: + insert_main: | + insert into TEMPLATE(keyspace,starter).TEMPLATE(table,cqlstarter) + (machine_id, message, time) values ({machine_id}, {message}, {time}) using timestamp {ts}; diff --git a/proxy/launch.go b/proxy/launch.go index 353b7a7f..a811589c 100644 --- a/proxy/launch.go +++ b/proxy/launch.go @@ -4,16 +4,18 @@ import ( "context" "flag" "fmt" - "github.com/datastax/zdm-proxy/proxy/pkg/config" - "github.com/datastax/zdm-proxy/proxy/pkg/runner" - log "github.com/sirupsen/logrus" "os" "os/signal" "syscall" + + log "github.com/sirupsen/logrus" + + "github.com/datastax/zdm-proxy/proxy/pkg/config" + "github.com/datastax/zdm-proxy/proxy/pkg/runner" ) // TODO: to be managed externally -const ZdmVersionString = "2.3.4" +const ZdmVersionString = "2.4.0" var displayVersion = flag.Bool("version", false, "display the ZDM proxy version and exit") var configFile = flag.String("config", "", "specify path to ZDM configuration file") diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index c38ead63..fe3cebdb 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -3,16 +3,20 @@ package config import ( "encoding/json" "fmt" + "net" + "os" + "slices" + "strconv" + "strings" + "sync" + "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/proxy/pkg/common" "github.com/kelseyhightower/envconfig" def "github.com/mcuadros/go-defaults" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" - "net" - "os" - "strconv" - "strings" + + "github.com/datastax/zdm-proxy/proxy/pkg/common" ) // Config holds the values of environment variables necessary for proper Proxy function. @@ -26,6 +30,7 @@ type Config struct { AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true" yaml:"async_handshake_timeout_ms"` LogLevel string `default:"INFO" split_words:"true" yaml:"log_level"` ControlConnMaxProtocolVersion string `default:"DseV2" split_words:"true" yaml:"control_conn_max_protocol_version"` // Numeric Cassandra OSS protocol version or DseV1 / DseV2 + BlockedProtocolVersions string `default:"" split_words:"true" yaml:"blocked_protocol_versions"` // Tracing (also known as distributed tracing - request id generation and logging) @@ -328,6 +333,11 @@ func (c *Config) Validate() error { return err } + _, err = c.ParseBlockedProtocolVersions() + if err != nil { + return err + } + return nil } @@ -392,10 +402,10 @@ func (c *Config) ParseControlConnMaxProtocolVersion() (primitive.ProtocolVersion ver, err := strconv.ParseUint(c.ControlConnMaxProtocolVersion, 10, 32) if err != nil { return 0, fmt.Errorf("could not parse control connection max protocol version, valid values are "+ - "2, 3, 4, DseV1, DseV2; original err: %w", err) + "2, 3, 4, 5, DseV1, DseV2; original err: %w", err) } - if ver < 2 || ver > 4 { - return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2") + if ver < 2 || ver > 5 { + return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, 5, DseV1, DseV2") } return primitive.ProtocolVersion(ver), nil } @@ -667,3 +677,68 @@ func isDefined(propertyValue string) bool { func isNotDefined(propertyValue string) bool { return !isDefined(propertyValue) } + +func (c *Config) ParseBlockedProtocolVersions() ([]primitive.ProtocolVersion, error) { + if isNotDefined(c.BlockedProtocolVersions) { + return []primitive.ProtocolVersion{}, nil + } + + versionsStr := strings.Split(c.BlockedProtocolVersions, ",") + versions := make([]primitive.ProtocolVersion, 0, len(versionsStr)) + for _, v := range versionsStr { + trimmed := strings.TrimSpace(v) + if trimmed == "" { + continue + } + parsedVersion, err := parseProtocolVersion(trimmed) + if err != nil { + return nil, fmt.Errorf("invalid value for ZDM_BLOCKED_PROTOCOL_VERSIONS (%v); possible values are: %v (case insensitive)", + trimmed, supportedProtocolVersionsStr()) + } + versions = append(versions, parsedVersion) + } + return versions, nil +} + +var protocolVersionStrMap = map[primitive.ProtocolVersion][]string{ + primitive.ProtocolVersion2: {"2", "v2"}, + primitive.ProtocolVersion3: {"3", "v3"}, + primitive.ProtocolVersion4: {"4", "v4"}, + primitive.ProtocolVersion5: {"5", "v5"}, + primitive.ProtocolVersionDse1: {"DseV1", "Dse_V1"}, + primitive.ProtocolVersionDse2: {"DseV2", "Dse_V2"}, +} + +var supportedProtocolVersionsStr = sync.OnceValue[[]string]( + func() []string { + versionsStr := make([]string, 0) + for _, strSlice := range protocolVersionStrMap { + for _, str := range strSlice { + versionsStr = append(versionsStr, str) + } + } + slices.Sort(versionsStr) + return versionsStr + }) + +var lowerCaseProtocolVersionsMap = sync.OnceValue[map[string]primitive.ProtocolVersion]( + func() map[string]primitive.ProtocolVersion { + m := make(map[string]primitive.ProtocolVersion) + for v, strSlice := range protocolVersionStrMap { + for _, str := range strSlice { + m[strings.ToLower(str)] = v + } + } + return m + }) + +func parseProtocolVersion(versionStr string) (primitive.ProtocolVersion, error) { + blockableProtocolVersions := lowerCaseProtocolVersionsMap() + lowerCaseVersionStr := strings.ToLower(versionStr) + matchedVersion, ok := blockableProtocolVersions[lowerCaseVersionStr] + if !ok { + return 0, fmt.Errorf("unrecognized protocol version (%s), allowed versions are %v (case insensitive)", + versionStr, supportedProtocolVersionsStr()) + } + return matchedVersion, nil +} diff --git a/proxy/pkg/config/config_blockedversions_test.go b/proxy/pkg/config/config_blockedversions_test.go new file mode 100644 index 00000000..83b09c2d --- /dev/null +++ b/proxy/pkg/config/config_blockedversions_test.go @@ -0,0 +1,122 @@ +package config + +import ( + "testing" + + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/require" +) + +func TestConfig_ParseBlockedProtocolVersions(t *testing.T) { + + type test struct { + name string + envVars []envVar + expectedBlockedVersions []primitive.ProtocolVersion + errExpected bool + errMsg string + } + + tests := []test{ + { + name: "Valid: no versions blocked", + envVars: []envVar{{"ZDM_BLOCKED_PROTOCOL_VERSIONS", ""}}, + expectedBlockedVersions: []primitive.ProtocolVersion{}, + errExpected: false, + errMsg: "", + }, + { + name: "Valid: no versions blocked (with spaces)", + envVars: []envVar{{"ZDM_BLOCKED_PROTOCOL_VERSIONS", ", , ,"}}, + expectedBlockedVersions: []primitive.ProtocolVersion{}, + errExpected: false, + errMsg: "", + }, + { + name: "Valid: v5 blocked", + envVars: []envVar{{"ZDM_BLOCKED_PROTOCOL_VERSIONS", "v5"}}, + expectedBlockedVersions: []primitive.ProtocolVersion{primitive.ProtocolVersion5}, + errExpected: false, + errMsg: "", + }, + { + name: "Valid: v2, v3, v5 blocked", + envVars: []envVar{{"ZDM_BLOCKED_PROTOCOL_VERSIONS", "v2,v3,v5"}}, + expectedBlockedVersions: []primitive.ProtocolVersion{0x2, 0x3, 0x5}, + errExpected: false, + errMsg: "", + }, + { + name: "Valid: 2, 3, 5 blocked", + envVars: []envVar{{"ZDM_BLOCKED_PROTOCOL_VERSIONS", "2,3,5"}}, + expectedBlockedVersions: []primitive.ProtocolVersion{0x2, 0x3, 0x5}, + errExpected: false, + errMsg: "", + }, + { + name: "Valid: 2, V3, 5 blocked", + envVars: []envVar{{"ZDM_BLOCKED_PROTOCOL_VERSIONS", "2,V3,5"}}, + expectedBlockedVersions: []primitive.ProtocolVersion{0x2, 0x3, 0x5}, + errExpected: false, + errMsg: "", + }, + { + name: "Valid: 2,v2,3,v3,4,v4,5,v5,dsev1,dse_v1,dsev2,dse_v2 blocked", + envVars: []envVar{{"ZDM_BLOCKED_PROTOCOL_VERSIONS", "2,v2,3,v3,4,v4,5,v5,DSEv1,DSE_V1,DSEV2,DSE_V2"}}, + expectedBlockedVersions: []primitive.ProtocolVersion{0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, + primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2, primitive.ProtocolVersionDse2}, + errExpected: false, + errMsg: "", + }, + { + name: "Invalid: unrecognized v1", + envVars: []envVar{{"ZDM_BLOCKED_PROTOCOL_VERSIONS", "v1"}}, + expectedBlockedVersions: nil, + errExpected: true, + errMsg: "invalid value for ZDM_BLOCKED_PROTOCOL_VERSIONS (v1); possible values are: [2 3 4 5 DseV1 DseV2 Dse_V1 Dse_V2 v2 v3 v4 v5] (case insensitive)", + }, + { + name: "Invalid: unrecognized sdasd", + envVars: []envVar{{"ZDM_BLOCKED_PROTOCOL_VERSIONS", "v2, sdasd"}}, + expectedBlockedVersions: nil, + errExpected: true, + errMsg: "invalid value for ZDM_BLOCKED_PROTOCOL_VERSIONS (sdasd); possible values are: [2 3 4 5 DseV1 DseV2 Dse_V1 Dse_V2 v2 v3 v4 v5] (case insensitive)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clearAllEnvVars() + + // set test-specific env vars + for _, envVar := range tt.envVars { + setEnvVar(envVar.vName, envVar.vValue) + } + + // set other general env vars + setOriginCredentialsEnvVars() + setTargetCredentialsEnvVars() + setOriginContactPointsAndPortEnvVars() + setTargetContactPointsAndPortEnvVars() + + conf, err := New().LoadConfig("") + if err != nil { + if tt.errExpected { + require.Equal(t, tt.errMsg, err.Error()) + return + } else { + t.Fatal("Unexpected configuration validation error, stopping test here") + } + } + + if conf == nil { + t.Fatal("No configuration validation error was thrown but the parsed configuration is null, stopping test here") + } else { + blockedVersions, err := conf.ParseBlockedProtocolVersions() + require.Nil(t, err) // validate should have failed before if err is expected + require.Equal(t, tt.expectedBlockedVersions, blockedVersions) + } + }) + } + +} diff --git a/proxy/pkg/config/config_test.go b/proxy/pkg/config/config_test.go index 74eaa557..35322da9 100644 --- a/proxy/pkg/config/config_test.go +++ b/proxy/pkg/config/config_test.go @@ -1,9 +1,10 @@ package config import ( + "testing" + "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/stretchr/testify/require" - "testing" ) func TestTargetConfig_WithBundleOnly(t *testing.T) { @@ -135,6 +136,12 @@ func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) { parsedProtocolVersion: primitive.ProtocolVersion4, errorMessage: "", }, + { + name: "ParsedV5", + controlConnMaxProtocolVersion: "5", + parsedProtocolVersion: primitive.ProtocolVersion5, + errorMessage: "", + }, { name: "ParsedDse1", controlConnMaxProtocolVersion: "DseV1", @@ -153,23 +160,17 @@ func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) { parsedProtocolVersion: primitive.ProtocolVersionDse2, errorMessage: "", }, - { - name: "UnsupportedCassandraV5", - controlConnMaxProtocolVersion: "5", - parsedProtocolVersion: 0, - errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", - }, { name: "UnsupportedCassandraV1", controlConnMaxProtocolVersion: "1", parsedProtocolVersion: 0, - errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, 5, DseV1, DseV2", }, { name: "InvalidValue", controlConnMaxProtocolVersion: "Dsev123", parsedProtocolVersion: 0, - errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", + errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, 5, DseV1, DseV2", }, } diff --git a/proxy/pkg/zdmproxy/clientconn.go b/proxy/pkg/zdmproxy/clientconn.go index 0c6c9403..e93fcea2 100644 --- a/proxy/pkg/zdmproxy/clientconn.go +++ b/proxy/pkg/zdmproxy/clientconn.go @@ -1,17 +1,18 @@ package zdmproxy import ( - "bufio" "context" "fmt" + "net" + "sync" + "sync/atomic" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/proxy/pkg/config" log "github.com/sirupsen/logrus" - "net" - "sync" - "sync/atomic" + + "github.com/datastax/zdm-proxy/proxy/pkg/config" ) const ClientConnectorLogPrefix = "CLIENT-CONNECTOR" @@ -35,6 +36,9 @@ type ClientConnector struct { // configuration object of the proxy conf *config.Config + // protocol versions blocked through configuration + blockedProtoVersions []primitive.ProtocolVersion + // channel on which the ClientConnector sends requests as it receives them from the client requestChannel chan<- *frame.RawFrame @@ -58,12 +62,14 @@ type ClientConnector struct { shutdownRequestCtx context.Context minProtoVer primitive.ProtocolVersion - compression *atomic.Value + + codecHelper *connCodecHelper } func NewClientConnector( connection net.Conn, conf *config.Config, + blockedProtoVersions []primitive.ProtocolVersion, localClientHandlerWg *sync.WaitGroup, requestsChan chan<- *frame.RawFrame, clientHandlerContext context.Context, @@ -78,9 +84,11 @@ func NewClientConnector( minProtoVer primitive.ProtocolVersion, compression *atomic.Value) *ClientConnector { + codecHelper := newConnCodecHelper(connection, connection.RemoteAddr().String(), conf.RequestReadBufferSizeBytes, conf.RequestWriteBufferSizeBytes, compression, clientHandlerContext) return &ClientConnector{ connection: connection, conf: conf, + blockedProtoVersions: blockedProtoVersions, requestChannel: requestsChan, clientHandlerWg: localClientHandlerWg, clientHandlerContext: clientHandlerContext, @@ -94,7 +102,8 @@ func NewClientConnector( ClientConnectorLogPrefix, false, false, - writeScheduler), + writeScheduler, + codecHelper), responsesDoneChan: responsesDoneChan, requestsDoneCtx: requestsDoneCtx, eventsDoneChan: eventsDoneChan, @@ -103,7 +112,7 @@ func NewClientConnector( shutdownRequestCtx: shutdownRequestCtx, clientHandlerShutdownRequestCancelFn: clientHandlerShutdownRequestCancelFn, minProtoVer: minProtoVer, - compression: compression, + codecHelper: codecHelper, } } @@ -176,26 +185,26 @@ func (cc *ClientConnector) listenForRequests() { setDrainModeNowFunc() }() - bufferedReader := bufio.NewReaderSize(cc.connection, cc.conf.RequestWriteBufferSizeBytes) connectionAddr := cc.connection.RemoteAddr().String() protocolErrOccurred := false var alreadySentProtocolErr *frame.RawFrame for cc.clientHandlerContext.Err() == nil { - f, err := readRawFrame(bufferedReader, connectionAddr, cc.clientHandlerContext) - - protocolErrResponseFrame, err, _ := checkProtocolError(f, cc.minProtoVer, cc.getCompression(), err, protocolErrOccurred, ClientConnectorLogPrefix) + f, _, err := cc.codecHelper.ReadRawFrame() + protocolErrResponseFrame, err, _ := checkProtocolError( + f, cc.minProtoVer, cc.blockedProtoVersions, cc.codecHelper.GetCompression(), err, protocolErrOccurred, ClientConnectorLogPrefix) if err != nil { handleConnectionError( err, cc.clientHandlerContext, cc.clientHandlerCancelFunc, ClientConnectorLogPrefix, "reading", connectionAddr) break } else if protocolErrResponseFrame != nil { + protocolErrResponseFrame.Header.StreamId = 0 alreadySentProtocolErr = protocolErrResponseFrame protocolErrOccurred = true cc.sendResponseToClient(protocolErrResponseFrame) continue } else if alreadySentProtocolErr != nil { clonedProtocolErr := alreadySentProtocolErr.DeepCopy() - clonedProtocolErr.Header.StreamId = f.Header.StreamId + clonedProtocolErr.Header.StreamId = 0 cc.sendResponseToClient(clonedProtocolErr) continue } @@ -223,7 +232,7 @@ func (cc *ClientConnector) sendOverloadedToClient(request *frame.RawFrame) { ErrorMessage: "Shutting down, please retry on next host.", } response := frame.NewFrame(request.Header.Version, request.Header.StreamId, msg) - rawResponse, err := codecs[cc.getCompression()].ConvertToRawFrame(response) + rawResponse, err := frameCodecs[cc.codecHelper.GetCompression()].ConvertToRawFrame(response) if err != nil { log.Errorf("[%s] Could not convert frame (%v) to raw frame: %v", ClientConnectorLogPrefix, response, err) } else { @@ -231,8 +240,10 @@ func (cc *ClientConnector) sendOverloadedToClient(request *frame.RawFrame) { } } -func checkProtocolError(f *frame.RawFrame, protoVer primitive.ProtocolVersion, compression primitive.Compression, - connErr error, protocolErrorOccurred bool, prefix string) (protocolErrResponse *frame.RawFrame, fatalErr error, errorCode int8) { +func checkProtocolError( + f *frame.RawFrame, protoVer primitive.ProtocolVersion, blockedVersions []primitive.ProtocolVersion, + compression primitive.Compression, connErr error, protocolErrorOccurred bool, prefix string) ( + protocolErrResponse *frame.RawFrame, fatalErr error, errorCode int8) { var protocolErrMsg *message.ProtocolError var streamId int16 var logMsg string @@ -242,8 +253,8 @@ func checkProtocolError(f *frame.RawFrame, protoVer primitive.ProtocolVersion, c streamId = 0 errorCode = ProtocolErrorDecodeError } else { - protocolErrMsg = checkProtocolVersion(f.Header.Version) - logMsg = "Protocol v5 detected while decoding a frame." + protocolErrMsg = checkProtocolVersion(f.Header.Version, blockedVersions) + logMsg = fmt.Sprintf("Protocol %v detected while decoding a frame.", f.Header.Version) streamId = f.Header.StreamId errorCode = ProtocolErrorUnsupportedVersion } @@ -266,7 +277,7 @@ func checkProtocolError(f *frame.RawFrame, protoVer primitive.ProtocolVersion, c func generateProtocolErrorResponseFrame(streamId int16, protoVer primitive.ProtocolVersion, compression primitive.Compression, protocolErrMsg *message.ProtocolError) (*frame.RawFrame, error) { response := frame.NewFrame(protoVer, streamId, protocolErrMsg) - rawResponse, err := codecs[compression].ConvertToRawFrame(response) + rawResponse, err := frameCodecs[compression].ConvertToRawFrame(response) if err != nil { return nil, err } @@ -277,7 +288,3 @@ func generateProtocolErrorResponseFrame(streamId int16, protoVer primitive.Proto func (cc *ClientConnector) sendResponseToClient(frame *frame.RawFrame) { cc.writeCoalescer.Enqueue(frame) } - -func (cc *ClientConnector) getCompression() primitive.Compression { - return cc.compression.Load().(primitive.Compression) -} diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go index d9ac22f8..26c4a654 100644 --- a/proxy/pkg/zdmproxy/clienthandler.go +++ b/proxy/pkg/zdmproxy/clienthandler.go @@ -6,20 +6,23 @@ import ( "encoding/hex" "errors" "fmt" - "github.com/datastax/go-cassandra-native-protocol/frame" - "github.com/datastax/go-cassandra-native-protocol/message" - "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/proxy/pkg/common" - "github.com/datastax/zdm-proxy/proxy/pkg/config" - "github.com/datastax/zdm-proxy/proxy/pkg/metrics" - "github.com/google/uuid" - log "github.com/sirupsen/logrus" "net" + "slices" "sort" "strings" "sync" "sync/atomic" "time" + + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + + "github.com/datastax/zdm-proxy/proxy/pkg/common" + "github.com/datastax/zdm-proxy/proxy/pkg/config" + "github.com/datastax/zdm-proxy/proxy/pkg/metrics" ) /* @@ -130,6 +133,7 @@ func NewClientHandler( originControlConn *ControlConn, targetControlConn *ControlConn, conf *config.Config, + blockedProtoVersions []primitive.ProtocolVersion, topologyConfig *common.TopologyConfig, targetUsername string, targetPassword string, @@ -278,6 +282,7 @@ func NewClientHandler( clientConnector: NewClientConnector( clientTcpConn, conf, + blockedProtoVersions, localClientHandlerWg, requestsChannel, clientHandlerContext, @@ -439,7 +444,7 @@ func (ch *ClientHandler) requestLoop() { if ready { ch.handshakeDone.Store(true) log.Infof( - "Handshake successful with client %s", connectionAddr) + "Handshake successful with client %s (%v, Compression: %v)", connectionAddr, f.Header.Version.String(), ch.getCompression()) } log.Tracef("ready? %t", ready) } else { @@ -688,6 +693,12 @@ func (ch *ClientHandler) tryProcessProtocolError(response *Response, protocolErr log.Debugf("[ClientHandler] Protocol version downgrade detected (%v) on %v, forwarding it to the client.", errMsg, response.connectorType) } + + // some clients might require stream id 0 on protocol errors (it's what C* does, or at least some C* versions) + // gocql for example has a bug where protocol version negotiation will fail if stream id of the protocol error isn't 0 + // https://issues.apache.org/jira/browse/CASSGO-97 + response.responseFrame.Header.StreamId = 0 + ch.clientConnector.sendResponseToClient(response.responseFrame) } return true @@ -699,7 +710,7 @@ func (ch *ClientHandler) tryProcessProtocolError(response *Response, protocolErr func decodeError(responseFrame *frame.RawFrame, compression primitive.Compression) (message.Error, error) { if responseFrame != nil && responseFrame.Header.OpCode == primitive.OpCodeError { - body, err := codecs[compression].DecodeBody( + body, err := frameCodecs[compression].DecodeBody( responseFrame.Header, bytes.NewReader(responseFrame.Body)) if err != nil { @@ -1114,6 +1125,17 @@ func (ch *ClientHandler) handleHandshakeRequest(request *frame.RawFrame, wg *syn if newAuthFrame != nil { request = newAuthFrame } + } else if request.Header.OpCode == primitive.OpCodeStartup { + clientStartup, err := defaultFrameCodec.DecodeBody(request.Header, bytes.NewReader(request.Body)) + if err != nil { + scheduledTaskChannel <- &handshakeRequestResult{ + authSuccess: false, + err: fmt.Errorf("failed to decode startup message: %w", err), + } + } + compression := clientStartup.Message.(*message.Startup).GetCompression() + ch.setCompression(compression) + ch.startupRequest.Store(request) } responseChan := make(chan *customResponse, 1) @@ -1180,15 +1202,7 @@ func (ch *ClientHandler) handleHandshakeRequest(request *frame.RawFrame, wg *syn ch.secondaryStartupResponse = secondaryResponse - clientStartup, err := defaultCodec.DecodeBody(request.Header, bytes.NewReader(request.Body)) - if err != nil { - return false, fmt.Errorf("failed to decode startup message: %w", err) - } - ch.setCompression(clientStartup.Message.(*message.Startup).GetCompression()) - - ch.startupRequest.Store(request) - - err = validateSecondaryStartupResponse(secondaryResponse, secondaryCluster) + err := validateSecondaryStartupResponse(secondaryResponse, secondaryCluster) if err != nil { return false, fmt.Errorf("unsuccessful startup on %v: %w", secondaryCluster, err) } @@ -1205,7 +1219,7 @@ func (ch *ClientHandler) handleHandshakeRequest(request *frame.RawFrame, wg *syn err: nil, } if aggregatedResponse.Header.OpCode == primitive.OpCodeReady || aggregatedResponse.Header.OpCode == primitive.OpCodeAuthSuccess { - // target handshake must happen within a single client request lifetime + // secondary handshake must happen within a single client request lifetime // to guarantee that no other request with the same // stream id goes to target in the meantime @@ -1349,7 +1363,9 @@ func (ch *ClientHandler) startSecondaryHandshake(asyncConnector bool) (chan erro } startupFrame := startupFrameInterface.(*frame.RawFrame) startupResponse := ch.secondaryStartupResponse - if startupResponse == nil { + if asyncConnector { + startupResponse = nil + } else if startupResponse == nil { return nil, errors.New("can not start secondary handshake before a Startup response was received") } @@ -1996,7 +2012,7 @@ func (ch *ClientHandler) aggregateAndTrackResponses( }, } buf := &bytes.Buffer{} - err := defaultCodec.EncodeBody(newHeader, newBody, buf) + err := ch.getCodec().EncodeBody(newHeader, newBody, buf) if err != nil { log.Errorf("Failed to encode OPTIONS body: %v", err) return responseFromTargetCassandra, common.ClusterTypeTarget @@ -2168,11 +2184,11 @@ func (ch *ClientHandler) setCompression(compression primitive.Compression) { } func (ch *ClientHandler) getCodec() frame.RawCodec { - return codecs[ch.getCompression()] + return frameCodecs[ch.getCompression()] } func decodeErrorResult(frame *frame.RawFrame, compression primitive.Compression) (message.Error, error) { - body, err := codecs[compression].DecodeBody(frame.Header, bytes.NewReader(frame.Body)) + body, err := frameCodecs[compression].DecodeBody(frame.Header, bytes.NewReader(frame.Body)) if err != nil { return nil, fmt.Errorf("could not decode error body: %w", err) } @@ -2199,7 +2215,7 @@ func createUnpreparedFrame(errVal *UnpreparedExecuteError, compression primitive f := frame.NewFrame(errVal.Header.Version, errVal.Header.StreamId, unpreparedMsg) f.Body.TracingId = errVal.Body.TracingId - rawFrame, err := codecs[compression].ConvertToRawFrame(f) + rawFrame, err := frameCodecs[compression].ConvertToRawFrame(f) if err != nil { return nil, fmt.Errorf("could not convert unprepared response frame to rawframe: %w", err) } @@ -2335,9 +2351,21 @@ func checkUnsupportedProtocolError(err error) *message.ProtocolError { return nil } -// checkProtocolVersion handles the case where the protocol library does not return an error but the proxy does not support a specific version -func checkProtocolVersion(version primitive.ProtocolVersion) *message.ProtocolError { - if version < primitive.ProtocolVersion5 || version.IsDse() { +func createStandardUnsupportedVersionString(version primitive.ProtocolVersion) string { + return fmt.Sprintf("Invalid or unsupported protocol version (%d)", version) +} + +// checkProtocolVersion handles the case where the protocol library does not return an error but the proxy does not support (or blocks) a specific version +func checkProtocolVersion(version primitive.ProtocolVersion, blockedVersions []primitive.ProtocolVersion) *message.ProtocolError { + if slices.Contains(blockedVersions, version) { + return &message.ProtocolError{ErrorMessage: createStandardUnsupportedVersionString(version)} + } + + if version.IsDse() { + return nil + } + + if version >= primitive.ProtocolVersion2 && version <= primitive.ProtocolVersion5 { return nil } diff --git a/proxy/pkg/zdmproxy/clusterconn.go b/proxy/pkg/zdmproxy/clusterconn.go index 26d3e2ad..39cf6d0f 100644 --- a/proxy/pkg/zdmproxy/clusterconn.go +++ b/proxy/pkg/zdmproxy/clusterconn.go @@ -1,23 +1,24 @@ package zdmproxy import ( - "bufio" "context" "encoding/hex" "errors" "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + log "github.com/sirupsen/logrus" + "github.com/datastax/zdm-proxy/proxy/pkg/common" "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/datastax/zdm-proxy/proxy/pkg/metrics" - log "github.com/sirupsen/logrus" - "io" - "net" - "sync" - "sync/atomic" - "time" ) type ClusterConnectionInfo struct { @@ -60,10 +61,9 @@ type ClusterConnector struct { cancelFunc context.CancelFunc responseChan chan<- *Response - responseReadBufferSizeBytes int - writeCoalescer *writeCoalescer - doneChan chan bool - frameProcessor FrameProcessor + writeCoalescer *writeCoalescer + doneChan chan bool + frameProcessor FrameProcessor handshakeDone *atomic.Value @@ -76,8 +76,9 @@ type ClusterConnector struct { lastHeartbeatTime *atomic.Value lastHeartbeatLock sync.Mutex - ccProtoVer primitive.ProtocolVersion - compression *atomic.Value + ccProtoVer primitive.ProtocolVersion + + codecHelper *connCodecHelper } func NewClusterConnectionInfo(connConfig ConnectionConfig, endpointConfig Endpoint, isOriginCassandra bool) *ClusterConnectionInfo { @@ -153,6 +154,7 @@ func NewClusterConnector( // Initialize heartbeat time lastHeartbeatTime := &atomic.Value{} lastHeartbeatTime.Store(time.Now()) + codecHelper := newConnCodecHelper(conn, conn.RemoteAddr().String(), conf.ResponseReadBufferSizeBytes, conf.ResponseWriteBufferSizeBytes, compression, clusterConnCtx) return &ClusterConnector{ conf: conf, @@ -175,19 +177,19 @@ func NewClusterConnector( string(connectorType), true, asyncConnector, - writeScheduler), - responseChan: responseChan, - frameProcessor: frameProcessor, - responseReadBufferSizeBytes: conf.ResponseReadBufferSizeBytes, - doneChan: make(chan bool), - readScheduler: readScheduler, - asyncConnector: asyncConnector, - asyncConnectorState: ConnectorStateHandshake, - asyncPendingRequests: asyncPendingRequests, - handshakeDone: handshakeDone, - lastHeartbeatTime: lastHeartbeatTime, - ccProtoVer: ccProtoVer, - compression: compression, + writeScheduler, + codecHelper), + responseChan: responseChan, + frameProcessor: frameProcessor, + doneChan: make(chan bool), + readScheduler: readScheduler, + asyncConnector: asyncConnector, + asyncConnectorState: ConnectorStateHandshake, + asyncPendingRequests: asyncPendingRequests, + handshakeDone: handshakeDone, + lastHeartbeatTime: lastHeartbeatTime, + ccProtoVer: ccProtoVer, + codecHelper: codecHelper, }, nil } @@ -252,14 +254,15 @@ func (cc *ClusterConnector) runResponseListeningLoop() { defer close(cc.doneChan) defer atomic.StoreInt32(&cc.asyncConnectorState, ConnectorStateShutdown) - bufferedReader := bufio.NewReaderSize(cc.connection, cc.responseReadBufferSizeBytes) connectionAddr := cc.connection.RemoteAddr().String() wg := &sync.WaitGroup{} defer wg.Wait() protocolErrOccurred := false for { - response, err := readRawFrame(bufferedReader, connectionAddr, cc.clusterConnContext) - protocolErrResponseFrame, err, errCode := checkProtocolError(response, cc.ccProtoVer, cc.getCompression(), err, protocolErrOccurred, string(cc.connectorType)) + response, state, err := cc.codecHelper.ReadRawFrame() + protocolErrResponseFrame, err, errCode := checkProtocolError( + response, cc.ccProtoVer, []primitive.ProtocolVersion{}, cc.codecHelper.GetCompression(), err, + protocolErrOccurred, string(cc.connectorType)) if err != nil { handleConnectionError( err, cc.clusterConnContext, cc.cancelFunc, string(cc.connectorType), "reading", connectionAddr) @@ -274,6 +277,15 @@ func (cc *ClusterConnector) runResponseListeningLoop() { } } + if !state.useSegments && response.Header.Version.SupportsModernFramingLayout() && + (response.Header.OpCode == primitive.OpCodeReady || response.Header.OpCode == primitive.OpCodeAuthenticate) { + err = cc.codecHelper.SetState(true) + if err != nil { + handleConnectionError(err, cc.clusterConnContext, cc.cancelFunc, string(cc.connectorType), "switching to segments", connectionAddr) + break + } + } + // when there's a protocol error, we cannot rely on the returned stream id, the only exception is // when it's a UnsupportedVersion error, which means the Frame was properly parsed by the native protocol library // but the proxy doesn't support the protocol version and in that case we can proceed with releasing the stream id in the mapper @@ -284,7 +296,7 @@ func (cc *ClusterConnector) runResponseListeningLoop() { // if releasing the stream id failed, check if it's a protocol error response // if it is then ignore the release error and forward the response to the client handler so that // it can be handled correctly - parsedResponse, parseErr := cc.getCodec().ConvertFromRawFrame(response) + parsedResponse, parseErr := state.frameCodec.ConvertFromRawFrame(response) if parseErr != nil { log.Errorf("[%v] Error converting frame when releasing stream id: %v. Original error: %v.", string(cc.connectorType), parseErr, releaseErr) continue @@ -330,7 +342,7 @@ func (cc *ClusterConnector) runResponseListeningLoop() { } func (cc *ClusterConnector) handleAsyncResponse(response *frame.RawFrame) *frame.RawFrame { - errMsg, err := decodeError(response, cc.getCompression()) + errMsg, err := decodeError(response, cc.codecHelper.GetCompression()) if err != nil { log.Errorf("[%s] Error occured while checking if error is a protocol error: %v.", cc.connectorType, err) cc.Shutdown() @@ -550,13 +562,13 @@ func (cc *ClusterConnector) sendHeartbeat(version primitive.ProtocolVersion, hea cc.lastHeartbeatTime.Store(time.Now()) optionsMsg := &message.Options{} heartBeatFrame := frame.NewFrame(version, -1, optionsMsg) - rawFrame, err := defaultCodec.ConvertToRawFrame(heartBeatFrame) + rawFrame, err := defaultFrameCodec.ConvertToRawFrame(heartBeatFrame) if err != nil { log.Errorf("Cannot convert heartbeat frame to raw frame: %v", err) return } log.Debugf("Sending heartbeat to cluster %v", cc.clusterType) - cc.sendRequestToCluster(rawFrame, true) + _ = cc.sendRequestToCluster(rawFrame, true) } // shouldSendHeartbeat looks up the value of the last heartbeat time in the atomic value @@ -567,9 +579,5 @@ func (cc *ClusterConnector) shouldSendHeartbeat(heartbeatIntervalMs int) bool { } func (cc *ClusterConnector) getCodec() frame.RawCodec { - return codecs[cc.getCompression()] -} - -func (cc *ClusterConnector) getCompression() primitive.Compression { - return cc.compression.Load().(primitive.Compression) + return cc.codecHelper.GetState().frameCodec } diff --git a/proxy/pkg/zdmproxy/coalescer.go b/proxy/pkg/zdmproxy/coalescer.go index 30c7c397..67e915c8 100644 --- a/proxy/pkg/zdmproxy/coalescer.go +++ b/proxy/pkg/zdmproxy/coalescer.go @@ -1,13 +1,15 @@ package zdmproxy import ( - "bytes" "context" - "github.com/datastax/go-cassandra-native-protocol/frame" - "github.com/datastax/zdm-proxy/proxy/pkg/config" - log "github.com/sirupsen/logrus" "net" "sync" + + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/primitive" + log "github.com/sirupsen/logrus" + + "github.com/datastax/zdm-proxy/proxy/pkg/config" ) const ( @@ -32,6 +34,11 @@ type writeCoalescer struct { writeBufferSizeBytes int scheduler *Scheduler + + codecHelper *connCodecHelper + + isClusterConnector bool + isAsyncConnector bool } func NewWriteCoalescer( @@ -41,23 +48,24 @@ func NewWriteCoalescer( shutdownContext context.Context, clientHandlerCancelFunc context.CancelFunc, logPrefix string, - isRequest bool, - isAsync bool, - scheduler *Scheduler) *writeCoalescer { + isClusterConnector bool, + isAsyncConnector bool, + scheduler *Scheduler, + codecHelper *connCodecHelper) *writeCoalescer { writeQueueSizeFrames := conf.RequestWriteQueueSizeFrames - if !isRequest { + if !isClusterConnector { writeQueueSizeFrames = conf.ResponseWriteQueueSizeFrames } - if isAsync { + if isAsyncConnector { writeQueueSizeFrames = conf.AsyncConnectorWriteQueueSizeFrames } writeBufferSizeBytes := conf.RequestWriteBufferSizeBytes - if !isRequest { + if !isClusterConnector { writeBufferSizeBytes = conf.ResponseWriteBufferSizeBytes } - if isAsync { + if isAsyncConnector { writeBufferSizeBytes = conf.AsyncConnectorWriteBufferSizeBytes } return &writeCoalescer{ @@ -71,6 +79,9 @@ func NewWriteCoalescer( waitGroup: &sync.WaitGroup{}, writeBufferSizeBytes: writeBufferSizeBytes, scheduler: scheduler, + isClusterConnector: isClusterConnector, + isAsyncConnector: isAsyncConnector, + codecHelper: codecHelper, } } @@ -85,30 +96,49 @@ func (recv *writeCoalescer) RunWriteQueueLoop() { defer recv.waitGroup.Done() draining := false - bufferedWriter := bytes.NewBuffer(make([]byte, 0, initialBufferSize)) wg := &sync.WaitGroup{} defer wg.Wait() - for { - var resultOk bool - var result *coalescerIterationResult + state := recv.codecHelper.GetState() - firstFrame, firstFrameOk := <-recv.writeQueue + var resultOk bool + var result coalescerIterationResult + for { + var firstFrame *frame.RawFrame + var firstFrameOk bool + if result.leftoverFrame != nil { + firstFrame = result.leftoverFrame + firstFrameOk = true + } else { + firstFrame, firstFrameOk = <-recv.writeQueue + } if !firstFrameOk { break } - resultChannel := make(chan *coalescerIterationResult, 1) - tempDraining := draining - tempBuffer := bufferedWriter + result = coalescerIterationResult{} + resultOk = false + + writeBuffer := recv.codecHelper.segWriter.GetWriteBuffer() + resultChannel := make(chan coalescerIterationResult, 1) wg.Add(1) recv.scheduler.Schedule(func() { defer wg.Done() firstFrameRead := false + state = recv.codecHelper.GetState() for { var f *frame.RawFrame var ok bool if firstFrameRead { + newState := recv.codecHelper.GetState() + if newState != state { + // state updated (compression or segments) + resultChannel <- coalescerIterationResult{} + close(resultChannel) + return + } + state = newState + select { case f, ok = <-recv.writeQueue: default: @@ -116,42 +146,61 @@ func (recv *writeCoalescer) RunWriteQueueLoop() { } if !ok { - t := &coalescerIterationResult{ - buffer: tempBuffer, - draining: tempDraining, - } - resultChannel <- t + resultChannel <- coalescerIterationResult{} close(resultChannel) return } - if tempDraining { + if draining { // continue draining the write queue without writing on connection until it is closed log.Tracef("[%v] Discarding frame from write queue because shutdown was requested: %v", recv.logPrefix, f.Header) continue } } else { + writeBuffer.Reset() firstFrameRead = true f = firstFrame ok = true } - log.Tracef("[%v] Writing %v on %v", recv.logPrefix, f.Header, connectionAddr) - err := writeRawFrame(tempBuffer, connectionAddr, recv.shutdownContext, f) - if err != nil { - tempDraining = true - handleConnectionError(err, recv.shutdownContext, recv.cancelFunc, recv.logPrefix, "writing", connectionAddr) - } else { - if tempBuffer.Len() >= recv.writeBufferSizeBytes { - t := &coalescerIterationResult{ - buffer: tempBuffer, - draining: tempDraining, + if !state.useSegments { + log.Tracef("[%v] Writing %v on %v", recv.logPrefix, f.Header, connectionAddr) + err := writeRawFrame(writeBuffer, connectionAddr, recv.shutdownContext, f) + if err != nil { + draining = true + handleConnectionError(err, recv.shutdownContext, recv.cancelFunc, recv.logPrefix, "writing", connectionAddr) + } else { + if !recv.isClusterConnector { + // this is the write loop of a client connector so this loop is writing responses + // we need to switch to segments once READY/AUTHENTICATE response is sent (if v5+) + + if (f.Header.OpCode == primitive.OpCodeReady || f.Header.OpCode == primitive.OpCodeAuthenticate) && + f.Header.Version.SupportsModernFramingLayout() { + resultChannel <- coalescerIterationResult{switchToSegments: true} + close(resultChannel) + return + } } - resultChannel <- t + } + } else { + log.Tracef("[%v] Writing %v to segment on %v", recv.logPrefix, f.Header, connectionAddr) + written, err := recv.codecHelper.segWriter.AppendFrameToSegmentPayload(f) + if err != nil { + draining = true + handleConnectionError(err, recv.shutdownContext, recv.cancelFunc, recv.logPrefix, "writing", connectionAddr) + } else if !written { + // need to write current payload before moving forward + resultChannel <- coalescerIterationResult{leftoverFrame: f} close(resultChannel) return } } + + if writeBuffer.Len() >= recv.writeBufferSizeBytes { + resultChannel <- coalescerIterationResult{} + close(resultChannel) + return + } } }) @@ -159,17 +208,35 @@ func (recv *writeCoalescer) RunWriteQueueLoop() { if !resultOk { break } + if draining { + continue + } - draining = result.draining - bufferedWriter = result.buffer - if bufferedWriter.Len() > 0 && !draining { - _, err := recv.connection.Write(bufferedWriter.Bytes()) - bufferedWriter.Reset() + if result.switchToSegments { + err := recv.codecHelper.SetState(true) // don't update local state variable yet, so old state is used to write this buffer if err != nil { - handleConnectionError(err, recv.shutdownContext, recv.cancelFunc, recv.logPrefix, "writing", connectionAddr) + handleConnectionError(err, recv.shutdownContext, recv.cancelFunc, recv.logPrefix, "switching to segments", connectionAddr) draining = true } } + + if writeBuffer.Len() > 0 { + if draining { + writeBuffer.Reset() + } else if !state.useSegments { + _, err := recv.connection.Write(writeBuffer.Bytes()) + if err != nil { + handleConnectionError(err, recv.shutdownContext, recv.cancelFunc, recv.logPrefix, "writing", connectionAddr) + draining = true + } + } else { + err := recv.codecHelper.segWriter.WriteSegments(recv.connection, state) + if err != nil { + handleConnectionError(err, recv.shutdownContext, recv.cancelFunc, recv.logPrefix, "writing", connectionAddr) + draining = true + } + } + } } }() } @@ -198,6 +265,6 @@ func (recv *writeCoalescer) Close() { } type coalescerIterationResult struct { - buffer *bytes.Buffer - draining bool + switchToSegments bool + leftoverFrame *frame.RawFrame } diff --git a/proxy/pkg/zdmproxy/codechelper.go b/proxy/pkg/zdmproxy/codechelper.go new file mode 100644 index 00000000..91851c5c --- /dev/null +++ b/proxy/pkg/zdmproxy/codechelper.go @@ -0,0 +1,211 @@ +package zdmproxy + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "sync/atomic" + + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/go-cassandra-native-protocol/segment" +) + +type connState struct { + useSegments bool // Protocol v5+ outer frame (segment) handling. See: https://github.com/apache/cassandra/blob/c713132aa6c20305a4a0157e9246057925ccbf78/doc/native_protocol_v5.spec + frameCodec frame.RawCodec + segmentCodec segment.Codec +} + +var emptyConnState = &connState{ + useSegments: false, + frameCodec: defaultFrameCodec, + segmentCodec: nil, +} + +type connCodecHelper struct { + state atomic.Pointer[connState] + compression *atomic.Value + + src *bufio.Reader + waitReadDataBuf []byte // buf to block waiting for data (1 byte) + waitReadDataReader *bytes.Reader + dualReader *DualReader + + segAccum SegmentAccumulator + + segWriter *SegmentWriter + + connectionAddr string + shutdownContext context.Context +} + +func newConnCodecHelper(src io.Reader, connectionAddr string, readBufferSizeBytes int, writeBufferSizeBytes int, compression *atomic.Value, + shutdownContext context.Context) *connCodecHelper { + writeBuffer := bytes.NewBuffer(make([]byte, 0, initialBufferSize)) + + bufferedReader := bufio.NewReaderSize(src, readBufferSizeBytes) + waitBuf := make([]byte, 1) // buf to block waiting for data (1 byte) + waitBufReader := bytes.NewReader(waitBuf) + return &connCodecHelper{ + state: atomic.Pointer[connState]{}, + compression: compression, + src: bufferedReader, + segAccum: NewSegmentAccumulator(defaultFrameCodec), + waitReadDataBuf: waitBuf, + waitReadDataReader: waitBufReader, + segWriter: NewSegmentWriter(writeBuffer, writeBufferSizeBytes, connectionAddr, shutdownContext), + connectionAddr: connectionAddr, + shutdownContext: shutdownContext, + dualReader: NewDualReader(waitBufReader, bufferedReader), + } +} + +func (recv *connCodecHelper) ReadRawFrame() (*frame.RawFrame, *connState, error) { + // Check if we already have a frame ready in the accumulator + if recv.segAccum.FrameReady() { + state := recv.GetState() + if !state.useSegments { + return nil, state, errors.New("unexpected state after checking that frame is ready to be read") + } + f, err := recv.segAccum.ReadFrame() + return f, state, err + } + + // block until data is available outside of codecHelper so that we can check the state (segments/compression) + // before reading the frame/segment otherwise it will check the state then enter a blocking state inside a codec + // but the state can be modified in the meantime + _, err := io.ReadFull(recv.src, recv.waitReadDataBuf) + if err != nil { + return nil, nil, adaptConnErr(recv.connectionAddr, recv.shutdownContext, err) + } + _ = recv.waitReadDataReader.UnreadByte() // reset reader1 to initial position + recv.dualReader.Reset() + state := recv.GetState() + if !state.useSegments { + rawFrame, err := defaultFrameCodec.DecodeRawFrame(recv.dualReader) // body is not being decompressed, so we can use default codec + if err != nil { + return nil, state, adaptConnErr(recv.connectionAddr, recv.shutdownContext, err) + } + return rawFrame, state, nil + } else { + for !recv.segAccum.FrameReady() { + sgmt, err := state.segmentCodec.DecodeSegment(recv.dualReader) + if err != nil { + return nil, state, adaptConnErr(recv.connectionAddr, recv.shutdownContext, err) + } + err = recv.segAccum.AppendSegmentPayload(sgmt.Payload.UncompressedData) + if err != nil { + return nil, state, err + } + } + f, err := recv.segAccum.ReadFrame() + return f, state, err + } +} + +// SetStartupCompression should be called as soon as the STARTUP request is received and the atomic.Value +// holding the primitive.Compression value is set. This method will update the state of this codec helper +// according to the value of Compression. +// +// This method should only be called once STARTUP is received and before the handshake proceeds because it +// will forcefully set a state where segments are disabled. +func (recv *connCodecHelper) SetStartupCompression() error { + return recv.SetState(false) +} + +// MaybeEnableSegments is a helper method to conditionally switch to segments if the provided protocol version supports them. +func (recv *connCodecHelper) MaybeEnableSegments(version primitive.ProtocolVersion) error { + if version.SupportsModernFramingLayout() { + return recv.SetState(true) + } + return nil +} + +// SetState updates the state of this codec helper loading the compression type from the atomic.Value provided +// during initialization and sets the underlying codecs to use segments or not according to the parameter. +func (recv *connCodecHelper) SetState(useSegments bool) error { + compression := recv.GetCompression() + if useSegments { + sCodec, ok := segmentCodecs[compression] + if !ok { + return fmt.Errorf("unknown segment compression %v", compression) + } + recv.state.Store(&connState{ + useSegments: true, + frameCodec: defaultFrameCodec, + segmentCodec: sCodec, + }) + return nil + } + + fCodec, ok := frameCodecs[compression] + if !ok { + return fmt.Errorf("unknown frame compression %v", compression) + } + recv.state.Store(&connState{ + useSegments: false, + frameCodec: fCodec, + segmentCodec: nil, + }) + return nil +} + +func (recv *connCodecHelper) GetState() *connState { + state := recv.state.Load() + if state == nil { + return emptyConnState + } + return state +} + +func (recv *connCodecHelper) GetCompression() primitive.Compression { + return recv.compression.Load().(primitive.Compression) +} + +// DualReader returns a Reader that's the logical concatenation of +// the provided input readers. They're read sequentially. Once all +// inputs have returned EOF, Read will return EOF. If any of the readers +// return a non-nil, non-EOF error, Read will return that error. +// It is identical to io.MultiReader but fixed to 2 readers so it avoids allocating a slice +type DualReader struct { + reader1 io.Reader + reader2 io.Reader + skipReader1 bool +} + +func (mr *DualReader) Read(p []byte) (n int, err error) { + currentReader := mr.reader1 + if mr.skipReader1 { + currentReader = mr.reader2 + } + for currentReader != nil { + n, err = currentReader.Read(p) + if err == io.EOF { + if mr.skipReader1 { + currentReader = nil + } else { + mr.skipReader1 = true + currentReader = mr.reader2 + } + } + if n > 0 || err != io.EOF { + if err == io.EOF && currentReader != nil { + err = nil + } + return + } + } + return 0, io.EOF +} + +func (mr *DualReader) Reset() { + mr.skipReader1 = false +} + +func NewDualReader(reader1 io.Reader, reader2 io.Reader) *DualReader { + return &DualReader{reader1: reader1, reader2: reader2, skipReader1: false} +} diff --git a/proxy/pkg/zdmproxy/codechelper_test.go b/proxy/pkg/zdmproxy/codechelper_test.go new file mode 100644 index 00000000..b8bb76aa --- /dev/null +++ b/proxy/pkg/zdmproxy/codechelper_test.go @@ -0,0 +1,737 @@ +package zdmproxy + +import ( + "bytes" + "context" + "fmt" + "io" + "sync/atomic" + "testing" + + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/go-cassandra-native-protocol/segment" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper to create a connCodecHelper for testing with a buffer as source +func createTestConnCodecHelper(src *bytes.Buffer) *connCodecHelper { + compression := &atomic.Value{} + compression.Store(primitive.CompressionNone) + ctx := context.Background() + return newConnCodecHelper(src, "test-addr:9042", 4096, 1024, compression, ctx) +} + +// Helper to write a frame as a segment to a buffer +func writeFrameAsSegment(t *testing.T, buf *bytes.Buffer, frm *frame.RawFrame, useSegments bool) { + if useSegments { + // Encode frame to get envelope + envelopeBytes := encodeRawFrameToBytes(t, frm) + + // Wrap in segment + seg := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: envelopeBytes}, + Header: &segment.Header{IsSelfContained: true}, + } + + err := defaultSegmentCodec.EncodeSegment(seg, buf) + require.NoError(t, err) + } else { + // Write frame directly (no segmentation) + err := defaultFrameCodec.EncodeRawFrame(frm, buf) + require.NoError(t, err) + } +} + +// TestConnCodecHelper_ReadSingleFrame_NoSegments tests reading a single frame without segmentation (v4) +func TestConnCodecHelper_ReadSingleFrame_NoSegments(t *testing.T) { + // Create a test frame + bodyContent := []byte("test query body") + testFrame := createTestRawFrame(primitive.ProtocolVersion4, 1, bodyContent) + + // Write frame to buffer (no segments for v4) + buf := &bytes.Buffer{} + writeFrameAsSegment(t, buf, testFrame, false) + + // Create codec helper + helper := createTestConnCodecHelper(buf) + + // Read the frame + readFrame, state, err := helper.ReadRawFrame() + require.NoError(t, err) + require.NotNil(t, readFrame) + require.NotNil(t, state) + + // Verify state shows no segments + assert.False(t, state.useSegments) + + // Verify the frame + assert.Equal(t, testFrame.Header.Version, readFrame.Header.Version) + assert.Equal(t, testFrame.Header.StreamId, readFrame.Header.StreamId) + assert.Equal(t, testFrame.Header.OpCode, readFrame.Header.OpCode) + assert.Equal(t, testFrame.Body, readFrame.Body) +} + +// TestConnCodecHelper_ReadSingleFrame_WithSegments tests reading a single frame with v5 segmentation +func TestConnCodecHelper_ReadSingleFrame_WithSegments(t *testing.T) { + // Create a test frame + bodyContent := []byte("test query body for v5") + testFrame := createTestRawFrame(primitive.ProtocolVersion5, 1, bodyContent) + + // Write frame as segment to buffer + buf := &bytes.Buffer{} + writeFrameAsSegment(t, buf, testFrame, true) + + // Create codec helper and enable segments + helper := createTestConnCodecHelper(buf) + err := helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Read the frame + readFrame, state, err := helper.ReadRawFrame() + require.NoError(t, err) + require.NotNil(t, readFrame) + require.NotNil(t, state) + + // Verify state shows segments enabled + assert.True(t, state.useSegments) + + // Verify the frame + assert.Equal(t, testFrame.Header.Version, readFrame.Header.Version) + assert.Equal(t, testFrame.Header.StreamId, readFrame.Header.StreamId) + assert.Equal(t, testFrame.Header.OpCode, readFrame.Header.OpCode) + assert.Equal(t, testFrame.Body, readFrame.Body) +} + +// TestConnCodecHelper_ReadMultipleFrames_NoSegments tests reading multiple frames without segmentation +func TestConnCodecHelper_ReadMultipleFrames_NoSegments(t *testing.T) { + // Create multiple test frames + frame1 := createTestRawFrame(primitive.ProtocolVersion4, 1, []byte("first frame")) + frame2 := createTestRawFrame(primitive.ProtocolVersion4, 2, []byte("second frame")) + frame3 := createTestRawFrame(primitive.ProtocolVersion4, 3, []byte("third frame")) + + // Write frames to buffer + buf := &bytes.Buffer{} + writeFrameAsSegment(t, buf, frame1, false) + writeFrameAsSegment(t, buf, frame2, false) + writeFrameAsSegment(t, buf, frame3, false) + + // Create codec helper + helper := createTestConnCodecHelper(buf) + + // Read and verify each frame + frames := []*frame.RawFrame{frame1, frame2, frame3} + for i, expectedFrame := range frames { + readFrame, _, err := helper.ReadRawFrame() + require.NoError(t, err, "Failed to read frame %d", i+1) + require.NotNil(t, readFrame) + + assert.Equal(t, expectedFrame.Header.StreamId, readFrame.Header.StreamId, + "Frame %d stream ID mismatch", i+1) + assert.Equal(t, expectedFrame.Body, readFrame.Body, + "Frame %d body mismatch", i+1) + } +} + +// TestConnCodecHelper_ReadMultipleFrames_WithSegments tests reading multiple frames with v5 segmentation +func TestConnCodecHelper_ReadMultipleFrames_WithSegments(t *testing.T) { + // Create multiple test frames + frame1 := createTestRawFrame(primitive.ProtocolVersion5, 1, []byte("first v5 frame")) + frame2 := createTestRawFrame(primitive.ProtocolVersion5, 2, []byte("second v5 frame")) + frame3 := createTestRawFrame(primitive.ProtocolVersion5, 3, []byte("third v5 frame")) + + // Write frames as segments to buffer + buf := &bytes.Buffer{} + writeFrameAsSegment(t, buf, frame1, true) + writeFrameAsSegment(t, buf, frame2, true) + writeFrameAsSegment(t, buf, frame3, true) + + // Create codec helper and enable segments + helper := createTestConnCodecHelper(buf) + err := helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Read and verify each frame + frames := []*frame.RawFrame{frame1, frame2, frame3} + for i, expectedFrame := range frames { + readFrame, state, err := helper.ReadRawFrame() + require.NoError(t, err, "Failed to read frame %d", i+1) + require.NotNil(t, readFrame) + assert.True(t, state.useSegments, "Segments should be enabled") + + assert.Equal(t, expectedFrame.Header.StreamId, readFrame.Header.StreamId, + "Frame %d stream ID mismatch", i+1) + assert.Equal(t, expectedFrame.Body, readFrame.Body, + "Frame %d body mismatch", i+1) + } +} + +// TestConnCodecHelper_SingleSegmentFrame tests reading a frame from a single self-contained segment +func TestConnCodecHelper_SingleSegmentFrame(t *testing.T) { + // Create a test frame + bodyContent := []byte("test query body") + testFrame := createTestRawFrame(primitive.ProtocolVersion5, 1, bodyContent) + + // Write frame as a self-contained segment to buffer + buf := &bytes.Buffer{} + writeFrameAsSegment(t, buf, testFrame, true) + + // Create codec helper and enable segments + helper := createTestConnCodecHelper(buf) + err := helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Verify frame is ready state is correct (internal check through reading) + readFrame, state, err := helper.ReadRawFrame() + require.NoError(t, err) + require.NotNil(t, readFrame) + require.True(t, state.useSegments) + + // Verify the frame + assert.Equal(t, testFrame.Header.Version, readFrame.Header.Version) + assert.Equal(t, testFrame.Header.StreamId, readFrame.Header.StreamId) + assert.Equal(t, testFrame.Header.OpCode, readFrame.Header.OpCode) + assert.Equal(t, testFrame.Body, readFrame.Body) +} + +// TestConnCodecHelper_MultipleSegmentPayloads tests accumulating a frame from multiple non-self-contained segments +func TestConnCodecHelper_MultipleSegmentPayloads(t *testing.T) { + // Create a frame with larger body + bodyContent := make([]byte, 100) + for i := range bodyContent { + bodyContent[i] = byte(i % 256) + } + testFrame := createTestRawFrame(primitive.ProtocolVersion5, 2, bodyContent) + + // Encode the frame + fullPayload := encodeRawFrameToBytes(t, testFrame) + + // Split the payload into multiple non-self-contained segments + buf := &bytes.Buffer{} + part1 := fullPayload[:40] // First part + part2 := fullPayload[40:] // Rest + + // Write first non-self-contained segment + seg1 := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: part1}, + Header: &segment.Header{IsSelfContained: false}, + } + err := defaultSegmentCodec.EncodeSegment(seg1, buf) + require.NoError(t, err) + + // Write second non-self-contained segment + seg2 := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: part2}, + Header: &segment.Header{IsSelfContained: false}, + } + err = defaultSegmentCodec.EncodeSegment(seg2, buf) + require.NoError(t, err) + + // Create codec helper and enable segments + helper := createTestConnCodecHelper(buf) + err = helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Read the frame (should accumulate from both segments automatically) + readFrame, state, err := helper.ReadRawFrame() + require.NoError(t, err) + require.NotNil(t, readFrame) + require.True(t, state.useSegments) + + // Verify the frame + assert.Equal(t, testFrame.Header.Version, readFrame.Header.Version) + assert.Equal(t, testFrame.Header.StreamId, readFrame.Header.StreamId) + assert.Equal(t, testFrame.Body, readFrame.Body) +} + +// TestConnCodecHelper_SequentialFramesInSeparateSegments tests reading multiple frames, +// each in its own self-contained segment +func TestConnCodecHelper_SequentialFramesInSeparateSegments(t *testing.T) { + // Create multiple test frames + frame1 := createTestRawFrame(primitive.ProtocolVersion5, 1, []byte("first frame")) + frame2 := createTestRawFrame(primitive.ProtocolVersion5, 2, []byte("second frame")) + frame3 := createTestRawFrame(primitive.ProtocolVersion5, 3, []byte("third frame")) + + // Write each frame as a separate self-contained segment to buffer + buf := &bytes.Buffer{} + writeFrameAsSegment(t, buf, frame1, true) + writeFrameAsSegment(t, buf, frame2, true) + writeFrameAsSegment(t, buf, frame3, true) + + // Create codec helper and enable segments + helper := createTestConnCodecHelper(buf) + err := helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Read and verify each frame + frames := []*frame.RawFrame{frame1, frame2, frame3} + for i, expectedFrame := range frames { + readFrame, state, err := helper.ReadRawFrame() + require.NoError(t, err, "Failed to read frame %d", i+1) + require.NotNil(t, readFrame) + require.True(t, state.useSegments) + + assert.Equal(t, expectedFrame.Header.StreamId, readFrame.Header.StreamId, + "Frame %d stream ID mismatch", i+1) + assert.Equal(t, expectedFrame.Body, readFrame.Body, + "Frame %d body mismatch", i+1) + } +} + +// TestConnCodecHelper_EmptyBufferEOF tests that reading from empty buffer returns EOF +func TestConnCodecHelper_EmptyBufferEOF(t *testing.T) { + // Create empty buffer + buf := &bytes.Buffer{} + helper := createTestConnCodecHelper(buf) + err := helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Try to read - should get EOF + readFrame, _, err := helper.ReadRawFrame() + require.Error(t, err) + require.Nil(t, readFrame) + assert.Contains(t, err.Error(), "EOF") +} + +// TestConnCodecHelper_MultipleEnvelopesInOneSegment tests that connCodecHelper can handle +// multiple envelopes packed into a single self-contained segment (per Protocol v5 spec Section 1). +// This is a CRITICAL test - if it fails, it indicates a bug in connCodecHelper.ReadRawFrame() +// where it doesn't check the internal accumulator before reading from the network. +func TestConnCodecHelper_MultipleEnvelopesInOneSegment(t *testing.T) { + testCases := []struct { + name string + envelopeCount int + }{ + {name: "Two envelopes in one segment", envelopeCount: 2}, + {name: "Three envelopes in one segment", envelopeCount: 3}, + {name: "Four envelopes in one segment", envelopeCount: 4}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create multiple envelopes + var envelopes []*frame.RawFrame + var combinedEnvelopePayload []byte + + for i := 0; i < tc.envelopeCount; i++ { + bodyContent := []byte(fmt.Sprintf("envelope_%d_data", i+1)) + envelope := createTestRawFrame(primitive.ProtocolVersion5, int16(i+1), bodyContent) + envelopes = append(envelopes, envelope) + + // Encode envelope and append to combined payload + encodedEnvelope := encodeRawFrameToBytes(t, envelope) + combinedEnvelopePayload = append(combinedEnvelopePayload, encodedEnvelope...) + } + + // Create ONE segment containing all envelopes + buf := &bytes.Buffer{} + seg := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: combinedEnvelopePayload}, + Header: &segment.Header{IsSelfContained: true}, + } + err := defaultSegmentCodec.EncodeSegment(seg, buf) + require.NoError(t, err) + + // Create codec helper and enable segments + helper := createTestConnCodecHelper(buf) + err = helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Read all envelopes back - THIS IS THE BUG TEST + // If ReadRawFrame() doesn't check the accumulator first, it will fail with EOF + // on the second call instead of returning the cached envelope + for i := 0; i < tc.envelopeCount; i++ { + readEnvelope, state, err := helper.ReadRawFrame() + + // If this fails with EOF on i > 0, it's the bug! + require.NoError(t, err, + "BUG: Failed to read envelope %d of %d - ReadRawFrame() should check accumulator before reading from source", + i+1, tc.envelopeCount) + require.NotNil(t, readEnvelope) + assert.True(t, state.useSegments) + + // Verify envelope content + assert.Equal(t, envelopes[i].Header.StreamId, readEnvelope.Header.StreamId, + "Envelope %d stream ID mismatch", i+1) + assert.Equal(t, envelopes[i].Body, readEnvelope.Body, + "Envelope %d body mismatch", i+1) + } + }) + } +} + +// TestConnCodecHelper_LargeFrameMultipleSegments tests reading a large frame split across multiple segments +func TestConnCodecHelper_LargeFrameMultipleSegments(t *testing.T) { + // Create a large frame that will require multiple segments + largeBody := make([]byte, segment.MaxPayloadLength*2+1000) + for i := range largeBody { + largeBody[i] = byte(i % 256) + } + testFrame := createTestRawFrame(primitive.ProtocolVersion5, 1, largeBody) + + // Encode the frame + envelopeBytes := encodeRawFrameToBytes(t, testFrame) + + // Split into multiple non-self-contained segments + buf := &bytes.Buffer{} + payloadLength := len(envelopeBytes) + + for offset := 0; offset < payloadLength; offset += segment.MaxPayloadLength { + end := offset + segment.MaxPayloadLength + if end > payloadLength { + end = payloadLength + } + + seg := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: envelopeBytes[offset:end]}, + Header: &segment.Header{IsSelfContained: false}, // Not self-contained + } + err := defaultSegmentCodec.EncodeSegment(seg, buf) + require.NoError(t, err) + } + + // Create codec helper and enable segments + helper := createTestConnCodecHelper(buf) + err := helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Read the frame (should accumulate from multiple segments) + readFrame, state, err := helper.ReadRawFrame() + require.NoError(t, err) + require.NotNil(t, readFrame) + assert.True(t, state.useSegments) + + // Verify the frame + assert.Equal(t, testFrame.Header.Version, readFrame.Header.Version) + assert.Equal(t, testFrame.Header.StreamId, readFrame.Header.StreamId) + assert.Equal(t, testFrame.Body, readFrame.Body) +} + +// TestConnCodecHelper_StateTransitions tests state transitions for enabling/disabling segments +func TestConnCodecHelper_StateTransitions(t *testing.T) { + buf := &bytes.Buffer{} + helper := createTestConnCodecHelper(buf) + + // Initially, state should be empty (no segments) + state := helper.GetState() + assert.False(t, state.useSegments) + + // Enable segments for v5 + err := helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + state = helper.GetState() + assert.True(t, state.useSegments) + assert.NotNil(t, state.segmentCodec) + + // Disable segments (e.g., for startup) + err = helper.SetStartupCompression() + require.NoError(t, err) + + state = helper.GetState() + assert.False(t, state.useSegments) + + // Enable again for v5 + err = helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + state = helper.GetState() + assert.True(t, state.useSegments) +} + +// TestConnCodecHelper_MixedProtocolVersions tests handling different protocol versions +func TestConnCodecHelper_MixedProtocolVersions(t *testing.T) { + testCases := []struct { + name string + version primitive.ProtocolVersion + shouldUseSegments bool + }{ + {name: "v3 - no segments", version: primitive.ProtocolVersion3, shouldUseSegments: false}, + {name: "v4 - no segments", version: primitive.ProtocolVersion4, shouldUseSegments: false}, + {name: "v5 - with segments", version: primitive.ProtocolVersion5, shouldUseSegments: true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a test frame + bodyContent := []byte(fmt.Sprintf("test for %s", tc.name)) + testFrame := createTestRawFrame(tc.version, 1, bodyContent) + + // Write frame to buffer + buf := &bytes.Buffer{} + writeFrameAsSegment(t, buf, testFrame, tc.shouldUseSegments) + + // Create codec helper + helper := createTestConnCodecHelper(buf) + + // Enable segments if protocol supports it + err := helper.MaybeEnableSegments(tc.version) + require.NoError(t, err) + + // Verify state + state := helper.GetState() + assert.Equal(t, tc.shouldUseSegments, state.useSegments, + "Segment usage mismatch for %s", tc.name) + + // Read and verify frame + readFrame, _, err := helper.ReadRawFrame() + require.NoError(t, err) + require.NotNil(t, readFrame) + + assert.Equal(t, testFrame.Header.Version, readFrame.Header.Version) + assert.Equal(t, testFrame.Body, readFrame.Body) + }) + } +} + +// TestConnCodecHelper_PartialEnvelopeAcrossSegments tests the edge case where a single envelope +// (frame) is split across multiple segments with partial header bytes. +// This ensures that connCodecHelper correctly accumulates partial envelope data across segments. +func TestConnCodecHelper_PartialEnvelopeAcrossSegments(t *testing.T) { + // Create a test frame + bodyContent := []byte("test body content for edge case") + testFrame := createTestRawFrame(primitive.ProtocolVersion5, 1, bodyContent) + fullEnvelope := encodeRawFrameToBytes(t, testFrame) + + // Protocol v5 header is 9 bytes + // Split envelope across 3 segments: + // Segment 1: First 3 bytes of envelope header (incomplete) + // Segment 2: Next 4 bytes of header (bytes 3-6, still incomplete - total 7 < 9) + // Segment 3: Remaining header bytes (bytes 7-8) + body + + buf := &bytes.Buffer{} + + // Write segment 1 with partial header (3 bytes) + seg1 := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: fullEnvelope[:3]}, + Header: &segment.Header{IsSelfContained: false}, + } + err := defaultSegmentCodec.EncodeSegment(seg1, buf) + require.NoError(t, err) + + // Write segment 2 with more partial header (4 bytes) + seg2 := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: fullEnvelope[3:7]}, + Header: &segment.Header{IsSelfContained: false}, + } + err = defaultSegmentCodec.EncodeSegment(seg2, buf) + require.NoError(t, err) + + // Write segment 3 with remaining header + body + seg3 := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: fullEnvelope[7:]}, + Header: &segment.Header{IsSelfContained: false}, + } + err = defaultSegmentCodec.EncodeSegment(seg3, buf) + require.NoError(t, err) + + // Create codec helper and enable segments + helper := createTestConnCodecHelper(buf) + err = helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Read the frame - should succeed despite header being split across 3 segments + readFrame, state, err := helper.ReadRawFrame() + require.NoError(t, err) + require.NotNil(t, readFrame) + assert.True(t, state.useSegments) + + // Verify frame content + assert.Equal(t, testFrame.Header.StreamId, readFrame.Header.StreamId) + assert.Equal(t, testFrame.Body, readFrame.Body) +} + +// TestConnCodecHelper_HeaderCompletionWithBodyInSegment tests the edge case where one segment +// completes the envelope header AND contains body bytes. +// This ensures the accumulator correctly transitions from header parsing to body accumulation. +func TestConnCodecHelper_HeaderCompletionWithBodyInSegment(t *testing.T) { + // Create a test frame with larger body + bodyContent := make([]byte, 50) + for i := range bodyContent { + bodyContent[i] = byte(i) + } + testFrame := createTestRawFrame(primitive.ProtocolVersion5, 1, bodyContent) + fullEnvelope := encodeRawFrameToBytes(t, testFrame) + + // v5 header is 9 bytes + // Segment 1: First 7 bytes of header (incomplete) + // Segment 2: Remaining 2 header bytes (7-8) + first 11 body bytes (9-19) + // This segment completes header AND has body data + // Segment 3: Remaining body bytes (20+) + + buf := &bytes.Buffer{} + + // Write segment 1 with partial header (7 bytes) + seg1 := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: fullEnvelope[:7]}, + Header: &segment.Header{IsSelfContained: false}, + } + err := defaultSegmentCodec.EncodeSegment(seg1, buf) + require.NoError(t, err) + + // Write segment 2 with header completion + some body bytes + seg2 := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: fullEnvelope[7:20]}, + Header: &segment.Header{IsSelfContained: false}, + } + err = defaultSegmentCodec.EncodeSegment(seg2, buf) + require.NoError(t, err) + + // Write segment 3 with remaining body bytes + seg3 := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: fullEnvelope[20:]}, + Header: &segment.Header{IsSelfContained: false}, + } + err = defaultSegmentCodec.EncodeSegment(seg3, buf) + require.NoError(t, err) + + // Create codec helper and enable segments + helper := createTestConnCodecHelper(buf) + err = helper.MaybeEnableSegments(primitive.ProtocolVersion5) + require.NoError(t, err) + + // Read the frame + readFrame, state, err := helper.ReadRawFrame() + require.NoError(t, err) + require.NotNil(t, readFrame) + assert.True(t, state.useSegments) + + // Verify frame content + assert.Equal(t, testFrame.Header.StreamId, readFrame.Header.StreamId) + assert.Equal(t, bodyContent, readFrame.Body) +} + +// === DualReader Tests === +// DualReader is an internal component of connCodecHelper + +// TestDualReader_NewDualReader tests the constructor +func TestDualReader_NewDualReader(t *testing.T) { + reader1 := bytes.NewReader([]byte("data1")) + reader2 := bytes.NewReader([]byte("data2")) + + dualReader := NewDualReader(reader1, reader2) + + require.NotNil(t, dualReader) + assert.Equal(t, reader1, dualReader.reader1) + assert.Equal(t, reader2, dualReader.reader2) + assert.False(t, dualReader.skipReader1) +} + +// TestDualReader_Read tests reading from both readers +func TestDualReader_Read(t *testing.T) { + data1 := []byte("first reader data") + data2 := []byte("second reader data") + reader1 := bytes.NewReader(data1) + reader2 := bytes.NewReader(data2) + + dualReader := NewDualReader(reader1, reader2) + + // Read all data + result := make([]byte, len(data1)+len(data2)) + n, err := io.ReadFull(dualReader, result) + require.NoError(t, err) + assert.Equal(t, len(data1)+len(data2), n) + + // Verify data + expectedData := append(data1, data2...) + assert.Equal(t, expectedData, result) + + // Further reads should return EOF + buf := make([]byte, 10) + n, err = dualReader.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) +} + +// TestDualReader_Read_FirstReaderOnly tests reading when second reader is empty +func TestDualReader_Read_FirstReaderOnly(t *testing.T) { + data1 := []byte("only first reader") + reader1 := bytes.NewReader(data1) + reader2 := bytes.NewReader([]byte{}) + + dualReader := NewDualReader(reader1, reader2) + + result := make([]byte, len(data1)) + n, err := io.ReadFull(dualReader, result) + require.NoError(t, err) + assert.Equal(t, len(data1), n) + assert.Equal(t, data1, result) +} + +// TestDualReader_Read_SecondReaderOnly tests reading when first reader is empty +func TestDualReader_Read_SecondReaderOnly(t *testing.T) { + data2 := []byte("only second reader") + reader1 := bytes.NewReader([]byte{}) + reader2 := bytes.NewReader(data2) + + dualReader := NewDualReader(reader1, reader2) + + result := make([]byte, len(data2)) + n, err := io.ReadFull(dualReader, result) + require.NoError(t, err) + assert.Equal(t, len(data2), n) + assert.Equal(t, data2, result) +} + +// TestDualReader_Reset tests resetting the reader +func TestDualReader_Reset(t *testing.T) { + data1 := []byte("first") + data2 := []byte("second") + reader1 := bytes.NewReader(data1) + reader2 := bytes.NewReader(data2) + + dualReader := NewDualReader(reader1, reader2) + + // Read some data to move past first reader + // Use io.ReadFull to ensure we read from both readers + buf := make([]byte, len(data1)+2) + n, err := io.ReadFull(dualReader, buf) + require.NoError(t, err) + assert.Equal(t, len(data1)+2, n) // Should have read from both readers + + // Reset + dualReader.Reset() + assert.False(t, dualReader.skipReader1) + + // Reset the underlying readers too + reader1.Seek(0, io.SeekStart) + reader2.Seek(0, io.SeekStart) + + // Read again + result := make([]byte, len(data1)+len(data2)) + n, err = io.ReadFull(dualReader, result) + require.NoError(t, err) + assert.Equal(t, len(data1)+len(data2), n) + + expectedData := append(data1, data2...) + assert.Equal(t, expectedData, result) +} + +// TestDualReader_Read_InChunks tests reading in multiple small chunks +func TestDualReader_Read_InChunks(t *testing.T) { + data1 := []byte("12345") + data2 := []byte("67890") + reader1 := bytes.NewReader(data1) + reader2 := bytes.NewReader(data2) + + dualReader := NewDualReader(reader1, reader2) + + // Read in small chunks + var result []byte + chunkSize := 2 + for { + buf := make([]byte, chunkSize) + n, err := dualReader.Read(buf) + if n > 0 { + result = append(result, buf[:n]...) + } + if err == io.EOF { + break + } + require.NoError(t, err) + } + + expectedData := append(data1, data2...) + assert.Equal(t, expectedData, result) +} diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 796e4ae7..37edb54b 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -4,15 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/datastax/go-cassandra-native-protocol/frame" - "github.com/datastax/go-cassandra-native-protocol/message" - "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/proxy/pkg/common" - "github.com/datastax/zdm-proxy/proxy/pkg/config" - "github.com/datastax/zdm-proxy/proxy/pkg/metrics" - "github.com/google/uuid" - "github.com/jpillora/backoff" - log "github.com/sirupsen/logrus" "math" "math/big" "math/rand" @@ -22,6 +13,17 @@ import ( "sync" "sync/atomic" "time" + + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/google/uuid" + "github.com/jpillora/backoff" + log "github.com/sirupsen/logrus" + + "github.com/datastax/zdm-proxy/proxy/pkg/common" + "github.com/datastax/zdm-proxy/proxy/pkg/config" + "github.com/datastax/zdm-proxy/proxy/pkg/metrics" ) type ControlConn struct { @@ -385,7 +387,7 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV newConn := NewCqlConnection(cc, endpoint, tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf, protoVer) err = newConn.InitializeContext(protoVer, ctx) var respErr *ResponseError - if err != nil && errors.As(err, &respErr) && respErr.IsProtocolError() && strings.Contains(err.Error(), "Invalid or unsupported protocol version") { + if err != nil && errors.As(err, &respErr) && respErr.IsProtocolError() { // unsupported protocol version // protocol renegotiation requires opening a new TCP connection err2 := newConn.Close() @@ -410,6 +412,8 @@ func downgradeProtocol(version primitive.ProtocolVersion) primitive.ProtocolVers case primitive.ProtocolVersionDse2: return primitive.ProtocolVersionDse1 case primitive.ProtocolVersionDse1: + return primitive.ProtocolVersion5 + case primitive.ProtocolVersion5: return primitive.ProtocolVersion4 case primitive.ProtocolVersion4: return primitive.ProtocolVersion3 diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 7dd30fc8..d9018251 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -4,11 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/datastax/go-cassandra-native-protocol/frame" - "github.com/datastax/go-cassandra-native-protocol/message" - "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/proxy/pkg/config" - log "github.com/sirupsen/logrus" "io" "net" "runtime" @@ -16,6 +11,13 @@ import ( "sync" "sync/atomic" "time" + + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + log "github.com/sirupsen/logrus" + + "github.com/datastax/zdm-proxy/proxy/pkg/config" ) const ( @@ -68,12 +70,16 @@ type cqlConn struct { authEnabled bool frameProcessor FrameProcessor protocolVersion *atomic.Value + codecHelper *connCodecHelper } var ( StreamIdMismatchErr = errors.New("stream id of the response is different from the stream id of the request") ) +const CqlConnReadBufferSizeBytes = 1024 +const CqlConnWriteBufferSizeBytes = 1024 + func (c *cqlConn) GetEndpoint() Endpoint { return c.endpoint } @@ -88,6 +94,8 @@ func NewCqlConnection( readTimeout time.Duration, writeTimeout time.Duration, conf *config.Config, protoVer primitive.ProtocolVersion) CqlConnection { ctx, cFn := context.WithCancel(context.Background()) + compressionValue := &atomic.Value{} + compressionValue.Store(primitive.CompressionNone) cqlConn := &cqlConn{ controlConn: controlConn, readTimeout: readTimeout, @@ -113,6 +121,7 @@ func NewCqlConnection( // protoVer is the proposed protocol version using which we will try to establish connectivity frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(protoVer, conf, nil)), protocolVersion: &atomic.Value{}, + codecHelper: newConnCodecHelper(conn, conn.RemoteAddr().String(), CqlConnReadBufferSizeBytes, CqlConnWriteBufferSizeBytes, compressionValue, ctx), } cqlConn.StartRequestLoop() cqlConn.StartResponseLoop() @@ -147,7 +156,8 @@ func (c *cqlConn) StartResponseLoop() { defer close(c.eventsQueue) defer log.Debugf("Shutting down response loop on %v.", c) for c.ctx.Err() == nil { - f, err := defaultCodec.DecodeFrame(c.conn) + var f *frame.Frame + rawFrame, state, err := c.codecHelper.ReadRawFrame() if err != nil { if isDisconnectErr(err) { log.Infof("[%v] Control connection to %v disconnected", c.controlConn.connConfig.GetClusterType(), c.conn.RemoteAddr().String()) @@ -157,7 +167,21 @@ func (c *cqlConn) StartResponseLoop() { c.cancelFn() break } - + f, err = state.frameCodec.ConvertFromRawFrame(rawFrame) + if err != nil { + log.Errorf("Failed to decode frame messge on cql connection %v: %v", c, err) + c.cancelFn() + break + } + if !state.useSegments && f.Header.Version.SupportsModernFramingLayout() && + (f.Header.OpCode == primitive.OpCodeReady || f.Header.OpCode == primitive.OpCodeAuthenticate) { + err = c.codecHelper.SetState(true) + if err != nil { + log.Errorf("Failed to switch to segments on cql connection %v: %v", c, err) + c.cancelFn() + break + } + } if f.Body.Message.GetOpCode() == primitive.OpCodeEvent { select { case c.eventsQueue <- f: @@ -206,15 +230,74 @@ func (c *cqlConn) StartRequestLoop() { for c.ctx.Err() == nil { select { case f := <-c.outgoingCh: - err := defaultCodec.EncodeFrame(f, c.conn) - if err != nil { - if isDisconnectErr(err) { - log.Infof("[%v] Control connection to %v disconnected", c.controlConn.connConfig.GetClusterType(), c.conn.RemoteAddr().String()) - } else { - log.Errorf("Failed to write/encode frame on cql connection %v: %v", c, err) + state := c.codecHelper.GetState() + if state.useSegments { + first := true + for { + if !first { + ok := false + select { + case f, ok = <-c.outgoingCh: + default: + } + if !ok { + state = c.codecHelper.GetState() + err := c.codecHelper.segWriter.WriteSegments(c.conn, state) + if err != nil { + log.Errorf("Failed to write segment to control connection %v: %v", c, err) + c.cancelFn() + return + } + break + } + } else { + first = false + } + + rawFrame, err := defaultFrameCodec.ConvertToRawFrame(f) + if err != nil { + log.Errorf("Failed to convert frame to raw frame while writing segment payload on control connection %v: %v", c, err) + c.cancelFn() + return + } + written, err := c.codecHelper.segWriter.AppendFrameToSegmentPayload(rawFrame) + if err != nil { + log.Errorf("Failed to write/encode frame to segment payload on control connection %v: %v", c, err) + c.cancelFn() + return + } + if !written { + state = c.codecHelper.GetState() + err = c.codecHelper.segWriter.WriteSegments(c.conn, state) + if err != nil { + log.Errorf("Failed to write segment to control connection %v: %v", c, err) + c.cancelFn() + return + } + written, err = c.codecHelper.segWriter.AppendFrameToSegmentPayload(rawFrame) + if err != nil { + log.Errorf("Failed to write/encode frame to segment payload on control connection %v: %v", c, err) + c.cancelFn() + return + } + if !written { + log.Errorf("SegWriter returned false even after flushing the payload on control connection %v: %v", c, err) + c.cancelFn() + return + } + } + } + } else { + err := defaultFrameCodec.EncodeFrame(f, c.conn) + if err != nil { + if isDisconnectErr(err) { + log.Infof("[%v] Control connection to %v disconnected", c.controlConn.connConfig.GetClusterType(), c.conn.RemoteAddr().String()) + } else { + log.Errorf("Failed to write/encode frame on cql connection %v: %v", c, err) + } + c.cancelFn() + return } - c.cancelFn() - return } case <-c.ctx.Done(): return @@ -348,6 +431,8 @@ func (c *cqlConn) SendAndReceive(request *frame.Frame, ctx context.Context) (*fr c.Close() } return nil, fmt.Errorf("context finished before completing receiving frame on %v: %w", c, readTimeoutCtx.Err()) + case <-c.ctx.Done(): + return nil, fmt.Errorf("cql connection was closed: %w", ctx.Err()) } } @@ -379,24 +464,29 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex if response, err = c.SendAndReceive(startup, ctx); err == nil { switch response.Body.Message.(type) { case *message.Ready: - log.Warnf("%v: expected AUTHENTICATE, got READY – is authentication required?", c) + log.Warnf("%v ControlConn: authentication is NOT enabled.", c.controlConn.connConfig.GetClusterType()) break case *message.Authenticate: authEnabled = true var authResponse *frame.Frame authResponse, err = performHandshakeStep(authenticator, version, -1, response) - if err == nil { + if err != nil { + return authEnabled, fmt.Errorf("authentication response processing failed: %w", err) + } + response, err = c.SendAndReceive(authResponse, ctx) + if err != nil { + return authEnabled, fmt.Errorf("could not send AUTH RESPONSE: %w", err) + } + _, authSuccess := response.Body.Message.(*message.AuthSuccess) + if !authSuccess { + authResponse, err = performHandshakeStep(authenticator, version, -1, response) + if err != nil { + return authEnabled, fmt.Errorf("second authentication response processing failed: %w", err) + } if response, err = c.SendAndReceive(authResponse, ctx); err != nil { - err = fmt.Errorf("could not send AUTH RESPONSE: %w", err) + return authEnabled, fmt.Errorf("could not send AUTH RESPONSE: %w", err) } else if _, authSuccess := response.Body.Message.(*message.AuthSuccess); !authSuccess { - authResponse, err = performHandshakeStep(authenticator, version, -1, response) - if err == nil { - if response, err = c.SendAndReceive(authResponse, ctx); err != nil { - err = fmt.Errorf("could not send AUTH RESPONSE: %w", err) - } else if _, authSuccess := response.Body.Message.(*message.AuthSuccess); !authSuccess { - err = fmt.Errorf("expected AUTH_SUCCESS, got %v", response.Body.Message) - } - } + return authEnabled, fmt.Errorf("expected AUTH_SUCCESS, got %v", response.Body.Message) } } case *message.ProtocolError: @@ -408,8 +498,6 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex if err == nil { log.Debugf("%v: handshake successful", c) c.initialized = true - } else { - log.Errorf("%v: handshake failed: %v", c, err) } return authEnabled, err } diff --git a/proxy/pkg/zdmproxy/cqlparser.go b/proxy/pkg/zdmproxy/cqlparser.go index 2b6f3e32..d1d5d43c 100644 --- a/proxy/pkg/zdmproxy/cqlparser.go +++ b/proxy/pkg/zdmproxy/cqlparser.go @@ -4,6 +4,8 @@ import ( "encoding/hex" "errors" "fmt" + "strings" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" @@ -11,7 +13,6 @@ import ( "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/datastax/zdm-proxy/proxy/pkg/metrics" log "github.com/sirupsen/logrus" - "strings" ) type forwardDecision string @@ -300,7 +301,7 @@ func (recv *frameDecodeContext) GetOrDecodeFrame() (*frame.Frame, error) { return recv.decodedFrame, nil } - if codec, ok := codecs[recv.compression]; ok { + if codec, ok := frameCodecs[recv.compression]; ok { decodedFrame, err := codec.ConvertFromRawFrame(recv.frame) if err != nil { return nil, fmt.Errorf("could not decode raw frame: %w", err) diff --git a/proxy/pkg/zdmproxy/cqlparser_test.go b/proxy/pkg/zdmproxy/cqlparser_test.go index d41d140a..1637e7d7 100644 --- a/proxy/pkg/zdmproxy/cqlparser_test.go +++ b/proxy/pkg/zdmproxy/cqlparser_test.go @@ -214,7 +214,7 @@ func mockAuthResponse(t *testing.T) *frame.RawFrame { func mockFrame(t *testing.T, message message.Message, version primitive.ProtocolVersion) *frame.RawFrame { f := frame.NewFrame(version, 1, message) - rawFrame, err := defaultCodec.ConvertToRawFrame(f) + rawFrame, err := defaultFrameCodec.ConvertToRawFrame(f) require.Nil(t, err) return rawFrame } diff --git a/proxy/pkg/zdmproxy/frame.go b/proxy/pkg/zdmproxy/frame.go index c24900ef..b15cb060 100644 --- a/proxy/pkg/zdmproxy/frame.go +++ b/proxy/pkg/zdmproxy/frame.go @@ -3,11 +3,13 @@ package zdmproxy import ( "context" "fmt" + "io" + "github.com/datastax/go-cassandra-native-protocol/compression/lz4" "github.com/datastax/go-cassandra-native-protocol/compression/snappy" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/primitive" - "io" + "github.com/datastax/go-cassandra-native-protocol/segment" ) type shutdownError struct { @@ -18,17 +20,33 @@ func (e *shutdownError) Error() string { return e.err } -var defaultCodec = frame.NewRawCodec() +var defaultFrameCodec = frame.NewRawCodec() +var defaultSegmentCodec = segment.NewCodec() -var codecs = map[primitive.Compression]frame.RawCodec{ - primitive.CompressionNone: defaultCodec, +var frameCodecs = map[primitive.Compression]frame.RawCodec{ + primitive.CompressionNone: defaultFrameCodec, primitive.CompressionLz4: frame.NewRawCodecWithCompression(lz4.Compressor{}), primitive.CompressionSnappy: frame.NewRawCodecWithCompression(snappy.Compressor{}), - primitive.Compression("none"): defaultCodec, + primitive.Compression("none"): defaultFrameCodec, primitive.Compression("lz4"): frame.NewRawCodecWithCompression(lz4.Compressor{}), primitive.Compression("snappy"): frame.NewRawCodecWithCompression(snappy.Compressor{}), } +var segmentCodecs = map[primitive.Compression]segment.Codec{ + primitive.CompressionNone: defaultSegmentCodec, + primitive.CompressionLz4: segment.NewCodecWithCompression(lz4.Compressor{}), + primitive.Compression("none"): defaultSegmentCodec, + primitive.Compression("lz4"): segment.NewCodecWithCompression(lz4.Compressor{}), +} + +func getFrameCodec(compression primitive.Compression) (frame.RawCodec, error) { + codec, ok := frameCodecs[compression] + if !ok { + return nil, fmt.Errorf("no codec for compression: %v", compression) + } + return codec, nil +} + var ShutdownErr = &shutdownError{err: "aborted due to shutdown request"} func adaptConnErr(connectionAddr string, clientHandlerContext context.Context, err error) error { @@ -45,16 +63,6 @@ func adaptConnErr(connectionAddr string, clientHandlerContext context.Context, e // Simple function that writes a rawframe with a single call to writeToConnection func writeRawFrame(writer io.Writer, connectionAddr string, clientHandlerContext context.Context, frame *frame.RawFrame) error { - err := defaultCodec.EncodeRawFrame(frame, writer) // body is already compressed if needed, so we can use default codec + err := defaultFrameCodec.EncodeRawFrame(frame, writer) // body is already compressed if needed, so we can use default codec return adaptConnErr(connectionAddr, clientHandlerContext, err) } - -// Simple function that reads data from a connection and builds a frame -func readRawFrame(reader io.Reader, connectionAddr string, clientHandlerContext context.Context) (*frame.RawFrame, error) { - rawFrame, err := defaultCodec.DecodeRawFrame(reader) // body is not being decompressed, so we can use default codec - if err != nil { - return nil, adaptConnErr(connectionAddr, clientHandlerContext, err) - } - - return rawFrame, nil -} diff --git a/proxy/pkg/zdmproxy/nativeprotocol.go b/proxy/pkg/zdmproxy/nativeprotocol.go index 98b0dfe1..8e1fe8fa 100644 --- a/proxy/pkg/zdmproxy/nativeprotocol.go +++ b/proxy/pkg/zdmproxy/nativeprotocol.go @@ -4,10 +4,11 @@ import ( "crypto/md5" "errors" "fmt" + "strings" + "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" - "strings" ) type ParsedRow struct { @@ -169,7 +170,8 @@ func EncodePreparedResult( } id := md5.Sum([]byte(query + keyspace)) return &message.PreparedResult{ - PreparedQueryId: id[:], + PreparedQueryId: id[:], + ResultMetadataId: id[:], ResultMetadata: &message.RowsMetadata{ ColumnCount: int32(len(columns)), Columns: columns, diff --git a/proxy/pkg/zdmproxy/proxy.go b/proxy/pkg/zdmproxy/proxy.go index 9a99a95a..ffd4f1c2 100644 --- a/proxy/pkg/zdmproxy/proxy.go +++ b/proxy/pkg/zdmproxy/proxy.go @@ -5,20 +5,23 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/datastax/zdm-proxy/proxy/pkg/common" - "github.com/datastax/zdm-proxy/proxy/pkg/config" - "github.com/datastax/zdm-proxy/proxy/pkg/metrics" - "github.com/datastax/zdm-proxy/proxy/pkg/metrics/noopmetrics" - "github.com/datastax/zdm-proxy/proxy/pkg/metrics/prommetrics" - "github.com/jpillora/backoff" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" "math/rand" "net" "runtime" "sync" "sync/atomic" "time" + + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/jpillora/backoff" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + + "github.com/datastax/zdm-proxy/proxy/pkg/common" + "github.com/datastax/zdm-proxy/proxy/pkg/config" + "github.com/datastax/zdm-proxy/proxy/pkg/metrics" + "github.com/datastax/zdm-proxy/proxy/pkg/metrics/noopmetrics" + "github.com/datastax/zdm-proxy/proxy/pkg/metrics/prommetrics" ) type ZdmProxy struct { @@ -37,6 +40,8 @@ type ZdmProxy struct { readMode common.ReadMode systemQueriesMode common.SystemQueriesMode + blockedProtoVersions []primitive.ProtocolVersion + proxyRand *rand.Rand lock *sync.RWMutex @@ -414,6 +419,13 @@ func (p *ZdmProxy) initializeGlobalStructures() error { log.Infof("Parsed Async latency buckets: %v", p.asyncBuckets) } + p.blockedProtoVersions, err = p.Conf.ParseBlockedProtocolVersions() + if err != nil { + return fmt.Errorf("failed to parse blocked protocol versions: %w", err) + } else { + log.Infof("Parsed Blocked Protocol Versions: %v", p.blockedProtoVersions) + } + p.activeClients = 0 return nil } @@ -554,6 +566,7 @@ func (p *ZdmProxy) handleNewConnection(clientConn net.Conn) { p.originControlConn, p.targetControlConn, p.Conf, + p.blockedProtoVersions, p.TopologyConfig, p.Conf.TargetUsername, p.Conf.TargetPassword, diff --git a/proxy/pkg/zdmproxy/querymodifier_test.go b/proxy/pkg/zdmproxy/querymodifier_test.go index c787f3a5..42275706 100644 --- a/proxy/pkg/zdmproxy/querymodifier_test.go +++ b/proxy/pkg/zdmproxy/querymodifier_test.go @@ -165,7 +165,7 @@ func TestReplaceQueryString(t *testing.T) { decodedFrame, statementQuery, err := context.GetOrDecodeAndInspect("", timeUuidGenerator) require.Nil(t, err) _, decodedFrame, statementQuery, statementsReplacedTerms, err := queryModifier.replaceQueryString(decodedFrame, statementQuery) - newRawFrame, err := defaultCodec.ConvertToRawFrame(decodedFrame) + newRawFrame, err := defaultFrameCodec.ConvertToRawFrame(decodedFrame) newContext := NewInitializedFrameDecodeContext(newRawFrame, primitive.CompressionNone, decodedFrame, statementQuery) require.Nil(t, err) require.Equal(t, len(test.positionsReplaced), len(statementsReplacedTerms)) diff --git a/proxy/pkg/zdmproxy/segment.go b/proxy/pkg/zdmproxy/segment.go new file mode 100644 index 00000000..5eac175e --- /dev/null +++ b/proxy/pkg/zdmproxy/segment.go @@ -0,0 +1,275 @@ +package zdmproxy + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/go-cassandra-native-protocol/segment" +) + +// SegmentAccumulator provides a way for the caller to build frames from segments. +// +// The caller appends segment payloads to this accumulator by calling AppendSegmentPayload +// and then retrieves frames by calling ReadFrame. +// +// The caller can check whether a frame is ready to be read by calling FrameReady(). +// +// There can be multiple frames in a segment so the caller should check FrameReady() again after calling ReadFrame(). +// +// This type is not "thread-safe". +type SegmentAccumulator interface { + ReadFrame() (*frame.RawFrame, error) + AppendSegmentPayload(payload []byte) error + FrameReady() bool +} + +type segmentAcc struct { + buf *bytes.Buffer + accumLength int + targetLength int + hdr *frame.Header + codec frame.RawDecoder + payloadReader *bytes.Reader + version primitive.ProtocolVersion + hdrBuf *bytes.Buffer +} + +func NewSegmentAccumulator(codec frame.RawDecoder) SegmentAccumulator { + return &segmentAcc{ + buf: nil, + accumLength: 0, + targetLength: 0, + hdr: nil, + codec: codec, + payloadReader: nil, + version: 0, + hdrBuf: bytes.NewBuffer(make([]byte, 0, primitive.FrameHeaderLengthV3AndHigher)), + } +} + +func (a *segmentAcc) FrameReady() bool { + return a.accumLength >= a.targetLength && a.hdr != nil +} + +func (a *segmentAcc) ReadFrame() (*frame.RawFrame, error) { + if !a.FrameReady() { + return nil, errors.New("frame is not ready") + } + payload := a.buf.Bytes() + actualPayload := payload[:a.targetLength] + var extraBytes []byte + if a.accumLength > a.targetLength { + extraBytes = payload[a.targetLength:] + } + hdr := a.hdr + a.reset() + err := a.AppendSegmentPayload(extraBytes) + if err != nil { + return nil, fmt.Errorf("could not carry over extra payload bytes to new payload: %w", err) + } + if hdr.Version.SupportsModernFramingLayout() && hdr.Flags.Contains(primitive.HeaderFlagCompressed) { + hdr.Flags = hdr.Flags.Remove(primitive.HeaderFlagCompressed) // gocql workaround (https://issues.apache.org/jira/browse/CASSGO-98) + } + return &frame.RawFrame{ + Header: hdr, + Body: actualPayload, + }, nil +} + +func (a *segmentAcc) reset() { + a.buf = nil // do not zero/reset current buffer, just allocate a new one + a.accumLength = 0 + a.targetLength = 0 + a.version = 0 + a.hdr = nil + a.hdrBuf.Reset() +} + +func (a *segmentAcc) AppendSegmentPayload(payload []byte) error { + if len(payload) == 0 { + return nil + } + + if a.payloadReader == nil { + a.payloadReader = bytes.NewReader(payload) + } else { + a.payloadReader.Reset(payload) + } + + if a.version == 0 { + v, err := a.readVersion(a.payloadReader) + if err != nil { + return fmt.Errorf("cannot read frame version in multipart segment: %w", err) + } + a.version = v + } + + if a.hdr == nil { + remainingBytes := a.version.FrameHeaderLengthInBytes() - a.hdrBuf.Len() + bytesToCopy := remainingBytes + done := true + if a.payloadReader.Len() < remainingBytes { + bytesToCopy = a.payloadReader.Len() + done = false + } + _, err := io.CopyN(a.hdrBuf, a.payloadReader, int64(bytesToCopy)) + if err != nil { + return fmt.Errorf("cannot read frame header bytes: %w", err) + } + if done { + a.hdr, err = a.codec.DecodeHeader(a.hdrBuf) + if err != nil { + return fmt.Errorf("cannot read frame header in multipart segment: %w", err) + } + a.targetLength = int(a.hdr.BodyLength) + a.buf = bytes.NewBuffer(make([]byte, 0, a.targetLength)) + } + } + + if a.payloadReader.Len() > 0 { + n, err := a.buf.ReadFrom(a.payloadReader) + if err != nil { + return fmt.Errorf("cannot copy payload to buffer: %w", err) + } + a.accumLength += int(n) + } + return nil +} + +func (a *segmentAcc) readVersion(reader *bytes.Reader) (primitive.ProtocolVersion, error) { + versionAndDirection, err := reader.ReadByte() + if err != nil { + return 0, fmt.Errorf("cannot decode header version and direction: %w", err) + } + _ = reader.UnreadByte() + + version := primitive.ProtocolVersion(versionAndDirection & 0b0111_1111) + err = primitive.CheckSupportedProtocolVersion(version) + if err != nil { + return 0, err + } + return version, nil +} + +type SegmentWriter struct { + payload *bytes.Buffer + connectionAddr string + clientHandlerContext context.Context + maxBufferSize int +} + +func NewSegmentWriter(writeBuffer *bytes.Buffer, maxBufferSize int, connectionAddr string, clientHandlerContext context.Context) *SegmentWriter { + return &SegmentWriter{ + payload: writeBuffer, + connectionAddr: connectionAddr, + clientHandlerContext: clientHandlerContext, + maxBufferSize: maxBufferSize, + } +} + +func FrameUncompressedLength(f *frame.RawFrame) (int, error) { + if f.Header.Flags.Contains(primitive.HeaderFlagCompressed) { + return -1, fmt.Errorf("cannot obtain uncompressed length of compressed frame: %v", f.String()) + } + return f.Header.Version.FrameHeaderLengthInBytes() + len(f.Body), nil +} + +func (w *SegmentWriter) GetWriteBuffer() *bytes.Buffer { + return w.payload +} + +func (w *SegmentWriter) canWriteFrameInternal(frameLength int) bool { + if frameLength > segment.MaxPayloadLength { // frame needs multiple segments + if w.payload.Len() > 0 { + // if frame needs multiple segments and there is already a frame in the payload then need to flush first + return false + } else { + return true + } + } else { // frame can be self contained + if w.payload.Len()+frameLength > segment.MaxPayloadLength { + // if frame can be self contained but adding it to the current payload exceeds the max length then need to flush first + return false + } else if w.payload.Len() >= 0 && w.payload.Len() >= w.maxBufferSize { + // if there is already data in the current payload and it exceeds the configured max buffer size then need to flush first + // max buffer size can be exceeded if payload is currently empty (otherwise the frame couldn't be written) + return false + } else { + return true + } + } +} + +func (w *SegmentWriter) WriteSegments(dst io.Writer, state *connState) error { + payload := w.payload.Bytes() + payloadLength := len(payload) + + if payloadLength <= 0 { + return errors.New("cannot write segment with empty payload") + } + + if payloadLength > segment.MaxPayloadLength { + segmentCount := payloadLength / segment.MaxPayloadLength + isExactMultiple := payloadLength%segment.MaxPayloadLength == 0 + if !isExactMultiple { + segmentCount++ + } + + // Split the payload buffer into segments + for i := range segmentCount { + segmentLength := segment.MaxPayloadLength + if i == segmentCount-1 && !isExactMultiple { + segmentLength = payloadLength % segment.MaxPayloadLength + } + start := i * segment.MaxPayloadLength + seg := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: payload[start : start+segmentLength]}, + Header: &segment.Header{IsSelfContained: false}, + } + err := state.segmentCodec.EncodeSegment(seg, dst) + if err != nil { + return adaptConnErr( + w.connectionAddr, + w.clientHandlerContext, + fmt.Errorf("cannot write segment %d of %d: %w", i+1, segmentCount, err)) + } + } + } else { + seg := &segment.Segment{ + Payload: &segment.Payload{UncompressedData: w.payload.Bytes()}, + Header: &segment.Header{IsSelfContained: true}, + } + err := state.segmentCodec.EncodeSegment(seg, dst) + if err != nil { + return adaptConnErr(w.connectionAddr, w.clientHandlerContext, fmt.Errorf("cannot write segment: %w", err)) + } + } + w.payload.Reset() + return nil +} + +func (w *SegmentWriter) AppendFrameToSegmentPayload(frm *frame.RawFrame) (bool, error) { + frameLength, err := FrameUncompressedLength(frm) + if err != nil { + return false, err + } + if !w.canWriteFrameInternal(frameLength) { + return false, nil + } + + err = w.writeToPayload(frm) + if err != nil { + return false, fmt.Errorf("cannot write frame to segment payload: %w", err) + } + return true, nil +} + +func (w *SegmentWriter) writeToPayload(f *frame.RawFrame) error { + // frames are always uncompressed in v5 (segments can be compressed) + return adaptConnErr(w.connectionAddr, w.clientHandlerContext, defaultFrameCodec.EncodeRawFrame(f, w.payload)) +} diff --git a/proxy/pkg/zdmproxy/segment_test.go b/proxy/pkg/zdmproxy/segment_test.go new file mode 100644 index 00000000..9e35dc68 --- /dev/null +++ b/proxy/pkg/zdmproxy/segment_test.go @@ -0,0 +1,292 @@ +package zdmproxy + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/go-cassandra-native-protocol/segment" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper function to create a simple raw frame for testing +func createTestRawFrame(version primitive.ProtocolVersion, streamId int16, bodyContent []byte) *frame.RawFrame { + return &frame.RawFrame{ + Header: &frame.Header{ + Version: version, + Flags: primitive.HeaderFlag(0), + StreamId: streamId, + OpCode: primitive.OpCodeQuery, + BodyLength: int32(len(bodyContent)), + }, + Body: bodyContent, + } +} + +// Helper function to encode a raw frame to bytes +func encodeRawFrameToBytes(t *testing.T, frm *frame.RawFrame) []byte { + buf := &bytes.Buffer{} + err := defaultFrameCodec.EncodeRawFrame(frm, buf) + require.NoError(t, err) + return buf.Bytes() +} + +// TestFrameUncompressedLength tests the FrameUncompressedLength function +func TestFrameUncompressedLength(t *testing.T) { + // Test with uncompressed frame + bodyContent := []byte("test body") + testFrame := createTestRawFrame(primitive.ProtocolVersion4, 1, bodyContent) + + length, err := FrameUncompressedLength(testFrame) + require.NoError(t, err) + + expectedLength := primitive.ProtocolVersion4.FrameHeaderLengthInBytes() + len(bodyContent) + assert.Equal(t, expectedLength, length) +} + +// TestFrameUncompressedLength_Compressed tests that compressed frames return error +func TestFrameUncompressedLength_Compressed(t *testing.T) { + bodyContent := []byte("test body") + testFrame := createTestRawFrame(primitive.ProtocolVersion4, 1, bodyContent) + testFrame.Header.Flags = primitive.HeaderFlagCompressed + + length, err := FrameUncompressedLength(testFrame) + require.Error(t, err) + assert.Equal(t, -1, length) + assert.Contains(t, err.Error(), "cannot obtain uncompressed length of compressed frame") +} + +// TestSegmentWriter_NewSegmentWriter tests the constructor +func TestSegmentWriter_NewSegmentWriter(t *testing.T) { + buf := &bytes.Buffer{} + ctx := context.Background() + addr := "127.0.0.1:9042" + + writer := NewSegmentWriter(buf, 128, addr, ctx) + + require.NotNil(t, writer) + assert.Equal(t, buf, writer.payload) + assert.Equal(t, addr, writer.connectionAddr) + assert.Equal(t, ctx, writer.clientHandlerContext) +} + +// TestSegmentWriter_GetWriteBuffer tests getting the write buffer +func TestSegmentWriter_GetWriteBuffer(t *testing.T) { + buf := &bytes.Buffer{} + ctx := context.Background() + writer := NewSegmentWriter(buf, 128, "127.0.0.1:9042", ctx) + + returnedBuf := writer.GetWriteBuffer() + assert.Equal(t, buf, returnedBuf) +} + +// TestSegmentWriter_CanWriteFrameInternal tests the internal frame capacity check +func TestSegmentWriter_CanWriteFrameInternal(t *testing.T) { + buf := &bytes.Buffer{} + ctx := context.Background() + writer := NewSegmentWriter(buf, 10000, "127.0.0.1:9042", ctx) + + // Test 1: Empty payload, frame fits in one segment + assert.True(t, writer.canWriteFrameInternal(1000)) + + // Test 2: Empty payload, frame needs multiple segments + assert.True(t, writer.canWriteFrameInternal(segment.MaxPayloadLength+1)) + + // Test 3: Empty payload, frame has exact length of max segment payload length + assert.True(t, writer.canWriteFrameInternal(segment.MaxPayloadLength)) + + // Test 4: Empty payload, frame with no body (e.g. OPTIONS message) + assert.True(t, writer.canWriteFrameInternal(primitive.FrameHeaderLengthV3AndHigher)) + + // Test 5: Empty payload, 0 length (just an edge case but it should be impossible for this to happen) + assert.True(t, writer.canWriteFrameInternal(0)) + + // Test 6: Write some data first + writer.payload.Write(make([]byte, 1000)) + + // Small frame that fits + assert.True(t, writer.canWriteFrameInternal(1000)) + + // Test 7: Frame that would exceed segment max payload after merging and there's already data in the payload + assert.False(t, writer.canWriteFrameInternal(segment.MaxPayloadLength-500)) + + // Test 8: Payload has data, adding frame would need multiple segments + writer.payload.Reset() + writer.payload.Write(make([]byte, 100)) + assert.False(t, writer.canWriteFrameInternal(segment.MaxPayloadLength+1)) +} + +// TestSegmentWriter_AppendFrameToSegmentPayload tests appending frames +func TestSegmentWriter_AppendFrameToSegmentPayload(t *testing.T) { + buf := &bytes.Buffer{} + ctx := context.Background() + writer := NewSegmentWriter(buf, 100000, "127.0.0.1:9042", ctx) + + bodyContent := []byte("test") + testFrame := createTestRawFrame(primitive.ProtocolVersion4, 1, bodyContent) + + // Append frame + written, err := writer.AppendFrameToSegmentPayload(testFrame) + require.NoError(t, err) + require.True(t, written) + + // Check that buffer has content + assert.Greater(t, buf.Len(), 0) +} + +// TestSegmentWriter_AppendFrameToSegmentPayload_CannotWrite tests when frame cannot be written +func TestSegmentWriter_AppendFrameToSegmentPayload_CannotWrite(t *testing.T) { + buf := &bytes.Buffer{} + ctx := context.Background() + writer := NewSegmentWriter(buf, 100, "127.0.0.1:9042", ctx) + + // Fill the buffer + writer.payload.Write(make([]byte, 1000)) + + // Try to append a frame that cannot fit + bodyContent := make([]byte, 5000) + testFrame := createTestRawFrame(primitive.ProtocolVersion4, 1, bodyContent) + + written, err := writer.AppendFrameToSegmentPayload(testFrame) + require.NoError(t, err) + require.False(t, written) // Should not be written +} + +// TestSegmentWriter_WriteSegments_SelfContained tests writing a self-contained segment +func TestSegmentWriter_WriteSegments_SelfContained(t *testing.T) { + testCases := []struct { + name string + frameCount int + }{ + {name: "Single frame", frameCount: 1}, + {name: "Two frames", frameCount: 2}, + {name: "Three frames", frameCount: 3}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := &bytes.Buffer{} + ctx := context.Background() + writer := NewSegmentWriter(buf, 100000, "127.0.0.1:9042", ctx) + + // Create a conn state with segment codec + state := &connState{ + useSegments: true, + frameCodec: defaultFrameCodec, + segmentCodec: defaultSegmentCodec, + } + + // Append multiple frames to the payload + var expectedEnvelopes [][]byte + for i := 0; i < tc.frameCount; i++ { + bodyContent := []byte(fmt.Sprintf("frame_%d_body", i+1)) + testFrame := createTestRawFrame(primitive.ProtocolVersion5, int16(i+1), bodyContent) + + // Append frame to segment payload + written, err := writer.AppendFrameToSegmentPayload(testFrame) + require.NoError(t, err, "Failed to append frame %d", i+1) + require.True(t, written, "Frame %d was not written", i+1) + + // Store expected envelope bytes + expectedEnvelopes = append(expectedEnvelopes, encodeRawFrameToBytes(t, testFrame)) + } + + // Write segments + dst := &bytes.Buffer{} + err := writer.WriteSegments(dst, state) + require.NoError(t, err) + + // Verify the payload was reset + assert.Equal(t, 0, writer.payload.Len()) + + // Verify something was written to dst + assert.Greater(t, dst.Len(), 0) + + // Decode the segment to verify + decodedSegment, err := state.segmentCodec.DecodeSegment(dst) + require.NoError(t, err) + assert.True(t, decodedSegment.Header.IsSelfContained) + + // Verify all frames are in the segment payload + var expectedPayload []byte + for _, envelope := range expectedEnvelopes { + expectedPayload = append(expectedPayload, envelope...) + } + assert.Equal(t, expectedPayload, decodedSegment.Payload.UncompressedData) + + // Verify we can decode all frames from the segment payload + payloadReader := bytes.NewReader(decodedSegment.Payload.UncompressedData) + for i := 0; i < tc.frameCount; i++ { + decodedFrame, err := defaultFrameCodec.DecodeRawFrame(payloadReader) + require.NoError(t, err, "Failed to decode frame %d from segment payload", i+1) + assert.Equal(t, int16(i+1), decodedFrame.Header.StreamId, "Frame %d has wrong stream ID", i+1) + assert.Equal(t, []byte(fmt.Sprintf("frame_%d_body", i+1)), decodedFrame.Body, "Frame %d has wrong body", i+1) + } + }) + } +} + +// TestSegmentWriter_WriteSegments_MultipleSegments tests writing multiple segments +func TestSegmentWriter_WriteSegments_MultipleSegments(t *testing.T) { + buf := &bytes.Buffer{} + ctx := context.Background() + writer := NewSegmentWriter(buf, 128, "127.0.0.1:9042", ctx) + + // Add data larger than MaxPayloadLength + largeData := make([]byte, segment.MaxPayloadLength*2+1000) + for i := range largeData { + largeData[i] = byte(i % 256) + } + writer.payload.Write(largeData) + + // Create a conn state with segment codec + state := &connState{ + useSegments: true, + frameCodec: defaultFrameCodec, + segmentCodec: defaultSegmentCodec, + } + + // Write segments + dst := &bytes.Buffer{} + err := writer.WriteSegments(dst, state) + require.NoError(t, err) + + // Verify the payload was reset + assert.Equal(t, 0, writer.payload.Len()) + + // Decode and verify segments + var reconstructedData []byte + for i := 0; i < 3; i++ { // Should have 3 segments + decodedSegment, err := state.segmentCodec.DecodeSegment(dst) + require.NoError(t, err, "Failed to decode segment %d", i) + assert.False(t, decodedSegment.Header.IsSelfContained, "Segment %d should not be self-contained", i) + reconstructedData = append(reconstructedData, decodedSegment.Payload.UncompressedData...) + } + + assert.Equal(t, 0, dst.Len()) + + // Verify reconstructed data matches original + assert.Equal(t, largeData, reconstructedData) +} + +// TestSegmentWriter_WriteSegments_EmptyPayload tests that writing empty payload returns error +func TestSegmentWriter_WriteSegments_EmptyPayload(t *testing.T) { + buf := &bytes.Buffer{} + ctx := context.Background() + writer := NewSegmentWriter(buf, 128, "127.0.0.1:9042", ctx) + + state := &connState{ + useSegments: true, + frameCodec: defaultFrameCodec, + segmentCodec: defaultSegmentCodec, + } + + dst := &bytes.Buffer{} + err := writer.WriteSegments(dst, state) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot write segment with empty payload") +} diff --git a/proxy/pkg/zdmproxy/startup.go b/proxy/pkg/zdmproxy/startup.go index 44702984..78f1161d 100644 --- a/proxy/pkg/zdmproxy/startup.go +++ b/proxy/pkg/zdmproxy/startup.go @@ -2,13 +2,15 @@ package zdmproxy import ( "fmt" + "net" + "time" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/proxy/pkg/common" log "github.com/sirupsen/logrus" - "net" - "time" + + "github.com/datastax/zdm-proxy/proxy/pkg/common" ) const ( @@ -23,6 +25,26 @@ func (recv *AuthError) Error() string { return fmt.Sprintf("authentication error: %v", recv.errMsg) } +func (ch *ClientHandler) getAuthSecondaryClusterConnector() *ClusterConnector { + if ch.forwardAuthToTarget { + // secondary is ORIGIN + return ch.originCassandraConnector + } else { + // secondary is TARGET + return ch.targetCassandraConnector + } +} + +func (ch *ClientHandler) getAuthPrimaryClusterConnector() *ClusterConnector { + if ch.forwardAuthToTarget { + // primary is TARGET + return ch.targetCassandraConnector + } else { + // primary is ORIGIN + return ch.originCassandraConnector + } +} + func (ch *ClientHandler) handleSecondaryHandshakeStartup( startupRequest *frame.RawFrame, startupResponse *frame.RawFrame, asyncConnector bool) error { @@ -138,8 +160,13 @@ func (ch *ClientHandler) handleSecondaryHandshakeStartup( } } + connector := ch.getAuthSecondaryClusterConnector() + if asyncConnector { + connector = ch.asyncConnector + } + newPhase, parsedFrame, done, err := handleSecondaryHandshakeResponse( - phase, response, clientIPAddress, clusterAddress, ch.getCompression(), logIdentifier) + connector, phase, response, clientIPAddress, clusterAddress, ch.getCompression(), logIdentifier) if err != nil { return err } @@ -161,9 +188,10 @@ func (ch *ClientHandler) handleSecondaryHandshakeStartup( } func handleSecondaryHandshakeResponse( + clusterConnector *ClusterConnector, phase int, f *frame.RawFrame, clientIPAddress net.Addr, clusterAddress net.Addr, compression primitive.Compression, logIdentifier string) (int, *frame.Frame, bool, error) { - parsedFrame, err := codecs[compression].ConvertFromRawFrame(f) + parsedFrame, err := frameCodecs[compression].ConvertFromRawFrame(f) if err != nil { return phase, nil, false, fmt.Errorf("could not decode frame from %v: %w", clusterAddress, err) }