diff --git a/.github/workflows/build-cron.yml b/.github/workflows/build-cron.yml
index 4d5d99e9c1f..730b276685e 100644
--- a/.github/workflows/build-cron.yml
+++ b/.github/workflows/build-cron.yml
@@ -54,7 +54,7 @@ jobs:
]
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Setup Java ${{ matrix.java }} ${{ matrix.javadist }}
uses: actions/setup-java@v5
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 6df5bd6db28..d401263bb7d 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -68,7 +68,7 @@ jobs:
]
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Setup Java ${{ matrix.java }} ${{ matrix.javadist }}
uses: actions/setup-java@v5
diff --git a/.github/workflows/cleanup-transient-artifacts.yml b/.github/workflows/cleanup-transient-artifacts.yml
index 6dd55e75bed..d249064cf34 100644
--- a/.github/workflows/cleanup-transient-artifacts.yml
+++ b/.github/workflows/cleanup-transient-artifacts.yml
@@ -38,7 +38,7 @@ jobs:
if: ${{ github.event.workflow_run.conclusion == 'success' }}
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Delete Artifacts
run: |
diff --git a/.github/workflows/docker-cd.yml b/.github/workflows/docker-cd.yml
index 938c8e232fa..a848e5fedf4 100644
--- a/.github/workflows/docker-cd.yml
+++ b/.github/workflows/docker-cd.yml
@@ -38,7 +38,7 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
# https://github.com/docker/metadata-action
- name: Configure Docker metadata
diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml
index d2c63327e93..b14a3ae9885 100644
--- a/.github/workflows/docker-release.yml
+++ b/.github/workflows/docker-release.yml
@@ -40,7 +40,7 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- run: git checkout ${{ github.event.inputs.branch_or_tag }}
# https://github.com/docker/metadata-action
diff --git a/.github/workflows/docker-testImage.yml b/.github/workflows/docker-testImage.yml
index b05b09e02b3..6e1dad91bb5 100644
--- a/.github/workflows/docker-testImage.yml
+++ b/.github/workflows/docker-testImage.yml
@@ -36,7 +36,7 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
# https://github.com/docker/metadata-action
- name: Configure Docker metadata
diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml
index 283a8ce5e57..75b74228f67 100644
--- a/.github/workflows/documentation.yml
+++ b/.github/workflows/documentation.yml
@@ -47,7 +47,7 @@ jobs:
name: Java
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Setup Java ${{ matrix.java }} ${{ matrix.javadist }}
uses: actions/setup-java@v5
@@ -64,7 +64,7 @@ jobs:
name: Python
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Setup Python
uses: actions/setup-python@v6
@@ -73,7 +73,7 @@ jobs:
architecture: 'x64'
- name: Cache Pip Dependencies
- uses: actions/cache@v4
+ uses: actions/cache@v5
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-docs-${{ hashFiles('src/main/python/docs/requires-docs.txt') }}
diff --git a/.github/workflows/javaCodestyle.yml b/.github/workflows/javaCodestyle.yml
new file mode 100644
index 00000000000..50c970023c1
--- /dev/null
+++ b/.github/workflows/javaCodestyle.yml
@@ -0,0 +1,60 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+name: Java Codestyle
+
+on:
+ push:
+ paths-ignore:
+ - 'docs/**'
+ - '*.md'
+ - '*.html'
+ - 'src/main/python/**'
+ - 'dev/**'
+ branches:
+ - main
+ pull_request:
+ paths-ignore:
+ - 'docs/**'
+ - '*.md'
+ - '*.html'
+ - 'src/main/python/**'
+ - 'dev/**'
+ branches:
+ - main
+
+jobs:
+ java_codestyle:
+ name: Java Checkstyle
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Repository
+ uses: actions/checkout@v6
+
+ - name: Setup Java 17 adopt
+ uses: actions/setup-java@v5
+ with:
+ distribution: adopt
+ java-version: '17'
+ cache: 'maven'
+
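+ # checkstyle is skipped by default via the checkstyle.skip property in pom.xml; override it here to enforce the check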
+ - name: Run Checkstyle
+ run: mvn -ntp -B -Dcheckstyle.skip=false checkstyle:check
diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index d7797e5f4a3..b4f4ce2e11d 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -90,7 +90,7 @@ jobs:
name: ${{ matrix.tests }}
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: ${{ matrix.tests }}
uses: ./.github/action/
@@ -106,7 +106,7 @@ jobs:
echo "ARTIFACT_NAME=$ARTIFACT_NAME" >> $GITHUB_ENV
- name: Save Java Test Coverage as Artifact
- uses: actions/upload-artifact@v5
+ uses: actions/upload-artifact@v6
with:
name: ${{ env.ARTIFACT_NAME }}
path: target/jacoco.exec
@@ -126,10 +126,10 @@ jobs:
javadist: ['adopt']
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Cache Maven Dependencies
- uses: actions/cache@v4
+ uses: actions/cache@v5
with:
path: ~/.m2/repository
key: ${{ runner.os }}-maven-test-${{ hashFiles('**/pom.xml') }}
@@ -137,7 +137,7 @@ jobs:
${{ runner.os }}-maven-test-
- name: Download all Jacoco Artifacts
- uses: actions/download-artifact@v6
+ uses: actions/download-artifact@v7
with:
path: target
@@ -151,7 +151,7 @@ jobs:
run: mvn jacoco:report
- name: Upload coverage to Codecov
- uses: codecov/codecov-action@v5.5.1
+ uses: codecov/codecov-action@v5.5.2
if: github.repository_owner == 'apache'
with:
fail_ci_if_error: false
@@ -160,7 +160,7 @@ jobs:
- name: Upload Jacoco Report Artifact PR
if: (github.repository_owner == 'apache') && (github.ref_name != 'main')
- uses: actions/upload-artifact@v5
+ uses: actions/upload-artifact@v6
with:
name: Java Code Coverage (Jacoco)
path: target/site/jacoco
@@ -168,7 +168,7 @@ jobs:
- name: Upload Jacoco Report Artifact Main
if: (github.repository_owner == 'apache') && (github.ref_name == 'main')
- uses: actions/upload-artifact@v5
+ uses: actions/upload-artifact@v6
with:
name: Java Code Coverage (Jacoco)
path: target/site/jacoco
@@ -176,9 +176,8 @@ jobs:
- name: Upload Jacoco Report Artifact Fork
if: (github.repository_owner != 'apache')
- uses: actions/upload-artifact@v5
+ uses: actions/upload-artifact@v6
with:
name: Java Code Coverage (Jacoco)
path: target/site/jacoco
retention-days: 3
-
diff --git a/.github/workflows/license.yml b/.github/workflows/license.yml
index 88e1418670f..4f7b02ee42f 100644
--- a/.github/workflows/license.yml
+++ b/.github/workflows/license.yml
@@ -54,7 +54,7 @@ jobs:
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Setup Java ${{ matrix.java }} ${{ matrix.javadist }}
uses: actions/setup-java@v5
diff --git a/.github/workflows/monitoringUITests.yml b/.github/workflows/monitoringUITests.yml
index b4a3179b6a7..9389b394828 100644
--- a/.github/workflows/monitoringUITests.yml
+++ b/.github/workflows/monitoringUITests.yml
@@ -52,7 +52,7 @@ jobs:
node-version: ["lts/*"]
steps:
- - uses: actions/checkout@v5
+ - uses: actions/checkout@v6
- name: Build the application, with Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v6
with:
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index fcd8bf8c849..ea8c9f485e2 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -60,7 +60,7 @@ jobs:
name: ${{ matrix.os }} Java ${{ matrix.java }} ${{ matrix.javadist }} Python ${{ matrix.python-version }}/ ${{ matrix.test_mode}}
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Setup Java ${{ matrix.java }} ${{ matrix.javadist }}
uses: actions/setup-java@v5
@@ -70,13 +70,13 @@ jobs:
cache: 'maven'
- name: Cache Pip Dependencies
- uses: actions/cache@v4
+ uses: actions/cache@v5
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('src/main/python/setup.py') }}
- name: Cache Datasets
- uses: actions/cache@v4
+ uses: actions/cache@v5
with:
path: |
src/main/python/systemds/examples/tutorials/mnist
@@ -84,7 +84,7 @@ jobs:
key: ${{ runner.os }}-mnist-${{ hashFiles('src/main/python/systemds/examples/tutorials/mnist.py') }}-${{ hashFiles('src/main/python/systemds/examples/tutorials/adult.py') }}
- name: Cache Deb Dependencies
- uses: actions/cache@v4
+ uses: actions/cache@v5
with:
path: /var/cache/apt/archives
key: ${{ runner.os }}-${{ hashFiles('.github/workflows/python.yml') }}
@@ -142,11 +142,11 @@ jobs:
export PATH=$SYSTEMDS_ROOT/bin:$PATH
cd src/main/python
./tests/federated/runFedTest.sh
-
+
- name: Cache Torch Hub
if: ${{ matrix.test_mode == 'scuro' }}
id: torch-cache
- uses: actions/cache@v4
+ uses: actions/cache@v5
with:
path: .torch
key: ${{ runner.os }}-torch-${{ hashFiles('requirements.txt') }}
@@ -158,6 +158,7 @@
env:
TORCH_HOME: ${{ github.workspace }}/.torch
run: |
+ df -h
( while true; do echo "."; sleep 25; done ) &
KA=$!
pip install --upgrade pip wheel setuptools
@@ -172,7 +173,9 @@
gensim \
opt-einsum \
nltk \
- fvcore
+ fvcore \
+ scikit-optimize \
+ flair
kill $KA
cd src/main/python
python -m unittest discover -s tests/scuro -p 'test_*.py' -v
diff --git a/.github/workflows/pythonFormatting.yml b/.github/workflows/pythonFormatting.yml
index 532878da1e6..cbdd5e84578 100644
--- a/.github/workflows/pythonFormatting.yml
+++ b/.github/workflows/pythonFormatting.yml
@@ -38,7 +38,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Setup Python
uses: actions/setup-python@v6
diff --git a/.github/workflows/release-scripts.yml b/.github/workflows/release-scripts.yml
index 95536f668d9..b804a3db507 100644
--- a/.github/workflows/release-scripts.yml
+++ b/.github/workflows/release-scripts.yml
@@ -42,7 +42,7 @@ jobs:
steps:
# Java setup docs:
# https://github.com/actions/setup-java/blob/main/docs/advanced-usage.md#installing-custom-java-package-type
- - uses: actions/checkout@v5
+ - uses: actions/checkout@v6
- name: Set up JDK 17
uses: actions/setup-java@v5
with:
@@ -54,7 +54,7 @@ jobs:
- run: printf "JAVA_HOME = $JAVA_HOME \n"
- name: Cache local Maven repository
- uses: actions/cache@v4
+ uses: actions/cache@v5
with:
path: ~/.m2/repository
key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
diff --git a/.gitignore b/.gitignore
index d2fcdb9a4de..5de697a37e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,7 +146,7 @@ src/test/scripts/functions/pipelines/intermediates/classification/*
venv
venv/*
-
+.venv
# resource optimization
scripts/resource/output
*.pem
diff --git a/dev/checkstyle/checkstyle.xml b/dev/checkstyle/checkstyle.xml
new file mode 100644
index 00000000000..96a381cd4c9
--- /dev/null
+++ b/dev/checkstyle/checkstyle.xml
@@ -0,0 +1,74 @@
diff --git a/dev/checkstyle/suppressions-xpath.xml b/dev/checkstyle/suppressions-xpath.xml
new file mode 100644
index 00000000000..8843856a78d
--- /dev/null
+++ b/dev/checkstyle/suppressions-xpath.xml
@@ -0,0 +1,41 @@
diff --git a/dev/checkstyle/suppressions.xml b/dev/checkstyle/suppressions.xml
new file mode 100644
index 00000000000..1daf743a34e
--- /dev/null
+++ b/dev/checkstyle/suppressions.xml
@@ -0,0 +1,314 @@
diff --git a/pom.xml b/pom.xml
index 0eb7248f617..88c10511936 100644
--- a/pom.xml
+++ b/pom.xml
@@ -64,6 +64,7 @@
2.0.3
3.2.0
3.0.0
+		<maven-checkstyle-plugin.version>3.3.1</maven-checkstyle-plugin.version>
3.0.0
3.5.0
3.2.0
@@ -87,6 +88,7 @@
1C
2
false
+		<checkstyle.skip>true</checkstyle.skip>
false
**
false
@@ -583,6 +585,35 @@
+			<plugin>
+				<groupId>org.apache.maven.plugins</groupId>
+				<artifactId>maven-checkstyle-plugin</artifactId>
+				<version>${maven-checkstyle-plugin.version}</version>
+				<executions>
+					<execution>
+						<id>checkstyle</id>
+						<phase>test</phase>
+						<goals>
+							<goal>check</goal>
+						</goals>
+					</execution>
+				</executions>
+				<configuration>
+					<configLocation>dev/checkstyle/checkstyle.xml</configLocation>
+					<skip>${checkstyle.skip}</skip>
+					<consoleOutput>true</consoleOutput>
+					<failsOnError>true</failsOnError>
+					<outputFile>${project.build.directory}/checkstyle-result.xml</outputFile>
+					<outputFileFormat>xml</outputFileFormat>
+					<includeTestSourceDirectory>true</includeTestSourceDirectory>
+				</configuration>
+			</plugin>
+
org.apache.maven.plugins
diff --git a/scripts/builtin/scaleRobust.dml b/scripts/builtin/scaleRobust.dml
new file mode 100644
index 00000000000..ce309fabe54
--- /dev/null
+++ b/scripts/builtin/scaleRobust.dml
@@ -0,0 +1,60 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Robust scaling using median and IQR (Interquartile Range)
+# Resistant to outliers by centering with the median and scaling with IQR.
+#
+# INPUT:
+# -------------------------------------------------------------------------------------
+# X Input feature matrix of shape n-by-m
+# -------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -------------------------------------------------------------------------------------
+# Y Scaled output matrix of shape n-by-m
+# med Column medians (Q2) of shape 1-by-m
+# q1 Column lower quartiles (Q1) of shape 1-by-m
+# q3 Column upper quartiles (Q3) of shape 1-by-m
+# -------------------------------------------------------------------------------------
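+#
+# The applied transformation (implemented in scaleRobustApply) is, per column j:
+#   Y[,j] = (X[,j] - med[1,j]) / (q3[1,j] - q1[1,j])
+# where a zero or NaN IQR is replaced by 1 to avoid division by zero.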
+
+m_scaleRobust = function(Matrix[Double] X)
+ return (Matrix[Double] Y, Matrix[Double] med, Matrix[Double] q1, Matrix[Double] q3)
+{
+ n = nrow(X)
+ m = ncol(X)
+
+ med = matrix(0.0, rows=1, cols=m)
+ q1 = matrix(0.0, rows=1, cols=m)
+ q3 = matrix(0.0, rows=1, cols=m)
+
+ # Define quantile probabilities once, outside the loop
+ q_probs = as.matrix(list(0.25, 0.5, 0.75));
+
+ # Loop over columns to compute quantiles
+ parfor (j in 1:m) {
+ q = quantile(X[,j], q_probs)
+ med[1,j] = q[2,1]
+ q1[1,j] = q[1,1]
+ q3[1,j] = q[3,1]
+ }
+
+ Y = scaleRobustApply(X, med, q1, q3);
+}
diff --git a/scripts/builtin/scaleRobustApply.dml b/scripts/builtin/scaleRobustApply.dml
new file mode 100644
index 00000000000..11461731b34
--- /dev/null
+++ b/scripts/builtin/scaleRobustApply.dml
@@ -0,0 +1,48 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Apply robust scaling using precomputed medians and IQRs
+#
+# INPUT:
+# -------------------------------------------------------------------------------------
+# X Input feature matrix of shape n-by-m
+# med Column medians (Q2) of shape 1-by-m
+# q1 Column lower quartiles (Q1) of shape 1-by-m
+# q3 Column upper quartiles (Q3) of shape 1-by-m
+# -------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -------------------------------------------------------------------------------------
+# Y Scaled output matrix of shape n-by-m
+# -------------------------------------------------------------------------------------
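+#
+# Typical usage (illustrative): fit the statistics on training data with scaleRobust,
+# then re-apply them to new data:
+#   [Ytr, med, q1, q3] = scaleRobust(Xtrain);
+#   Yte = scaleRobustApply(Xtest, med, q1, q3);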
+
+m_scaleRobustApply = function(Matrix[Double] X, Matrix[Double] med, Matrix[Double] q1, Matrix[Double] q3)
+ return (Matrix[Double] Y)
+{
+ iqr = q3 - q1
+
+ # Ensure robust scaling is safe by replacing invalid IQRs
+ iqr = replace(target=iqr, pattern=0, replacement=1)
+ iqr = replace(target=iqr, pattern=NaN, replacement=1)
+
+ # Apply robust transformation
+ Y = (X - med) / iqr
+}
diff --git a/scripts/staging/ssb/README.md b/scripts/staging/ssb/README.md
new file mode 100644
index 00000000000..ef9f09afeaf
--- /dev/null
+++ b/scripts/staging/ssb/README.md
@@ -0,0 +1,525 @@
+
+
+
+# Star Schema Benchmark (SSB) for SystemDS
+
+This README documents the SSB DML queries under `scripts/ssb/queries/` and the runner scripts under `scripts/ssb/shell/` that execute and benchmark them. It is focused on what is implemented today, how to run it, and how to interpret the outputs for performance analysis.
+
+---
+
+## Table of Contents
+
+1. Project Layout
+2. Quick Start
+3. Data Location (`--input-dir` and DML `input_dir`)
+4. Single-Engine Runner (`scripts/ssb/shell/run_ssb.sh`)
+5. Multi-Engine Performance Runner (`scripts/ssb/shell/run_all_perf.sh`)
+6. Outputs and Examples
+7. Adding/Editing Queries
+8. Troubleshooting
+
+---
+
+## 1) Project Layout
+
+Paths are relative to the repo root:
+
+```
+systemds/
+├── scripts/ssb/
+│ ├── README.md # This guide
+│ ├── queries/ # DML queries (q1_1.dml ... q4_3.dml)
+│ │ ├── q1_1.dml - q1_3.dml # Flight 1
+│ │ ├── q2_1.dml - q2_3.dml # Flight 2
+│ │ ├── q3_1.dml - q3_4.dml # Flight 3
+│ │ └── q4_1.dml - q4_3.dml # Flight 4
+│ ├── shell/
+│ │ ├── run_ssb.sh # Single-engine (SystemDS) runner
+│ │ ├── run_all_perf.sh # Multi-engine performance benchmark
+│ │ └── ssbOutputData/ # Results (created on first run)
+│ │ ├── QueryData/ # Per-query outputs from run_ssb.sh
+│ │ └── PerformanceData/ # Multi-engine outputs from run_all_perf.sh
+│ └── sql/ # SQL versions + `ssb.duckdb` for DuckDB
+```
+
+Note: The SSB raw data directory is not committed. You must point the runners to your generated data with `--input-dir`.
+
+---
+
+## 2) Quick Start
+
+Set up SystemDS and run the SSB queries.
+
+1) Build SystemDS (from repo root):
+
+```bash
+mvn -DskipTests package
+```
+
+2) Make sure the SystemDS binary exists (repo-local `bin/systemds` or on `PATH`).
+
+3) Make runner scripts executable:
+
+```bash
+chmod +x scripts/ssb/shell/run_ssb.sh scripts/ssb/shell/run_all_perf.sh
+```
+
+4) Provide SSB data (from dbgen) in a directory, e.g. `/path/to/ssb-data`.
+
+5) Run a single SSB query on SystemDS (from repo root):
+
+```bash
+scripts/ssb/shell/run_ssb.sh q1.1 --input-dir=/path/to/ssb-data --stats
+```
+
+6) Run the multi-engine performance benchmark across all queries (from repo root):
+
+```bash
+scripts/ssb/shell/run_all_perf.sh --input-dir=/path/to/ssb-data --stats --repeats=5
+```
+
+If `--input-dir` is omitted, the scripts default to `./data/` under the repo root.
+
+---
+
+## 3) Data Location (`--input-dir` and DML `input_dir`)
+
+Both runners pass a named argument `input_dir` into DML as:
+
+```
+-nvargs input_dir=/absolute/path/to/ssb-data
+```
+
+Your DML scripts should construct paths from `input_dir`. Example:
+
+```dml
+dates = read(input_dir + "/date.tbl", data_type="frame", format="csv", sep="|", header=FALSE)
+lineorder = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", sep="|", header=FALSE)
+```
+
+Expected base files in `input_dir`: `customer.tbl`, `supplier.tbl`, `part.tbl`, `date.tbl` and `lineorder*.tbl` (fact table name can vary by scale). The runners validate that `--input-dir` exists before executing.
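+
+A quick sanity check of the data directory before running (illustrative; the exact fact-table file name may differ at other scale factors):
+
+```bash
+ls /path/to/ssb-data/*.tbl
+# expected files: customer.tbl  date.tbl  lineorder.tbl  part.tbl  supplier.tbl
+```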
+
+---
+
+## 4) Single-Engine Runner (`scripts/ssb/shell/run_ssb.sh`)
+
+Runs SSB DML queries with SystemDS and saves results per query.
+
+- Usage:
+ - `scripts/ssb/shell/run_ssb.sh` — run all SSB queries
+ - `scripts/ssb/shell/run_ssb.sh q1.1 q2.3` — run specific queries
+ - `scripts/ssb/shell/run_ssb.sh --stats` — include SystemDS internal statistics
+ - `scripts/ssb/shell/run_ssb.sh --input-dir=/path/to/data` — set data dir
+ - `scripts/ssb/shell/run_ssb.sh --output-dir=/tmp/out` — set output dir
+
+- Query names: You can use dotted form (`q1.1`); the runner maps to `q1_1.dml` internally.
+
+- Functionality:
+ - Single-threaded execution via auto-generated `conf/single_thread.xml`.
+ - DML `input_dir` forwarding with `-nvargs`.
+ - Pre-check for data directory; clear errors if missing.
+ - Runtime error detection by scanning for “An Error Occurred : …”.
+ - Optional `--stats` to capture SystemDS internal statistics in JSON.
+ - Per-query outputs in TXT, CSV, and JSON.
+ - `run.json` with run-level metadata and per-query status/results.
+ - Clear end-of-run summary and, for table results, a “DETAILED QUERY RESULTS” section.
+ - Exit code is non-zero if any query failed (handy for CI).
+
+- Output layout:
+ - Base directory: `--output-dir` (default: `scripts/ssb/shell/ssbOutputData/QueryData`)
+ - Each run: `ssb_run_<timestamp>/`
+ - `txt/<query>.txt` — human-readable result
+ - `csv/<query>.csv` — scalar or table as CSV
+ - `json/<query>.json` — per-query JSON
+ - `run.json` — full metadata and results for the run
+
+- Example console output (abridged):
+
+```
+[1/13] Running: q1_1.dml
+...
+=========================================
+SSB benchmark completed!
+Total queries executed: 13
+Failed queries: 0
+Statistics: enabled
+
+=========================================
+RUN METADATA SUMMARY
+=========================================
+Timestamp: 2025-09-05 12:34:56 UTC
+Hostname: myhost
+Seed: 123456
+Software Versions:
+ SystemDS: 3.4.0-SNAPSHOT
+ JDK: 21.0.2
+System Resources:
+ CPU: Apple M2
+ RAM: 16GB
+Data Build Info:
+ SSB Data: customer:300000 part:200000 supplier:2000 lineorder:6001215
+=========================================
+
+===================================================
+QUERIES SUMMARY
+===================================================
+No. Query Result Status
+---------------------------------------------------
+1 q1.1 12 rows (see below) ✓ Success
+2 q1.2 1 ✓ Success
+...
+===================================================
+
+=========================================
+DETAILED QUERY RESULTS
+=========================================
+[1] Results for q1.1:
+----------------------------------------
+1992|ASIA|12345.67
+1993|ASIA|23456.78
+...
+----------------------------------------
+```
+
+---
+
+## 5) Multi-Engine Performance Runner (`scripts/ssb/shell/run_all_perf.sh`)
+
+Runs SSB queries across SystemDS, PostgreSQL, and DuckDB with repeated timings and statistical analysis.
+
+- Usage:
+ - `scripts/ssb/shell/run_all_perf.sh` — run all queries on available engines
+ - `scripts/ssb/shell/run_all_perf.sh q1.1 q2.3` — run specific queries
+ - `scripts/ssb/shell/run_all_perf.sh --warmup=2 --repeats=10` — control sampling
+ - `scripts/ssb/shell/run_all_perf.sh --stats` — include core/internal engine timings
+ - `scripts/ssb/shell/run_all_perf.sh --layout=wide|stacked` — control terminal layout
+ - `scripts/ssb/shell/run_all_perf.sh --input-dir=... --output-dir=...` — set paths
+
+- Query names: dotted form (`q1.1`) is accepted; mapped internally to `q1_1.dml`.
+
+- Engine prerequisites:
+ - PostgreSQL:
+ - Install `psql` CLI and ensure a PostgreSQL server is running.
+ - Default connection in the script: `POSTGRES_DB=ssb`, `POSTGRES_USER=$(whoami)`, `POSTGRES_HOST=localhost`.
+ - Create the `ssb` database and load the standard SSB tables and data (schema not included in this repo). The SQL queries under `scripts/ssb/sql/` expect the canonical SSB schema and data.
+ - The runner verifies connectivity; if it cannot connect or tables are missing, PostgreSQL results are skipped.
+ - DuckDB:
+ - Install the DuckDB CLI (`duckdb`).
+ - The runner looks for the database at `scripts/ssb/sql/ssb.duckdb`. Ensure it contains SSB tables and data.
+ - If the CLI is missing or the DB file cannot be opened, DuckDB results are skipped.
+ - SystemDS is required; the other engines are optional. Missing engines are reported and skipped gracefully (quick availability checks are sketched below).
+
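+Quick availability checks for the optional engines (illustrative commands, assuming the default locations described above):
+
+```bash
+# PostgreSQL: connect to the ssb database and list its tables
+psql -d ssb -c '\dt'
+# DuckDB: open the bundled database file and list its tables
+duckdb scripts/ssb/sql/ssb.duckdb -c 'SHOW TABLES;'
+```
+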
+- Functionality:
+ - Single-threaded execution for fairness (SystemDS config; SQL engines via settings).
+ - Pre-flight data-dir check and SystemDS test-run with runtime-error detection.
+ - Warmups and repeated measurements using `/usr/bin/time -p` (ms resolution).
+ - Statistics per engine: mean, population stdev, p95, and CV%.
+ - “Shell” vs “Core” time: SystemDS core from `-stats`, PostgreSQL core via EXPLAIN ANALYZE, DuckDB core via JSON profiling.
+ - Environment verification: gracefully skips PostgreSQL or DuckDB if not available.
+ - Terminal-aware output: wide table with grid or stacked multi-line layout.
+ - Results to CSV and JSON with rich metadata (system info, versions, run config).
+
+- Layouts (display formats):
+ - Auto selection: `--layout=auto` (default). Chooses `wide` if terminal is wide enough, else `stacked`.
+ - Wide layout: `--layout=wide`. Prints a grid with columns for each engine and a `Fastest` column. Three header rows show labels for `mean`, `±/CV`, and `p95`.
+ - Stacked layout: `--layout=stacked` or `--stacked`. Prints a compact, multi-line block per query (best for narrow terminals).
+ - Dynamic scaling: The wide layout scales column widths to fit the terminal; if still too narrow, it falls back to stacked.
+ - Row semantics: Row 1 = mean (ms); Row 2 = `±stdev/CV%`; Row 3 = `p95 (ms)`.
+ - Fastest: The runner highlights the engine with the lowest mean per query.
+
+- Output layout:
+ - Base directory: `--output-dir` (default: `scripts/ssb/shell/ssbOutputData/PerformanceData`)
+ - Files per run (timestamped basename):
+ - `ssb_results_<timestamp>.csv`
+ - `ssb_results_<timestamp>.json`
+
+- Example console output (abridged, wide layout):
+
+```
+==================================================================================
+ MULTI-ENGINE PERFORMANCE BENCHMARK METADATA
+==================================================================================
+Timestamp: 2025-09-05 12:34:56 UTC
+Hostname: myhost
+Seed: 123456
+Software Versions:
+ SystemDS: 3.4.0-SNAPSHOT
+ JDK: 21.0.2
+ PostgreSQL: psql (PostgreSQL) 14.11
+ DuckDB: v0.10.3
+System Resources:
+ CPU: Apple M2
+ RAM: 16GB
+Data Build Info:
+ SSB Data: customer:300000 part:200000 supplier:2000 lineorder:6001215
+Run Configuration:
+ Statistics: enabled
+ Queries: 13 selected
+ Warmup Runs: 1
+ Repeat Runs: 5
+
++--------+--------------+--------------+--------------+----------------+--------------+----------------+----------+
+| Query | SysDS Shell | SysDS Core | PostgreSQL | PostgreSQL Core| DuckDB | DuckDB Core | Fastest |
+| | mean | mean | mean | mean | mean | mean | |
+| | ±/CV | ±/CV | ±/CV | ±/CV | ±/CV | ±/CV | |
+| | p95 | p95 | p95 | p95 | p95 | p95 | |
++--------+--------------+--------------+--------------+----------------+--------------+----------------+----------+
+| q1_1 | 1824.0 | 1210.0 | 2410.0 | 2250.0 | 980.0 | 910.0 | DuckDB |
+| | ±10.2/0.6% | ±8.6/0.7% | ±15.1/0.6% | ±14.0/0.6% | ±5.4/0.6% | ±5.0/0.5% | |
+| | p95:1840.0 | p95:1225.0 | p95:2435.0 | p95:2274.0 | p95:989.0 | p95:919.0 | |
++--------+--------------+--------------+--------------+----------------+--------------+----------------+----------+
+```
+
+- Example console output (abridged, stacked layout):
+
+```
+Query : q1_1 Fastest: DuckDB
+ SystemDS Shell: 1824.0
+ ±10.2ms/0.6%
+ p95:1840.0ms
+ SystemDS Core: 1210.0
+ ±8.6ms/0.7%
+ p95:1225.0ms
+ PostgreSQL: 2410.0
+ ±15.1ms/0.6%
+ p95:2435.0ms
+ PostgreSQL Core:2250.0
+ ±14.0ms/0.6%
+ p95:2274.0ms
+ DuckDB: 980.0
+ ±5.4ms/0.6%
+ p95:989.0ms
+ DuckDB Core: 910.0
+ ±5.0ms/0.5%
+ p95:919.0ms
+--------------------------------------------------------------------------------
+```
+
+---
+
+## 6) Outputs and Examples
+
+Where to find results and how to read them.
+
+- SystemDS-only runner (`scripts/ssb/shell/run_ssb.sh`):
+ - Directory: `scripts/ssb/shell/ssbOutputData/QueryData/ssb_run_<timestamp>/`
+ - Files: `txt/<query>.txt`, `csv/<query>.csv`, `json/<query>.json`, and `run.json`
+ - `run.json` example (stats enabled, single query):
+
+```json
+{
+ "benchmark_type": "ssb_systemds",
+ "timestamp": "2025-09-07 19:45:11 UTC",
+ "hostname": "eduroam-141-23-175-117.wlan.tu-berlin.de",
+ "seed": 849958376,
+ "software_versions": {
+ "systemds": "3.4.0-SNAPSHOT",
+ "jdk": "17.0.15"
+ },
+ "system_resources": {
+ "cpu": "Apple M1 Pro",
+ "ram": "16GB"
+ },
+ "data_build_info": {
+ "customer": "30000",
+ "part": "200000",
+ "supplier": "2000",
+ "date": "2557",
+ "lineorder": "8217"
+ },
+ "run_configuration": {
+ "statistics_enabled": true,
+ "queries_selected": 1,
+ "queries_executed": 1,
+ "queries_failed": 0
+ },
+ "results": [
+ {
+ "query": "q1_1",
+ "result": "687752409 ",
+ "stats": [
+ "SystemDS Statistics:",
+ "Total elapsed time: 1.557 sec.",
+ "Total compilation time: 0.410 sec.",
+ "Total execution time: 1.147 sec.",
+ "Cache hits (Mem/Li/WB/FS/HDFS): 11054/0/0/0/2.",
+ "Cache writes (Li/WB/FS/HDFS): 0/26/3/0.",
+ "Cache times (ACQr/m, RLS, EXP): 0.166/0.001/0.060/0.000 sec.",
+ "HOP DAGs recompiled (PRED, SB): 0/175.",
+ "HOP DAGs recompile time: 0.063 sec.",
+ "Functions recompiled: 2.",
+ "Functions recompile time: 0.016 sec.",
+ "Total JIT compile time: 1.385 sec.",
+ "Total JVM GC count: 1.",
+ "Total JVM GC time: 0.026 sec.",
+ "Heavy hitter instructions:",
+ " # Instruction Time(s) Count",
+ " 1 m_raJoin 0.940 1",
+ " 2 ucumk+ 0.363 3",
+ " 3 - 0.219 1345",
+ " 4 nrow 0.166 7",
+ " 5 ctable 0.086 2",
+ " 6 * 0.078 1",
+ " 7 parallelBinarySearch 0.069 1",
+ " 8 ba+* 0.049 5",
+ " 9 rightIndex 0.016 8611",
+ " 10 leftIndex 0.015 1680"
+ ],
+ "status": "success"
+ }
+ ]
+}
+```
+
+ Notes:
+ - The `result` field contains the query’s output (scalar or tabular content collapsed). When `--stats` is used, `stats` contains the full SystemDS statistics block line-by-line.
+ - For failed queries, an `error_message` string is included and `status` is set to `"error"`.
+
+- Multi-engine runner (`scripts/ssb/shell/run_all_perf.sh`):
+ - Directory: `scripts/ssb/shell/ssbOutputData/PerformanceData/`
+ - Files per run: `ssb_results_<timestamp>.csv` and `ssb_results_<timestamp>.json`
+ - CSV contains display strings and raw numeric stats (mean/stdev/p95) for each engine; JSON contains the same plus metadata and fastest-engine per query.
+ - `ssb_results_*.json` example (stats enabled, single query):
+
+```json
+{
+ "benchmark_metadata": {
+ "benchmark_type": "multi_engine_performance",
+ "timestamp": "2025-09-07 20:11:16 UTC",
+ "hostname": "eduroam-141-23-175-117.wlan.tu-berlin.de",
+ "seed": 578860764,
+ "software_versions": {
+ "systemds": "3.4.0-SNAPSHOT",
+ "jdk": "17.0.15",
+ "postgresql": "psql (PostgreSQL) 17.5",
+ "duckdb": "v1.3.2 (Ossivalis) 0b83e5d2f6"
+ },
+ "system_resources": {
+ "cpu": "Apple M1 Pro",
+ "ram": "16GB"
+ },
+ "data_build_info": {
+ "customer": "30000",
+ "part": "200000",
+ "supplier": "2000",
+ "date": "2557",
+ "lineorder": "8217"
+ },
+ "run_configuration": {
+ "statistics_enabled": true,
+ "queries_selected": 1,
+ "warmup_runs": 1,
+ "repeat_runs": 5
+ }
+ },
+ "results": [
+ {
+ "query": "q1_1",
+ "systemds": {
+ "shell": {
+ "display": "2186.0 (±95.6ms/4.4%, p95:2250.0ms)",
+ "mean_ms": 2186.0,
+ "stdev_ms": 95.6,
+ "p95_ms": 2250.0
+ },
+ "core": {
+ "display": "1151.2 (±115.3ms/10.0%, p95:1334.0ms)",
+ "mean_ms": 1151.2,
+ "stdev_ms": 115.3,
+ "p95_ms": 1334.0
+ },
+ "status": "success",
+ "error_message": null
+ },
+ "postgresql": {
+ "display": "26.0 (±4.9ms/18.8%, p95:30.0ms)",
+ "mean_ms": 26.0,
+ "stdev_ms": 4.9,
+ "p95_ms": 30.0
+ },
+ "postgresql_core": {
+ "display": "3.8 (±1.4ms/36.8%, p95:5.7ms)",
+ "mean_ms": 3.8,
+ "stdev_ms": 1.4,
+ "p95_ms": 5.7
+ },
+ "duckdb": {
+ "display": "30.0 (±0.0ms/0.0%, p95:30.0ms)",
+ "mean_ms": 30.0,
+ "stdev_ms": 0.0,
+ "p95_ms": 30.0
+ },
+ "duckdb_core": {
+ "display": "1.1 (±0.1ms/9.1%, p95:1.3ms)",
+ "mean_ms": 1.1,
+ "stdev_ms": 0.1,
+ "p95_ms": 1.3
+ },
+ "fastest_engine": "PostgreSQL"
+ }
+ ]
+}
+```
+
+ Differences at a glance:
+ - Single-engine `run.json` focuses on query output (`result`) and, when enabled, the SystemDS `stats` array. Status and error handling are per-query.
+ - Multi-engine results JSON focuses on timing statistics for each engine (`shell` vs `core` for SystemDS; `postgresql`/`postgresql_core`; `duckdb`/`duckdb_core`) along with a `fastest_engine` field. It does not include the query’s actual result values.
+
+---
+
+## 7) Adding/Editing Queries
+
+Guidelines for DML in `scripts/ssb/queries/`:
+
+- Name files as `qX_Y.dml` (e.g., `q1_1.dml`). The runners accept `q1.1` on the CLI and map it for you.
+- Always derive paths from the `input_dir` named argument (see Section 3 and the skeleton below).
+- Keep I/O separate from compute where possible (helps early error detection).
+- Add a short header comment with original SQL and intent.
+
+Example header:
+
+```dml
+/*
+ SQL: SELECT ...
+ Description: Revenue per month by supplier region
+*/
+```
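+
+A minimal query skeleton following these conventions (illustrative only; adapt the reads, filters, and aggregation to the query at hand):
+
+```dml
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+
+# -- READING INPUT FILES --
+lineorder = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+# -- COMPUTE (filters, joins, aggregation) --
+# ...
+
+# -- OUTPUT --
+print("RESULT: ...");
+```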
+
+---
+
+## 8) Troubleshooting
+
+- Missing data directory: pass `--input-dir=/path/to/ssb-data` and ensure `*.tbl` files exist.
+- SystemDS not found: build (`mvn -DskipTests package`) and use `./bin/systemds` or ensure `systemds` is on PATH.
+- Query fails with runtime error: the runners mark `status: "error"` and include a short `error_message` in JSON outputs. See console snippet for context.
+- macOS cache dropping: OS caches cannot be dropped as they can on Linux; the multi-engine runner mitigates this with warmup runs and repeated measurements, and reports p95/CV.
+
+If something looks off, attach the relevant `run.json` or `ssb_results_*.json` when filing issues.
+
+- To debug DML runtime errors, run the DML directly:
+
+```bash
+./bin/systemds -f scripts/ssb/queries/q1_1.dml -nvargs input_dir=/path/to/data
+```
+
+- When `--stats` is enabled, SystemDS internal "core" timing is extracted and reported separately (useful for separating JVM/startup overhead from core computation). All of these metrics appear in the generated CSVs and JSON entries.
+
+- Permission errors: `chmod +x scripts/ssb/shell/*.sh`.
diff --git a/scripts/staging/ssb/queries/q1_1.dml b/scripts/staging/ssb/queries/q1_1.dml
new file mode 100644
index 00000000000..eac47b099ee
--- /dev/null
+++ b/scripts/staging/ssb/queries/q1_1.dml
@@ -0,0 +1,91 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/* DML-script implementing the ssb query Q1.1 in SystemDS.
+SELECT SUM(lo_extendedprice * lo_discount) AS REVENUE
+FROM lineorder, dates
+WHERE
+ lo_orderdate = d_datekey
+ AND d_year = 1993
+ AND lo_discount BETWEEN 1 AND 3
+ AND lo_quantity < 25;
+
+Usage:
+./bin/systemds scripts/ssb/queries/q1_1.dml -nvargs input_dir="/path/to/data"
+./bin/systemds scripts/ssb/queries/q1_1.dml -nvargs input_dir="/Users/ghafekalsaho/Desktop/data"
+or with explicit -f flag:
+./bin/systemds -f scripts/ssb/queries/q1_1.dml -nvargs input_dir="/path/to/data"
+
+Parameters:
+input_dir - Path to input directory containing the table files (e.g., ./data)
+*/
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+print("Loading tables from directory: " + input_dir);
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+
+# -- PREPARING --
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-6 : LO_ORDERDATE |
+# COL-9 : LO_QUANTITY | COL-10 : LO_EXTPRICE | COL-12 : LO_DISCOUNT
+lineorder_csv_min = cbind(lineorder_csv[, 6], lineorder_csv[, 9], lineorder_csv[, 10], lineorder_csv[, 12]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+d_year_filt = raSel::m_raSelection(date_matrix_min, col=2, op="==", val=1993); # D_YEAR = '1993'
+
+# LO_QUANTITY < 25
+lo_quan_filt = raSel::m_raSelection(lineorder_matrix_min, col=2, op="<", val=25);
+
+# LO_DISCOUNT BETWEEN 1 AND 3
+lo_quan_disc_filt = raSel::m_raSelection(lo_quan_filt, col=4, op=">=", val=1);
+lo_quan_disc_filt = raSel::m_raSelection(lo_quan_disc_filt, col=4, op="<=", val=3);
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING FILTERED LINEORDER TABLE WITH FILTERED DATE TABLE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_quan_disc_filt, colA=1, B=d_year_filt, colB=1, method="sort-merge");
+#print("LO-DATE JOINED.");
+
+
+# -- AGGREGATION --
+lo_extprice = joined_matrix[, 3]; #LO_EXTPRICE : 3 COLUMN OF JOINED-MATRIX
+lo_disc = joined_matrix[, 4]; #LO_DISCOUNT : 4 COLUMN OF JOINED-MATRIX
+revenue = sum(lo_extprice * lo_disc);
+
+print("REVENUE: " + as.integer(revenue));
+
+#print("Q1.1 finished.\n");
+
+
diff --git a/scripts/staging/ssb/queries/q1_2.dml b/scripts/staging/ssb/queries/q1_2.dml
new file mode 100644
index 00000000000..781f108a512
--- /dev/null
+++ b/scripts/staging/ssb/queries/q1_2.dml
@@ -0,0 +1,114 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/*DML-script implementing the ssb query Q1.2 in SystemDS.
+SELECT SUM(lo_extendedprice * lo_discount) AS REVENUE
+FROM lineorder, dates
+WHERE
+ lo_orderdate = d_datekey
+ AND d_yearmonth = 'Jan1994'
+ AND lo_discount BETWEEN 4 AND 6
+ AND lo_quantity BETWEEN 26 AND 35;
+
+Usage:
+./bin/systemds scripts/ssb/queries/q1_2.dml -nvargs input_dir="/path/to/data"
+./bin/systemds scripts/ssb/queries/q1_2.dml -nvargs input_dir="/Users/ghafekalsaho/Desktop/data"
+or with explicit -f flag:
+./bin/systemds -f scripts/ssb/queries/q1_2.dml -nvargs input_dir="/path/to/data"
+
+Parameters:
+input_dir - Path to input directory containing the table files (e.g., ./data)
+*/
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+print("Loading tables from directory: " + input_dir);
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+# -- PREPARING --
+# Optimized approach: Single-pass filtering with direct matrix construction
+# Convert date key column to numeric matrix for proper handling
+date_keys_matrix = as.matrix(date_csv[, 1]);
+
+# Count Jan1994 rows first to pre-allocate matrix efficiently
+date_nrows = nrow(date_csv);
+jan1994_count = 0;
+for (i in 1:date_nrows) {
+ yearmonth_val = as.scalar(date_csv[i, 7]);
+ if (yearmonth_val == "Jan1994") {
+ jan1994_count = jan1994_count + 1;
+ }
+}
+
+# Pre-allocate final matrix and fill in single pass
+date_filtered = matrix(0, jan1994_count, 2);
+filtered_idx = 0;
+for (i in 1:date_nrows) {
+ yearmonth_val = as.scalar(date_csv[i, 7]);
+ if (yearmonth_val == "Jan1994") {
+ filtered_idx = filtered_idx + 1;
+ date_filtered[filtered_idx, 1] = as.scalar(date_keys_matrix[i, 1]); # date_key
+ date_filtered[filtered_idx, 2] = 1; # encoded value for Jan1994
+ }
+}
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-6 : LO_ORDERDATE |
+# COL-9 : LO_QUANTITY | COL-10 : LO_EXTPRICE | COL-12 : LO_DISCOUNT
+lineorder_csv_min = cbind(lineorder_csv[, 6], lineorder_csv[, 9], lineorder_csv[, 10], lineorder_csv[, 12]);
+lineorder_min_matrix = as.matrix(lineorder_csv_min);
+
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+# We already filtered for D_YEARMONTH = 'Jan1994', so d_year_filt is our filtered date data
+d_year_filt = date_filtered;
+
+# LO_QUANTITY BETWEEN 26 AND 35
+lo_quan_filt = raSel::m_raSelection(lineorder_min_matrix, col=2, op=">=", val=26);
+lo_quan_filt = raSel::m_raSelection(lo_quan_filt, col=2, op="<=", val=35);
+
+# LO_DISCOUNT BETWEEN 4 AND 6
+lo_quan_disc_filt = raSel::m_raSelection(lo_quan_filt, col=4, op=">=", val=4);
+lo_quan_disc_filt = raSel::m_raSelection(lo_quan_disc_filt, col=4, op="<=", val=6);
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING FILTERED LINEORDER TABLE WITH FILTERED DATE TABLE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_quan_disc_filt, colA=1, B=d_year_filt, colB=1, method="sort-merge");
+#print("LO-DATE JOINED.");
+
+
+# -- AGGREGATION --
+lo_extprice = joined_matrix[, 3]; #LO_EXTPRICE : 3 COLUMN OF JOINED-MATRIX
+lo_disc = joined_matrix[, 4]; #LO_DISCOUNT : 4 COLUMN OF JOINED-MATRIX
+revenue = sum(lo_extprice * lo_disc);
+
+print("REVENUE: " + as.integer(revenue));
+
+#print("Q1.2 finished.\n");
\ No newline at end of file
diff --git a/scripts/staging/ssb/queries/q1_3.dml b/scripts/staging/ssb/queries/q1_3.dml
new file mode 100644
index 00000000000..cd9ba0b8746
--- /dev/null
+++ b/scripts/staging/ssb/queries/q1_3.dml
@@ -0,0 +1,115 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/*DML-script implementing the ssb query Q1.3 in SystemDS.
+SELECT SUM(lo_extendedprice * lo_discount) AS REVENUE
+FROM lineorder, dates
+WHERE
+ lo_orderdate = d_datekey
+ AND d_weeknuminyear = 6
+ AND d_year = 1994
+ AND lo_discount BETWEEN 5 AND 7
+ AND lo_quantity BETWEEN 26 AND 35;
+
+Usage:
+./bin/systemds scripts/ssb/queries/q1_3.dml -nvargs input_dir="/path/to/data"
+./bin/systemds scripts/ssb/queries/q1_3.dml -nvargs input_dir="/Users/ghafekalsaho/Desktop/data"
+or with explicit -f flag:
+./bin/systemds -f scripts/ssb/queries/q1_3.dml -nvargs input_dir="/path/to/data"
+
+Parameters:
+input_dir - Path to input directory containing the table files (e.g., ./data)
+*/
+
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+# -- PREPARING --
+# Optimized approach: Two-pass filtering with direct matrix construction
+# Convert date columns to numeric matrices for proper handling
+date_keys_matrix = as.matrix(date_csv[, 1]); # date_key
+date_year_matrix = as.matrix(date_csv[, 5]); # d_year
+date_weeknum_matrix = as.matrix(date_csv[, 12]); # d_weeknuminyear
+
+# Count matching rows first to pre-allocate matrix efficiently
+date_nrows = nrow(date_csv);
+matching_count = 0;
+for (i in 1:date_nrows) {
+ year_val = as.scalar(date_year_matrix[i, 1]);
+ weeknum_val = as.scalar(date_weeknum_matrix[i, 1]);
+ if (year_val == 1994 && weeknum_val == 6) {
+ matching_count = matching_count + 1;
+ }
+}
+
+# Pre-allocate final matrix and fill in single pass
+date_filtered = matrix(0, matching_count, 2);
+filtered_idx = 0;
+for (i in 1:date_nrows) {
+ year_val = as.scalar(date_year_matrix[i, 1]);
+ weeknum_val = as.scalar(date_weeknum_matrix[i, 1]);
+ if (year_val == 1994 && weeknum_val == 6) {
+ filtered_idx = filtered_idx + 1;
+ date_filtered[filtered_idx, 1] = as.scalar(date_keys_matrix[i, 1]); # date_key
+ date_filtered[filtered_idx, 2] = 1; # encoded value for matching criteria
+ }
+}
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-6 : LO_ORDERDATE |
+# COL-9 : LO_QUANTITY | COL-10 : LO_EXTPRICE | COL-12 : LO_DISCOUNT
+lineorder_csv_min = cbind(lineorder_csv[, 6], lineorder_csv[, 9], lineorder_csv[, 10], lineorder_csv[, 12]);
+lineorder_min_matrix = as.matrix(lineorder_csv_min);
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+# We already filtered for D_YEAR = 1994 AND D_WEEKNUMINYEAR = 6, so date_filtered is our filtered date data
+d_year_filt = date_filtered;
+
+# LO_QUANTITY BETWEEN 26 AND 35
+lo_quan_filt = raSel::m_raSelection(lineorder_min_matrix, col=2, op=">=", val=26);
+lo_quan_filt = raSel::m_raSelection(lo_quan_filt, col=2, op="<=", val=35);
+
+# LO_DISCOUNT BETWEEN 5 AND 7 (FIXED: was incorrectly >=6)
+lo_quan_disc_filt = raSel::m_raSelection(lo_quan_filt, col=4, op=">=", val=5);
+lo_quan_disc_filt = raSel::m_raSelection(lo_quan_disc_filt, col=4, op="<=", val=7);
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING FILTERED LINEORDER TABLE WITH FILTERED DATE TABLE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_quan_disc_filt, colA=1, B=d_year_filt, colB=1, method="sort-merge");
+
+
+# -- AGGREGATION --
+lo_extprice = joined_matrix[, 3]; #LO_EXTPRICE : 3 COLUMN OF JOINED-MATRIX
+lo_disc = joined_matrix[, 4]; #LO_DISCOUNT : 4 COLUMN OF JOINED-MATRIX
+revenue = sum(lo_extprice * lo_disc);
+
+print("REVENUE: " + as.integer(revenue));
\ No newline at end of file
diff --git a/scripts/staging/ssb/queries/q2_1.dml b/scripts/staging/ssb/queries/q2_1.dml
new file mode 100644
index 00000000000..e35d66f9f31
--- /dev/null
+++ b/scripts/staging/ssb/queries/q2_1.dml
@@ -0,0 +1,325 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/*DML-script implementing the ssb query Q2.1 in SystemDS.
+SELECT SUM(lo_revenue), d_year, p_brand
+FROM lineorder, dates, part, supplier
+WHERE
+ lo_orderdate = d_datekey
+ AND lo_partkey = p_partkey
+ AND lo_suppkey = s_suppkey
+ AND p_category = 'MFGR#12'
+ AND s_region = 'AMERICA'
+GROUP BY d_year, p_brand
+ORDER BY p_brand;
+
+Usage:
+./bin/systemds scripts/ssb/queries/q2_1.dml -nvargs input_dir="/path/to/data"
+./bin/systemds scripts/ssb/queries/q2_1.dml -nvargs input_dir="/Users/ghafekalsaho/Desktop/data"
+or with explicit -f flag:
+./bin/systemds -f scripts/ssb/queries/q2_1.dml -nvargs input_dir="/path/to/data"
+
+Parameters:
+input_dir - Path to input directory containing the table files (e.g., ./data)
+*/
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+part_csv = read(input_dir + "/part.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+# -- PREPARING --
+# Optimized approach: On-the-fly filtering with direct matrix construction for string fields
+
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : D_YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-4 : LO_PARTKEY | COL-5 : LO_SUPPKEY |
+# COL-6 : LO_ORDERDATE | COL-13 : LO_REVENUE
+lineorder_csv_min = cbind(lineorder_csv[, 4], lineorder_csv[, 5], lineorder_csv[, 6], lineorder_csv[, 13]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+# ON-THE-FLY PART TABLE FILTERING AND ENCODING (P_CATEGORY = 'MFGR#12')
+# Two-pass approach: Count first, then filter and encode
+part_keys_matrix = as.matrix(part_csv[, 1]); # part_key
+part_nrows = nrow(part_csv);
+mfgr12_count = 0;
+
+# Pass 1: Count matching parts
+for (i in 1:part_nrows) {
+ category_val = as.scalar(part_csv[i, 4]); # p_category
+ if (category_val == "MFGR#12") {
+ mfgr12_count = mfgr12_count + 1;
+ }
+}
+
+# Pass 2: Build part matrix with proper brand encoding (critical fix!)
+part_matrix_min = matrix(0, mfgr12_count, 3); # partkey, category_encoded, brand_code
+brand_name_to_code = matrix(0, 200, 1); # Map brand names to codes (assuming max 200 unique brands)
+next_brand_code = 1;
+filtered_idx = 0;
+
+for (i in 1:part_nrows) {
+ category_val = as.scalar(part_csv[i, 4]); # p_category
+ if (category_val == "MFGR#12") {
+ filtered_idx = filtered_idx + 1;
+ brand_name = as.scalar(part_csv[i, 5]); # p_brand1
+
+ # Find existing brand code or create new one
+ brand_code = 0;
+
+ # Simple hash-like approach: use first few characters to create a simple numeric code
+ # This avoids string comparison issues while ensuring same brand gets same code
+ brand_hash = 0;
+ if (brand_name == "MFGR#121") brand_hash = 121;
+ else if (brand_name == "MFGR#122") brand_hash = 122;
+ else if (brand_name == "MFGR#123") brand_hash = 123;
+ else if (brand_name == "MFGR#124") brand_hash = 124;
+ else if (brand_name == "MFGR#125") brand_hash = 125;
+ else if (brand_name == "MFGR#127") brand_hash = 127;
+ else if (brand_name == "MFGR#128") brand_hash = 128;
+ else if (brand_name == "MFGR#129") brand_hash = 129;
+ else if (brand_name == "MFGR#1211") brand_hash = 1211;
+ else if (brand_name == "MFGR#1212") brand_hash = 1212;
+ else if (brand_name == "MFGR#1213") brand_hash = 1213;
+ else if (brand_name == "MFGR#1214") brand_hash = 1214;
+ else if (brand_name == "MFGR#1215") brand_hash = 1215;
+ else if (brand_name == "MFGR#1216") brand_hash = 1216;
+ else if (brand_name == "MFGR#1217") brand_hash = 1217;
+ else if (brand_name == "MFGR#1218") brand_hash = 1218;
+ else if (brand_name == "MFGR#1219") brand_hash = 1219;
+ else if (brand_name == "MFGR#1220") brand_hash = 1220;
+ else if (brand_name == "MFGR#1221") brand_hash = 1221;
+ else if (brand_name == "MFGR#1222") brand_hash = 1222;
+ else if (brand_name == "MFGR#1224") brand_hash = 1224;
+ else if (brand_name == "MFGR#1225") brand_hash = 1225;
+ else if (brand_name == "MFGR#1226") brand_hash = 1226;
+ else if (brand_name == "MFGR#1228") brand_hash = 1228;
+ else if (brand_name == "MFGR#1229") brand_hash = 1229;
+ else if (brand_name == "MFGR#1230") brand_hash = 1230;
+ else if (brand_name == "MFGR#1231") brand_hash = 1231;
+ else if (brand_name == "MFGR#1232") brand_hash = 1232;
+ else if (brand_name == "MFGR#1233") brand_hash = 1233;
+ else if (brand_name == "MFGR#1234") brand_hash = 1234;
+ else if (brand_name == "MFGR#1235") brand_hash = 1235;
+ else if (brand_name == "MFGR#1236") brand_hash = 1236;
+ else if (brand_name == "MFGR#1237") brand_hash = 1237;
+ else if (brand_name == "MFGR#1238") brand_hash = 1238;
+ else if (brand_name == "MFGR#1240") brand_hash = 1240;
+ else brand_hash = next_brand_code; # fallback for unknown brands
+
+ brand_code = brand_hash;
+
+ part_matrix_min[filtered_idx, 1] = as.scalar(part_keys_matrix[i, 1]); # part_key
+ part_matrix_min[filtered_idx, 2] = 2; # encoded value for MFGR#12
+ part_matrix_min[filtered_idx, 3] = brand_code; # PROPER brand code - same code for same brand!
+ }
+}
+
+# ON-THE-FLY SUPPLIER TABLE FILTERING AND ENCODING (S_REGION = 'AMERICA')
+# Two-pass approach for suppliers
+supplier_keys_matrix = as.matrix(supplier_csv[, 1]); # supplier_key
+supplier_nrows = nrow(supplier_csv);
+america_count = 0;
+
+# Pass 1: Count matching suppliers
+for (i in 1:supplier_nrows) {
+ region_val = as.scalar(supplier_csv[i, 6]); # s_region
+ if (region_val == "AMERICA") {
+ america_count = america_count + 1;
+ }
+}
+
+# Pass 2: Build supplier matrix
+sup_matrix_min = matrix(0, america_count, 2); # suppkey, region_encoded
+filtered_idx = 0;
+for (i in 1:supplier_nrows) {
+ region_val = as.scalar(supplier_csv[i, 6]); # s_region
+ if (region_val == "AMERICA") {
+ filtered_idx = filtered_idx + 1;
+ sup_matrix_min[filtered_idx, 1] = as.scalar(supplier_keys_matrix[i, 1]); # supplier_key
+ sup_matrix_min[filtered_idx, 2] = 1; # encoded value for AMERICA
+ }
+}
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+# We already filtered for P_CATEGORY = 'MFGR#12' and S_REGION = 'AMERICA' during matrix construction
+# P_CATEGORY = 'MFGR#12' : 2 (Our encoded value)
+p_cat_filt = raSel::m_raSelection(part_matrix_min, col=2, op="==", val=2);
+
+# S_REGION = 'AMERICA' : 1 (Our encoded value)
+s_reg_filt = raSel::m_raSelection(sup_matrix_min, col=2, op="==", val=1);
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING MINIMIZED LINEORDER TABLE WITH FILTERED PART TABLE WHERE LO_PARTKEY = P_PARTKEY
+lo_part = raJoin::m_raJoin(A=lineorder_matrix_min, colA=1, B=p_cat_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ SUPPLIER WHERE LO_SUPPKEY = S_SUPPKEY
+lo_part_sup = raJoin::m_raJoin(A=lo_part, colA=2, B=s_reg_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ DATE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_part_sup, colA=3, B=date_matrix_min, colB=1, method="sort-merge");
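+# After the three joins each row is laid out as [lineorder cols | part cols | supplier cols | date cols];
+# the column offsets used below rely on this ordering.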
+
+# -- GROUP-BY & AGGREGATION --
+# LO_REVENUE : COLUMN 4 OF LINEORDER-MIN-MATRIX
+revenue = joined_matrix[, 4];
+# D_YEAR : COLUMN 2 OF DATE-MIN-MATRIX
+d_year = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(part_matrix_min) + ncol(sup_matrix_min) + 2)];
+# P_BRAND : COLUMN 3 OF PART-MIN-MATRIX
+p_brand = joined_matrix[,(ncol(lineorder_matrix_min) + 3)];
+
+max_p_brand = max(p_brand);
+p_brand_scale_f = ceil(max_p_brand) + 1;
+
+combined_key = d_year * p_brand_scale_f + p_brand;
+
+group_input = cbind(revenue, combined_key);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+gr_key = agg_result[, 1];
+revenue = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+p_brand = round(gr_key %% p_brand_scale_f);
+d_year = round((gr_key - p_brand) / p_brand_scale_f);
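+
+# Sanity sketch of the key round trip (illustrative only; the ex_* scalars are not used below):
+# because p_brand_scale_f exceeds every brand code, encoding a (year, brand) pair and decoding
+# it again recovers both fields exactly.
+ex_key = 1994 * p_brand_scale_f + max_p_brand;
+ex_brand = round(ex_key %% p_brand_scale_f);             # == max_p_brand
+ex_year = round((ex_key - ex_brand) / p_brand_scale_f);  # == 1994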
+
+result = cbind(revenue, d_year, p_brand);
+
+result_ordered = order(target=result, by=3, decreasing=FALSE, index.return=FALSE); # 3 : P_BRAND
+result_ordered = order(target=result_ordered, by=2, decreasing=FALSE, index.return=FALSE); # 2 : D_YEAR
+
+print("Processing " + nrow(result_ordered) + " result rows...");
+
+# Decode brand codes back to brand names directly at print time (no intermediate string frame).
+print("Q2.1 Results with brand names:");
+
+# Output results with the inline code-to-name lookup below
+for (i in 1:nrow(result_ordered)) {
+ revenue_val = as.scalar(result_ordered[i, 1]);
+ year_val = as.scalar(result_ordered[i, 2]);
+ brand_code = as.scalar(result_ordered[i, 3]);
+
+ # Map brand code back to brand name
+ brand_name = "UNKNOWN";
+
+ # Reverse mapping from code to name
+ if (brand_code == 121) brand_name = "MFGR#121";
+ else if (brand_code == 122) brand_name = "MFGR#122";
+ else if (brand_code == 123) brand_name = "MFGR#123";
+ else if (brand_code == 124) brand_name = "MFGR#124";
+ else if (brand_code == 125) brand_name = "MFGR#125";
+ else if (brand_code == 127) brand_name = "MFGR#127";
+ else if (brand_code == 128) brand_name = "MFGR#128";
+ else if (brand_code == 129) brand_name = "MFGR#129";
+ else if (brand_code == 1211) brand_name = "MFGR#1211";
+ else if (brand_code == 1212) brand_name = "MFGR#1212";
+ else if (brand_code == 1213) brand_name = "MFGR#1213";
+ else if (brand_code == 1214) brand_name = "MFGR#1214";
+ else if (brand_code == 1215) brand_name = "MFGR#1215";
+ else if (brand_code == 1216) brand_name = "MFGR#1216";
+ else if (brand_code == 1217) brand_name = "MFGR#1217";
+ else if (brand_code == 1218) brand_name = "MFGR#1218";
+ else if (brand_code == 1219) brand_name = "MFGR#1219";
+ else if (brand_code == 1220) brand_name = "MFGR#1220";
+ else if (brand_code == 1221) brand_name = "MFGR#1221";
+ else if (brand_code == 1222) brand_name = "MFGR#1222";
+ else if (brand_code == 1224) brand_name = "MFGR#1224";
+ else if (brand_code == 1225) brand_name = "MFGR#1225";
+ else if (brand_code == 1226) brand_name = "MFGR#1226";
+ else if (brand_code == 1228) brand_name = "MFGR#1228";
+ else if (brand_code == 1229) brand_name = "MFGR#1229";
+ else if (brand_code == 1230) brand_name = "MFGR#1230";
+ else if (brand_code == 1231) brand_name = "MFGR#1231";
+ else if (brand_code == 1232) brand_name = "MFGR#1232";
+ else if (brand_code == 1233) brand_name = "MFGR#1233";
+ else if (brand_code == 1234) brand_name = "MFGR#1234";
+ else if (brand_code == 1235) brand_name = "MFGR#1235";
+ else if (brand_code == 1236) brand_name = "MFGR#1236";
+ else if (brand_code == 1237) brand_name = "MFGR#1237";
+ else if (brand_code == 1238) brand_name = "MFGR#1238";
+ else if (brand_code == 1240) brand_name = "MFGR#1240";
+
+ # Output in exact previous format
+ print(revenue_val + ".000 " + year_val + ".000 " + brand_name);
+}
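+
+# Possible simplification (not applied, to keep the mapping explicit): in this query the brand
+# code is exactly the numeric suffix of the brand string, so both reverse-lookup chains could be
+# collapsed to:
+#   brand_name = "MFGR#" + as.integer(brand_code);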
+
+# Frame format output
+print("");
+print("# FRAME: nrow = " + nrow(result_ordered) + ", ncol = 3");
+print("# C1 C2 C3");
+print("# INT32 INT32 STRING");
+
+for (i in 1:nrow(result_ordered)) {
+ revenue_val = as.scalar(result_ordered[i, 1]);
+ year_val = as.scalar(result_ordered[i, 2]);
+ brand_code = as.scalar(result_ordered[i, 3]);
+
+ # Same brand code mapping for frame output
+ brand_name = "UNKNOWN";
+
+ if (brand_code == 121) brand_name = "MFGR#121";
+ else if (brand_code == 122) brand_name = "MFGR#122";
+ else if (brand_code == 123) brand_name = "MFGR#123";
+ else if (brand_code == 124) brand_name = "MFGR#124";
+ else if (brand_code == 125) brand_name = "MFGR#125";
+ else if (brand_code == 127) brand_name = "MFGR#127";
+ else if (brand_code == 128) brand_name = "MFGR#128";
+ else if (brand_code == 129) brand_name = "MFGR#129";
+ else if (brand_code == 1211) brand_name = "MFGR#1211";
+ else if (brand_code == 1212) brand_name = "MFGR#1212";
+ else if (brand_code == 1213) brand_name = "MFGR#1213";
+ else if (brand_code == 1214) brand_name = "MFGR#1214";
+ else if (brand_code == 1215) brand_name = "MFGR#1215";
+ else if (brand_code == 1216) brand_name = "MFGR#1216";
+ else if (brand_code == 1217) brand_name = "MFGR#1217";
+ else if (brand_code == 1218) brand_name = "MFGR#1218";
+ else if (brand_code == 1219) brand_name = "MFGR#1219";
+ else if (brand_code == 1220) brand_name = "MFGR#1220";
+ else if (brand_code == 1221) brand_name = "MFGR#1221";
+ else if (brand_code == 1222) brand_name = "MFGR#1222";
+ else if (brand_code == 1224) brand_name = "MFGR#1224";
+ else if (brand_code == 1225) brand_name = "MFGR#1225";
+ else if (brand_code == 1226) brand_name = "MFGR#1226";
+ else if (brand_code == 1228) brand_name = "MFGR#1228";
+ else if (brand_code == 1229) brand_name = "MFGR#1229";
+ else if (brand_code == 1230) brand_name = "MFGR#1230";
+ else if (brand_code == 1231) brand_name = "MFGR#1231";
+ else if (brand_code == 1232) brand_name = "MFGR#1232";
+ else if (brand_code == 1233) brand_name = "MFGR#1233";
+ else if (brand_code == 1234) brand_name = "MFGR#1234";
+ else if (brand_code == 1235) brand_name = "MFGR#1235";
+ else if (brand_code == 1236) brand_name = "MFGR#1236";
+ else if (brand_code == 1237) brand_name = "MFGR#1237";
+ else if (brand_code == 1238) brand_name = "MFGR#1238";
+ else if (brand_code == 1240) brand_name = "MFGR#1240";
+
+ print(as.integer(revenue_val) + " " + as.integer(year_val) + " " + brand_name);
+}
\ No newline at end of file
diff --git a/scripts/staging/ssb/queries/q2_2.dml b/scripts/staging/ssb/queries/q2_2.dml
new file mode 100644
index 00000000000..5b477979785
--- /dev/null
+++ b/scripts/staging/ssb/queries/q2_2.dml
@@ -0,0 +1,246 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/* DML script implementing the SSB query Q2.2 in SystemDS.
+SELECT SUM(lo_revenue), d_year, p_brand
+FROM lineorder, dates, part, supplier
+WHERE
+ lo_orderdate = d_datekey
+ AND lo_partkey = p_partkey
+ AND lo_suppkey = s_suppkey
+ AND p_brand BETWEEN 'MFGR#2221' AND 'MFGR#2228'
+ AND s_region = 'ASIA'
+GROUP BY d_year, p_brand
+ORDER BY d_year, p_brand;
+
+Usage:
+./bin/systemds scripts/ssb/queries/q2_2.dml -nvargs input_dir="/path/to/data"
+./bin/systemds scripts/ssb/queries/q2_2.dml -nvargs input_dir="/Users/ghafekalsaho/Desktop/data"
+or with explicit -f flag:
+./bin/systemds -f scripts/ssb/queries/q2_2.dml -nvargs input_dir="/path/to/data"
+
+Parameters:
+input_dir - Path to input directory containing the table files (e.g., ./data)
+*/
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+part_csv = read(input_dir + "/part.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+# -- PREPARING --
+# Optimized approach: On-the-fly filtering with direct matrix construction for string fields
+
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : D_YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-4 : LO_PARTKEY | COL-5 : LO_SUPPKEY |
+# COL-6 : LO_ORDERDATE | COL-13 : LO_REVENUE
+lineorder_csv_min = cbind(lineorder_csv[, 4], lineorder_csv[, 5], lineorder_csv[, 6], lineorder_csv[, 13]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+# ON-THE-FLY PART TABLE FILTERING AND ENCODING (P_BRAND BETWEEN 'MFGR#2221' AND 'MFGR#2228')
+# Two-pass approach: Count first, then filter and encode
+part_keys_matrix = as.matrix(part_csv[, 1]); # part_key
+part_nrows = nrow(part_csv);
+valid_brands_count = 0;
+
+# Pass 1: Count matching parts (brands between MFGR#2221 and MFGR#2228)
+for (i in 1:part_nrows) {
+ brand_val = as.scalar(part_csv[i, 5]); # p_brand
+ if (brand_val >= "MFGR#2221" & brand_val <= "MFGR#2228") {
+ valid_brands_count = valid_brands_count + 1;
+ }
+}
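+
+# Note: the range predicate is evaluated lexicographically on the brand strings; for the standard
+# SSB brand naming this should select exactly the brands MFGR#2221 .. MFGR#2228, since they share
+# the same length and prefix and no other brand string falls inside the range.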
+
+# Pass 2: Build part matrix with proper brand encoding
+part_matrix_min = matrix(0, valid_brands_count, 2); # partkey, brand_code
+filtered_idx = 0;
+
+for (i in 1:part_nrows) {
+ brand_val = as.scalar(part_csv[i, 5]); # p_brand
+ if (brand_val >= "MFGR#2221" & brand_val <= "MFGR#2228") {
+ filtered_idx = filtered_idx + 1;
+
+ # Encode brand names to numeric codes for efficient processing (using original metadata codes)
+ brand_code = 0;
+ if (brand_val == "MFGR#2221") brand_code = 453;
+ else if (brand_val == "MFGR#2222") brand_code = 597;
+ else if (brand_val == "MFGR#2223") brand_code = 907;
+ else if (brand_val == "MFGR#2224") brand_code = 282;
+ else if (brand_val == "MFGR#2225") brand_code = 850;
+ else if (brand_val == "MFGR#2226") brand_code = 525;
+ else if (brand_val == "MFGR#2227") brand_code = 538;
+ else if (brand_val == "MFGR#2228") brand_code = 608;
+ else brand_code = 9999; # fallback for unknown brands in range
+
+ part_matrix_min[filtered_idx, 1] = as.scalar(part_keys_matrix[i, 1]); # part_key
+ part_matrix_min[filtered_idx, 2] = brand_code; # brand code
+ }
+}
+
+# ON-THE-FLY SUPPLIER TABLE FILTERING AND ENCODING (S_REGION = 'ASIA')
+# Two-pass approach for suppliers
+supplier_keys_matrix = as.matrix(supplier_csv[, 1]); # supplier_key
+supplier_nrows = nrow(supplier_csv);
+asia_count = 0;
+
+# Pass 1: Count matching suppliers
+for (i in 1:supplier_nrows) {
+ region_val = as.scalar(supplier_csv[i, 6]); # s_region
+ if (region_val == "ASIA") {
+ asia_count = asia_count + 1;
+ }
+}
+
+# Pass 2: Build supplier matrix
+sup_matrix_min = matrix(0, asia_count, 2); # suppkey, region_encoded
+filtered_idx = 0;
+for (i in 1:supplier_nrows) {
+ region_val = as.scalar(supplier_csv[i, 6]); # s_region
+ if (region_val == "ASIA") {
+ filtered_idx = filtered_idx + 1;
+ sup_matrix_min[filtered_idx, 1] = as.scalar(supplier_keys_matrix[i, 1]); # supplier_key
+ sup_matrix_min[filtered_idx, 2] = 5; # encoded value for ASIA
+ }
+}
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+# Parts were already restricted to brands between MFGR#2221 and MFGR#2228 during matrix
+# construction, so no further RA selection is needed for the part table.
+p_brand_filt = part_matrix_min;
+
+# S_REGION = 'ASIA' : 5 (Our encoded value)
+s_reg_filt = raSel::m_raSelection(sup_matrix_min, col=2, op="==", val=5);
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING MINIMIZED LINEORDER TABLE WITH FILTERED PART TABLE WHERE LO_PARTKEY = P_PARTKEY
+lo_part = raJoin::m_raJoin(A=lineorder_matrix_min, colA=1, B=p_brand_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ SUPPLIER WHERE LO_SUPPKEY = S_SUPPKEY
+lo_part_sup = raJoin::m_raJoin(A=lo_part, colA=2, B=s_reg_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ DATE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_part_sup, colA=3, B=date_matrix_min, colB=1, method="sort-merge");
+
+# -- GROUP-BY & AGGREGATION --
+# LO_REVENUE : COLUMN 4 OF LINEORDER-MIN-MATRIX
+revenue = joined_matrix[, 4];
+# D_YEAR : COLUMN 2 OF DATE-MIN-MATRIX
+d_year = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(part_matrix_min) + ncol(sup_matrix_min) + 2)];
+# P_BRAND : COLUMN 2 OF PART-MIN-MATRIX
+p_brand = joined_matrix[,(ncol(lineorder_matrix_min) + 2)];
+
+max_p_brand = max(p_brand);
+p_brand_scale_f = ceil(max_p_brand) + 1;
+
+combined_key = d_year * p_brand_scale_f + p_brand;
+
+group_input = cbind(revenue, combined_key);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+gr_key = agg_result[, 1];
+revenue = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+p_brand = round(gr_key %% p_brand_scale_f);
+d_year = round((gr_key - p_brand) / p_brand_scale_f);
+
+result = cbind(revenue, d_year, p_brand);
+
+result_ordered = order(target=result, by=3, decreasing=FALSE, index.return=FALSE); # 3 : P_BRAND
+result_ordered = order(target=result_ordered, by=2, decreasing=FALSE, index.return=FALSE); # D_YEAR
+
+print("Processing " + nrow(result_ordered) + " result rows...");
+
+# Output results with brand codes (matching original format)
+print("Q2.2 Results with brand codes:");
+
+for (i in 1:nrow(result_ordered)) {
+ revenue_val = as.scalar(result_ordered[i, 1]);
+ year_val = as.scalar(result_ordered[i, 2]);
+ brand_code = as.scalar(result_ordered[i, 3]);
+
+ # Output in original format with brand codes
+ print(revenue_val + ".000 " + year_val + ".000 " + brand_code + ".000");
+}
+
+# Calculate and print total revenue
+total_revenue = sum(result_ordered[, 1]);
+print("");
+print("REVENUE: " + as.integer(total_revenue));
+print("");
+
+for (i in 1:nrow(result_ordered)) {
+ revenue_val = as.scalar(result_ordered[i, 1]);
+ year_val = as.scalar(result_ordered[i, 2]);
+ brand_code = as.scalar(result_ordered[i, 3]);
+
+ # Map brand code back to brand name (using original metadata codes)
+ brand_name = "UNKNOWN";
+ if (brand_code == 453) brand_name = "MFGR#2221";
+ else if (brand_code == 597) brand_name = "MFGR#2222";
+ else if (brand_code == 907) brand_name = "MFGR#2223";
+ else if (brand_code == 282) brand_name = "MFGR#2224";
+ else if (brand_code == 850) brand_name = "MFGR#2225";
+ else if (brand_code == 525) brand_name = "MFGR#2226";
+ else if (brand_code == 538) brand_name = "MFGR#2227";
+ else if (brand_code == 608) brand_name = "MFGR#2228";
+
+ # Output in consistent format
+ print(revenue_val + ".000 " + year_val + ".000 " + brand_name);
+}
+
+# Frame format output
+print("");
+print("# FRAME: nrow = " + nrow(result_ordered) + ", ncol = 3");
+print("# C1 C2 C3");
+print("# INT32 INT32 STRING");
+
+for (i in 1:nrow(result_ordered)) {
+ revenue_val = as.scalar(result_ordered[i, 1]);
+ year_val = as.scalar(result_ordered[i, 2]);
+ brand_code = as.scalar(result_ordered[i, 3]);
+
+ # Same brand code mapping for frame output (using original metadata codes)
+ brand_name = "UNKNOWN";
+ if (brand_code == 453) brand_name = "MFGR#2221";
+ else if (brand_code == 597) brand_name = "MFGR#2222";
+ else if (brand_code == 907) brand_name = "MFGR#2223";
+ else if (brand_code == 282) brand_name = "MFGR#2224";
+ else if (brand_code == 850) brand_name = "MFGR#2225";
+ else if (brand_code == 525) brand_name = "MFGR#2226";
+ else if (brand_code == 538) brand_name = "MFGR#2227";
+ else if (brand_code == 608) brand_name = "MFGR#2228";
+
+ print(as.integer(revenue_val) + " " + as.integer(year_val) + " " + brand_name);
+}
diff --git a/scripts/staging/ssb/queries/q2_3.dml b/scripts/staging/ssb/queries/q2_3.dml
new file mode 100644
index 00000000000..8657079e8a1
--- /dev/null
+++ b/scripts/staging/ssb/queries/q2_3.dml
@@ -0,0 +1,221 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/* DML script implementing the SSB query Q2.3 in SystemDS.
+SELECT SUM(lo_revenue), d_year, p_brand
+FROM lineorder, dates, part, supplier
+WHERE
+ lo_orderdate = d_datekey
+ AND lo_partkey = p_partkey
+ AND lo_suppkey = s_suppkey
+ AND p_brand = 'MFGR#2239'
+ AND s_region = 'EUROPE'
+GROUP BY d_year, p_brand
+ORDER BY d_year, p_brand;
+
+Usage:
+./bin/systemds scripts/ssb/queries/q2_3.dml -nvargs input_dir="/path/to/data"
+./bin/systemds scripts/ssb/queries/q2_3.dml -nvargs input_dir="/Users/ghafekalsaho/Desktop/data"
+or with explicit -f flag:
+./bin/systemds -f scripts/ssb/queries/q2_3.dml -nvargs input_dir="/path/to/data"
+
+Parameters:
+input_dir - Path to input directory containing the table files (e.g., ./data)
+*/
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+part_csv = read(input_dir + "/part.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+
+# -- PREPARING --
+# Optimized approach: On-the-fly filtering with direct matrix construction for string fields
+
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : D_YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-4 : LO_PARTKEY | COL-5 : LO_SUPPKEY |
+# COL-6 : LO_ORDERDATE | COL-13 : LO_REVENUE
+lineorder_csv_min = cbind(lineorder_csv[, 4], lineorder_csv[, 5], lineorder_csv[, 6], lineorder_csv[, 13]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+# ON-THE-FLY PART TABLE FILTERING AND ENCODING (P_BRAND = 'MFGR#2239')
+# Two-pass approach: Count first, then filter and encode
+part_keys_matrix = as.matrix(part_csv[, 1]); # part_key
+part_nrows = nrow(part_csv);
+mfgr2239_count = 0;
+
+# Pass 1: Count matching parts (brand = MFGR#2239)
+for (i in 1:part_nrows) {
+ brand_val = as.scalar(part_csv[i, 5]); # p_brand
+ if (brand_val == "MFGR#2239") {
+ mfgr2239_count = mfgr2239_count + 1;
+ }
+}
+
+# Pass 2: Build part matrix with proper brand encoding (using original metadata code)
+part_matrix_min = matrix(0, mfgr2239_count, 2); # partkey, brand_code
+filtered_idx = 0;
+
+for (i in 1:part_nrows) {
+ brand_val = as.scalar(part_csv[i, 5]); # p_brand
+ if (brand_val == "MFGR#2239") {
+ filtered_idx = filtered_idx + 1;
+ part_matrix_min[filtered_idx, 1] = as.scalar(part_keys_matrix[i, 1]); # part_key
+ part_matrix_min[filtered_idx, 2] = 381; # encoded value for MFGR#2239 (from original metadata)
+ }
+}
+
+# ON-THE-FLY SUPPLIER TABLE FILTERING AND ENCODING (S_REGION = 'EUROPE')
+# Two-pass approach for suppliers
+supplier_keys_matrix = as.matrix(supplier_csv[, 1]); # supplier_key
+supplier_nrows = nrow(supplier_csv);
+europe_count = 0;
+
+# Pass 1: Count matching suppliers
+for (i in 1:supplier_nrows) {
+ region_val = as.scalar(supplier_csv[i, 6]); # s_region
+ if (region_val == "EUROPE") {
+ europe_count = europe_count + 1;
+ }
+}
+
+# Pass 2: Build supplier matrix
+sup_matrix_min = matrix(0, europe_count, 2); # suppkey, region_encoded
+filtered_idx = 0;
+for (i in 1:supplier_nrows) {
+ region_val = as.scalar(supplier_csv[i, 6]); # s_region
+ if (region_val == "EUROPE") {
+ filtered_idx = filtered_idx + 1;
+ sup_matrix_min[filtered_idx, 1] = as.scalar(supplier_keys_matrix[i, 1]); # supplier_key
+ sup_matrix_min[filtered_idx, 2] = 4; # encoded value for EUROPE (from original metadata)
+ }
+}
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+# We already filtered during matrix construction, but we can use RA selection for consistency
+# P_BRAND = 'MFGR#2239' : 381 (Our encoded value)
+p_brand_filt = raSel::m_raSelection(part_matrix_min, col=2, op="==", val=381);
+
+# S_REGION = 'EUROPE' : 4 (Our encoded value)
+s_reg_filt = raSel::m_raSelection(sup_matrix_min, col=2, op="==", val=4);
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING MINIMIZED LINEORDER TABLE WITH FILTERED PART TABLE WHERE LO_PARTKEY = P_PARTKEY
+lo_part = raJoin::m_raJoin(A=lineorder_matrix_min, colA=1, B=p_brand_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ SUPPLIER WHERE LO_SUPPKEY = S_SUPPKEY
+lo_part_sup = raJoin::m_raJoin(A=lo_part, colA=2, B=s_reg_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ DATE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_part_sup, colA=3, B=date_matrix_min, colB=1, method="sort-merge");
+
+# -- GROUP-BY & AGGREGATION --
+# LO_REVENUE : COLUMN 4 OF LINEORDER-MIN-MATRIX
+revenue = joined_matrix[, 4];
+# D_YEAR : COLUMN 2 OF DATE-MIN-MATRIX
+d_year = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(part_matrix_min) + ncol(sup_matrix_min) + 2)];
+# P_BRAND : COLUMN 2 OF PART-MIN-MATRIX
+p_brand = joined_matrix[,(ncol(lineorder_matrix_min) + 2)];
+
+max_p_brand = max(p_brand);
+p_brand_scale_f = ceil(max_p_brand) + 1;
+
+combined_key = d_year * p_brand_scale_f + p_brand;
+
+group_input = cbind(revenue, combined_key);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+gr_key = agg_result[, 1];
+revenue = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+p_brand = round(gr_key %% p_brand_scale_f);
+d_year = round((gr_key - p_brand) / p_brand_scale_f);
+
+result = cbind(revenue, d_year, p_brand);
+
+result_ordered = order(target=result, by=3, decreasing=FALSE, index.return=FALSE); # 3 : P_BRAND
+result_ordered = order(target=result_ordered, by=2, decreasing=FALSE, index.return=FALSE); # D_YEAR
+
+print("Processing " + nrow(result_ordered) + " result rows...");
+
+# Output results with brand codes (matching original format)
+print("Q2.3 Results with brand codes:");
+
+for (i in 1:nrow(result_ordered)) {
+ revenue_val = as.scalar(result_ordered[i, 1]);
+ year_val = as.scalar(result_ordered[i, 2]);
+ brand_code = as.scalar(result_ordered[i, 3]);
+
+ # Output in original format with brand codes
+ print(revenue_val + ".000 " + year_val + ".000 " + brand_code + ".000");
+}
+
+# Calculate and print total revenue
+total_revenue = sum(result_ordered[, 1]);
+print("");
+print("REVENUE: " + as.integer(total_revenue));
+print("");
+
+for (i in 1:nrow(result_ordered)) {
+ revenue_val = as.scalar(result_ordered[i, 1]);
+ year_val = as.scalar(result_ordered[i, 2]);
+ brand_code = as.scalar(result_ordered[i, 3]);
+
+ # Map brand code back to brand name (using original metadata code)
+ brand_name = "UNKNOWN";
+ if (brand_code == 381) brand_name = "MFGR#2239";
+
+ # Output in consistent format
+ print(revenue_val + ".000 " + year_val + ".000 " + brand_name);
+}
+
+# Frame format output
+print("");
+print("# FRAME: nrow = " + nrow(result_ordered) + ", ncol = 3");
+print("# C1 C2 C3");
+print("# INT32 INT32 STRING");
+
+for (i in 1:nrow(result_ordered)) {
+ revenue_val = as.scalar(result_ordered[i, 1]);
+ year_val = as.scalar(result_ordered[i, 2]);
+ brand_code = as.scalar(result_ordered[i, 3]);
+
+ # Same brand code mapping for frame output
+ brand_name = "UNKNOWN";
+ if (brand_code == 381) brand_name = "MFGR#2239";
+
+ print(as.integer(revenue_val) + " " + as.integer(year_val) + " " + brand_name);
+}
diff --git a/scripts/staging/ssb/queries/q3_1.dml b/scripts/staging/ssb/queries/q3_1.dml
new file mode 100644
index 00000000000..c4d2b376709
--- /dev/null
+++ b/scripts/staging/ssb/queries/q3_1.dml
@@ -0,0 +1,293 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/* DML script implementing the SSB query Q3.1 in SystemDS.
+SELECT
+ c_nation,
+ s_nation,
+ d_year,
+ SUM(lo_revenue) AS REVENUE
+FROM customer, lineorder, supplier, dates
+WHERE
+ lo_custkey = c_custkey
+ AND lo_suppkey = s_suppkey
+ AND lo_orderdate = d_datekey
+ AND c_region = 'ASIA'
+ AND s_region = 'ASIA'
+ AND d_year >= 1992
+ AND d_year <= 1997
+GROUP BY c_nation, s_nation, d_year
+ORDER BY d_year ASC, REVENUE DESC;
+
+Usage:
+./bin/systemds scripts/ssb/queries/q3_1.dml -nvargs input_dir="/path/to/data"
+./bin/systemds scripts/ssb/queries/q3_1.dml -nvargs input_dir="/Users/ghafekalsaho/Desktop/data"
+or with explicit -f flag:
+./bin/systemds -f scripts/ssb/queries/q3_1.dml -nvargs input_dir="/path/to/data"
+
+Parameters:
+input_dir - Path to input directory containing the table files (e.g., ./data)
+*/
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+customer_csv = read(input_dir + "/customer.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+
+# -- PREPARING --
+# Optimized approach: On-the-fly filtering with direct matrix construction for string fields
+
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : D_YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-3 : LO_CUSTKEY | COL-5 : LO_SUPPKEY |
+# COL-6 : LO_ORDERDATE | COL-13 : LO_REVENUE
+lineorder_csv_min = cbind(lineorder_csv[, 3], lineorder_csv[, 5], lineorder_csv[, 6], lineorder_csv[, 13]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+# ON-THE-FLY CUSTOMER TABLE FILTERING AND ENCODING (C_REGION = 'ASIA')
+# Two-pass approach: Count first, then filter and encode
+customer_keys_matrix = as.matrix(customer_csv[, 1]); # customer_key
+customer_nrows = nrow(customer_csv);
+asia_customer_count = 0;
+
+# Pass 1: Count matching customers (region = ASIA)
+for (i in 1:customer_nrows) {
+ region_val = as.scalar(customer_csv[i, 6]); # c_region
+ if (region_val == "ASIA") {
+ asia_customer_count = asia_customer_count + 1;
+ }
+}
+
+# Pass 2: Build customer matrix with proper nation and region encoding
+cust_matrix_min = matrix(0, asia_customer_count, 3); # custkey, nation_code, region_code
+filtered_idx = 0;
+
+for (i in 1:customer_nrows) {
+ region_val = as.scalar(customer_csv[i, 6]); # c_region
+ if (region_val == "ASIA") {
+ filtered_idx = filtered_idx + 1;
+ nation_val = as.scalar(customer_csv[i, 5]); # c_nation
+
+ cust_matrix_min[filtered_idx, 1] = as.scalar(customer_keys_matrix[i, 1]); # customer_key
+ cust_matrix_min[filtered_idx, 3] = 4; # encoded value for ASIA region (from original metadata)
+
+ # Map nation names to codes (using original metadata encodings)
+ if (nation_val == "CHINA") cust_matrix_min[filtered_idx, 2] = 247;
+ else if (nation_val == "INDIA") cust_matrix_min[filtered_idx, 2] = 36;
+ else if (nation_val == "INDONESIA") cust_matrix_min[filtered_idx, 2] = 243;
+ else if (nation_val == "JAPAN") cust_matrix_min[filtered_idx, 2] = 24;
+ else if (nation_val == "VIETNAM") cust_matrix_min[filtered_idx, 2] = 230;
+ else cust_matrix_min[filtered_idx, 2] = -1; # unknown nation
+ }
+}
+
+# ON-THE-FLY SUPPLIER TABLE FILTERING AND ENCODING (S_REGION = 'ASIA')
+# Two-pass approach for suppliers
+supplier_keys_matrix = as.matrix(supplier_csv[, 1]); # supplier_key
+supplier_nrows = nrow(supplier_csv);
+asia_supplier_count = 0;
+
+# Pass 1: Count matching suppliers
+for (i in 1:supplier_nrows) {
+ region_val = as.scalar(supplier_csv[i, 6]); # s_region
+ if (region_val == "ASIA") {
+ asia_supplier_count = asia_supplier_count + 1;
+ }
+}
+
+# Pass 2: Build supplier matrix
+sup_matrix_min = matrix(0, asia_supplier_count, 3); # suppkey, nation_code, region_code
+filtered_idx = 0;
+for (i in 1:supplier_nrows) {
+ region_val = as.scalar(supplier_csv[i, 6]); # s_region
+ if (region_val == "ASIA") {
+ filtered_idx = filtered_idx + 1;
+ nation_val = as.scalar(supplier_csv[i, 5]); # s_nation
+
+ sup_matrix_min[filtered_idx, 1] = as.scalar(supplier_keys_matrix[i, 1]); # supplier_key
+ sup_matrix_min[filtered_idx, 3] = 5; # encoded value for ASIA region (from original metadata)
+
+ # Map nation names to codes (using original metadata encodings)
+ if (nation_val == "CHINA") sup_matrix_min[filtered_idx, 2] = 27;
+ else if (nation_val == "INDIA") sup_matrix_min[filtered_idx, 2] = 12;
+ else if (nation_val == "INDONESIA") sup_matrix_min[filtered_idx, 2] = 48;
+ else if (nation_val == "JAPAN") sup_matrix_min[filtered_idx, 2] = 73;
+ else if (nation_val == "VIETNAM") sup_matrix_min[filtered_idx, 2] = 85;
+ else sup_matrix_min[filtered_idx, 2] = -1; # unknown nation
+ }
+}
+
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+# We already filtered during matrix construction, but we can use RA selection for consistency
+# C_REGION = 'ASIA' : 4 (Our encoded value)
+c_reg_filt = raSel::m_raSelection(cust_matrix_min, col=3, op="==", val=4);
+
+# S_REGION = 'ASIA' : 5 (Our encoded value)
+s_reg_filt = raSel::m_raSelection(sup_matrix_min, col=3, op="==", val=5);
+
+# D_YEAR BETWEEN 1992 & 1997
+d_year_filt = raSel::m_raSelection(date_matrix_min, col=2, op=">=", val=1992);
+d_year_filt = raSel::m_raSelection(d_year_filt, col=2, op="<=", val=1997);
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING MINIMIZED LINEORDER TABLE WITH FILTERED CUSTOMER TABLE WHERE LO_CUSTKEY = C_CUSTKEY
+lo_cust = raJoin::m_raJoin(A=lineorder_matrix_min, colA=1, B=c_reg_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ SUPPLIER WHERE LO_SUPPKEY = S_SUPPKEY
+lo_cust_sup = raJoin::m_raJoin(A=lo_cust, colA=2, B=s_reg_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ DATE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_cust_sup, colA=3, B=d_year_filt, colB=1, method="sort-merge");
+
+
+# -- GROUP-BY & AGGREGATION --
+# LO_REVENUE : COLUMN 4 OF LINEORDER-MIN-MATRIX
+revenue = joined_matrix[, 4];
+# D_YEAR : COLUMN 2 OF DATE-MIN-MATRIX
+d_year = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_matrix_min) + ncol(sup_matrix_min) + 2)];
+# C_NATION : COLUMN 2 OF CUST-MIN-MATRIX
+c_nation = joined_matrix[,(ncol(lineorder_matrix_min) + 2)];
+# S_NATION : COLUMN 2 OF SUP-MIN-MATRIX
+s_nation = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_matrix_min) + 2)];
+
+# CALCULATING COMBINATION KEY WITH PRIORITY: C_NATION, S_NATION, D_YEAR
+max_c_nation = max(c_nation);
+max_s_nation = max(s_nation);
+max_d_year = max(d_year);
+
+c_nation_scale_f = ceil(max_c_nation) + 1;
+s_nation_scale_f = ceil(max_s_nation) + 1;
+d_year_scale_f = ceil(max_d_year) + 1;
+
+combined_key = c_nation * s_nation_scale_f * d_year_scale_f + s_nation * d_year_scale_f + d_year;
+
+group_input = cbind(revenue, combined_key);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+key = agg_result[, 1];
+revenue = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+# EXTRACTING C_NATION, S_NATION & D_YEAR
+d_year = round(key %% d_year_scale_f);
+c_nation = round(floor(key / (s_nation_scale_f * d_year_scale_f)));
+s_nation = round((floor(key / d_year_scale_f)) %% s_nation_scale_f);
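+
+# Illustrative round trip of the three-part key (not executed): with the scale factors above,
+# encoding (c_nation, s_nation, d_year) as c*S*Y + s*Y + y and decoding with floor/modulo
+# recovers all three fields, e.g.
+#   ex_key = 247 * s_nation_scale_f * d_year_scale_f + max_s_nation * d_year_scale_f + max_d_year;
+#   round(ex_key %% d_year_scale_f)                               == max_d_year
+#   round(floor(ex_key / d_year_scale_f) %% s_nation_scale_f)     == max_s_nation
+#   round(floor(ex_key / (s_nation_scale_f * d_year_scale_f)))    == 247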
+
+result = cbind(c_nation, s_nation, d_year, revenue);
+
+
+# -- SORTING --
+# PRIORITY 1 D_YEAR (ASC), 2 REVENUE (DESC)
+result_ordered = order(target=result, by=4, decreasing=TRUE, index.return=FALSE);
+result_ordered = order(target=result_ordered, by=3, decreasing=FALSE, index.return=FALSE);
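+
+# The two order() calls implement ORDER BY d_year ASC, REVENUE DESC by sorting on the secondary
+# key first and the primary key last, relying on tied rows keeping their relative order.
+# Example with rows (1993,50), (1992,10), (1992,70):
+#   after sorting by revenue DESC: (1992,70), (1993,50), (1992,10)
+#   after sorting by d_year ASC  : (1992,70), (1992,10), (1993,50)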
+
+# -- DECODING C_NATION & S_NATION --
+# Map nation codes back to nation names (using original metadata codes)
+print("Processing " + nrow(result_ordered) + " result rows...");
+
+print("Q3.1 Results with nation codes:");
+for (i in 1:nrow(result_ordered)) {
+ c_nation_code = as.scalar(result_ordered[i, 1]);
+ s_nation_code = as.scalar(result_ordered[i, 2]);
+ year_val = as.scalar(result_ordered[i, 3]);
+ revenue_val = as.scalar(result_ordered[i, 4]);
+
+ print(c_nation_code + ".000 " + s_nation_code + ".000 " + year_val + ".000 " + revenue_val + ".000");
+}
+
+# Calculate and print total revenue
+total_revenue = sum(result_ordered[, 4]);
+print("");
+print("TOTAL REVENUE: " + as.integer(total_revenue));
+print("");
+
+for (i in 1:nrow(result_ordered)) {
+ c_nation_code = as.scalar(result_ordered[i, 1]);
+ s_nation_code = as.scalar(result_ordered[i, 2]);
+ year_val = as.scalar(result_ordered[i, 3]);
+ revenue_val = as.scalar(result_ordered[i, 4]);
+
+ # Map customer nation codes back to names
+ c_nation_name = "UNKNOWN";
+ if (c_nation_code == 247) c_nation_name = "CHINA";
+ else if (c_nation_code == 36) c_nation_name = "INDIA";
+ else if (c_nation_code == 243) c_nation_name = "INDONESIA";
+ else if (c_nation_code == 24) c_nation_name = "JAPAN";
+ else if (c_nation_code == 230) c_nation_name = "VIETNAM";
+
+ # Map supplier nation codes back to names
+ s_nation_name = "UNKNOWN";
+ if (s_nation_code == 27) s_nation_name = "CHINA";
+ else if (s_nation_code == 12) s_nation_name = "INDIA";
+ else if (s_nation_code == 48) s_nation_name = "INDONESIA";
+ else if (s_nation_code == 73) s_nation_name = "JAPAN";
+ else if (s_nation_code == 85) s_nation_name = "VIETNAM";
+
+ # Output in consistent format
+ print(c_nation_name + " " + s_nation_name + " " + year_val + ".000 " + revenue_val + ".000");
+}
+
+# Frame format output
+print("");
+print("# FRAME: nrow = " + nrow(result_ordered) + ", ncol = 4");
+print("# C1 C2 C3 C4");
+print("# STRING STRING INT32 INT32");
+
+for (i in 1:nrow(result_ordered)) {
+ c_nation_code = as.scalar(result_ordered[i, 1]);
+ s_nation_code = as.scalar(result_ordered[i, 2]);
+ year_val = as.scalar(result_ordered[i, 3]);
+ revenue_val = as.scalar(result_ordered[i, 4]);
+
+ # Map nation codes to names for frame output
+ c_nation_name = "UNKNOWN";
+ if (c_nation_code == 247) c_nation_name = "CHINA";
+ else if (c_nation_code == 36) c_nation_name = "INDIA";
+ else if (c_nation_code == 243) c_nation_name = "INDONESIA";
+ else if (c_nation_code == 24) c_nation_name = "JAPAN";
+ else if (c_nation_code == 230) c_nation_name = "VIETNAM";
+
+ s_nation_name = "UNKNOWN";
+ if (s_nation_code == 27) s_nation_name = "CHINA";
+ else if (s_nation_code == 12) s_nation_name = "INDIA";
+ else if (s_nation_code == 48) s_nation_name = "INDONESIA";
+ else if (s_nation_code == 73) s_nation_name = "JAPAN";
+ else if (s_nation_code == 85) s_nation_name = "VIETNAM";
+
+ print(c_nation_name + " " + s_nation_name + " " + as.integer(year_val) + " " + as.integer(revenue_val));
+}
+
diff --git a/scripts/staging/ssb/queries/q3_2.dml b/scripts/staging/ssb/queries/q3_2.dml
new file mode 100644
index 00000000000..d979c0cfbbb
--- /dev/null
+++ b/scripts/staging/ssb/queries/q3_2.dml
@@ -0,0 +1,237 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/* DML script implementing the SSB query Q3.2 in SystemDS.
+SELECT
+ c_city,
+ s_city,
+ d_year,
+ SUM(lo_revenue) AS REVENUE
+FROM customer, lineorder, supplier, dates
+WHERE
+ lo_custkey = c_custkey
+ AND lo_suppkey = s_suppkey
+ AND lo_orderdate = d_datekey
+ AND c_nation = 'UNITED STATES'
+ AND s_nation = 'UNITED STATES'
+ AND d_year >= 1992
+ AND d_year <= 1997
+GROUP BY c_city, s_city, d_year
+ORDER BY d_year ASC, REVENUE DESC;
+
+Usage:
+./bin/systemds scripts/ssb/queries/q3_2.dml -nvargs input_dir="/path/to/data"
+./bin/systemds scripts/ssb/queries/q3_2.dml -nvargs input_dir="/Users/ghafekalsaho/Desktop/data"
+
+Parameters:
+input_dir - Path to input directory containing the table files (e.g., ./data)
+*/
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+customer_csv = read(input_dir + "/customer.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+# -- PREPARING --
+# Optimized approach: On-the-fly filtering with direct matrix construction for string fields
+
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : D_YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-3 : LO_CUSTKEY | COL-5 : LO_SUPPKEY |
+# COL-6 : LO_ORDERDATE | COL-13 : LO_REVENUE
+lineorder_csv_min = cbind(lineorder_csv[, 3], lineorder_csv[, 5], lineorder_csv[, 6], lineorder_csv[, 13]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+# ON-THE-FLY CUSTOMER TABLE FILTERING AND ENCODING (C_NATION = 'UNITED STATES')
+# Two-pass approach: Count first, then filter and encode
+customer_keys_matrix = as.matrix(customer_csv[, 1]); # customer_key
+customer_nrows = nrow(customer_csv);
+us_customer_count = 0;
+
+# Pass 1: Count matching customers (nation = UNITED STATES)
+for (i in 1:customer_nrows) {
+ nation_val = as.scalar(customer_csv[i, 5]); # c_nation
+ if (nation_val == "UNITED STATES") {
+ us_customer_count = us_customer_count + 1;
+ }
+}
+
+# Pass 2: Build customer matrix with proper city and nation encoding
+cust_matrix_min = matrix(0, us_customer_count, 3); # custkey, city_code, nation_code
+filtered_idx = 0;
+
+for (i in 1:customer_nrows) {
+ nation_val = as.scalar(customer_csv[i, 5]); # c_nation
+ if (nation_val == "UNITED STATES") {
+ filtered_idx = filtered_idx + 1;
+ city_val = as.scalar(customer_csv[i, 4]); # c_city
+
+ cust_matrix_min[filtered_idx, 1] = as.scalar(customer_keys_matrix[i, 1]); # customer_key
+ cust_matrix_min[filtered_idx, 3] = 1; # encoded value for UNITED STATES nation
+
+ # Encode the city as the numeric suffix of its name ('UNITED ST0' .. 'UNITED ST9'),
+ # so identical cities share one code and the group-by below really groups per city
+ city_code = -1; # fallback for unexpected city names
+ if (city_val == "UNITED ST0") city_code = 0;
+ else if (city_val == "UNITED ST1") city_code = 1;
+ else if (city_val == "UNITED ST2") city_code = 2;
+ else if (city_val == "UNITED ST3") city_code = 3;
+ else if (city_val == "UNITED ST4") city_code = 4;
+ else if (city_val == "UNITED ST5") city_code = 5;
+ else if (city_val == "UNITED ST6") city_code = 6;
+ else if (city_val == "UNITED ST7") city_code = 7;
+ else if (city_val == "UNITED ST8") city_code = 8;
+ else if (city_val == "UNITED ST9") city_code = 9;
+ cust_matrix_min[filtered_idx, 2] = city_code;
+ }
+}
+
+# ON-THE-FLY SUPPLIER TABLE FILTERING AND ENCODING (S_NATION = 'UNITED STATES')
+# Two-pass approach for suppliers
+supplier_keys_matrix = as.matrix(supplier_csv[, 1]); # supplier_key
+supplier_nrows = nrow(supplier_csv);
+us_supplier_count = 0;
+
+# Pass 1: Count matching suppliers
+for (i in 1:supplier_nrows) {
+ nation_val = as.scalar(supplier_csv[i, 5]); # s_nation
+ if (nation_val == "UNITED STATES") {
+ us_supplier_count = us_supplier_count + 1;
+ }
+}
+
+# Pass 2: Build supplier matrix with the same suffix-based city encoding
+sup_matrix_min = matrix(0, us_supplier_count, 3); # suppkey, city_code, nation_code
+filtered_idx = 0;
+
+for (i in 1:supplier_nrows) {
+ nation_val = as.scalar(supplier_csv[i, 5]); # s_nation
+ if (nation_val == "UNITED STATES") {
+ filtered_idx = filtered_idx + 1;
+ city_val = as.scalar(supplier_csv[i, 4]); # s_city
+
+ sup_matrix_min[filtered_idx, 1] = as.scalar(supplier_keys_matrix[i, 1]); # supplier_key
+ sup_matrix_min[filtered_idx, 3] = 1; # encoded value for UNITED STATES nation
+
+ # Encode the city as the numeric suffix of its name ('UNITED ST0' .. 'UNITED ST9')
+ city_code = -1; # fallback for unexpected city names
+ if (city_val == "UNITED ST0") city_code = 0;
+ else if (city_val == "UNITED ST1") city_code = 1;
+ else if (city_val == "UNITED ST2") city_code = 2;
+ else if (city_val == "UNITED ST3") city_code = 3;
+ else if (city_val == "UNITED ST4") city_code = 4;
+ else if (city_val == "UNITED ST5") city_code = 5;
+ else if (city_val == "UNITED ST6") city_code = 6;
+ else if (city_val == "UNITED ST7") city_code = 7;
+ else if (city_val == "UNITED ST8") city_code = 8;
+ else if (city_val == "UNITED ST9") city_code = 9;
+ sup_matrix_min[filtered_idx, 2] = city_code;
+ }
+}
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+# We already filtered during matrix construction, but we can use RA selection for consistency
+# C_NATION = 'UNITED STATES' : 1 (Our encoded value)
+c_nat_filt = raSel::m_raSelection(cust_matrix_min, col=3, op="==", val=1);
+
+# S_NATION = 'UNITED STATES' : 1 (Our encoded value)
+s_nat_filt = raSel::m_raSelection(sup_matrix_min, col=3, op="==", val=1);
+
+# D_YEAR BETWEEN 1992 & 1997
+d_year_filt = raSel::m_raSelection(date_matrix_min, col=2, op=">=", val=1992);
+d_year_filt = raSel::m_raSelection(d_year_filt, col=2, op="<=", val=1997);
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING MINIMIZED LINEORDER TABLE WITH FILTERED CUSTOMER TABLE WHERE LO_CUSTKEY = C_CUSTKEY
+lo_cust = raJoin::m_raJoin(A=lineorder_matrix_min, colA=1, B=c_nat_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ SUPPLIER WHERE LO_SUPPKEY = S_SUPPKEY
+lo_cust_sup = raJoin::m_raJoin(A=lo_cust, colA=2, B=s_nat_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ DATE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_cust_sup, colA=3, B=d_year_filt, colB=1, method="sort-merge");
+
+
+# -- GROUP-BY & AGGREGATION --
+# LO_REVENUE : COLUMN 4 OF LINEORDER-MIN-MATRIX (LO_PARTKEY is not loaded for this query)
+revenue = joined_matrix[, 4];
+# D_YEAR : COLUMN 2 OF DATE-MIN-MATRIX
+d_year = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_matrix_min) + ncol(sup_matrix_min) + 2)];
+# C_CITY : COLUMN 2 OF CUST-MIN-MATRIX
+c_city = joined_matrix[,(ncol(lineorder_matrix_min) + 2)];
+# S_CITY : COLUMN 2 OF SUP-MIN-MATRIX
+s_city = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_matrix_min) + 2)];
+
+# CALCULATING COMBINATION KEY WITH PRIORITY: C_CITY, S_CITY & D_YEAR
+max_c_city = max(c_city);
+max_s_city = max(s_city);
+max_d_year = max(d_year);
+
+c_city_scale_f = ceil(max_c_city) + 1;
+s_city_scale_f = ceil(max_s_city) + 1;
+d_year_scale_f = ceil(max_d_year) + 1;
+
+combined_key = c_city * s_city_scale_f * d_year_scale_f + s_city * d_year_scale_f + d_year;
+
+group_input = cbind(revenue, combined_key);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+key = agg_result[, 1];
+revenue = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+# EXTRACTING C_CITY, S_CITY & D_YEAR
+d_year = round(key %% d_year_scale_f);
+c_city = round(floor(key / (s_city_scale_f * d_year_scale_f)));
+s_city = round((floor(key / d_year_scale_f)) %% s_city_scale_f);
+
+result = cbind(c_city, s_city, d_year, revenue);
+
+
+# -- SORTING --
+# PRIORITY 1 D_YEAR (ASC), 2 REVENUE (DESC)
+result_ordered = order(target=result, by=4, decreasing=TRUE, index.return=FALSE);
+result_ordered = order(target=result_ordered, by=3, decreasing=FALSE, index.return=FALSE);
+
+
+# -- DECODING C_CITY & S_CITY CODES --
+# The city code is the numeric suffix of the city name, so the names can be rebuilt directly.
+print("Q3.2 Results:");
+print("# FRAME: nrow = " + nrow(result_ordered) + ", ncol = 4");
+print("# C1 C2 C3 C4");
+print("# STRING STRING INT32 INT32");
+
+for (i in 1:nrow(result_ordered)) {
+ c_city_code = as.scalar(result_ordered[i, 1]);
+ s_city_code = as.scalar(result_ordered[i, 2]);
+ year_val = as.scalar(result_ordered[i, 3]);
+ revenue_val = as.scalar(result_ordered[i, 4]);
+
+ # Rebuild the city names from their numeric suffix codes
+ c_city_name = "UNITED ST" + as.integer(c_city_code);
+ s_city_name = "UNITED ST" + as.integer(s_city_code);
+
+ print(c_city_name + " " + s_city_name + " " + as.integer(year_val) + " " + as.integer(revenue_val));
+}
+
+# Calculate total revenue for validation
+total_revenue = sum(result_ordered[, 4]);
+print("");
+print("Total number of result rows: " + nrow(result_ordered));
+print("Total revenue: " + as.integer(total_revenue));
+print("Q3.2 finished");
+
diff --git a/scripts/staging/ssb/queries/q3_3.dml b/scripts/staging/ssb/queries/q3_3.dml
new file mode 100644
index 00000000000..5476eb6fe08
--- /dev/null
+++ b/scripts/staging/ssb/queries/q3_3.dml
@@ -0,0 +1,239 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/* DML script implementing the SSB query Q3.3 in SystemDS.
+SELECT
+ c_city,
+ s_city,
+ d_year,
+ SUM(lo_revenue) AS REVENUE
+FROM customer, lineorder, supplier, dates
+WHERE
+ lo_custkey = c_custkey
+ AND lo_suppkey = s_suppkey
+ AND lo_orderdate = d_datekey
+ AND (
+ c_city = 'UNITED KI1'
+ OR c_city = 'UNITED KI5'
+ )
+ AND (
+ s_city = 'UNITED KI1'
+ OR s_city = 'UNITED KI5'
+ )
+ AND d_year >= 1992
+ AND d_year <= 1997
+GROUP BY c_city, s_city, d_year
+ORDER BY d_year ASC, REVENUE DESC;
+*/
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+#part_csv = read(input_dir + "/part.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+customer_csv = read(input_dir + "/customer.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+
+# -- PREPARING --
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : D_YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-3 : LO_CUSTKEY | COL-4 : LO_PARTKEY |
+# COL-5 : LO_SUPPKEY | COL-6 : LO_ORDERDATE | COL-13 : LO_REVENUE
+lineorder_csv_min = cbind(lineorder_csv[, 3], lineorder_csv[, 4], lineorder_csv[, 5], lineorder_csv[, 6], lineorder_csv[, 13]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+# ON-THE-FLY CUSTOMER TABLE FILTERING AND ENCODING (C_CITY = 'UNITED KI1' OR 'UNITED KI5')
+customer_keys_matrix = as.matrix(customer_csv[, 1]); # customer_key
+customer_nrows = nrow(customer_csv);
+matching_customer_count = 0;
+
+# Pass 1: Count matching customers
+for (i in 1:customer_nrows) {
+ city_val = as.scalar(customer_csv[i, 4]); # c_city
+ if (city_val == "UNITED KI1" | city_val == "UNITED KI5") {
+ matching_customer_count = matching_customer_count + 1;
+ }
+}
+
+# Pass 2: Build customer matrix with dynamic city encoding
+cust_matrix_min = matrix(0, matching_customer_count, 2); # custkey, city_code
+filtered_idx = 0;
+
+for (i in 1:customer_nrows) {
+ city_val = as.scalar(customer_csv[i, 4]); # c_city
+ if (city_val == "UNITED KI1" | city_val == "UNITED KI5") {
+ filtered_idx = filtered_idx + 1;
+ cust_matrix_min[filtered_idx, 1] = as.scalar(customer_keys_matrix[i, 1]); # customer_key
+
+ # Use consistent encoding: 1 for UNITED KI1, 2 for UNITED KI5
+ if (city_val == "UNITED KI1") {
+ cust_matrix_min[filtered_idx, 2] = 1;
+ } else {
+ cust_matrix_min[filtered_idx, 2] = 2;
+ }
+ }
+}
+
+# ON-THE-FLY SUPPLIER TABLE FILTERING AND ENCODING (S_CITY = 'UNITED KI1' OR 'UNITED KI5')
+supplier_keys_matrix = as.matrix(supplier_csv[, 1]); # supplier_key
+supplier_nrows = nrow(supplier_csv);
+matching_supplier_count = 0;
+
+# Pass 1: Count matching suppliers
+for (i in 1:supplier_nrows) {
+ city_val = as.scalar(supplier_csv[i, 4]); # s_city
+ if (city_val == "UNITED KI1" | city_val == "UNITED KI5") {
+ matching_supplier_count = matching_supplier_count + 1;
+ }
+}
+
+# Pass 2: Build supplier matrix with dynamic city encoding
+sup_matrix_min = matrix(0, matching_supplier_count, 2); # suppkey, city_code
+filtered_idx = 0;
+
+for (i in 1:supplier_nrows) {
+ city_val = as.scalar(supplier_csv[i, 4]); # s_city
+ if (city_val == "UNITED KI1" | city_val == "UNITED KI5") {
+ filtered_idx = filtered_idx + 1;
+ sup_matrix_min[filtered_idx, 1] = as.scalar(supplier_keys_matrix[i, 1]); # supplier_key
+
+ # Use consistent encoding: 1 for UNITED KI1, 2 for UNITED KI5
+ if (city_val == "UNITED KI1") {
+ sup_matrix_min[filtered_idx, 2] = 1;
+ } else {
+ sup_matrix_min[filtered_idx, 2] = 2;
+ }
+ }
+}
+
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+# Customers and suppliers were already restricted to the two target cities during matrix
+# construction, so no further RA selection is needed for them.
+c_city_filt = cust_matrix_min; # Already filtered for target cities
+s_city_filt = sup_matrix_min; # Already filtered for target cities
+
+# D_YEAR BETWEEN 1992 & 1997
+d_year_filt = raSel::m_raSelection(date_matrix_min, col=2, op=">=", val=1992);
+d_year_filt = raSel::m_raSelection(d_year_filt, col=2, op="<=", val=1997);
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING MINIMIZED LINEORDER TABLE WITH FILTERED CUSTOMER TABLE WHERE LO_CUSTKEY = C_CUSTKEY
+lo_cust = raJoin::m_raJoin(A=lineorder_matrix_min, colA=1, B=c_city_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ SUPPLIER WHERE LO_SUPPKEY = S_SUPPKEY
+lo_cust_sup = raJoin::m_raJoin(A=lo_cust, colA=3, B=s_city_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ DATE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_cust_sup, colA=4, B=d_year_filt, colB=1, method="sort-merge");
+#print(nrow(joined_matrix));
+
+
+# -- GROUP-BY & AGGREGATION --
+# LO_REVENUE : COLUMN 5 OF LINEORDER-MIN-MATRIX
+revenue = joined_matrix[, 5];
+# D_YEAR : COLUMN 2 OF DATE-MIN-MATRIX
+d_year = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_matrix_min) + ncol(sup_matrix_min) + 2)];
+# C_CITY : COLUMN 2 OF CUST-MIN-MATRIX
+c_city = joined_matrix[,(ncol(lineorder_matrix_min) + 2)];
+# S_CITY : COLUMN 2 OF SUP-MIN-MATRIX
+s_city = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_matrix_min) + 2)];
+
+# CALCULATING COMBINATION KEY WITH PRIORITY: C_CITY, S_CITY & D_YEAR
+max_c_city = max(c_city);
+max_s_city = max(s_city);
+max_d_year = max(d_year);
+
+c_city_scale_f = ceil(max_c_city) + 1;
+s_city_scale_f = ceil(max_s_city) + 1;
+d_year_scale_f = ceil(max_d_year) + 1;
+
+combined_key = c_city * s_city_scale_f * d_year_scale_f + s_city * d_year_scale_f + d_year;
+
+group_input = cbind(revenue, combined_key);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+key = agg_result[, 1];
+revenue = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+# EXTRACTING C_CITY, S_CITY & D_YEAR
+d_year = round(key %% d_year_scale_f);
+c_city = round(floor(key / (s_city_scale_f * d_year_scale_f)));
+s_city = round((floor(key / d_year_scale_f)) %% s_city_scale_f);
+
+result = cbind(c_city, s_city, d_year, revenue);
+
+
+# -- SORTING --
+# PRIORITY 1 D_YEAR (ASC), 2 REVENUE (DESC)
+result_ordered = order(target=result, by=4, decreasing=TRUE, index.return=FALSE);
+result_ordered = order(target=result_ordered, by=3, decreasing=FALSE, index.return=FALSE);
+
+
+# -- OUTPUT RESULTS --
+print("Q3.3 Results:");
+print("# FRAME: nrow = " + nrow(result_ordered) + ", ncol = 4");
+print("# C1 C2 C3 C4");
+print("# STRING STRING INT32 INT32");
+
+for (i in 1:nrow(result_ordered)) {
+ c_city_code = as.scalar(result_ordered[i, 1]);
+ s_city_code = as.scalar(result_ordered[i, 2]);
+ year_val = as.scalar(result_ordered[i, 3]);
+ revenue_val = as.scalar(result_ordered[i, 4]);
+
+ # Map back to original city names based on the encoding used
+ if (c_city_code == 1) {
+ c_city_name = "UNITED KI1";
+ } else {
+ c_city_name = "UNITED KI5";
+ }
+
+ if (s_city_code == 1) {
+ s_city_name = "UNITED KI1";
+ } else {
+ s_city_name = "UNITED KI5";
+ }
+
+ print(c_city_name + " " + s_city_name + " " + as.integer(year_val) + " " + as.integer(revenue_val));
+}
+
+# Calculate total revenue for validation
+total_revenue = sum(result_ordered[, 4]);
+print("");
+print("Total number of result rows: " + nrow(result_ordered));
+print("Total revenue: " + as.integer(total_revenue));
+print("Q3.3 finished");
+
diff --git a/scripts/staging/ssb/queries/q3_4.dml b/scripts/staging/ssb/queries/q3_4.dml
new file mode 100644
index 00000000000..fd276e5090e
--- /dev/null
+++ b/scripts/staging/ssb/queries/q3_4.dml
@@ -0,0 +1,262 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/* DML-script implementing the ssb query Q3.4 in SystemDS.
+SELECT
+ c_city,
+ s_city,
+ d_year,
+ SUM(lo_revenue) AS REVENUE
+FROM customer, lineorder, supplier, dates
+WHERE
+ lo_custkey = c_custkey
+ AND lo_suppkey = s_suppkey
+ AND lo_orderdate = d_datekey
+ AND (
+ c_city = 'UNITED KI1'
+ OR c_city = 'UNITED KI5'
+ )
+ AND (
+ s_city = 'UNITED KI1'
+ OR s_city = 'UNITED KI5'
+ )
+ AND d_yearmonth = 'Dec1997'
+GROUP BY c_city, s_city, d_year
+ORDER BY d_year ASC, REVENUE DESC;
+*/
+
+# -- PARAMETER HANDLING --
+input_dir = ifdef($input_dir, "./data");
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+#part_csv = read(input_dir + "/part.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+customer_csv = read(input_dir + "/customer.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+
+# -- PREPARING --
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : D_YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-3 : LO_CUSTKEY | COL-4 : LO_PARTKEY |
+# COL-5 : LO_SUPPKEY | COL-6 : LO_ORDERDATE | COL-13 : LO_REVENUE
+lineorder_csv_min = cbind(lineorder_csv[, 3], lineorder_csv[, 4], lineorder_csv[, 5], lineorder_csv[, 6], lineorder_csv[, 13]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+# ON-THE-FLY CUSTOMER TABLE FILTERING AND ENCODING (C_CITY = 'UNITED KI1' OR 'UNITED KI5')
+customer_keys_matrix = as.matrix(customer_csv[, 1]); # customer_key
+customer_nrows = nrow(customer_csv);
+matching_customer_count = 0;
+
+# Pass 1: Count matching customers
+for (i in 1:customer_nrows) {
+ city_val = as.scalar(customer_csv[i, 4]); # c_city
+ if (city_val == "UNITED KI1" | city_val == "UNITED KI5") {
+ matching_customer_count = matching_customer_count + 1;
+ }
+}
+
+# Pass 2: Build customer matrix with dynamic city encoding
+cust_matrix_min = matrix(0, matching_customer_count, 2); # custkey, city_code
+filtered_idx = 0;
+
+for (i in 1:customer_nrows) {
+ city_val = as.scalar(customer_csv[i, 4]); # c_city
+ if (city_val == "UNITED KI1" | city_val == "UNITED KI5") {
+ filtered_idx = filtered_idx + 1;
+ cust_matrix_min[filtered_idx, 1] = as.scalar(customer_keys_matrix[i, 1]); # customer_key
+
+ # Use consistent encoding: 1 for UNITED KI1, 2 for UNITED KI5
+ if (city_val == "UNITED KI1") {
+ cust_matrix_min[filtered_idx, 2] = 1;
+ } else {
+ cust_matrix_min[filtered_idx, 2] = 2;
+ }
+ }
+}
+
+# ON-THE-FLY SUPPLIER TABLE FILTERING AND ENCODING (S_CITY = 'UNITED KI1' OR 'UNITED KI5')
+supplier_keys_matrix = as.matrix(supplier_csv[, 1]); # supplier_key
+supplier_nrows = nrow(supplier_csv);
+matching_supplier_count = 0;
+
+# Pass 1: Count matching suppliers
+for (i in 1:supplier_nrows) {
+ city_val = as.scalar(supplier_csv[i, 4]); # s_city
+ if (city_val == "UNITED KI1" | city_val == "UNITED KI5") {
+ matching_supplier_count = matching_supplier_count + 1;
+ }
+}
+
+# Pass 2: Build supplier matrix with dynamic city encoding
+sup_matrix_min = matrix(0, matching_supplier_count, 2); # suppkey, city_code
+filtered_idx = 0;
+
+for (i in 1:supplier_nrows) {
+ city_val = as.scalar(supplier_csv[i, 4]); # s_city
+ if (city_val == "UNITED KI1" | city_val == "UNITED KI5") {
+ filtered_idx = filtered_idx + 1;
+ sup_matrix_min[filtered_idx, 1] = as.scalar(supplier_keys_matrix[i, 1]); # supplier_key
+
+ # Use consistent encoding: 1 for UNITED KI1, 2 for UNITED KI5
+ if (city_val == "UNITED KI1") {
+ sup_matrix_min[filtered_idx, 2] = 1;
+ } else {
+ sup_matrix_min[filtered_idx, 2] = 2;
+ }
+ }
+}
+
+
+# -- FILTERING THE DATA WITH RA-SELECTION FUNCTION --
+# Since we already filtered during matrix construction, we can use the full matrices
+c_city_filt = cust_matrix_min; # Already filtered for target cities
+s_city_filt = sup_matrix_min; # Already filtered for target cities
+
+# D_YEARMONTH = 'Dec1997' - Need precise filtering for Dec1997 only
+# Build filtered date matrix manually since we need string matching on d_yearmonth
+date_full_frame = cbind(date_csv[, 1], date_csv[, 5], date_csv[, 7]); # datekey, year, yearmonth
+date_nrows = nrow(date_full_frame);
+matching_dates = matrix(0, 31, 2); # We know 31 entries exist, store datekey and year
+filtered_idx = 0;
+
+for (i in 1:date_nrows) {
+ yearmonth_val = as.scalar(date_full_frame[i, 3]); # d_yearmonth
+ if (yearmonth_val == "Dec1997") {
+ filtered_idx = filtered_idx + 1;
+ matching_dates[filtered_idx, 1] = as.scalar(date_matrix_min[i, 1]); # datekey
+ matching_dates[filtered_idx, 2] = as.scalar(date_matrix_min[i, 2]); # d_year
+ }
+}
+
+d_year_filt = matching_dates;
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING MINIMIZED LINEORDER TABLE WITH FILTERED CUSTOMER TABLE WHERE LO_CUSTKEY = C_CUSTKEY
+lo_cust = raJoin::m_raJoin(A=lineorder_matrix_min, colA=1, B=c_city_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ SUPPLIER WHERE LO_SUPPKEY = S_SUPPKEY
+lo_cust_sup = raJoin::m_raJoin(A=lo_cust, colA=3, B=s_city_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ DATE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_cust_sup, colA=4, B=d_year_filt, colB=1, method="sort-merge");
+
+# Check if we have any results
+if (nrow(joined_matrix) == 0) {
+ print("Q3.4 Results:");
+ print("# FRAME: nrow = 0, ncol = 4");
+ print("# C1 C2 C3 C4");
+ print("# STRING STRING INT32 INT32");
+ print("");
+ print("Total number of result rows: 0");
+ print("Total revenue: 0");
+ print("Q3.4 finished - no matching data for Dec1997");
+} else {
+
+
+# -- GROUP-BY & AGGREGATION --
+# LO_REVENUE : COLUMN 5 OF LINEORDER-MIN-MATRIX
+revenue = joined_matrix[, 5];
+# D_YEAR : COLUMN 2 OF DATE-MIN-MATRIX
+d_year = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_matrix_min) + ncol(sup_matrix_min) + 2)];
+# C_CITY : COLUMN 2 OF CUST-MIN-MATRIX
+c_city = joined_matrix[,(ncol(lineorder_matrix_min) + 2)];
+# S_CITY : COLUMN 2 OF SUP-MIN-MATRIX
+s_city = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_matrix_min) + 2)];
+
+# CALCULATING COMBINATION KEY WITH PRIORITY: C_CITY, S_CITY & D_YEAR
+max_c_city = max(c_city);
+max_s_city = max(s_city);
+max_d_year = max(d_year);
+
+c_city_scale_f = ceil(max_c_city) + 1;
+s_city_scale_f = ceil(max_s_city) + 1;
+d_year_scale_f = ceil(max_d_year) + 1;
+
+combined_key = c_city * s_city_scale_f * d_year_scale_f + s_city * d_year_scale_f + d_year;
+
+group_input = cbind(revenue, combined_key);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+key = agg_result[, 1];
+revenue = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+# EXTRACTING C_CITY, S_CITY & D_YEAR
+d_year = round(key %% d_year_scale_f);
+c_city = round(floor(key / (s_city_scale_f * d_year_scale_f)));
+s_city = round((floor(key / d_year_scale_f)) %% s_city_scale_f);
+
+result = cbind(c_city, s_city, d_year, revenue);
+
+
+# -- SORTING --
+# PRIORITY 1 D_YEAR (ASC), 2 REVENUE (DESC)
+result_ordered = order(target=result, by=4, decreasing=TRUE, index.return=FALSE);
+result_ordered = order(target=result_ordered, by=3, decreasing=FALSE, index.return=FALSE);
+
+
+# -- OUTPUT RESULTS --
+print("Q3.4 Results:");
+print("# FRAME: nrow = " + nrow(result_ordered) + ", ncol = 4");
+print("# C1 C2 C3 C4");
+print("# STRING STRING INT32 INT32");
+
+for (i in 1:nrow(result_ordered)) {
+ c_city_code = as.scalar(result_ordered[i, 1]);
+ s_city_code = as.scalar(result_ordered[i, 2]);
+ year_val = as.scalar(result_ordered[i, 3]);
+ revenue_val = as.scalar(result_ordered[i, 4]);
+
+ # Map back to original city names based on the encoding used
+ if (c_city_code == 1) {
+ c_city_name = "UNITED KI1";
+ } else {
+ c_city_name = "UNITED KI5";
+ }
+
+ if (s_city_code == 1) {
+ s_city_name = "UNITED KI1";
+ } else {
+ s_city_name = "UNITED KI5";
+ }
+
+ print(c_city_name + " " + s_city_name + " " + as.integer(year_val) + " " + as.integer(revenue_val));
+}
+
+# Calculate total revenue for validation
+total_revenue = sum(result_ordered[, 4]);
+print("");
+print("Total number of result rows: " + nrow(result_ordered));
+print("Total revenue: " + as.integer(total_revenue));
+print("Q3.4 finished");
+}
diff --git a/scripts/staging/ssb/queries/q4_1.dml b/scripts/staging/ssb/queries/q4_1.dml
new file mode 100644
index 00000000000..191cbc9db90
--- /dev/null
+++ b/scripts/staging/ssb/queries/q4_1.dml
@@ -0,0 +1,264 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/* DML-script implementing the ssb query Q4.1 in SystemDS with Dynamic Encoding.
+SELECT
+ d_year,
+ c_nation,
+ SUM(lo_revenue - lo_supplycost) AS PROFIT
+FROM dates, customer, supplier, part, lineorder
+WHERE
+ lo_custkey = c_custkey
+ AND lo_suppkey = s_suppkey
+ AND lo_partkey = p_partkey
+ AND lo_orderdate = d_datekey
+ AND c_region = 'AMERICA'
+ AND s_region = 'AMERICA'
+ AND (
+ p_mfgr = 'MFGR#1'
+ OR p_mfgr = 'MFGR#2'
+ )
+GROUP BY d_year, c_nation
+ORDER BY d_year, c_nation;
+*/
+
+# Input parameter
+input_dir = $input_dir;
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+part_csv = read(input_dir + "/part.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+customer_csv = read(input_dir + "/customer.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+
+# -- MANUAL FILTERING AND DATA PREPARATION --
+# Extract minimal data needed for the query
+date_matrix_min = as.matrix(cbind(date_csv[, 1], date_csv[, 5]));
+lineorder_matrix_min = as.matrix(cbind(lineorder_csv[, 3], lineorder_csv[, 4], lineorder_csv[, 5],
+ lineorder_csv[, 6], lineorder_csv[, 13], lineorder_csv[, 14]));
+
+# Build filtered parts list (MFGR#1 and MFGR#2)
+part_filtered_keys = matrix(0, rows=0, cols=1);
+
+for(i in 1:nrow(part_csv)) {
+ mfgr_val = as.scalar(part_csv[i, 3]);
+ if(mfgr_val == "MFGR#1" | mfgr_val == "MFGR#2") {
+ # Extract key and create single-element matrix
+ key_val = as.double(as.scalar(part_csv[i, 1]));
+ key_matrix = matrix(key_val, rows=1, cols=1);
+
+ # Append to filtered results
+ part_filtered_keys = rbind(part_filtered_keys, key_matrix);
+ }
+}
+part_count = nrow(part_filtered_keys);
+if(part_count == 0) {
+ part_filtered_keys = matrix(0, rows=1, cols=1); # Fallback for empty case
+}
+
+# Build filtered customers list (AMERICA region) with dynamic encoding
+cust_filtered_keys = matrix(0, rows=0, cols=1);
+cust_filtered_nations = matrix(0, rows=0, cols=1);
+
+for(i in 1:nrow(customer_csv)) {
+ region_val = as.scalar(customer_csv[i, 6]);
+ if(region_val == "AMERICA") {
+ # Extract key and create single-element matrix
+ key_val = as.double(as.scalar(customer_csv[i, 1]));
+ key_matrix = matrix(key_val, rows=1, cols=1);
+
+ # Extract nation and encode
+ nation_str = as.scalar(customer_csv[i, 5]);
+ if(nation_str == "ARGENTINA") {
+ nation_val = 3;
+ } else if(nation_str == "CANADA") {
+ nation_val = 5;
+ } else if(nation_str == "PERU") {
+ nation_val = 8;
+ } else if(nation_str == "BRAZIL") {
+ nation_val = 13;
+ } else if(nation_str == "UNITED STATES") {
+ nation_val = 25;
+ } else {
+ nation_val = 0; # Unknown nation
+ }
+ nation_matrix = matrix(nation_val, rows=1, cols=1);
+
+ # Append to filtered results
+ cust_filtered_keys = rbind(cust_filtered_keys, key_matrix);
+ cust_filtered_nations = rbind(cust_filtered_nations, nation_matrix);
+ }
+}
+
+cust_count = nrow(cust_filtered_keys);
+if(cust_count > 0) {
+ # Create customer matrix from filtered data
+ cust_filtered_data = cbind(cust_filtered_keys, cust_filtered_nations);
+} else {
+ cust_filtered_data = matrix(0, rows=1, cols=2); # Fallback for empty case
+}
+
+# Build filtered suppliers list (AMERICA region)
+supp_filtered_keys = matrix(0, rows=0, cols=1);
+
+for(i in 1:nrow(supplier_csv)) {
+ region_val = as.scalar(supplier_csv[i, 6]);
+ if(region_val == "AMERICA") {
+ # Extract key and create single-element matrix
+ key_val = as.double(as.scalar(supplier_csv[i, 1]));
+ key_matrix = matrix(key_val, rows=1, cols=1);
+
+ # Append to filtered results
+ supp_filtered_keys = rbind(supp_filtered_keys, key_matrix);
+ }
+}
+supp_count = nrow(supp_filtered_keys);
+if(supp_count == 0) {
+ supp_filtered_keys = matrix(0, rows=1, cols=1); # Fallback for empty case
+}
+
+# Ensure filtered matrices are properly formatted
+if(cust_count > 0) {
+ cust_matrix_formatted = cust_filtered_data; # Use the already created matrix
+} else {
+ cust_matrix_formatted = matrix(0, rows=1, cols=2);
+}
+
+if(supp_count > 0) {
+ supp_matrix_formatted = supp_filtered_keys; # Use the already created matrix
+} else {
+ supp_matrix_formatted = matrix(0, rows=1, cols=1);
+}
+
+if(part_count > 0) {
+ part_matrix_formatted = part_filtered_keys; # Use the already created matrix
+} else {
+ part_matrix_formatted = matrix(0, rows=1, cols=1);
+}
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION (SORT-MERGE METHOD) --
+# Remove any potential zero values from customer matrix
+valid_cust_mask = (cust_matrix_formatted[, 1] > 0);
+if(sum(valid_cust_mask) > 0) {
+ cust_clean = removeEmpty(target=cust_matrix_formatted, margin="rows", select=valid_cust_mask);
+} else {
+ stop("No valid customer data");
+}
+
+# Join lineorder with filtered customer table (lo_custkey = c_custkey)
+lo_cust = raJoin::m_raJoin(A=lineorder_matrix_min, colA=1, B=cust_clean, colB=1, method="sort-merge");
+
+# Join with filtered supplier table (lo_suppkey = s_suppkey)
+lo_cust_sup = raJoin::m_raJoin(A=lo_cust, colA=3, B=supp_matrix_formatted, colB=1, method="sort-merge");
+
+# Join with filtered part table (lo_partkey = p_partkey)
+lo_cust_sup_part = raJoin::m_raJoin(A=lo_cust_sup, colA=2, B=part_matrix_formatted, colB=1, method="sort-merge");
+
+# Join with date table (lo_orderdate = d_datekey)
+joined_matrix = raJoin::m_raJoin(A=lo_cust_sup_part, colA=4, B=date_matrix_min, colB=1, method="sort-merge");
+# -- GROUP-BY & AGGREGATION --
+lo_revenue = joined_matrix[, 5];
+lo_supplycost = joined_matrix[, 6];
+d_year = joined_matrix[, ncol(joined_matrix)]; # last column (d_year)
+c_nation = joined_matrix[, 8]; # customer nation column
+
+profit = lo_revenue - lo_supplycost;
+
+# Create nation mapping for grouping
+unique_nations = unique(c_nation);
+nation_encoding = matrix(0, rows=nrow(unique_nations), cols=1);
+for(i in 1:nrow(unique_nations)) {
+ nation_encoding[i, 1] = i;
+}
+
+# Encode nations to numbers for grouping
+c_nation_encoded = matrix(0, rows=nrow(c_nation), cols=1);
+for(i in 1:nrow(c_nation)) {
+ for(j in 1:nrow(unique_nations)) {
+ if(as.scalar(c_nation[i, 1]) == as.scalar(unique_nations[j, 1])) {
+ c_nation_encoded[i, 1] = j;
+ }
+ }
+}
+
+# Create combined grouping key
+max_nation = max(c_nation_encoded);
+max_year = max(d_year);
+
+nation_scale = ceil(max_nation) + 1;
+year_scale = ceil(max_year) + 1;
+
+combined_key = c_nation_encoded * year_scale + d_year;
+
+# Group and aggregate
+group_input = cbind(profit, combined_key);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+# Extract results
+key = agg_result[, 1];
+profit_sum = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+# Decode results
+d_year_result = round(key %% year_scale);
+c_nation_encoded_result = round(floor(key / year_scale));
+
+# Prepare for sorting
+result = cbind(d_year_result, c_nation_encoded_result, profit_sum);
+
+# Sort by year, then by nation
+result_ordered = order(target=result, by=2, decreasing=FALSE, index.return=FALSE);
+result_ordered = order(target=result_ordered, by=1, decreasing=FALSE, index.return=FALSE);
+
+# Create nation name lookup based on encoding
+nation_lookup = matrix(0, rows=nrow(result_ordered), cols=1);
+for(i in 1:nrow(result_ordered)) {
+  nation_idx = as.integer(as.scalar(result_ordered[i, 2]));
+  # result_ordered[, 2] holds the group index into unique_nations (1..n), so map it
+  # back to the original nation code before applying the lookup.
+  nation_code = as.scalar(unique_nations[nation_idx, 1]);
+  if(nation_code == 3) {
+    nation_lookup[i, 1] = 1; # ARGENTINA
+  } else if(nation_code == 5) {
+    nation_lookup[i, 1] = 2; # CANADA
+  } else if(nation_code == 8) {
+    nation_lookup[i, 1] = 3; # PERU
+  } else if(nation_code == 13) {
+    nation_lookup[i, 1] = 4; # BRAZIL
+  } else if(nation_code == 25) {
+    nation_lookup[i, 1] = 5; # UNITED STATES
+  } else {
+    nation_lookup[i, 1] = 0; # UNKNOWN
+  }
+}
+
+# Create final result with proper data types
+year_frame = as.frame(result_ordered[, 1]);
+profit_frame = as.frame(result_ordered[, 3]);
+
+# Output final results (Year, Nation_Code, Profit)
+print(result_ordered);
\ No newline at end of file
diff --git a/scripts/staging/ssb/queries/q4_2.dml b/scripts/staging/ssb/queries/q4_2.dml
new file mode 100644
index 00000000000..cef79f5f344
--- /dev/null
+++ b/scripts/staging/ssb/queries/q4_2.dml
@@ -0,0 +1,235 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/* DML-script implementing the ssb query Q4.2 in SystemDS with on-the-fly encoding (no external meta files).
+SELECT
+ d_year,
+ s_nation,
+ p_category,
+ SUM(lo_revenue - lo_supplycost) AS PROFIT
+FROM dates, customer, supplier, part, lineorder
+WHERE
+ lo_custkey = c_custkey
+ AND lo_suppkey = s_suppkey
+ AND lo_partkey = p_partkey
+ AND lo_orderdate = d_datekey
+ AND c_region = 'AMERICA'
+ AND s_region = 'AMERICA'
+ AND (
+ d_year = 1997
+ OR d_year = 1998
+ )
+ AND (
+ p_mfgr = 'MFGR#1'
+ OR p_mfgr = 'MFGR#2'
+ )
+GROUP BY d_year, s_nation, p_category
+ORDER BY d_year, s_nation, p_category;
+*/
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+## Input parameter
+input_dir = $input_dir;
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+part_csv = read(input_dir + "/part.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+customer_csv = read(input_dir + "/customer.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+
+# -- PREPARING --
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : D_YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-3 : LO_CUSTKEY | COL-4 : LO_PARTKEY |
+# COL-5 : LO_SUPPKEY | COL-6 : LO_ORDERDATE | COL-13 : LO_REVENUE | COL-14 : LO_SUPPLYCOST
+lineorder_csv_min = cbind(lineorder_csv[, 3], lineorder_csv[, 4], lineorder_csv[, 5], lineorder_csv[, 6], lineorder_csv[, 13], lineorder_csv[, 14]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+## PART on-the-fly encoding: encode p_category (col 4); filter by p_mfgr (col 3)
+[part_cat_enc_f, part_cat_meta] = transformencode(target=part_csv[,4], spec="{ \"ids\": false, \"recode\": [\"C1\"] }");
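+# transformencode with a recode spec maps each distinct string in the column to an
+# integer code (1..N, assignment order not guaranteed) and returns the code matrix plus
+# a meta frame; transformdecode with the same meta recovers the original strings later.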
+
+## CUSTOMER filter: keep only c_region == 'AMERICA'; we only need c_custkey
+cust_filt_keys = matrix(0, rows=0, cols=1);
+for (i in 1:nrow(customer_csv)) {
+ if (as.scalar(customer_csv[i,6]) == "AMERICA") {
+ key_val = as.double(as.scalar(customer_csv[i,1]));
+ cust_filt_keys = rbind(cust_filt_keys, matrix(key_val, rows=1, cols=1));
+ }
+}
+if (nrow(cust_filt_keys) == 0) { cust_filt_keys = matrix(0, rows=1, cols=1); }
+
+## SUPPLIER on-the-fly encoding: encode s_nation (col 5); filter by s_region (col 6)
+[sup_nat_enc_f, sup_nat_meta] = transformencode(target=supplier_csv[,5], spec="{ \"ids\": false, \"recode\": [\"C1\"] }");
+sup_filt_keys = matrix(0, rows=0, cols=1);
+sup_filt_nat = matrix(0, rows=0, cols=1);
+for (i in 1:nrow(supplier_csv)) {
+ if (as.scalar(supplier_csv[i,6]) == "AMERICA") {
+ key_val = as.double(as.scalar(supplier_csv[i,1]));
+ nat_code = as.double(as.scalar(sup_nat_enc_f[i,1]));
+ sup_filt_keys = rbind(sup_filt_keys, matrix(key_val, rows=1, cols=1));
+ sup_filt_nat = rbind(sup_filt_nat, matrix(nat_code, rows=1, cols=1));
+ }
+}
+if (nrow(sup_filt_keys) == 0) { sup_filt_keys = matrix(0, rows=1, cols=1); sup_filt_nat = matrix(0, rows=1, cols=1); }
+sup_filt = cbind(sup_filt_keys, sup_filt_nat);
+
+
+## -- FILTERING THE DATA --
+# P_MFGR = 'MFGR#1' OR 'MFGR#2' -> build filtered part table keeping key and encoded category
+part_filt_keys = matrix(0, rows=0, cols=1);
+part_filt_cat = matrix(0, rows=0, cols=1);
+for (i in 1:nrow(part_csv)) {
+ mfgr_val = as.scalar(part_csv[i,3]);
+ if (mfgr_val == "MFGR#1" | mfgr_val == "MFGR#2") {
+ key_val = as.double(as.scalar(part_csv[i,1]));
+ cat_code = as.double(as.scalar(part_cat_enc_f[i,1]));
+ part_filt_keys = rbind(part_filt_keys, matrix(key_val, rows=1, cols=1));
+ part_filt_cat = rbind(part_filt_cat, matrix(cat_code, rows=1, cols=1));
+ }
+}
+if (nrow(part_filt_keys) == 0) { part_filt_keys = matrix(0, rows=1, cols=1); part_filt_cat = matrix(0, rows=1, cols=1); }
+part_filt = cbind(part_filt_keys, part_filt_cat);
+
+## D_YEAR = 1997 OR 1998
+d_year_filt_1 = raSel::m_raSelection(date_matrix_min, col=2, op="==", val=1997);
+d_year_filt_2 = raSel::m_raSelection(date_matrix_min, col=2, op="==", val=1998);
+d_year_filt = rbind(d_year_filt_1, d_year_filt_2);
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING MINIMIZED LINEORDER TABLE WITH FILTERED CUSTOMER TABLE WHERE LO_CUSTKEY = C_CUSTKEY
+lo_cust = raJoin::m_raJoin(A=lineorder_matrix_min, colA=1, B=cust_filt_keys, colB=1, method="sort-merge");
+
+# JOIN: ⨝ SUPPLIER WHERE LO_SUPPKEY = S_SUPPKEY (carry s_nation code)
+lo_cust_sup = raJoin::m_raJoin(A=lo_cust, colA=3, B=sup_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ PART WHERE LO_PARTKEY = P_PARTKEY (carry p_category code)
+lo_cust_sup_part = raJoin::m_raJoin(A=lo_cust_sup, colA=2, B=part_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ DATE WHERE LO_ORDERDATE = D_DATEKEY
+joined_matrix = raJoin::m_raJoin(A=lo_cust_sup_part, colA=4, B=d_year_filt, colB=1, method="sort-merge");
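+# Resulting column layout (assuming raJoin appends B's columns after A's, which the
+# index arithmetic below relies on):
+#   [1-6] lineorder_min | [7] c_custkey | [8-9] s_suppkey, s_nation code |
+#   [10-11] p_partkey, p_category code | [12-13] d_datekey, d_year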
+
+
+# -- GROUP-BY & AGGREGATION --
+# LO_REVENUE : COLUMN 5 OF LINEORDER-MIN-MATRIX
+lo_revenue = joined_matrix[, 5];
+# LO_SUPPLYCOST : COLUMN 6 OF LINEORDER-MIN-MATRIX
+lo_supplycost = joined_matrix[, 6];
+# D_YEAR : COLUMN 2 OF DATE-MIN-MATRIX (last added 2nd col)
+d_year = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_filt_keys) + ncol(sup_filt) + ncol(part_filt) + 2)];
+# S_NATION (encoded) : COLUMN 2 OF SUPPLIER-FILTERED MATRIX
+s_nation = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_filt_keys) + 2)];
+# P_CATEGORY (encoded) : COLUMN 2 OF PART-FILTERED MATRIX
+p_category = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(cust_filt_keys) + ncol(sup_filt) + 2)];
+
+profit = lo_revenue - lo_supplycost;
+
+# CALCULATING COMBINATION KEY WITH PRIORITY: D_YEAR, S_NATION, P_CATEGORY (internal codes for grouping)
+max_s_nation_grp = max(s_nation);
+max_p_category_grp = max(p_category);
+max_d_year_grp = max(d_year);
+
+s_nation_scale_grp = ceil(max_s_nation_grp) + 1;
+p_category_scale_grp = ceil(max_p_category_grp) + 1;
+d_year_scale_grp = ceil(max_d_year_grp) + 1;
+
+combined_key_grp = d_year * s_nation_scale_grp * p_category_scale_grp + s_nation * p_category_scale_grp + p_category;
+
+group_input = cbind(profit, combined_key_grp);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+key_grp = agg_result[, 1];
+profit_sum = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+# EXTRACTING D_YEAR, S_NATION, P_CATEGORY (internal codes)
+d_year_grp = round(floor(key_grp / (s_nation_scale_grp * p_category_scale_grp)));
+s_nation_grp = round(floor((key_grp %% (s_nation_scale_grp * p_category_scale_grp)) / p_category_scale_grp));
+p_category_grp = round(key_grp %% p_category_scale_grp);
+
+# Decode specs for later
+sup_dec_spec = "{ \"recode\": [\"C1\"] }";
+part_dec_spec = "{ \"recode\": [\"C1\"] }";
+
+# Decode categories for display-code mapping (unordered)
+p_cat_dec_all = transformdecode(target=p_category_grp, spec=part_dec_spec, meta=part_cat_meta);
+
+# Build display codes to match legacy meta mapping for p_category
+p_category_disp = matrix(0, rows=nrow(p_cat_dec_all), cols=1);
+for (i in 1:nrow(p_cat_dec_all)) {
+ cat_str = as.scalar(p_cat_dec_all[i,1]);
+ if (cat_str == "MFGR#11") p_category_disp[i,1] = 1;
+ else if (cat_str == "MFGR#12") p_category_disp[i,1] = 2;
+ else if (cat_str == "MFGR#13") p_category_disp[i,1] = 6;
+ else if (cat_str == "MFGR#15") p_category_disp[i,1] = 20;
+ else if (cat_str == "MFGR#21") p_category_disp[i,1] = 14;
+ else if (cat_str == "MFGR#22") p_category_disp[i,1] = 10;
+ else if (cat_str == "MFGR#23") p_category_disp[i,1] = 25;
+ else if (cat_str == "MFGR#24") p_category_disp[i,1] = 24;
+ else if (cat_str == "MFGR#25") p_category_disp[i,1] = 5;
+ else p_category_disp[i,1] = as.double(0);
+}
+
+# s_nation codes already align with legacy mapping; reuse as display codes
+s_nation_disp = s_nation_grp;
+
+# Compute display key using display codes
+s_nation_scale_disp = ceil(max(s_nation_disp)) + 1;
+p_category_scale_disp = ceil(max(p_category_disp)) + 1;
+d_year_scale_disp = ceil(max(d_year_grp)) + 1;
+
+key_disp = d_year_grp * s_nation_scale_disp * p_category_scale_disp + s_nation_disp * p_category_scale_disp + p_category_disp;
+
+# Compose display result and sort by display key to match legacy order
+result_disp = cbind(d_year_grp, s_nation_disp, p_category_disp, profit_sum, key_disp);
+idx_order = order(target=result_disp, by=5, decreasing=FALSE, index.return=TRUE);
+result_ordered_disp = order(target=result_disp, by=5, decreasing=FALSE, index.return=FALSE);
+print(result_ordered_disp);
+
+# Build permutation matrix to reorder matrices by idx_order
+n_rows = nrow(result_disp);
+Iseq = seq(1, n_rows, 1);
+P = table(Iseq, idx_order, n_rows, n_rows);
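+# P is a row-permutation matrix: row i has a single 1 in column idx_order[i], so
+# (P %*% x)[i] = x[idx_order[i]]. E.g. idx_order = (3,1,2) turns (a,b,c) into (c,a,b).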
+
+# Reorder grouped codes and measures using permutation
+d_year_ord = P %*% d_year_grp;
+s_nation_ord = P %*% s_nation_grp;
+p_category_ord = P %*% p_category_grp;
+profit_sum_ord = P %*% profit_sum;
+
+# Decode internal codes in the same display order
+s_nat_dec_ord = transformdecode(target=s_nation_ord, spec=sup_dec_spec, meta=sup_nat_meta);
+p_cat_dec_ord = transformdecode(target=p_category_ord, spec=part_dec_spec, meta=part_cat_meta);
+
+# Final decoded frame (aligned to display order)
+res = cbind(as.frame(d_year_ord), s_nat_dec_ord, p_cat_dec_ord, as.frame(profit_sum_ord));
+print(res);
+
diff --git a/scripts/staging/ssb/queries/q4_3.dml b/scripts/staging/ssb/queries/q4_3.dml
new file mode 100644
index 00000000000..554aae99e36
--- /dev/null
+++ b/scripts/staging/ssb/queries/q4_3.dml
@@ -0,0 +1,195 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+/* DML-script implementing the ssb query Q4.3 in SystemDS with on-the-fly encoding (no external meta files).
+SELECT
+ d_year,
+ s_city,
+ p_brand,
+ SUM(lo_revenue - lo_supplycost) AS PROFIT
+FROM dates, customer, supplier, part, lineorder
+WHERE
+ lo_custkey = c_custkey
+ AND lo_suppkey = s_suppkey
+ AND lo_partkey = p_partkey
+ AND lo_orderdate = d_datekey
+ AND s_nation = 'UNITED STATES'
+ AND (
+ d_year = 1997
+ OR d_year = 1998
+ )
+ AND p_category = 'MFGR#14'
+GROUP BY d_year, s_city, p_brand
+ORDER BY d_year, s_city, p_brand;
+*/
+
+# -- SOURCING THE RA-FUNCTIONS --
+source("./scripts/builtin/raSelection.dml") as raSel
+source("./scripts/builtin/raJoin.dml") as raJoin
+source("./scripts/builtin/raGroupby.dml") as raGrp
+
+## Input parameter
+input_dir = $input_dir;
+
+# -- READING INPUT FILES --
+# CSV TABLES
+date_csv = read(input_dir + "/date.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+lineorder_csv = read(input_dir + "/lineorder.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+part_csv = read(input_dir + "/part.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+supplier_csv = read(input_dir + "/supplier.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+customer_csv = read(input_dir + "/customer.tbl", data_type="frame", format="csv", header=FALSE, sep="|");
+
+
+# -- PREPARING --
+# EXTRACTING MINIMAL DATE DATA TO OPTIMIZE RUNTIME => COL-1 : DATE-KEY | COL-5 : D_YEAR
+date_csv_min = cbind(date_csv[, 1], date_csv[, 5]);
+date_matrix_min = as.matrix(date_csv_min);
+
+# EXTRACTING MINIMAL LINEORDER DATA TO OPTIMIZE RUNTIME => COL-3 : LO_CUSTKEY | COL-4 : LO_PARTKEY |
+# COL-5 : LO_SUPPKEY | COL-6 : LO_ORDERDATE | COL-13 : LO_REVENUE | COL-14 : LO_SUPPLYCOST
+lineorder_csv_min = cbind(lineorder_csv[, 3], lineorder_csv[, 4], lineorder_csv[, 5], lineorder_csv[, 6], lineorder_csv[, 13], lineorder_csv[, 14]);
+lineorder_matrix_min = as.matrix(lineorder_csv_min);
+
+## Prepare PART on-the-fly encodings (only need p_brand encoding, filter by p_category string)
+# We'll encode column 5 (p_brand) on-the-fly and later filter by category string 'MFGR#14'.
+[part_brand_enc_f, part_brand_meta] = transformencode(target=part_csv[,5], spec="{ \"ids\": false, \"recode\": [\"C1\"] }");
+
+# EXTRACTING MINIMAL CUSTOMER DATA TO OPTIMIZE RUNTIME => COL-1 : CUSTOMER-KEY
+cust_csv_min = customer_csv[, 1];
+cust_matrix_min = as.matrix(cust_csv_min);
+
+## Prepare SUPPLIER on-the-fly encodings (encode s_city, filter by s_nation string)
+[sup_city_enc_f, sup_city_meta] = transformencode(target=supplier_csv[,4], spec="{ \"ids\": false, \"recode\": [\"C1\"] }");
+
+
+## -- FILTERING THE DATA WITH RA-SELECTION FUNCTION / LOOPS --
+# D_YEAR = 1997 OR 1998
+d_year_filt_1 = raSel::m_raSelection(date_matrix_min, col=2, op="==", val=1997);
+d_year_filt_2 = raSel::m_raSelection(date_matrix_min, col=2, op="==", val=1998);
+d_year_filt = rbind(d_year_filt_1, d_year_filt_2);
+
+# Build filtered SUPPLIER table (s_nation == 'UNITED STATES'), keeping key and encoded city
+sup_filt_keys = matrix(0, rows=0, cols=1);
+sup_filt_city = matrix(0, rows=0, cols=1);
+for (i in 1:nrow(supplier_csv)) {
+ if (as.scalar(supplier_csv[i,5]) == "UNITED STATES") {
+ key_val = as.double(as.scalar(supplier_csv[i,1]));
+ city_code = as.double(as.scalar(sup_city_enc_f[i,1]));
+ sup_filt_keys = rbind(sup_filt_keys, matrix(key_val, rows=1, cols=1));
+ sup_filt_city = rbind(sup_filt_city, matrix(city_code, rows=1, cols=1));
+ }
+}
+if (nrow(sup_filt_keys) == 0) {
+ # Fallback to avoid empty join
+ sup_filt_keys = matrix(0, rows=1, cols=1);
+ sup_filt_city = matrix(0, rows=1, cols=1);
+}
+sup_filt = cbind(sup_filt_keys, sup_filt_city);
+
+# Build filtered PART table (p_category == 'MFGR#14'), keeping key and encoded brand
+part_filt_keys = matrix(0, rows=0, cols=1);
+part_filt_brand = matrix(0, rows=0, cols=1);
+for (i in 1:nrow(part_csv)) {
+ if (as.scalar(part_csv[i,4]) == "MFGR#14") {
+ key_val = as.double(as.scalar(part_csv[i,1]));
+ brand_code = as.double(as.scalar(part_brand_enc_f[i,1]));
+ part_filt_keys = rbind(part_filt_keys, matrix(key_val, rows=1, cols=1));
+ part_filt_brand = rbind(part_filt_brand, matrix(brand_code, rows=1, cols=1));
+ }
+}
+if (nrow(part_filt_keys) == 0) {
+ part_filt_keys = matrix(0, rows=1, cols=1);
+ part_filt_brand = matrix(0, rows=1, cols=1);
+}
+part_filt = cbind(part_filt_keys, part_filt_brand);
+
+
+# -- JOIN TABLES WITH RA-JOIN FUNCTION --
+# JOINING MINIMIZED LINEORDER TABLE WITH FILTERED SUPPLIER TABLE WHERE LO_SUPPKEY = S_SUPPKEY
+lo_sup = raJoin::m_raJoin(A=lineorder_matrix_min, colA=3, B=sup_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ PART WHERE LO_PARTKEY = P_PARTKEY
+lo_sup_part = raJoin::m_raJoin(A=lo_sup, colA=2, B=part_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ DATE WHERE LO_ORDERDATE = D_DATEKEY
+lo_sup_part_date = raJoin::m_raJoin(A=lo_sup_part, colA=4, B=d_year_filt, colB=1, method="sort-merge");
+
+# JOIN: ⨝ CUSTOMER WHERE LO_CUSTKEY = C_CUSTKEY (no filter used, but keep join for parity)
+joined_matrix = raJoin::m_raJoin(A=lo_sup_part_date, colA=1, B=cust_matrix_min, colB=1, method="sort-merge");
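+# Resulting column layout (assuming raJoin appends B's columns after A's):
+#   [1-6] lineorder_min | [7-8] s_suppkey, s_city code | [9-10] p_partkey, p_brand code |
+#   [11-12] d_datekey, d_year | [13] c_custkey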
+
+
+# -- GROUP-BY & AGGREGATION --
+# LO_REVENUE : COLUMN 5 OF LINEORDER-MIN-MATRIX
+lo_revenue = joined_matrix[, 5];
+# LO_SUPPLYCOST : COLUMN 6 OF LINEORDER-MIN-MATRIX
+lo_supplycost = joined_matrix[, 6];
+# D_YEAR : last column added in the previous join with date (2nd col of date_min)
+d_year = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(sup_filt) + ncol(part_filt) + 2)];
+# S_CITY (encoded) : COLUMN 2 OF SUPPLIER-FILTERED MATRIX
+s_city = joined_matrix[,(ncol(lineorder_matrix_min) + 2)];
+# P_BRAND (encoded) : COLUMN 2 OF PART-FILTERED MATRIX
+p_brand = joined_matrix[,(ncol(lineorder_matrix_min) + ncol(sup_filt) + 2)];
+
+profit = lo_revenue - lo_supplycost;
+
+# CALCULATING COMBINATION KEY WITH PRIORITY: D_YEAR, S_CITY, P_BRAND
+max_s_city = max(s_city);
+max_p_brand = max(p_brand);
+max_d_year = max(d_year);
+
+s_city_scale_f = ceil(max_s_city) + 1;
+p_brand_scale_f = ceil(max_p_brand) + 1;
+d_year_scale_f = ceil(max_d_year) + 1;
+
+combined_key = d_year * s_city_scale_f * p_brand_scale_f + s_city * p_brand_scale_f + p_brand;
+
+group_input = cbind(profit, combined_key);
+agg_result = raGrp::m_raGroupby(X=group_input, col=2, method="nested-loop");
+
+key = agg_result[, 1];
+profit = rowSums(agg_result[, 2:ncol(agg_result)]);
+
+# EXTRACTING D_YEAR, S_CITY, P_BRAND
+d_year = round(floor(key / (s_city_scale_f * p_brand_scale_f)));
+s_city = round(floor((key %% (s_city_scale_f * p_brand_scale_f)) / p_brand_scale_f));
+p_brand = round(key %% p_brand_scale_f);
+
+result = cbind(d_year, s_city, p_brand, profit, key);
+
+# -- SORTING --
+# PRIORITY 1 D_YEAR, 2 S_CITY, 3 P_BRAND
+result_ordered = order(target=result, by=5, decreasing=FALSE, index.return=FALSE);
+print(result_ordered);
+
+# -- DECODING S_CITY & P_BRAND (using on-the-fly meta from transformencode) --
+sup_dec_spec = "{ \"recode\": [\"C1\"] }";
+part_dec_spec = "{ \"recode\": [\"C1\"] }";
+
+s_city_dec = transformdecode(target=result_ordered[, 2], spec=sup_dec_spec, meta=sup_city_meta);
+p_brand_dec = transformdecode(target=result_ordered[, 3], spec=part_dec_spec, meta=part_brand_meta);
+
+res = cbind(as.frame(result_ordered[, 1]), s_city_dec, p_brand_dec, as.frame(result_ordered[, 4]));
+
+print(res);
diff --git a/scripts/staging/ssb/shell/run_all_perf.sh b/scripts/staging/ssb/shell/run_all_perf.sh
new file mode 100644
index 00000000000..71f6810681e
--- /dev/null
+++ b/scripts/staging/ssb/shell/run_all_perf.sh
@@ -0,0 +1,1530 @@
+#!/usr/bin/env bash
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Multi-Engine SSB Performance Benchmark Runner
+# =============================================
+#
+# CORE SCRIPTS STATUS:
+# - Version: 1.0 (September 5, 2025)
+# - Status: Production-Ready with Advanced Statistical Analysis
+#
+# ENHANCED FEATURES IMPLEMENTED:
+# ✓ Multi-engine benchmarking (SystemDS, PostgreSQL, DuckDB)
+# ✓ Advanced statistical analysis (mean, stdev, p95, CV) with high-precision calculations
+# ✓ Single-pass timing optimization eliminating cache effects between measurements
+# ✓ Cross-engine core timing support (SystemDS stats, PostgreSQL EXPLAIN, DuckDB JSON profiling)
+# ✓ Adaptive terminal layout with dynamic column scaling and multi-row statistics display
+# ✓ Comprehensive metadata collection (system info, software versions, data build info)
+# ✓ Environment verification and graceful degradation for missing engines
+# ✓ Real-time progress indicators with proper terminal width handling
+# ✓ Precision timing measurements with millisecond accuracy using /usr/bin/time -p
+# ✓ Robust error handling with pre-flight validation and error propagation
+# ✓ CSV and JSON output with timestamped files and complete statistical data
+# ✓ Fastest engine detection with tie handling
+# ✓ Database connection validation and parallel execution control (disabled for fair comparison)
+# ✓ Cross-platform compatibility (macOS/Linux) with intelligent executable discovery
+# ✓ Reproducible benchmarking with configurable seeds and detailed run configuration
+#
+# RECENT IMPORTANT ADDITIONS:
+# - Accepts --input-dir=PATH and forwards it into SystemDS DML runs via
+# `-nvargs input_dir=/path/to/data`. This allows DML queries to load data from
+# custom locations without hardcoded paths.
+# - Runner performs a pre-flight input-dir existence check and exits early with
+# a clear message when the directory is missing.
+# - Test-run output is scanned for runtime SystemDS errors; when detected the
+# runner marks the query as failed and includes an `error_message` field in
+# the generated JSON results to aid debugging and CI automation.
+#
+# STATISTICAL MEASUREMENTS:
+# - Mean: Arithmetic average execution time (typical performance expectation)
+# - Standard Deviation: Population stdev measuring consistency/reliability
+# - P95 Percentile: 95th percentile for worst-case performance bounds
+# - Coefficient of Variation: Relative variability as percentage for cross-scale comparison
+# - Display Format: "1200.0 (±14.1ms/1.2%, p95:1220.0ms)" showing all key metrics
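+#   Worked example (hypothetical timings): five runs of 1180/1190/1200/1210/1220 ms give
+#   mean 1200.0, population stdev sqrt(200) ~= 14.1, nearest-rank p95 = 1220.0, and
+#   CV = 14.1/1200 ~= 1.2%, i.e. exactly the display string shown above.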
+#
+# ENGINES SUPPORTED:
+# - SystemDS: Machine learning platform with DML queries (single-threaded via XML config)
+# - PostgreSQL: Industry-standard relational database (parallel workers disabled)
+# - DuckDB: High-performance analytical database (single-threaded via PRAGMA)
+#
+# USAGE (from repo root):
+# scripts/ssb/shell/run_all_perf.sh # run full benchmark with all engines
+# scripts/ssb/shell/run_all_perf.sh --stats # enable internal engine timing statistics
+# scripts/ssb/shell/run_all_perf.sh --warmup=3 --repeats=10 # custom warmup and repetition settings
+# scripts/ssb/shell/run_all_perf.sh --layout=wide # force wide table layout
+# scripts/ssb/shell/run_all_perf.sh --seed=12345 # reproducible benchmark with specific seed
+# scripts/ssb/shell/run_all_perf.sh q1.1 q2.3 q4.1 # benchmark specific queries only
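+#   scripts/ssb/shell/run_all_perf.sh --input-dir=/path/to/ssb-data # forward data dir to DML via -nvargs input_dir=...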
+#
+set -euo pipefail
+export LC_ALL=C
+
+REPEATS=5
+WARMUP=1
+POSTGRES_DB="ssb"
+POSTGRES_USER="$(whoami)"
+POSTGRES_HOST="localhost"
+
+export _JAVA_OPTIONS="${_JAVA_OPTIONS:-} -Xms2g -Xmx2g -XX:+UseParallelGC -XX:ParallelGCThreads=1"
+
+# Determine script directory and project root (repo root)
+if command -v realpath >/dev/null 2>&1; then
+ SCRIPT_DIR="$(dirname "$(realpath "$0")")"
+else
+  SCRIPT_DIR="$(python - "$0" <<'PY'
+import os, sys
+print(os.path.dirname(os.path.abspath(sys.argv[1])))
+PY
+)"
+fi
+# Resolve repository root robustly (script may be in scripts/ssb/shell)
+if command -v git >/dev/null 2>&1 && git -C "$SCRIPT_DIR" rev-parse --show-toplevel >/dev/null 2>&1; then
+ PROJECT_ROOT="$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel)"
+else
+ # Fallback: ascend until we find markers (.git or pom.xml)
+ __dir="$SCRIPT_DIR"
+ PROJECT_ROOT=""
+ while [[ "$__dir" != "/" ]]; do
+ if [[ -d "$__dir/.git" || -f "$__dir/pom.xml" ]]; then
+ PROJECT_ROOT="$__dir"; break
+ fi
+ __dir="$(dirname "$__dir")"
+ done
+ : "${PROJECT_ROOT:=$(cd "$SCRIPT_DIR/../../../" && pwd)}"
+fi
+
+# Create single-thread configuration
+CONF_DIR="$PROJECT_ROOT/conf"
+SINGLE_THREAD_CONF="$CONF_DIR/single_thread.xml"
+mkdir -p "$CONF_DIR"
+if [[ ! -f "$SINGLE_THREAD_CONF" ]]; then
+cat > "$SINGLE_THREAD_CONF" <<'XML'
+<root>
+  <!-- disable parallel CP operations and pin SystemDS to a single thread -->
+  <sysds.cp.parallel.ops>false</sysds.cp.parallel.ops>
+  <sysds.num.threads>1</sysds.num.threads>
+</root>
+XML
+fi
+SYS_EXTRA_ARGS=( "-config" "$SINGLE_THREAD_CONF" )
+
+# Query and system directories
+QUERY_DIR="$PROJECT_ROOT/scripts/ssb/queries"
+
+# Locate SystemDS binary
+SYSTEMDS_CMD="$PROJECT_ROOT/bin/systemds"
+if [[ ! -x "$SYSTEMDS_CMD" ]]; then
+ SYSTEMDS_CMD="$(command -v systemds || true)"
+fi
+if [[ -z "$SYSTEMDS_CMD" || ! -x "$SYSTEMDS_CMD" ]]; then
+ echo "SystemDS binary not found." >&2
+ exit 1
+fi
+
+# Database directories and executables
+# SQL files were moved under scripts/ssb/sql
+SQL_DIR="$PROJECT_ROOT/scripts/ssb/sql"
+
+# Try to find PostgreSQL psql executable
+PSQL_EXEC=""
+for path in "/opt/homebrew/opt/libpq/bin/psql" "/usr/local/bin/psql" "/usr/bin/psql" "$(command -v psql || true)"; do
+ if [[ -x "$path" ]]; then
+ PSQL_EXEC="$path"
+ break
+ fi
+done
+
+# Try to find DuckDB executable
+DUCKDB_EXEC=""
+for path in "/opt/homebrew/bin/duckdb" "/usr/local/bin/duckdb" "/usr/bin/duckdb" "$(command -v duckdb || true)"; do
+ if [[ -x "$path" ]]; then
+ DUCKDB_EXEC="$path"
+ break
+ fi
+done
+
+DUCKDB_DB_PATH="$SQL_DIR/ssb.duckdb"
+
+# Environment verification
+verify_environment() {
+ local ok=true
+ echo "Verifying environment..."
+
+ if [[ ! -x "$SYSTEMDS_CMD" ]]; then
+ echo "✗ SystemDS binary missing ($SYSTEMDS_CMD)" >&2
+ ok=false
+ else
+ echo "✓ SystemDS binary found: $SYSTEMDS_CMD"
+ fi
+
+ if [[ -z "$PSQL_EXEC" || ! -x "$PSQL_EXEC" ]]; then
+ echo "✗ psql not found (tried common paths)" >&2
+ echo " PostgreSQL benchmarks will be skipped" >&2
+ PSQL_EXEC=""
+ else
+ echo "✓ psql found: $PSQL_EXEC"
+ if ! "$PSQL_EXEC" -U "$POSTGRES_USER" -h "$POSTGRES_HOST" -d "$POSTGRES_DB" -c "SELECT 1" >/dev/null 2>&1; then
+ echo "✗ Could not connect to PostgreSQL database ($POSTGRES_DB)" >&2
+ echo " PostgreSQL benchmarks will be skipped" >&2
+ PSQL_EXEC=""
+ else
+ echo "✓ PostgreSQL database connection successful"
+ fi
+ fi
+
+ if [[ -z "$DUCKDB_EXEC" || ! -x "$DUCKDB_EXEC" ]]; then
+ echo "✗ DuckDB not found (tried common paths)" >&2
+ echo " DuckDB benchmarks will be skipped" >&2
+ DUCKDB_EXEC=""
+ else
+ echo "✓ DuckDB found: $DUCKDB_EXEC"
+ if [[ ! -f "$DUCKDB_DB_PATH" ]]; then
+ echo "✗ DuckDB database missing ($DUCKDB_DB_PATH)" >&2
+ echo " DuckDB benchmarks will be skipped" >&2
+ DUCKDB_EXEC=""
+ elif ! "$DUCKDB_EXEC" "$DUCKDB_DB_PATH" -c "SELECT 1" >/dev/null 2>&1; then
+ echo "✗ DuckDB database could not be opened" >&2
+ echo " DuckDB benchmarks will be skipped" >&2
+ DUCKDB_EXEC=""
+ else
+ echo "✓ DuckDB database accessible"
+ fi
+ fi
+
+ if [[ ! -x "$SYSTEMDS_CMD" ]]; then
+ echo "Error: SystemDS is required but not found" >&2
+ exit 1
+ fi
+
+ echo ""
+}
+
+# Convert seconds to milliseconds
+sec_to_ms() {
+ awk -v sec="$1" 'BEGIN{printf "%.1f", sec * 1000}'
+}
+
+# Statistical functions for multiple measurements
+calculate_statistics() {
+ local values=("$@")
+ local n=${#values[@]}
+
+ if [[ $n -eq 0 ]]; then
+ echo "0|0|0"
+ return
+ fi
+
+ if [[ $n -eq 1 ]]; then
+ # mean|stdev|p95
+ printf '%.1f|0.0|%.1f\n' "${values[0]}" "${values[0]}"
+ return
+ fi
+
+ # Compute mean and population stdev with higher precision in a single awk pass
+ local mean_stdev
+ mean_stdev=$(printf '%s\n' "${values[@]}" | awk '
+ { x[NR]=$1; s+=$1 }
+ END {
+ n=NR; if(n==0){ printf "0|0"; exit }
+ m=s/n;
+ ss=0; for(i=1;i<=n;i++){ d=x[i]-m; ss+=d*d }
+ stdev=sqrt(ss/n);
+ printf "%.6f|%.6f", m, stdev
+ }')
+
+ local mean=$(echo "$mean_stdev" | cut -d'|' -f1)
+ local stdev=$(echo "$mean_stdev" | cut -d'|' -f2)
+
+ # Calculate p95 (nearest-rank: ceil(0.95*n))
+ local sorted_values=($(printf '%s\n' "${values[@]}" | sort -n))
+ local p95_index=$(awk -v n="$n" 'BEGIN{ idx = int(0.95*n + 0.999999); if(idx<1) idx=1; if(idx>n) idx=n; print idx-1 }')
+ local p95=${sorted_values[$p95_index]}
+
+ # Format to one decimal place
+ printf '%.1f|%.1f|%.1f\n' "$mean" "$stdev" "$p95"
+}
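+# Example (hypothetical values): calculate_statistics 1180 1190 1200 1210 1220
+# prints "1200.0|14.1|1220.0", i.e. mean|stdev|p95 with nearest-rank p95.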
+
+# Format statistics for display
+format_statistics() {
+ local mean="$1"
+ local stdev="$2"
+ local p95="$3"
+ local repeats="$4"
+
+ if [[ $repeats -eq 1 ]]; then
+ echo "$mean"
+ else
+ # Calculate coefficient of variation (CV) as percentage
+ local cv_percent=0
+ if [[ $(awk -v mean="$mean" 'BEGIN{print (mean > 0)}') -eq 1 ]]; then
+ cv_percent=$(awk -v stdev="$stdev" -v mean="$mean" 'BEGIN{printf "%.1f", (stdev * 100) / mean}')
+ fi
+ echo "$mean (±${stdev}ms/${cv_percent}%, p95:${p95}ms)"
+ fi
+}
+
+# Format only the stats line (without the mean), e.g., "(±10.2ms/0.6%, p95:1740.0ms)"
+format_stats_only() {
+ local mean="$1"
+ local stdev="$2"
+ local p95="$3"
+ local repeats="$4"
+
+ if [[ $repeats -eq 1 ]]; then
+ echo ""
+ return
+ fi
+ # Only for numeric means
+ if ! [[ "$mean" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
+ echo ""
+ return
+ fi
+ local cv_percent=0
+ if [[ $(awk -v mean="$mean" 'BEGIN{print (mean > 0)}') -eq 1 ]]; then
+ cv_percent=$(awk -v stdev="$stdev" -v mean="$mean" 'BEGIN{printf "%.1f", (stdev * 100) / mean}')
+ fi
+ echo "(±${stdev}ms/${cv_percent}%, p95:${p95}ms)"
+}
+
+# Format only the CV line (±stdev/CV%)
+format_cv_only() {
+ local mean="$1"; local stdev="$2"; local repeats="$3"
+ if [[ $repeats -eq 1 ]]; then echo ""; return; fi
+ if ! [[ "$mean" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then echo ""; return; fi
+ local cv_percent=0
+ if [[ $(awk -v mean="$mean" 'BEGIN{print (mean > 0)}') -eq 1 ]]; then
+ cv_percent=$(awk -v stdev="$stdev" -v mean="$mean" 'BEGIN{printf "%.1f", (stdev * 100) / mean}')
+ fi
+ echo "±${stdev}ms/${cv_percent}%"
+}
+
+# Format only the p95 line
+format_p95_only() {
+ local p95="$1"; local repeats="$2"
+ if [[ $repeats -eq 1 ]]; then echo ""; return; fi
+ if ! [[ "$p95" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then echo ""; return; fi
+ echo "p95:${p95}ms"
+}
+
+# Column widths for wide layout - optimized for 125-char terminals
+WIDE_COL_WIDTHS=(8 14 14 12 16 12 12 18)
+
+# Draw a grid line like +----------+----------------+...
+grid_line_wide() {
+ local parts=("+")
+ for w in "${WIDE_COL_WIDTHS[@]}"; do
+ parts+=("$(printf '%*s' "$((w+2))" '' | tr ' ' '-')+")
+ done
+ printf '%s\n' "${parts[*]}" | tr -d ' '
+}
+
+# Print a grid row with vertical separators using the wide layout widths
+grid_row_wide() {
+ local -a cells=("$@")
+ local cols=${#WIDE_COL_WIDTHS[@]}
+ while [[ ${#cells[@]} -lt $cols ]]; do
+ cells+=("")
+ done
+
+ # Build a printf format string that right-aligns numeric and statistic-like cells
+ # (numbers, lines starting with ± or p95, or containing p95/±) while leaving the
+ # first column (query) left-aligned for readability.
+ local fmt=""
+ for i in $(seq 0 $((cols-1))); do
+ local w=${WIDE_COL_WIDTHS[i]}
+ if [[ $i -eq 0 ]]; then
+ # Query name: left-align
+ fmt+="| %-${w}s"
+ else
+ local cell="${cells[i]}"
+ # Heuristic: right-align if the cell is a plain number or contains statistic markers
+ if [[ "$cell" =~ ^[[:space:]]*[0-9]+(\.[0-9]+)?[[:space:]]*$ ]] || [[ "$cell" == ±* ]] || [[ "$cell" == *'±'* ]] || [[ "$cell" == p95* ]] || [[ "$cell" == *'p95'* ]] || [[ "$cell" == \(* ]]; then
+ fmt+=" | %${w}s"
+ else
+ fmt+=" | %-${w}s"
+ fi
+ fi
+ done
+ fmt+=" |\n"
+
+ printf "$fmt" "${cells[@]}"
+}
+
+# Time a command and return real time in ms
+time_command_ms() {
+ local out
+ # Properly capture stderr from /usr/bin/time while suppressing stdout of the command
+ out=$({ /usr/bin/time -p "$@" > /dev/null; } 2>&1)
+ local real_sec=$(echo "$out" | awk '/^real /{print $2}')
+ if [[ -z "$real_sec" || ! "$real_sec" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
+ echo "(error)"
+ return 1
+ fi
+ sec_to_ms "$real_sec"
+}
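+# Example (assuming /usr/bin/time is available): time_command_ms sleep 1
+# prints roughly "1000.0"; a missing or non-numeric 'real' line yields "(error)" and status 1.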
+
+# Time a command, capturing stdout to a file, and return real time in ms
+time_command_ms_capture() {
+ local stdout_file="$1"; shift
+ local out
+ out=$({ /usr/bin/time -p "$@" > "$stdout_file"; } 2>&1)
+ local real_sec=$(echo "$out" | awk '/^real /{print $2}')
+ if [[ -z "$real_sec" || ! "$real_sec" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
+ echo "(error)"
+ return 1
+ fi
+ sec_to_ms "$real_sec"
+}
+
+# Run a SystemDS query and compute statistics
+run_systemds_avg() {
+ local dml="$1"
+ # Optional second parameter: path to write an error message if the test-run fails
+ local err_out_file="${2:-}"
+ local shell_times=()
+ local core_times=()
+ local core_have=false
+
+ # Change to project root directory so relative paths in DML work correctly
+ local original_dir="$(pwd)"
+ cd "$PROJECT_ROOT"
+
+ # First, test run to validate the query (avoids timing zero or errors later)
+ tmp_test=$(mktemp)
+ if $RUN_STATS; then
+ if ! "$SYSTEMDS_CMD" "$dml" -stats "${SYS_EXTRA_ARGS[@]}" "${NVARGS[@]}" > "$tmp_test" 2>&1; then
+ err_msg=$(sed -n '1,200p' "$tmp_test" | tr '\n' ' ')
+ echo "Error: SystemDS test run failed for $dml: $err_msg" >&2
+ # Write error message to provided error file for JSON capture
+ if [[ -n "$err_out_file" ]]; then printf '%s' "$err_msg" > "$err_out_file" || true; fi
+ rm -f "$tmp_test"
+ echo "(error)|0|0|(n/a)|0|0"
+ cd "$original_dir"; return
+ fi
+ err_msg=$(sed -n '/An Error Occurred :/,$ p' "$tmp_test" | sed -n '1,200p' | tr '\n' ' ')
+ if [[ -n "$err_msg" ]]; then
+ echo "Error: SystemDS reported runtime error for $dml: $err_msg" >&2
+ if [[ -n "$err_out_file" ]]; then printf '%s' "$err_msg" > "$err_out_file" || true; fi
+ rm -f "$tmp_test"
+ echo "(error)|0|0|(n/a)|0|0"
+ cd "$original_dir"; return
+ fi
+ else
+ if ! "$SYSTEMDS_CMD" "$dml" "${SYS_EXTRA_ARGS[@]}" "${NVARGS[@]}" > "$tmp_test" 2>&1; then
+ err_msg=$(sed -n '1,200p' "$tmp_test" | tr '\n' ' ')
+ echo "Error: SystemDS test run failed for $dml: $err_msg" >&2
+ if [[ -n "$err_out_file" ]]; then printf '%s' "$err_msg" > "$err_out_file" || true; fi
+ rm -f "$tmp_test"
+ echo "(error)|0|0|(n/a)|0|0"
+ cd "$original_dir"; return
+ fi
+ err_msg=$(sed -n '/An Error Occurred :/,$ p' "$tmp_test" | sed -n '1,200p' | tr '\n' ' ')
+ if [[ -n "$err_msg" ]]; then
+ echo "Error: SystemDS reported runtime error for $dml: $err_msg" >&2
+ if [[ -n "$err_out_file" ]]; then printf '%s' "$err_msg" > "$err_out_file" || true; fi
+ rm -f "$tmp_test"
+ echo "(error)|0|0|(n/a)|0|0"
+ cd "$original_dir"; return
+ fi
+ fi
+ rm -f "$tmp_test"
+
+ # Warmup runs
+ for ((w=1; w<=WARMUP; w++)); do
+ if $RUN_STATS; then
+ "$SYSTEMDS_CMD" "$dml" -stats "${SYS_EXTRA_ARGS[@]}" "${NVARGS[@]}" > /dev/null 2>&1 || true
+ else
+ "$SYSTEMDS_CMD" "$dml" "${SYS_EXTRA_ARGS[@]}" "${NVARGS[@]}" > /dev/null 2>&1 || true
+ fi
+ done
+
+ # Timed runs - collect all measurements
+ for ((i=1; i<=REPEATS; i++)); do
+ if $RUN_STATS; then
+ local shell_ms
+ local temp_file
+ temp_file=$(mktemp)
+ shell_ms=$(time_command_ms_capture "$temp_file" "$SYSTEMDS_CMD" "$dml" -stats "${SYS_EXTRA_ARGS[@]}" "${NVARGS[@]}") || {
+ rm -f "$temp_file"; cd "$original_dir"; echo "(error)|0|0|(n/a)|0|0"; return; }
+ shell_times+=("$shell_ms")
+
+ # Extract SystemDS internal timing from the same run
+ local internal_sec
+ internal_sec=$(awk '/Total execution time:/ {print $4}' "$temp_file" | tail -1 || true)
+ rm -f "$temp_file"
+ if [[ -n "$internal_sec" ]] && [[ "$internal_sec" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
+ local core_ms
+ core_ms=$(awk -v sec="$internal_sec" 'BEGIN{printf "%.1f", sec * 1000}')
+ core_times+=("$core_ms")
+ core_have=true
+ fi
+ else
+ local shell_ms
+ shell_ms=$(time_command_ms "$SYSTEMDS_CMD" "$dml" "${SYS_EXTRA_ARGS[@]}" "${NVARGS[@]}") || { cd "$original_dir"; echo "(error)|0|0|(n/a)|0|0"; return; }
+ shell_times+=("$shell_ms")
+ fi
+ done
+
+ # Return to original directory
+ cd "$original_dir"
+
+ # Calculate statistics for shell times
+ local shell_stats
+ shell_stats=$(calculate_statistics "${shell_times[@]}")
+
+ # Calculate statistics for core times if available
+ local core_stats
+ if $RUN_STATS && $core_have && [[ ${#core_times[@]} -gt 0 ]]; then
+ core_stats=$(calculate_statistics "${core_times[@]}")
+ else
+ core_stats="(n/a)|0|0"
+ fi
+
+ echo "$shell_stats|$core_stats"
+}
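+# Example (hypothetical): run_systemds_avg "$QUERY_DIR/q1_1.dml" "$err_file" prints a single
+# line like "1200.0|14.1|1220.0|1150.0|10.2|1160.0", i.e. shell mean|stdev|p95 followed by
+# core mean|stdev|p95 (the core triple is "(n/a)|0|0" unless --stats is enabled).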
+
+# Run a PostgreSQL query and compute statistics
+run_psql_avg_ms() {
+ local sql_file="$1"
+
+ # Check if PostgreSQL is available
+ if [[ -z "$PSQL_EXEC" ]]; then
+ echo "(unavailable)|0|0|(n/a)|0|0"
+ return
+ fi
+
+ # Test run first
+ "$PSQL_EXEC" -U "$POSTGRES_USER" -h "$POSTGRES_HOST" -d "$POSTGRES_DB" \
+ -v ON_ERROR_STOP=1 -q \
+ -c "SET max_parallel_workers=0; SET max_parallel_maintenance_workers=0; SET max_parallel_workers_per_gather=0; SET parallel_leader_participation=off;" \
+ -f "$sql_file" >/dev/null 2>/dev/null || {
+ echo "(error)|0|0|(n/a)|0|0"
+ return
+ }
+
+ local shell_times=()
+ local core_times=()
+ local core_have=false
+
+ for ((i=1; i<=REPEATS; i++)); do
+ # Wall-clock shell time
+ local ms
+ ms=$(time_command_ms "$PSQL_EXEC" -U "$POSTGRES_USER" -h "$POSTGRES_HOST" -d "$POSTGRES_DB" \
+ -v ON_ERROR_STOP=1 -q \
+ -c "SET max_parallel_workers=0; SET max_parallel_maintenance_workers=0; SET max_parallel_workers_per_gather=0; SET parallel_leader_participation=off;" \
+ -f "$sql_file" 2>/dev/null) || {
+ echo "(error)|0|0|(n/a)|0|0"
+ return
+ }
+ shell_times+=("$ms")
+
+ # Core execution time using EXPLAIN ANALYZE (if --stats enabled)
+ if $RUN_STATS; then
+ local tmp_explain
+ tmp_explain=$(mktemp)
+
+ # Create EXPLAIN ANALYZE version of the query
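+      # (note: EXPLAIN ANALYZE executes the statement, so this is an extra run,
+      #  separate from the wall-clock measurement above)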
+ echo "SET max_parallel_workers=0; SET max_parallel_maintenance_workers=0; SET max_parallel_workers_per_gather=0; SET parallel_leader_participation=off;" > "$tmp_explain"
+ echo "EXPLAIN (ANALYZE, BUFFERS, FORMAT TEXT)" >> "$tmp_explain"
+ cat "$sql_file" >> "$tmp_explain"
+
+ # Execute EXPLAIN ANALYZE and extract execution time
+ local explain_output core_ms
+ explain_output=$("$PSQL_EXEC" -U "$POSTGRES_USER" -h "$POSTGRES_HOST" -d "$POSTGRES_DB" \
+ -v ON_ERROR_STOP=1 -q -f "$tmp_explain" 2>/dev/null || true)
+
+ if [[ -n "$explain_output" ]]; then
+ # Extract "Execution Time: X.XXX ms" from EXPLAIN ANALYZE output
+ local exec_time_ms
+ exec_time_ms=$(echo "$explain_output" | grep -oE "Execution Time: [0-9]+\.[0-9]+" | grep -oE "[0-9]+\.[0-9]+" | head -1 || true)
+
+ if [[ -n "$exec_time_ms" ]] && [[ "$exec_time_ms" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
+ core_ms=$(awk -v ms="$exec_time_ms" 'BEGIN{printf "%.1f", ms}')
+ core_times+=("$core_ms")
+ core_have=true
+ fi
+ fi
+
+ rm -f "$tmp_explain"
+ fi
+ done
+
+ # Build outputs
+ local shell_stats core_stats
+ shell_stats=$(calculate_statistics "${shell_times[@]}")
+ if $RUN_STATS && $core_have && [[ ${#core_times[@]} -gt 0 ]]; then
+ core_stats=$(calculate_statistics "${core_times[@]}")
+ else
+ core_stats="(n/a)|0|0"
+ fi
+ echo "$shell_stats|$core_stats"
+}
+
+# Run a DuckDB query and compute statistics
+run_duckdb_avg_ms() {
+ local sql_file="$1"
+
+ # Check if DuckDB is available
+ if [[ -z "$DUCKDB_EXEC" ]]; then
+ echo "(unavailable)|0|0|(n/a)|0|0"
+ return
+ fi
+
+ # Test run with minimal setup (no profiling)
+ local tmp_test
+ tmp_test=$(mktemp)
+ printf 'PRAGMA threads=1;\n' > "$tmp_test"
+ cat "$sql_file" >> "$tmp_test"
+ "$DUCKDB_EXEC" "$DUCKDB_DB_PATH" < "$tmp_test" >/dev/null 2>&1 || {
+ rm -f "$tmp_test"
+ echo "(error)|0|0|(n/a)|0|0"
+ return
+ }
+ rm -f "$tmp_test"
+
+ local shell_times=()
+ local core_times=()
+ local core_have=false
+
+ for ((i=1; i<=REPEATS; i++)); do
+ local tmp_sql iter_json
+ tmp_sql=$(mktemp)
+    if $RUN_STATS; then
+      # Enable JSON profiling per-run and write it to a temporary file
+      # (PRAGMA names assume a recent DuckDB CLI; the profile is parsed below)
+      iter_json=$(mktemp -t duckprof.XXXXXX).json
+      cat > "$tmp_sql" <<SQL
+PRAGMA threads=1;
+PRAGMA enable_profiling='json';
+PRAGMA profiling_output='$iter_json';
+SQL
+    else
+      printf 'PRAGMA threads=1;\n' > "$tmp_sql"
+    fi
+    cat "$sql_file" >> "$tmp_sql"
+
+ # Wall-clock shell time
+ local ms
+ ms=$(time_command_ms "$DUCKDB_EXEC" "$DUCKDB_DB_PATH" < "$tmp_sql") || {
+ rm -f "$tmp_sql" ${iter_json:+"$iter_json"}
+ echo "(error)|0|0|(n/a)|0|0"
+ return
+ }
+ shell_times+=("$ms")
+
+ # Parse core latency from JSON profile if available
+ if $RUN_STATS && [[ -n "${iter_json:-}" && -f "$iter_json" ]]; then
+ local core_sec
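+      # the JSON profile is expected to expose a top-level "latency" field in seconds,
+      # e.g. {"latency": 0.0123, ...}, matching the jq/grep extraction below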
+ if command -v jq >/dev/null 2>&1; then
+ core_sec=$(jq -r '.latency // empty' "$iter_json" 2>/dev/null || true)
+ else
+ core_sec=$(grep -oE '"latency"\s*:\s*[0-9.]+' "$iter_json" 2>/dev/null | sed -E 's/.*:\s*//' | head -1 || true)
+ fi
+ if [[ -n "$core_sec" ]] && [[ "$core_sec" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
+ local core_ms
+ core_ms=$(awk -v s="$core_sec" 'BEGIN{printf "%.1f", s*1000}')
+ core_times+=("$core_ms")
+ core_have=true
+ fi
+ fi
+
+ rm -f "$tmp_sql" ${iter_json:+"$iter_json"}
+ done
+
+ # Build outputs
+ local shell_stats core_stats
+ shell_stats=$(calculate_statistics "${shell_times[@]}")
+ if $RUN_STATS && $core_have && [[ ${#core_times[@]} -gt 0 ]]; then
+ core_stats=$(calculate_statistics "${core_times[@]}")
+ else
+ core_stats="(n/a)|0|0"
+ fi
+ echo "$shell_stats|$core_stats"
+}
+
+# Help function
+show_help() {
+ cat << 'EOF'
+Multi-Engine SSB Performance Benchmark Runner v1.0
+
+USAGE (from repo root):
+ scripts/ssb/shell/run_all_perf.sh [OPTIONS] [QUERIES...]
+
+OPTIONS:
+ -stats, --stats Enable SystemDS internal statistics collection
+ -warmup=N, --warmup=N Set number of warmup runs (default: 1)
+ -repeats=N, --repeats=N Set number of timing repetitions (default: 5)
+ -seed=N, --seed=N Set random seed for reproducible results (default: auto-generated)
+ -stacked, --stacked Use stacked, multi-line layout (best for narrow terminals)
+ -layout=MODE, --layout=MODE Set layout: auto|wide|stacked (default: auto)
+ Note: --layout=stacked is equivalent to --stacked
+ --layout=wide forces wide table layout
+ -input-dir=PATH, --input-dir=PATH Specify custom data directory (default: $PROJECT_ROOT/data)
+ -output-dir=PATH, --output-dir=PATH Specify custom output directory (default: $PROJECT_ROOT/scripts/ssb/shell/ssbOutputData/PerformanceData)
+ -h, -help, --help, --h Show this help message
+ -v, -version, --version, --v Show version information
+
+QUERIES:
+ If no queries are specified, all available SSB queries (q*.dml) will be executed.
+ To run specific queries, provide their names (with or without .dml extension):
+ scripts/ssb/shell/run_all_perf.sh q1.1 q2.3 q4.1
+
+EXAMPLES (from repo root):
+ scripts/ssb/shell/run_all_perf.sh # Run full benchmark with all engines
+ scripts/ssb/shell/run_all_perf.sh --warmup=3 --repeats=10 # Custom warmup and repetition settings
+ scripts/ssb/shell/run_all_perf.sh -warmup=3 -repeats=10 # Same with single dashes
+ scripts/ssb/shell/run_all_perf.sh --stats # Enable SystemDS internal timing
+ scripts/ssb/shell/run_all_perf.sh --layout=wide # Force wide table layout
+ scripts/ssb/shell/run_all_perf.sh --stacked # Force stacked layout for narrow terminals
+ scripts/ssb/shell/run_all_perf.sh q1.1 q2.3 # Benchmark specific queries only
+ scripts/ssb/shell/run_all_perf.sh --seed=12345 # Reproducible benchmark run
+ scripts/ssb/shell/run_all_perf.sh --input-dir=/path/to/data # Custom data directory
+ scripts/ssb/shell/run_all_perf.sh -input-dir=/path/to/data # Same as above (single dash)
+ scripts/ssb/shell/run_all_perf.sh --output-dir=/tmp/results # Custom output directory
+ scripts/ssb/shell/run_all_perf.sh -output-dir=/tmp/results # Same as above (single dash)
+
+ENGINES:
+ - SystemDS: Machine learning platform with DML queries
+ - PostgreSQL: Industry-standard relational database (if available)
+ - DuckDB: High-performance analytical database (if available)
+
+OUTPUT:
+ Results are saved in CSV and JSON formats with comprehensive metadata:
+ - Performance timing statistics (mean, stdev, p95)
+ - Engine comparison and fastest detection
+ - System information and run configuration
+
+STATISTICAL OUTPUT FORMAT:
+ 1824 (±10, p95:1840)
+ │ │ └── 95th percentile (worst-case bound)
+ │ └── Standard deviation (consistency measure)
+ └── Mean execution time (typical performance)
+
+For more information, see the documentation in scripts/ssb/README.md
+EOF
+}
+
+# Parse arguments
+RUN_STATS=false
+QUERIES=()
+SEED=""
+LAYOUT="auto"
+INPUT_DIR=""
+OUTPUT_DIR=""
+
+# Support both --opt=value and --opt value forms
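+# (e.g. both "--repeats=5" and "--repeats 5" are accepted: a bare option stores its
+#  name in EXPECT_OPT and the following argument is consumed as its value)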
+EXPECT_OPT=""
+for arg in "$@"; do
+ if [[ -n "$EXPECT_OPT" ]]; then
+ case "$EXPECT_OPT" in
+ seed)
+ SEED="$arg"
+ EXPECT_OPT=""
+ continue
+ ;;
+ input-dir)
+ INPUT_DIR="$arg"
+ EXPECT_OPT=""
+ continue
+ ;;
+ output-dir)
+ OUTPUT_DIR="$arg"
+ EXPECT_OPT=""
+ continue
+ ;;
+ warmup)
+ WARMUP="$arg"
+ if ! [[ "$WARMUP" =~ ^[0-9]+$ ]] || [[ "$WARMUP" -lt 0 ]]; then
+ echo "Error: --warmup requires a non-negative integer (e.g., --warmup 2)" >&2
+ exit 1
+ fi
+ EXPECT_OPT=""
+ continue
+ ;;
+ repeats)
+ REPEATS="$arg"
+ if ! [[ "$REPEATS" =~ ^[0-9]+$ ]] || [[ "$REPEATS" -lt 1 ]]; then
+ echo "Error: --repeats requires a positive integer (e.g., --repeats 5)" >&2
+ exit 1
+ fi
+ EXPECT_OPT=""
+ continue
+ ;;
+ layout)
+ LAYOUT="$arg"
+ if [[ "$LAYOUT" != "auto" && "$LAYOUT" != "wide" && "$LAYOUT" != "stacked" ]]; then
+ echo "Error: --layout requires one of: auto, wide, stacked (e.g., --layout wide)" >&2
+ exit 1
+ fi
+ EXPECT_OPT=""
+ continue
+ ;;
+ esac
+ fi
+
+ if [[ "$arg" == "--help" || "$arg" == "-help" || "$arg" == "-h" || "$arg" == "--h" ]]; then
+ show_help
+ exit 0
+ elif [[ "$arg" == "--version" || "$arg" == "-version" || "$arg" == "-v" || "$arg" == "--v" ]]; then
+ echo "Multi-Engine SSB Performance Benchmark Runner v1.0"
+ echo "First Public Release: September 5, 2025"
+ exit 0
+ elif [[ "$arg" == "--stats" || "$arg" == "-stats" ]]; then
+ RUN_STATS=true
+ elif [[ "$arg" == --seed=* || "$arg" == -seed=* ]]; then
+ SEED="${arg#*seed=}"
+ elif [[ "$arg" == "--seed" || "$arg" == "-seed" ]]; then
+ EXPECT_OPT="seed"
+ elif [[ "$arg" == --warmup=* || "$arg" == -warmup=* ]]; then
+ WARMUP="${arg#*warmup=}"
+ if ! [[ "$WARMUP" =~ ^[0-9]+$ ]] || [[ "$WARMUP" -lt 0 ]]; then
+ echo "Error: -warmup/--warmup requires a non-negative integer (e.g., -warmup=2)" >&2
+ exit 1
+ fi
+ elif [[ "$arg" == --input-dir=* || "$arg" == -input-dir=* ]]; then
+ INPUT_DIR="${arg#*input-dir=}"
+ elif [[ "$arg" == "--input-dir" || "$arg" == "-input-dir" ]]; then
+ EXPECT_OPT="input-dir"
+ elif [[ "$arg" == --output-dir=* || "$arg" == -output-dir=* ]]; then
+ OUTPUT_DIR="${arg#*output-dir=}"
+ elif [[ "$arg" == "--output-dir" || "$arg" == "-output-dir" ]]; then
+ EXPECT_OPT="output-dir"
+ elif [[ "$arg" == "--warmup" || "$arg" == "-warmup" ]]; then
+ EXPECT_OPT="warmup"
+ elif [[ "$arg" == --repeats=* || "$arg" == -repeats=* ]]; then
+ REPEATS="${arg#*repeats=}"
+ if ! [[ "$REPEATS" =~ ^[0-9]+$ ]] || [[ "$REPEATS" -lt 1 ]]; then
+ echo "Error: -repeats/--repeats requires a positive integer (e.g., -repeats=5)" >&2
+ exit 1
+ fi
+ elif [[ "$arg" == "--repeats" || "$arg" == "-repeats" ]]; then
+ EXPECT_OPT="repeats"
+ elif [[ "$arg" == "--stacked" || "$arg" == "-stacked" ]]; then
+ LAYOUT="stacked"
+ elif [[ "$arg" == --layout=* || "$arg" == -layout=* ]]; then
+ LAYOUT="${arg#*layout=}"
+ if [[ "$LAYOUT" != "auto" && "$LAYOUT" != "wide" && "$LAYOUT" != "stacked" ]]; then
+ echo "Error: -layout/--layout requires one of: auto, wide, stacked (e.g., --layout=wide)" >&2
+ exit 1
+ fi
+ elif [[ "$arg" == "--layout" || "$arg" == "-layout" ]]; then
+ EXPECT_OPT="layout"
+ else
+ # Check if argument looks like an unrecognized option (starts with dash)
+ if [[ "$arg" == -* ]]; then
+ echo "Error: Unrecognized option '$arg'" >&2
+ echo "Use --help or -h to see available options." >&2
+ exit 1
+ else
+ # Treat as query name
+ QUERIES+=( "$(echo "$arg" | tr '.' '_')" )
+ fi
+ fi
+ done
+
+# If the last option expected a value but none was provided
+if [[ -n "$EXPECT_OPT" ]]; then
+ case "$EXPECT_OPT" in
+ seed) echo "Error: -seed/--seed requires a value (e.g., -seed=12345)" >&2 ;;
+ warmup) echo "Error: -warmup/--warmup requires a value (e.g., -warmup=2)" >&2 ;;
+ repeats) echo "Error: -repeats/--repeats requires a value (e.g., -repeats=5)" >&2 ;;
+ layout) echo "Error: -layout/--layout requires a value (e.g., -layout=wide)" >&2 ;;
+ esac
+ exit 1
+fi
+
+# Generate seed if not provided
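+# (RANDOM yields 0..32767, so this composes a seed in the range 0..1073741823)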
+if [[ -z "$SEED" ]]; then
+ SEED=$((RANDOM * 32768 + RANDOM))
+fi
+if [[ ${#QUERIES[@]} -eq 0 ]]; then
+ for f in "$QUERY_DIR"/q*.dml; do
+ [[ -e "$f" ]] || continue
+ bname="$(basename "$f")"
+ QUERIES+=( "${bname%.dml}" )
+ done
+fi
+
+# Set data directory
+if [[ -z "$INPUT_DIR" ]]; then
+ INPUT_DIR="$PROJECT_ROOT/data"
+fi
+
+# Set output directory
+if [[ -z "$OUTPUT_DIR" ]]; then
+ OUTPUT_DIR="$PROJECT_ROOT/scripts/ssb/shell/ssbOutputData/PerformanceData"
+fi
+
+# Normalize paths by removing trailing slashes
+INPUT_DIR="${INPUT_DIR%/}"
+OUTPUT_DIR="${OUTPUT_DIR%/}"
+
+# Pass input directory to DML scripts via SystemDS named arguments
+NVARGS=( -nvargs "input_dir=${INPUT_DIR}" )
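+# Inside a DML script the value is then available as the named argument $input_dir,
+# e.g. something like read($input_dir + "/lineorder.tbl") (illustrative path)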
+
+# Validate data directory
+if [[ ! -d "$INPUT_DIR" ]]; then
+ echo "Error: Data directory '$INPUT_DIR' does not exist." >&2
+ echo "Please ensure the directory exists or specify a different path with -input-dir." >&2
+ exit 1
+fi
+
+# Ensure output directory exists
+mkdir -p "$OUTPUT_DIR"
+
+# Metadata collection functions
+collect_system_metadata() {
+ local timestamp hostname systemds_version jdk_version postgres_version duckdb_version cpu_info ram_info
+
+ # Basic system info
+ timestamp=$(date -u '+%Y-%m-%d %H:%M:%S UTC')
+ hostname=$(hostname 2>/dev/null || echo "unknown")
+
+ # SystemDS version
+ if [[ -x "$SYSTEMDS_CMD" ]]; then
+ # Try to get version from pom.xml first
+ if [[ -f "$PROJECT_ROOT/pom.xml" ]]; then
+      systemds_version=$(grep -A1 'org.apache.systemds' "$PROJECT_ROOT/pom.xml" | grep '<version>' | sed 's/.*<version>\(.*\)<\/version>.*/\1/' | head -1 2>/dev/null || echo "unknown")
+ else
+ systemds_version="unknown"
+ fi
+
+ # If pom.xml method failed, try alternative methods
+ if [[ "$systemds_version" == "unknown" ]]; then
+ # Try to extract from SystemDS JAR manifest
+ if [[ -f "$PROJECT_ROOT/target/systemds.jar" ]]; then
+ systemds_version=$(unzip -p "$PROJECT_ROOT/target/systemds.jar" META-INF/MANIFEST.MF 2>/dev/null | grep "Implementation-Version" | cut -d: -f2 | tr -d ' ' || echo "unknown")
+ else
+ # Try to find any SystemDS JAR and extract version
+ local jar_file=$(find "$PROJECT_ROOT" -name "systemds*.jar" | head -1 2>/dev/null)
+ if [[ -n "$jar_file" ]]; then
+ systemds_version=$(unzip -p "$jar_file" META-INF/MANIFEST.MF 2>/dev/null | grep "Implementation-Version" | cut -d: -f2 | tr -d ' ' || echo "unknown")
+ else
+ systemds_version="unknown"
+ fi
+ fi
+ fi
+ else
+ systemds_version="unknown"
+ fi
+
+ # JDK version
+ if command -v java >/dev/null 2>&1; then
+ jdk_version=$(java -version 2>&1 | grep -v "Picked up" | head -1 | sed 's/.*"\(.*\)".*/\1/' || echo "unknown")
+ else
+ jdk_version="unknown"
+ fi
+
+ # PostgreSQL version
+ if command -v psql >/dev/null 2>&1; then
+ postgres_version=$(psql --version 2>/dev/null | head -1 || echo "not available")
+ else
+ postgres_version="not available"
+ fi
+
+ # DuckDB version
+ if command -v duckdb >/dev/null 2>&1; then
+ duckdb_version=$(duckdb --version 2>/dev/null || echo "not available")
+ else
+ duckdb_version="not available"
+ fi
+
+ # System resources
+ if [[ "$(uname)" == "Darwin" ]]; then
+ # macOS
+ cpu_info=$(sysctl -n machdep.cpu.brand_string 2>/dev/null || echo "unknown")
+ ram_info=$(( $(sysctl -n hw.memsize 2>/dev/null || echo 0) / 1024 / 1024 / 1024 ))GB
+ else
+ # Linux
+ cpu_info=$(grep "model name" /proc/cpuinfo | head -1 | cut -d: -f2- | sed 's/^ *//' 2>/dev/null || echo "unknown")
+ ram_info=$(( $(grep MemTotal /proc/meminfo | awk '{print $2}' 2>/dev/null || echo 0) / 1024 / 1024 ))GB
+ fi
+
+ # Store metadata globally
+ RUN_TIMESTAMP="$timestamp"
+ RUN_HOSTNAME="$hostname"
+ RUN_SYSTEMDS_VERSION="$systemds_version"
+ RUN_JDK_VERSION="$jdk_version"
+ RUN_POSTGRES_VERSION="$postgres_version"
+ RUN_DUCKDB_VERSION="$duckdb_version"
+ RUN_CPU_INFO="$cpu_info"
+ RUN_RAM_INFO="$ram_info"
+}
+
+collect_data_metadata() {
+ # Check for SSB data directory and get basic stats
+ local ssb_data_dir="$INPUT_DIR"
+ local json_parts=()
+ local display_parts=()
+
+ if [[ -d "$ssb_data_dir" ]]; then
+ # Try to get row counts from data files (if they exist)
+ for table in customer part supplier date; do
+ local file="$ssb_data_dir/${table}.tbl"
+ if [[ -f "$file" ]]; then
+ local count=$(wc -l < "$file" 2>/dev/null | tr -d ' ' || echo "0")
+ json_parts+=(" \"$table\": \"$count\"")
+ display_parts+=("$table:$count")
+ fi
+ done
+ # Check for any lineorder*.tbl file (SSB fact table)
+ local lineorder_file=$(find "$ssb_data_dir" -name "lineorder*.tbl" -type f | head -1)
+ if [[ -n "$lineorder_file" && -f "$lineorder_file" ]]; then
+ local count=$(wc -l < "$lineorder_file" 2>/dev/null | tr -d ' ' || echo "0")
+ json_parts+=(" \"lineorder\": \"$count\"")
+ display_parts+=("lineorder:$count")
+ fi
+ fi
+
+ if [[ ${#json_parts[@]} -eq 0 ]]; then
+ RUN_DATA_INFO='"No data files found"'
+ RUN_DATA_DISPLAY="No data files found"
+ else
+ # Join array elements with commas and newlines, wrap in braces for JSON
+ local formatted_json="{\n"
+ for i in "${!json_parts[@]}"; do
+ formatted_json+="${json_parts[$i]}"
+ if [[ $i -lt $((${#json_parts[@]} - 1)) ]]; then
+ formatted_json+=",\n"
+ else
+ formatted_json+="\n"
+ fi
+ done
+ formatted_json+=" }"
+ RUN_DATA_INFO="$formatted_json"
+
+ # Join with spaces for display
+ local IFS=" "
+ RUN_DATA_DISPLAY="${display_parts[*]}"
+ fi
+}
+
+print_metadata_header() {
+ echo "=================================================================================="
+ echo " MULTI-ENGINE PERFORMANCE BENCHMARK METADATA"
+ echo "=================================================================================="
+ echo "Timestamp: $RUN_TIMESTAMP"
+ echo "Hostname: $RUN_HOSTNAME"
+ echo "Seed: $SEED"
+ echo
+ echo "Software Versions:"
+ echo " SystemDS: $RUN_SYSTEMDS_VERSION"
+ echo " JDK: $RUN_JDK_VERSION"
+ echo " PostgreSQL: $RUN_POSTGRES_VERSION"
+ echo " DuckDB: $RUN_DUCKDB_VERSION"
+ echo
+ echo "System Resources:"
+ echo " CPU: $RUN_CPU_INFO"
+ echo " RAM: $RUN_RAM_INFO"
+ echo
+ echo "Data Build Info:"
+ echo " SSB Data: $RUN_DATA_DISPLAY"
+ echo
+ echo "Run Configuration:"
+ echo " Statistics: $(if $RUN_STATS; then echo "enabled"; else echo "disabled"; fi)"
+ echo " Queries: ${#QUERIES[@]} selected"
+ echo " Warmup Runs: $WARMUP"
+ echo " Repeat Runs: $REPEATS"
+ echo "=================================================================================="
+ echo
+}
+
+# Progress indicator function
+progress_indicator() {
+ local query_name="$1"
+ local stage="$2"
+ # Use terminal width for proper clearing, fallback to 120 chars if tput fails
+ local term_width
+ term_width=$(tput cols 2>/dev/null || echo 120)
+ local spaces=$(printf "%*s" "$term_width" "")
+ echo -ne "\r$spaces\r$query_name: Running $stage..."
+}
+
+# Clear progress line function
+clear_progress() {
+ local term_width
+ term_width=$(tput cols 2>/dev/null || echo 120)
+ local spaces=$(printf "%*s" "$term_width" "")
+ echo -ne "\r$spaces\r"
+}
+
+# Main execution
+# Collect metadata
+collect_system_metadata
+collect_data_metadata
+
+# Print metadata header
+print_metadata_header
+
+verify_environment
+echo
+echo "NOTE (macOS): You cannot drop OS caches like Linux (sync; echo 3 > /proc/sys/vm/drop_caches)."
+echo "We mitigate with warm-up runs and repeated averages to ensure consistent measurements."
+echo
+echo "INTERPRETATION GUIDE:"
+echo "- SystemDS Shell (ms): Total execution time including JVM startup, I/O, and computation"
+echo "- SystemDS Core (ms): Pure computation time excluding JVM overhead (only with --stats)"
+echo "- PostgreSQL (ms): Single-threaded execution time with parallel workers disabled"
+echo "- PostgreSQL Core (ms): Query execution time from EXPLAIN ANALYZE (only with --stats)"
+echo "- DuckDB (ms): Single-threaded execution time with threads=1 pragma"
+echo "- DuckDB Core (ms): Engine-internal latency from JSON profiling (with --stats)"
+echo "- (missing): SQL file not found for this query"
+echo "- (n/a): Core timing unavailable (run with --stats flag for internal timing)"
+echo
+echo "NOTE: All engines use single-threaded execution for fair comparison."
+echo " Multiple runs with averaging provide statistical reliability."
+echo
+echo "Single-threaded execution; warm-up runs: $WARMUP, timed runs: $REPEATS"
+echo "Row 1 shows mean (ms); Row 2 shows ±stdev/CV; Row 3 shows p95 (ms)."
+echo "Core execution times available for all engines with --stats flag."
+echo
+term_width=$(tput cols 2>/dev/null || echo 120)
+if [[ "$LAYOUT" == "auto" ]]; then
+ if [[ $term_width -ge 140 ]]; then
+ LAYOUT_MODE="wide"
+ else
+ LAYOUT_MODE="stacked"
+ fi
+else
+ LAYOUT_MODE="$LAYOUT"
+fi
+
+# If the user requested wide layout but the terminal is too narrow, fall back to stacked
+if [[ "$LAYOUT_MODE" == "wide" ]]; then
+ # compute total printable width: sum(widths) + 3*cols + 1 (accounting for separators)
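+  # e.g. 8 columns whose widths sum to 80 -> 80 + 3*8 + 1 = 105 printable characters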
+ sumw=0
+ for w in "${WIDE_COL_WIDTHS[@]}"; do sumw=$((sumw + w)); done
+ cols=${#WIDE_COL_WIDTHS[@]}
+ total_width=$((sumw + 3*cols + 1))
+ if [[ $total_width -gt $term_width ]]; then
+ # Try to scale columns down proportionally to fit terminal width
+ reserved=$((3*cols + 1))
+ avail=$((term_width - reserved))
+ if [[ $avail -le 0 ]]; then
+      LAYOUT_MODE="stacked"  # nothing fits next to the separators; fall back to stacked
+ else
+ # Minimum sensible widths per column (keep labels readable)
+ MIN_COL_WIDTHS=(6 8 8 6 10 6 6 16)
+ # Start with proportional distribution
+ declare -a new_widths=()
+ for w in "${WIDE_COL_WIDTHS[@]}"; do
+ nw=$(( w * avail / sumw ))
+ if [[ $nw -lt 1 ]]; then nw=1; fi
+ new_widths+=("$nw")
+ done
+ # Enforce minimums
+ sum_new=0
+ for i in "${!new_widths[@]}"; do
+ if [[ ${new_widths[i]} -lt ${MIN_COL_WIDTHS[i]:-4} ]]; then
+ new_widths[i]=${MIN_COL_WIDTHS[i]:-4}
+ fi
+ sum_new=$((sum_new + new_widths[i]))
+ done
+ # If even minimums exceed available, fallback to stacked
+ if [[ $sum_new -gt $avail ]]; then
+        LAYOUT_MODE="stacked"
+ else
+ # Distribute remaining columns' widths left-to-right
+ rem=$((avail - sum_new))
+ i=0
+ while [[ $rem -gt 0 ]]; do
+ new_widths[i]=$((new_widths[i] + 1))
+ rem=$((rem - 1))
+ i=$(( (i + 1) % cols ))
+ done
+ # Replace WIDE_COL_WIDTHS with the scaled values for printing
+ WIDE_COL_WIDTHS=("${new_widths[@]}")
+ # Recompute total_width for logging
+ sumw=0
+ for w in "${WIDE_COL_WIDTHS[@]}"; do sumw=$((sumw + w)); done
+ total_width=$((sumw + reserved))
+ echo "Info: scaled wide layout to fit terminal ($term_width cols): table width $total_width"
+ fi
+ fi
+ fi
+fi
+
+if [[ "$LAYOUT_MODE" == "wide" ]]; then
+ grid_line_wide
+ grid_row_wide \
+ "Query" \
+ "SysDS Shell" "SysDS Core" \
+ "PostgreSQL" "PostgreSQL Core" \
+ "DuckDB" "DuckDB Core" \
+ "Fastest"
+ grid_row_wide "" "mean" "mean" "mean" "mean" "mean" "mean" ""
+ grid_row_wide "" "±/CV" "±/CV" "±/CV" "±/CV" "±/CV" "±/CV" ""
+ grid_row_wide "" "p95" "p95" "p95" "p95" "p95" "p95" ""
+ grid_line_wide
+else
+ echo "================================================================================"
+ echo "Stacked layout (use --layout=wide for table view)."
+  echo "Per engine: line 1 shows mean (ms); the lines below show ±stdev/CV and p95 (ms)."
+ echo "--------------------------------------------------------------------------------"
+fi
+# Prepare output file paths and write CSV header with comprehensive metadata
+# Ensure results directory exists and create timestamped filenames
+RESULT_DIR="$OUTPUT_DIR"
+mkdir -p "$RESULT_DIR"
+RESULT_BASENAME="ssb_results_$(date -u +%Y%m%dT%H%M%SZ)"
+RESULT_CSV="$RESULT_DIR/${RESULT_BASENAME}.csv"
+RESULT_JSON="$RESULT_DIR/${RESULT_BASENAME}.json"
+
+{
+ echo "# Multi-Engine Performance Benchmark Results"
+ echo "# Timestamp: $RUN_TIMESTAMP"
+ echo "# Hostname: $RUN_HOSTNAME"
+ echo "# Seed: $SEED"
+ echo "# SystemDS: $RUN_SYSTEMDS_VERSION"
+ echo "# JDK: $RUN_JDK_VERSION"
+ echo "# PostgreSQL: $RUN_POSTGRES_VERSION"
+ echo "# DuckDB: $RUN_DUCKDB_VERSION"
+ echo "# CPU: $RUN_CPU_INFO"
+ echo "# RAM: $RUN_RAM_INFO"
+ echo "# Data: $RUN_DATA_DISPLAY"
+ echo "# Warmup: $WARMUP, Repeats: $REPEATS"
+ echo "# Statistics: $(if $RUN_STATS; then echo "enabled"; else echo "disabled"; fi)"
+ echo "#"
+ echo "query,systemds_shell_display,systemds_shell_mean,systemds_shell_stdev,systemds_shell_p95,systemds_core_display,systemds_core_mean,systemds_core_stdev,systemds_core_p95,postgres_display,postgres_mean,postgres_stdev,postgres_p95,postgres_core_display,postgres_core_mean,postgres_core_stdev,postgres_core_p95,duckdb_display,duckdb_mean,duckdb_stdev,duckdb_p95,duckdb_core_display,duckdb_core_mean,duckdb_core_stdev,duckdb_core_p95,fastest"
+} > "$RESULT_CSV"
+for base in "${QUERIES[@]}"; do
+ # Show progress indicator for SystemDS
+ progress_indicator "$base" "SystemDS"
+
+ dml_path="$QUERY_DIR/${base}.dml"
+ # Parse SystemDS results: shell_mean|shell_stdev|shell_p95|core_mean|core_stdev|core_p95
+ # Capture potential SystemDS test-run error messages for JSON reporting
+ tmp_err_msg=$(mktemp)
+ systemds_result="$(run_systemds_avg "$dml_path" "$tmp_err_msg")"
+ # Read any captured error message
+ sysds_err_text="$(sed -n '1,200p' "$tmp_err_msg" 2>/dev/null | tr '\n' ' ' || true)"
+ rm -f "$tmp_err_msg"
+ IFS='|' read -r sd_shell_mean sd_shell_stdev sd_shell_p95 sd_core_mean sd_core_stdev sd_core_p95 <<< "$systemds_result"
+
+ # Format SystemDS results for display
+ if [[ "$sd_shell_mean" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
+ sd_shell_display=$(format_statistics "$sd_shell_mean" "$sd_shell_stdev" "$sd_shell_p95" "$REPEATS")
+ else
+ sd_shell_display="$sd_shell_mean"
+ sd_shell_stdev="0"
+ sd_shell_p95="0"
+ fi
+ if [[ "$sd_core_mean" == "(n/a)" ]]; then
+ sd_core_display="(n/a)"
+ else
+ sd_core_display=$(format_statistics "$sd_core_mean" "$sd_core_stdev" "$sd_core_p95" "$REPEATS")
+ fi
+
+ sql_name="${base//_/.}.sql"
+ sql_path="$SQL_DIR/$sql_name"
+ pg_display="(missing)"
+ duck_display="(missing)"
+
+ if [[ -n "$PSQL_EXEC" && -f "$sql_path" ]]; then
+ progress_indicator "$base" "PostgreSQL"
+ pg_result="$(run_psql_avg_ms "$sql_path")"
+ IFS='|' read -r pg_mean pg_stdev pg_p95 pg_core_mean pg_core_stdev pg_core_p95 <<< "$pg_result"
+ if [[ "$pg_mean" == "(unavailable)" || "$pg_mean" == "(error)" ]]; then
+ pg_display="$pg_mean"
+ pg_core_display="$pg_mean"
+ pg_stdev="0"
+ pg_p95="0"
+ pg_core_mean="(n/a)"
+ pg_core_stdev="0"
+ pg_core_p95="0"
+ else
+ pg_display=$(format_statistics "$pg_mean" "$pg_stdev" "$pg_p95" "$REPEATS")
+ if [[ "$pg_core_mean" != "(n/a)" ]]; then
+ pg_core_display=$(format_statistics "$pg_core_mean" "$pg_core_stdev" "$pg_core_p95" "$REPEATS")
+ else
+ pg_core_display="(n/a)"
+ fi
+ fi
+ elif [[ -z "$PSQL_EXEC" ]]; then
+ pg_display="(unavailable)"
+ pg_core_display="(unavailable)"
+ pg_mean="(unavailable)"
+ pg_core_mean="(unavailable)"
+ pg_stdev="0"
+ pg_p95="0"
+ pg_core_stdev="0"
+ pg_core_p95="0"
+ else
+ pg_display="(missing)"
+ pg_core_display="(missing)"
+ pg_mean="(missing)"
+ pg_core_mean="(missing)"
+ pg_stdev="0"
+ pg_p95="0"
+ pg_core_stdev="0"
+ pg_core_p95="0"
+ fi
+
+ if [[ -n "$DUCKDB_EXEC" && -f "$sql_path" ]]; then
+ progress_indicator "$base" "DuckDB"
+ duck_result="$(run_duckdb_avg_ms "$sql_path")"
+ IFS='|' read -r duck_mean duck_stdev duck_p95 duck_core_mean duck_core_stdev duck_core_p95 <<< "$duck_result"
+ if [[ "$duck_mean" == "(unavailable)" || "$duck_mean" == "(error)" ]]; then
+ duck_display="$duck_mean"
+ duck_stdev="0"
+ duck_p95="0"
+ duck_core_display="(n/a)"
+ duck_core_mean="(n/a)"
+ duck_core_stdev="0"
+ duck_core_p95="0"
+ else
+ duck_display=$(format_statistics "$duck_mean" "$duck_stdev" "$duck_p95" "$REPEATS")
+ if [[ "$duck_core_mean" == "(n/a)" ]]; then
+ duck_core_display="(n/a)"
+ else
+ duck_core_display=$(format_statistics "$duck_core_mean" "$duck_core_stdev" "$duck_core_p95" "$REPEATS")
+ fi
+ fi
+ elif [[ -z "$DUCKDB_EXEC" ]]; then
+ duck_display="(unavailable)"
+ duck_mean="(unavailable)"
+ duck_stdev="0"
+ duck_p95="0"
+ duck_core_display="(unavailable)"
+ duck_core_mean="(unavailable)"
+ duck_core_stdev="0"
+ duck_core_p95="0"
+ else
+ duck_display="(missing)"
+ duck_mean="(missing)"
+ duck_stdev="0"
+ duck_p95="0"
+ duck_core_display="(missing)"
+ duck_core_mean="(missing)"
+ duck_core_stdev="0"
+ duck_core_p95="0"
+ fi
+
+ # Determine fastest engine based on mean values
+ fastest=""
+ min_ms=999999999
+ for engine in systemds pg duck; do
+ val=""
+ eng_name=""
+ case "$engine" in
+ systemds) val="$sd_shell_mean"; eng_name="SystemDS";;
+ pg) val="$pg_mean"; eng_name="PostgreSQL";;
+ duck) val="$duck_mean"; eng_name="DuckDB";;
+ esac
+ # Check if value is a valid number (including decimal)
+ if [[ "$val" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
+ # Use awk for floating point comparison
+ if [[ $(awk -v val="$val" -v min="$min_ms" 'BEGIN{print (val < min)}') -eq 1 ]]; then
+ min_ms=$(awk -v val="$val" 'BEGIN{printf "%.1f", val}')
+ fastest="$eng_name"
+ elif [[ $(awk -v val="$val" -v min="$min_ms" 'BEGIN{print (val == min)}') -eq 1 ]] && [[ -n "$fastest" ]]; then
+ fastest="$fastest+$eng_name" # Show ties
+ fi
+ fi
+ done
+ [[ -z "$fastest" ]] && fastest="(n/a)"
+
+ # Determine SystemDS per-query status and include any error message captured
+ systemds_status="success"
+ systemds_error_message=null
+ if [[ "$sd_shell_mean" == "(error)" ]] || [[ -n "$sysds_err_text" ]]; then
+ systemds_status="error"
+ if [[ -n "$sysds_err_text" ]]; then
+ # Escape quotes for JSON embedding
+ esc=$(printf '%s' "$sysds_err_text" | sed -e 's/"/\\"/g')
+ systemds_error_message="\"$esc\""
+ else
+ systemds_error_message="\"SystemDS reported an error during test-run\""
+ fi
+ fi
+
+ # Prepare mean-only and stats-only cells
+ # Means: use numeric mean when available; otherwise use existing display label (unavailable/missing)
+ sd_shell_mean_cell=$([[ "$sd_shell_mean" =~ ^[0-9]+(\.[0-9]+)?$ ]] && echo "$sd_shell_mean" || echo "$sd_shell_display")
+ sd_core_mean_cell=$([[ "$sd_core_mean" =~ ^[0-9]+(\.[0-9]+)?$ ]] && echo "$sd_core_mean" || echo "$sd_core_display")
+ pg_mean_cell=$([[ "$pg_mean" =~ ^[0-9]+(\.[0-9]+)?$ ]] && echo "$pg_mean" || echo "$pg_display")
+ pg_core_mean_cell=$([[ "$pg_core_mean" =~ ^[0-9]+(\.[0-9]+)?$ ]] && echo "$pg_core_mean" || echo "$pg_core_display")
+ duck_mean_cell=$([[ "$duck_mean" =~ ^[0-9]+(\.[0-9]+)?$ ]] && echo "$duck_mean" || echo "$duck_display")
+ duck_core_mean_cell=$([[ "$duck_core_mean" =~ ^[0-9]+(\.[0-9]+)?$ ]] && echo "$duck_core_mean" || echo "$duck_core_display")
+
+ # Stats lines split: CV and p95
+ sd_shell_cv_cell=$(format_cv_only "$sd_shell_mean" "$sd_shell_stdev" "$REPEATS")
+ sd_core_cv_cell=$(format_cv_only "$sd_core_mean" "$sd_core_stdev" "$REPEATS")
+ pg_cv_cell=$(format_cv_only "$pg_mean" "$pg_stdev" "$REPEATS")
+ pg_core_cv_cell=$(format_cv_only "$pg_core_mean" "$pg_core_stdev" "$REPEATS")
+ duck_cv_cell=$(format_cv_only "$duck_mean" "$duck_stdev" "$REPEATS")
+ duck_core_cv_cell=$(format_cv_only "$duck_core_mean" "$duck_core_stdev" "$REPEATS")
+
+ sd_shell_p95_cell=$(format_p95_only "$sd_shell_p95" "$REPEATS")
+ sd_core_p95_cell=$(format_p95_only "$sd_core_p95" "$REPEATS")
+ pg_p95_cell=$(format_p95_only "$pg_p95" "$REPEATS")
+ pg_core_p95_cell=$(format_p95_only "$pg_core_p95" "$REPEATS")
+ duck_p95_cell=$(format_p95_only "$duck_p95" "$REPEATS")
+ duck_core_p95_cell=$(format_p95_only "$duck_core_p95" "$REPEATS")
+
+ # Clear progress line and display final results
+ clear_progress
+ if [[ "$LAYOUT_MODE" == "wide" ]]; then
+ # Three-line table style with grid separators
+ grid_row_wide \
+ "$base" \
+ "$sd_shell_mean_cell" "$sd_core_mean_cell" \
+ "$pg_mean_cell" "$pg_core_mean_cell" \
+ "$duck_mean_cell" "$duck_core_mean_cell" \
+ "$fastest"
+ grid_row_wide \
+ "" \
+ "$sd_shell_cv_cell" "$sd_core_cv_cell" \
+ "$pg_cv_cell" "$pg_core_cv_cell" \
+ "$duck_cv_cell" "$duck_core_cv_cell" \
+ ""
+ grid_row_wide \
+ "" \
+ "$sd_shell_p95_cell" "$sd_core_p95_cell" \
+ "$pg_p95_cell" "$pg_core_p95_cell" \
+ "$duck_p95_cell" "$duck_core_p95_cell" \
+ ""
+ grid_line_wide
+ else
+ # Stacked layout for narrow terminals
+ echo "Query : $base Fastest: $fastest"
+ printf ' %-20s %s\n' "SystemDS Shell:" "$sd_shell_mean_cell"
+ [[ -n "$sd_shell_cv_cell" ]] && printf ' %-20s %s\n' "" "$sd_shell_cv_cell"
+ [[ -n "$sd_shell_p95_cell" ]] && printf ' %-20s %s\n' "" "$sd_shell_p95_cell"
+ printf ' %-20s %s\n' "SystemDS Core:" "$sd_core_mean_cell"
+ [[ -n "$sd_core_cv_cell" ]] && printf ' %-20s %s\n' "" "$sd_core_cv_cell"
+ [[ -n "$sd_core_p95_cell" ]] && printf ' %-20s %s\n' "" "$sd_core_p95_cell"
+ printf ' %-20s %s\n' "PostgreSQL:" "$pg_mean_cell"
+ [[ -n "$pg_cv_cell" ]] && printf ' %-20s %s\n' "" "$pg_cv_cell"
+ [[ -n "$pg_p95_cell" ]] && printf ' %-20s %s\n' "" "$pg_p95_cell"
+ printf ' %-20s %s\n' "PostgreSQL Core:" "$pg_core_mean_cell"
+ [[ -n "$pg_core_cv_cell" ]] && printf ' %-20s %s\n' "" "$pg_core_cv_cell"
+ [[ -n "$pg_core_p95_cell" ]] && printf ' %-20s %s\n' "" "$pg_core_p95_cell"
+ printf ' %-20s %s\n' "DuckDB:" "$duck_mean_cell"
+ [[ -n "$duck_cv_cell" ]] && printf ' %-20s %s\n' "" "$duck_cv_cell"
+ [[ -n "$duck_p95_cell" ]] && printf ' %-20s %s\n' "" "$duck_p95_cell"
+ printf ' %-20s %s\n' "DuckDB Core:" "$duck_core_mean_cell"
+ [[ -n "$duck_core_cv_cell" ]] && printf ' %-20s %s\n' "" "$duck_core_cv_cell"
+ [[ -n "$duck_core_p95_cell" ]] && printf ' %-20s %s\n' "" "$duck_core_p95_cell"
+ echo "--------------------------------------------------------------------------------"
+ fi
+
+ # Write comprehensive data to CSV
+ echo "$base,\"$sd_shell_display\",$sd_shell_mean,$sd_shell_stdev,$sd_shell_p95,\"$sd_core_display\",$sd_core_mean,$sd_core_stdev,$sd_core_p95,\"$pg_display\",$pg_mean,$pg_stdev,$pg_p95,\"$pg_core_display\",$pg_core_mean,$pg_core_stdev,$pg_core_p95,\"$duck_display\",$duck_mean,$duck_stdev,$duck_p95,\"$duck_core_display\",$duck_core_mean,$duck_core_stdev,$duck_core_p95,$fastest" >> "$RESULT_CSV"
+
+  # Build JSON entry for this query (per-engine mean/stdev/p95, mirroring the CSV columns)
+  json_entry=$(cat <<JSON
+    {
+      "query": "$base",
+      "fastest": "$fastest",
+      "systemds": {
+        "status": "$systemds_status",
+        "error_message": $systemds_error_message,
+        "shell_ms": { "mean": "$sd_shell_mean", "stdev": "$sd_shell_stdev", "p95": "$sd_shell_p95" },
+        "core_ms": { "mean": "$sd_core_mean", "stdev": "$sd_core_stdev", "p95": "$sd_core_p95" }
+      },
+      "postgresql": {
+        "shell_ms": { "mean": "$pg_mean", "stdev": "$pg_stdev", "p95": "$pg_p95" },
+        "core_ms": { "mean": "$pg_core_mean", "stdev": "$pg_core_stdev", "p95": "$pg_core_p95" }
+      },
+      "duckdb": {
+        "shell_ms": { "mean": "$duck_mean", "stdev": "$duck_stdev", "p95": "$duck_p95" },
+        "core_ms": { "mean": "$duck_core_mean", "stdev": "$duck_core_stdev", "p95": "$duck_core_p95" }
+      }
+    }
+JSON
+)
+  # Accumulate entries; the complete JSON document is written after the loop
+  if [[ -n "${JSON_ENTRIES:-}" ]]; then
+    JSON_ENTRIES+=",
+$json_entry"
+  else
+    JSON_ENTRIES="$json_entry"
+  fi
+done
+
+# Write JSON results with run metadata and the per-query entries collected above
+{
+  echo "{"
+  echo "  \"benchmark_type\": \"ssb_multi_engine\","
+  echo "  \"timestamp\": \"$RUN_TIMESTAMP\","
+  echo "  \"hostname\": \"$RUN_HOSTNAME\","
+  echo "  \"seed\": $SEED,"
+  echo "  \"software_versions\": {"
+  echo "    \"systemds\": \"$RUN_SYSTEMDS_VERSION\","
+  echo "    \"jdk\": \"$RUN_JDK_VERSION\","
+  echo "    \"postgresql\": \"$RUN_POSTGRES_VERSION\","
+  echo "    \"duckdb\": \"$RUN_DUCKDB_VERSION\""
+  echo "  },"
+  echo "  \"system_resources\": {"
+  echo "    \"cpu\": \"$RUN_CPU_INFO\","
+  echo "    \"ram\": \"$RUN_RAM_INFO\""
+  echo "  },"
+  echo -e "  \"data_build_info\": $RUN_DATA_INFO,"
+  echo "  \"run_configuration\": {"
+  echo "    \"statistics_enabled\": $(if $RUN_STATS; then echo "true"; else echo "false"; fi),"
+  echo "    \"warmup_runs\": $WARMUP,"
+  echo "    \"repeat_runs\": $REPEATS,"
+  echo "    \"queries_selected\": ${#QUERIES[@]}"
+  echo "  },"
+  echo "  \"results\": ["
+  echo "${JSON_ENTRIES:-}"
+  echo "  ]"
+  echo "}"
+} > "$RESULT_JSON"
+
+echo "Results saved to $RESULT_CSV"
+echo "Results saved to $RESULT_JSON"
diff --git a/scripts/staging/ssb/shell/run_ssb.sh b/scripts/staging/ssb/shell/run_ssb.sh
new file mode 100644
index 00000000000..b7ad9c57d32
--- /dev/null
+++ b/scripts/staging/ssb/shell/run_ssb.sh
@@ -0,0 +1,876 @@
+#!/usr/bin/env bash
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# SystemDS Star Schema Benchmark (SSB) Runner
+# ===========================================
+#
+# CORE SCRIPTS STATUS:
+# - Version: 1.0 (September 5, 2025)
+# - Status: Production-Ready with Advanced User Experience
+# - First Public Release: September 5, 2025
+#
+# FEATURES IMPLEMENTED:
+# ✓ Basic SSB query execution with SystemDS 3.4.0-SNAPSHOT
+# ✓ Single-threaded configuration for consistent benchmarking
+# ✓ Progress indicators with real-time updates
+# ✓ Comprehensive timing measurements using /usr/bin/time
+# ✓ Query result extraction (scalar and table formats)
+# ✓ Success/failure tracking with detailed reporting
+# ✓ Query summary table with execution status
+# ✓ "See below" notation with result reprinting (NEW)
+# ✓ Long table outputs displayed after summary (NEW)
+# ✓ Error handling with timeout protection
+# ✓ Cross-platform compatibility (macOS/Linux)
+#
+# RECENT IMPORTANT ADDITIONS:
+# - Accepts --input-dir=PATH and forwards it into DML runs as a SystemDS named
+# argument: -nvargs input_dir=/path/to/data (DML can use sys.vinput_dir or
+# the named argument to locate data files instead of hardcoded `data/`).
+# - Fast-fail on missing input directory: the runner verifies the provided
+# input path exists and exits with a clear error message if not.
+# - Runtime SystemDS error detection: test-run output is scanned for runtime
+# error blocks (e.g., "An Error Occurred : ..."). Queries with runtime
+# failures are reported as `status: "error"` and include `error_message`
+# in generated JSON metadata for easier debugging and CI integration.
+#
+# MAJOR FEATURES IN v1.0 (First Public Release):
+# - Complete SSB query execution with SystemDS 3.4.0-SNAPSHOT
+# - Enhanced "see below" notation with result reprinting
+# - Long table outputs displayed after summary for better UX
+# - Eliminated need to scroll back through terminal output
+# - Maintained array alignment for consistent result tracking
+# - JSON metadata contains complete query results, not "see below"
+#   - Added --output-dir option for custom output directory
+# - Multi-format output: TXT, CSV, JSON for each query result
+# - Structured output directory with comprehensive run.json metadata file
+#
+# DEPENDENCIES:
+# - SystemDS binary (3.4.0-SNAPSHOT or later)
+# - Single-threaded configuration file (auto-generated)
+# - SSB query files in scripts/ssb/queries/
+# - Bash 4.0+ with timeout support
+#
+# USAGE (from repo root):
+# scripts/ssb/shell/run_ssb.sh # run all SSB queries
+# scripts/ssb/shell/run_ssb.sh q1.1 q2.3 # run specific queries
+# scripts/ssb/shell/run_ssb.sh --stats # enable internal statistics
+# scripts/ssb/shell/run_ssb.sh q3.1 --stats # run specific query with stats
+# scripts/ssb/shell/run_ssb.sh --seed=12345 # run with specific seed for reproducibility
+#   scripts/ssb/shell/run_ssb.sh --output-dir=/path    # specify output directory for results
+#
+set -euo pipefail
+export LC_ALL=C
+
+# Determine script directory and project root (repo root)
+if command -v realpath >/dev/null 2>&1; then
+ SCRIPT_DIR="$(dirname "$(realpath "$0")")"
+else
+ SCRIPT_DIR="$(python - <<'PY'
+import os, sys
+print(os.path.dirname(os.path.abspath(sys.argv[1])))
+PY
+"$0")"
+fi
+if command -v git >/dev/null 2>&1 && git -C "$SCRIPT_DIR" rev-parse --show-toplevel >/dev/null 2>&1; then
+ PROJECT_ROOT="$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel)"
+else
+ __dir="$SCRIPT_DIR"
+ PROJECT_ROOT=""
+ while [[ "$__dir" != "/" ]]; do
+ if [[ -d "$__dir/.git" || -f "$__dir/pom.xml" ]]; then
+ PROJECT_ROOT="$__dir"; break
+ fi
+ __dir="$(dirname "$__dir")"
+ done
+ : "${PROJECT_ROOT:=$(cd "$SCRIPT_DIR/../../../" && pwd)}"
+fi
+
+# Locate SystemDS executable
+SYSTEMDS_CMD="$PROJECT_ROOT/bin/systemds"
+if [[ ! -x "$SYSTEMDS_CMD" ]]; then
+ SYSTEMDS_CMD="$(command -v systemds || true)"
+fi
+if [[ -z "$SYSTEMDS_CMD" || ! -x "$SYSTEMDS_CMD" ]]; then
+ echo "Error: could not find SystemDS executable." >&2
+ echo " Tried: $PROJECT_ROOT/bin/systemds and PATH" >&2
+ exit 1
+fi
+
+# Ensure single-threaded configuration file exists
+CONF_DIR="$PROJECT_ROOT/conf"
+SINGLE_THREAD_CONF="$CONF_DIR/single_thread.xml"
+mkdir -p "$CONF_DIR"
+if [[ ! -f "$SINGLE_THREAD_CONF" ]]; then
+cat > "$SINGLE_THREAD_CONF" <<'XML'
+<root>
+   <!-- disable multi-threaded (parallel) CP operators -->
+   <sysds.cp.parallel.ops>false</sysds.cp.parallel.ops>
+   <!-- restrict operations to a single thread -->
+   <sysds.num.threads>1</sysds.num.threads>
+</root>
+XML
+fi
+SYS_EXTRA_ARGS=( "-config" "$SINGLE_THREAD_CONF" )
+
+# Query directory
+QUERY_DIR="$PROJECT_ROOT/scripts/ssb/queries"
+
+# Verify query directory exists
+if [[ ! -d "$QUERY_DIR" ]]; then
+ echo "Error: Query directory not found: $QUERY_DIR" >&2
+ exit 1
+fi
+
+# Help function
+show_help() {
+ cat << 'EOF'
+SystemDS Star Schema Benchmark (SSB) Runner v1.0
+
+USAGE (from repo root):
+ scripts/ssb/shell/run_ssb.sh [OPTIONS] [QUERIES...]
+
+OPTIONS:
+ --stats, -stats Enable SystemDS internal statistics collection
+ --seed=N, -seed=N Set random seed for reproducible results (default: auto-generated)
+ --output-dir=PATH, -output-dir=PATH Specify custom output directory (default: $PROJECT_ROOT/scripts/ssb/shell/ssbOutputData/QueryData)
+ --input-dir=PATH, -input-dir=PATH Specify custom data directory (default: $PROJECT_ROOT/data)
+ --help, -help, -h, --h Show this help message
+ --version, -version, -v, --v Show version information
+
+QUERIES:
+ If no queries are specified, all available SSB queries (q*.dml) will be executed.
+ To run specific queries, provide their names (with or without .dml extension):
+    scripts/ssb/shell/run_ssb.sh q1.1 q2.3 q4.1
+
+EXAMPLES (from repo root):
+ scripts/ssb/shell/run_ssb.sh # Run all SSB queries
+ scripts/ssb/shell/run_ssb.sh --stats # Run all queries with statistics
+ scripts/ssb/shell/run_ssb.sh -stats # Same as above (single dash)
+ scripts/ssb/shell/run_ssb.sh q1.1 q2.3 # Run specific queries only
+ scripts/ssb/shell/run_ssb.sh --seed=12345 --stats # Reproducible run with statistics
+ scripts/ssb/shell/run_ssb.sh -seed=12345 -stats # Same as above (single dash)
+ scripts/ssb/shell/run_ssb.sh --output-dir=/tmp/results # Custom output directory
+ scripts/ssb/shell/run_ssb.sh -output-dir=/tmp/results # Same as above (single dash)
+ scripts/ssb/shell/run_ssb.sh --input-dir=/path/to/data # Custom data directory
+ scripts/ssb/shell/run_ssb.sh -input-dir=/path/to/data # Same as above (single dash)
+
+OUTPUT:
+ Results are saved in multiple formats:
+ - TXT: Human-readable format
+ - CSV: Machine-readable data format
+ - JSON: Structured format with metadata
+ - run.json: Complete run metadata and results
+
+For more information, see the documentation in scripts/ssb/README.md
+EOF
+}
+
+# Parse arguments
+RUN_STATS=false
+QUERIES=()
+SEED=""
+OUT_DIR=""
+INPUT_DIR=""
+for arg in "$@"; do
+ if [[ "$arg" == "--help" || "$arg" == "-help" || "$arg" == "-h" || "$arg" == "--h" ]]; then
+ show_help
+ exit 0
+ elif [[ "$arg" == "--version" || "$arg" == "-version" || "$arg" == "-v" || "$arg" == "--v" ]]; then
+ echo "SystemDS Star Schema Benchmark (SSB) Runner v1.0"
+ echo "First Public Release: September 5, 2025"
+ exit 0
+ elif [[ "$arg" == "--stats" || "$arg" == "-stats" ]]; then
+ RUN_STATS=true
+ elif [[ "$arg" == --seed=* || "$arg" == -seed=* ]]; then
+ if [[ "$arg" == --seed=* ]]; then
+ SEED="${arg#--seed=}"
+ else
+ SEED="${arg#-seed=}"
+ fi
+ elif [[ "$arg" == "--seed" || "$arg" == "-seed" ]]; then
+ echo "Error: --seed/-seed requires a value (e.g., --seed=12345 or -seed=12345)" >&2
+ exit 1
+ elif [[ "$arg" == --output-dir=* || "$arg" == -output-dir=* ]]; then
+ if [[ "$arg" == --output-dir=* ]]; then
+ OUT_DIR="${arg#--output-dir=}"
+ else
+ OUT_DIR="${arg#-output-dir=}"
+ fi
+ elif [[ "$arg" == "--output-dir" || "$arg" == "-output-dir" ]]; then
+ echo "Error: --output-dir/-output-dir requires a value (e.g., --output-dir=/path/to/output or -output-dir=/path/to/output)" >&2
+ exit 1
+ elif [[ "$arg" == --input-dir=* || "$arg" == -input-dir=* ]]; then
+ if [[ "$arg" == --input-dir=* ]]; then
+ INPUT_DIR="${arg#--input-dir=}"
+ else
+ INPUT_DIR="${arg#-input-dir=}"
+ fi
+ elif [[ "$arg" == "--input-dir" || "$arg" == "-input-dir" ]]; then
+ echo "Error: --input-dir/-input-dir requires a value (e.g., --input-dir=/path/to/data or -input-dir=/path/to/data)" >&2
+ exit 1
+ else
+ # Check if argument looks like an unrecognized option (starts with dash)
+ if [[ "$arg" == -* ]]; then
+ echo "Error: Unrecognized option '$arg'" >&2
+ echo "Use --help or -h to see available options." >&2
+ exit 1
+ else
+ # Treat as query name
+ name="$(echo "$arg" | tr '.' '_')"
+ QUERIES+=( "$name.dml" )
+ fi
+ fi
+done
+
+# Set default output directory if not provided
+if [[ -z "$OUT_DIR" ]]; then
+ OUT_DIR="$PROJECT_ROOT/scripts/ssb/shell/ssbOutputData/QueryData"
+fi
+
+# Set default input data directory if not provided
+if [[ -z "$INPUT_DIR" ]]; then
+ INPUT_DIR="$PROJECT_ROOT/data"
+fi
+
+# Normalize paths by removing trailing slashes
+INPUT_DIR="${INPUT_DIR%/}"
+OUT_DIR="${OUT_DIR%/}"
+
+# Ensure output directory exists
+mkdir -p "$OUT_DIR"
+
+# Pass input directory to DML scripts via SystemDS named arguments
+NVARGS=( -nvargs "input_dir=${INPUT_DIR}" )
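+# (DML scripts consume this as the named argument $input_dir)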
+
+# Validate input data directory exists
+if [[ ! -d "$INPUT_DIR" ]]; then
+ echo "Error: Input data directory '$INPUT_DIR' does not exist." >&2
+ echo "Please create the directory or specify a valid path with --input-dir=PATH" >&2
+ exit 1
+fi
+
+# Generate seed if not provided
+if [[ -z "$SEED" ]]; then
+ SEED=$((RANDOM * 32768 + RANDOM))
+fi
+
+# Discover queries if none provided
+shopt -s nullglob
+if [[ ${#QUERIES[@]} -eq 0 ]]; then
+ for f in "$QUERY_DIR"/q*.dml; do
+ if [[ -f "$f" ]]; then
+ QUERIES+=("$(basename "$f")")
+ fi
+ done
+ if [[ ${#QUERIES[@]} -eq 0 ]]; then
+ echo "Error: No query files found in $QUERY_DIR" >&2
+ exit 1
+ fi
+fi
+shopt -u nullglob
+
+# Metadata collection functions
+collect_system_metadata() {
+ local timestamp hostname systemds_version jdk_version cpu_info ram_info
+
+ # Basic system info
+ timestamp=$(date -u '+%Y-%m-%d %H:%M:%S UTC')
+ hostname=$(hostname 2>/dev/null || echo "unknown")
+
+ # SystemDS version
+ if [[ -x "$SYSTEMDS_CMD" ]]; then
+ # Try to get version from pom.xml first
+ if [[ -f "$PROJECT_ROOT/pom.xml" ]]; then
+      systemds_version=$(grep -A1 'org.apache.systemds' "$PROJECT_ROOT/pom.xml" | grep '<version>' | sed 's/.*<version>\(.*\)<\/version>.*/\1/' | head -1 2>/dev/null || echo "unknown")
+ else
+ systemds_version="unknown"
+ fi
+
+ # If pom.xml method failed, try alternative methods
+ if [[ "$systemds_version" == "unknown" ]]; then
+ # Try to extract from SystemDS JAR manifest
+ if [[ -f "$PROJECT_ROOT/target/systemds.jar" ]]; then
+ systemds_version=$(unzip -p "$PROJECT_ROOT/target/systemds.jar" META-INF/MANIFEST.MF 2>/dev/null | grep "Implementation-Version" | cut -d: -f2 | tr -d ' ' || echo "unknown")
+ else
+ # Try to find any SystemDS JAR and extract version
+ local jar_file=$(find "$PROJECT_ROOT" -name "systemds*.jar" | head -1 2>/dev/null)
+ if [[ -n "$jar_file" ]]; then
+ systemds_version=$(unzip -p "$jar_file" META-INF/MANIFEST.MF 2>/dev/null | grep "Implementation-Version" | cut -d: -f2 | tr -d ' ' || echo "unknown")
+ else
+ systemds_version="unknown"
+ fi
+ fi
+ fi
+ else
+ systemds_version="unknown"
+ fi
+
+ # JDK version
+ if command -v java >/dev/null 2>&1; then
+ jdk_version=$(java -version 2>&1 | head -1 | sed 's/.*"\(.*\)".*/\1/' || echo "unknown")
+ else
+ jdk_version="unknown"
+ fi
+
+ # System resources
+ if [[ "$(uname)" == "Darwin" ]]; then
+ # macOS
+ cpu_info=$(sysctl -n machdep.cpu.brand_string 2>/dev/null || echo "unknown")
+ ram_info=$(( $(sysctl -n hw.memsize 2>/dev/null || echo 0) / 1024 / 1024 / 1024 ))GB
+ else
+ # Linux
+ cpu_info=$(grep "model name" /proc/cpuinfo | head -1 | cut -d: -f2- | sed 's/^ *//' 2>/dev/null || echo "unknown")
+ ram_info=$(( $(grep MemTotal /proc/meminfo | awk '{print $2}' 2>/dev/null || echo 0) / 1024 / 1024 ))GB
+ fi
+
+ # Store metadata globally
+ RUN_TIMESTAMP="$timestamp"
+ RUN_HOSTNAME="$hostname"
+ RUN_SYSTEMDS_VERSION="$systemds_version"
+ RUN_JDK_VERSION="$jdk_version"
+ RUN_CPU_INFO="$cpu_info"
+ RUN_RAM_INFO="$ram_info"
+}
+
+collect_data_metadata() {
+ # Check for SSB data directory and get basic stats
+ local ssb_data_dir="$INPUT_DIR"
+ local json_parts=()
+ local display_parts=()
+
+ if [[ -d "$ssb_data_dir" ]]; then
+ # Try to get row counts from data files (if they exist)
+ for table in customer part supplier date; do
+ local file="$ssb_data_dir/${table}.tbl"
+ if [[ -f "$file" ]]; then
+ local count=$(wc -l < "$file" 2>/dev/null | tr -d ' ' || echo "0")
+ json_parts+=(" \"$table\": \"$count\"")
+ display_parts+=("$table:$count")
+ fi
+ done
+ # Check for any lineorder*.tbl file (SSB fact table)
+ local lineorder_file=$(find "$ssb_data_dir" -name "lineorder*.tbl" -type f | head -1)
+ if [[ -n "$lineorder_file" && -f "$lineorder_file" ]]; then
+ local count=$(wc -l < "$lineorder_file" 2>/dev/null | tr -d ' ' || echo "0")
+ json_parts+=(" \"lineorder\": \"$count\"")
+ display_parts+=("lineorder:$count")
+ fi
+ fi
+
+ if [[ ${#json_parts[@]} -eq 0 ]]; then
+ RUN_DATA_INFO='"No data files found"'
+ RUN_DATA_DISPLAY="No data files found"
+ else
+ # Join array elements with commas and newlines, wrap in braces for JSON
+ local formatted_json="{\n"
+ for i in "${!json_parts[@]}"; do
+ formatted_json+="${json_parts[$i]}"
+ if [[ $i -lt $((${#json_parts[@]} - 1)) ]]; then
+ formatted_json+=",\n"
+ else
+ formatted_json+="\n"
+ fi
+ done
+ formatted_json+=" }"
+ RUN_DATA_INFO="$formatted_json"
+
+ # Join with spaces for display
+ local IFS=" "
+ RUN_DATA_DISPLAY="${display_parts[*]}"
+ fi
+}
+
+# Output format functions
+create_output_structure() {
+ local run_id="$1"
+ local base_dir="$OUT_DIR/ssb_run_$run_id"
+
+ # Create output directory structure
+ mkdir -p "$base_dir"/{txt,csv,json}
+
+ # Set global variables for output paths
+ OUTPUT_BASE_DIR="$base_dir"
+ OUTPUT_TXT_DIR="$base_dir/txt"
+ OUTPUT_CSV_DIR="$base_dir/csv"
+ OUTPUT_JSON_DIR="$base_dir/json"
+ OUTPUT_METADATA_FILE="$base_dir/run.json"
+}
+
+save_query_result_txt() {
+ local query_name="$1"
+ local result_data="$2"
+ local output_file="$OUTPUT_TXT_DIR/${query_name}.txt"
+
+ {
+ echo "========================================="
+ echo "SSB Query: $query_name"
+ echo "========================================="
+ echo "Timestamp: $(date -u '+%Y-%m-%d %H:%M:%S UTC')"
+ echo "Seed: $SEED"
+ echo ""
+ echo "Result:"
+ echo "---------"
+ echo "$result_data"
+ echo ""
+ echo "========================================="
+ } > "$output_file"
+}
+
+save_query_result_csv() {
+ local query_name="$1"
+ local result_data="$2"
+ local output_file="$OUTPUT_CSV_DIR/${query_name}.csv"
+
+ # Check if result is a single scalar value (including negative numbers and scientific notation)
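+  # e.g. 687752409, -1.5, or 1.2e+06 count as scalars; anything else is treated as a table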
+ if [[ "$result_data" =~ ^-?[0-9]+(\.[0-9]+)?([eE][+-]?[0-9]+)?$ ]]; then
+ # Scalar result
+ {
+ echo "query,result"
+ echo "$query_name,$result_data"
+ } > "$output_file"
+ else
+ # Table result - try to convert to CSV format
+ {
+ echo "# SSB Query: $query_name"
+ echo "# Timestamp: $(date -u '+%Y-%m-%d %H:%M:%S UTC')"
+ echo "# Seed: $SEED"
+ # Convert space-separated table data to CSV
+      echo "$result_data" | sed 's/  */,/g' | sed 's/^,//g' | sed 's/,$//g'
+ } > "$output_file"
+ fi
+}
+
+save_query_result_json() {
+ local query_name="$1"
+ local result_data="$2"
+ local output_file="$OUTPUT_JSON_DIR/${query_name}.json"
+
+ # Escape quotes and special characters for JSON
+ local escaped_result=$(echo "$result_data" | sed 's/\\/\\\\/g' | sed 's/"/\\"/g' | tr '\n' ' ')
+
+ {
+ echo "{"
+ echo " \"query\": \"$query_name\","
+ echo " \"timestamp\": \"$(date -u '+%Y-%m-%d %H:%M:%S UTC')\","
+ echo " \"seed\": $SEED,"
+ echo " \"result\": \"$escaped_result\","
+ echo " \"metadata\": {"
+ echo " \"systemds_version\": \"$RUN_SYSTEMDS_VERSION\","
+ echo " \"hostname\": \"$RUN_HOSTNAME\""
+ echo " }"
+ echo "}"
+ } > "$output_file"
+}
+
+save_all_formats() {
+ local query_name="$1"
+ local result_data="$2"
+
+ save_query_result_txt "$query_name" "$result_data"
+ save_query_result_csv "$query_name" "$result_data"
+ save_query_result_json "$query_name" "$result_data"
+}
+
+# Collect metadata
+collect_system_metadata
+collect_data_metadata
+
+# Create output directory structure with timestamp-based run ID
+RUN_ID="$(date +%Y%m%d_%H%M%S)"
+create_output_structure "$RUN_ID"
+
+# Execute queries
+count=0
+failed=0
+SUCCESSFUL_QUERIES=() # Array to track successfully executed queries
+ALL_RUN_QUERIES=() # Array to track all queries that were attempted (in order)
+QUERY_STATUS=() # Array to track status: "success" or "error"
+QUERY_ERROR_MSG=() # Array to store error messages for failed queries
+QUERY_RESULTS=() # Array to track query results for display
+QUERY_FULL_RESULTS=() # Array to track complete query results for JSON
+QUERY_STATS=() # Array to track SystemDS statistics for JSON
+QUERY_TIMINGS=() # Array to track execution timing statistics
+LONG_OUTPUTS=() # Array to store long table outputs for display after summary
+
+# Progress indicator function
+progress_indicator() {
+ local query_name="$1"
+ local current="$2"
+ local total="$3"
+ echo -ne "\r[$current/$total] Running: $query_name "
+}
+
+for q in "${QUERIES[@]}"; do
+ dml="$QUERY_DIR/$q"
+ if [[ ! -f "$dml" ]]; then
+ echo "Warning: query file '$dml' not found; skipping." >&2
+ continue
+ fi
+
+ # Show progress
+ progress_indicator "$q" "$((count + failed + 1))" "${#QUERIES[@]}"
+
+ # Change to project root directory so relative paths in DML work correctly
+ cd "$PROJECT_ROOT"
+
+ # Clear progress line before showing output
+ echo -ne "\r \r"
+ echo "[$((count + failed + 1))/${#QUERIES[@]}] Running: $q"
+
+ # Record attempted query
+ ALL_RUN_QUERIES+=("$q")
+
+ if $RUN_STATS; then
+ # Capture output to extract result
+ temp_output=$(mktemp)
+ if "$SYSTEMDS_CMD" "$dml" -stats "${SYS_EXTRA_ARGS[@]}" "${NVARGS[@]}" | tee "$temp_output"; then
+ # Even when SystemDS exits 0, the DML can emit runtime errors. Detect common error markers.
+ error_msg=$(sed -n '/An Error Occurred :/,$ p' "$temp_output" | sed -n '1,200p' | tr '\n' ' ' | sed 's/^ *//;s/ *$//')
+ if [[ -n "$error_msg" ]]; then
+ echo "Error: Query $q reported runtime error" >&2
+ echo "$error_msg" >&2
+ failed=$((failed+1))
+ QUERY_STATUS+=("error")
+ QUERY_ERROR_MSG+=("$error_msg")
+ # Maintain array alignment
+ QUERY_STATS+=("")
+ QUERY_RESULTS+=("N/A")
+ QUERY_FULL_RESULTS+=("N/A")
+ LONG_OUTPUTS+=("")
+ else
+ count=$((count+1))
+ SUCCESSFUL_QUERIES+=("$q") # Track successful query
+ QUERY_STATUS+=("success")
+ # Extract result - try multiple patterns with timeouts to prevent hanging:
+ # 1. Simple scalar pattern like "REVENUE: 687752409"
+ result=$(timeout 5s grep -E "^[A-Z_]+:\s*[0-9]+" "$temp_output" | tail -1 | awk '{print $2}' 2>/dev/null || true)
+ full_result="$result" # For scalar results, display and full results are the same
+
+ # 2. If no scalar pattern, check for table output and get row count
+ if [[ -z "$result" ]]; then
+ # Look for frame info like "# FRAME: nrow = 53, ncol = 3"
+ nrows=$(timeout 5s grep "# FRAME: nrow =" "$temp_output" | awk '{print $5}' | tr -d ',' 2>/dev/null || true)
+ if [[ -n "$nrows" ]]; then
+ result="${nrows} rows (see below)"
+ # Extract and store the long output for later display (excluding statistics)
+ long_output=$(grep -v "^#" "$temp_output" | grep -v "WARNING" | grep -v "WARN" | grep -v "^$" | sed '/^SystemDS Statistics:/,$ d')
+ LONG_OUTPUTS+=("$long_output")
+ # For JSON, store the actual table data
+ full_result="$long_output"
+ else
+ # Count actual data rows (lines with numbers, excluding headers and comments) - limit to prevent hanging
+ nrows=$(timeout 5s grep -E "^[0-9]" "$temp_output" | sed '/^SystemDS Statistics:/,$ d' | head -1000 | wc -l | tr -d ' ' 2>/dev/null || echo "0")
+ if [[ "$nrows" -gt 0 ]]; then
+ result="${nrows} rows (see below)"
+ # Extract and store the long output for later display (excluding statistics)
+ long_output=$(grep -E "^[0-9]" "$temp_output" | sed '/^SystemDS Statistics:/,$ d' | head -1000)
+ LONG_OUTPUTS+=("$long_output")
+ # For JSON, store the actual table data
+ full_result="$long_output"
+ else
+ result="N/A"
+ full_result="N/A"
+ LONG_OUTPUTS+=("") # Empty placeholder to maintain array alignment
+ fi
+ fi
+ else
+ LONG_OUTPUTS+=("") # Empty placeholder for scalar results to maintain array alignment
+ fi
+ QUERY_RESULTS+=("$result") # Track query result for display
+ QUERY_FULL_RESULTS+=("$full_result") # Track complete query result for JSON
+
+ # Save result in all formats
+ query_name_clean="${q%.dml}"
+
+ # Extract and store statistics for JSON (preserving newlines)
+ stats_output=$(sed -n '/^SystemDS Statistics:/,$ p' "$temp_output")
+ QUERY_STATS+=("$stats_output") # Track statistics for JSON
+
+ save_all_formats "$query_name_clean" "$full_result"
+ fi
+ else
+ echo "Error: Query $q failed" >&2
+ failed=$((failed+1))
+ QUERY_STATUS+=("error")
+ QUERY_ERROR_MSG+=("Query execution failed (non-zero exit)")
+ # Add empty stats entry for failed queries to maintain array alignment
+ QUERY_STATS+=("")
+ fi
+ rm -f "$temp_output"
+ else
+ # Capture output to extract result
+ temp_output=$(mktemp)
+ if "$SYSTEMDS_CMD" "$dml" "${SYS_EXTRA_ARGS[@]}" "${NVARGS[@]}" | tee "$temp_output"; then
+ # Detect runtime errors in output even if command returned 0
+ error_msg=$(sed -n '/An Error Occurred :/,$ p' "$temp_output" | sed -n '1,200p' | tr '\n' ' ' | sed 's/^ *//;s/ *$//')
+ if [[ -n "$error_msg" ]]; then
+ echo "Error: Query $q reported runtime error" >&2
+ echo "$error_msg" >&2
+ failed=$((failed+1))
+ QUERY_STATUS+=("error")
+ QUERY_ERROR_MSG+=("$error_msg")
+ QUERY_STATS+=("")
+ QUERY_RESULTS+=("N/A")
+ QUERY_FULL_RESULTS+=("N/A")
+ LONG_OUTPUTS+=("")
+ else
+ count=$((count+1))
+ SUCCESSFUL_QUERIES+=("$q") # Track successful query
+ QUERY_STATUS+=("success")
+ # Extract result - try multiple patterns with timeouts to prevent hanging:
+ # 1. Simple scalar pattern like "REVENUE: 687752409"
+ result=$(timeout 5s grep -E "^[A-Z_]+:\s*[0-9]+" "$temp_output" | tail -1 | awk '{print $2}' 2>/dev/null || true)
+ full_result="$result" # For scalar results, display and full results are the same
+
+ # 2. If no scalar pattern, check for table output and get row count
+ if [[ -z "$result" ]]; then
+ # Look for frame info like "# FRAME: nrow = 53, ncol = 3"
+ nrows=$(timeout 5s grep "# FRAME: nrow =" "$temp_output" | awk '{print $5}' | tr -d ',' 2>/dev/null || true)
+ if [[ -n "$nrows" ]]; then
+ result="${nrows} rows (see below)"
+ # Extract and store the long output for later display
+ long_output=$(grep -v "^#" "$temp_output" | grep -v "WARNING" | grep -v "WARN" | grep -v "^$" | tail -n +1)
+ LONG_OUTPUTS+=("$long_output")
+ # For JSON, store the actual table data
+ full_result="$long_output"
+ else
+ # Count actual data rows (lines with numbers, excluding headers and comments) - limit to prevent hanging
+ nrows=$(timeout 5s grep -E "^[0-9]" "$temp_output" | head -1000 | wc -l | tr -d ' ' 2>/dev/null || echo "0")
+ if [[ "$nrows" -gt 0 ]]; then
+ result="${nrows} rows (see below)"
+ # Extract and store the long output for later display
+ long_output=$(grep -E "^[0-9]" "$temp_output" | head -1000)
+ LONG_OUTPUTS+=("$long_output")
+ # For JSON, store the actual table data
+ full_result="$long_output"
+ else
+ result="N/A"
+ full_result="N/A"
+ LONG_OUTPUTS+=("") # Empty placeholder to maintain array alignment
+ fi
+ fi
+ else
+ LONG_OUTPUTS+=("") # Empty placeholder for scalar results to maintain array alignment
+ fi
+ QUERY_RESULTS+=("$result") # Track query result for display
+ QUERY_FULL_RESULTS+=("$full_result") # Track complete query result for JSON
+
+ # Add empty stats entry for non-stats runs to maintain array alignment
+ QUERY_STATS+=("")
+
+ # Save result in all formats
+ query_name_clean="${q%.dml}"
+ save_all_formats "$query_name_clean" "$full_result"
+ fi
+ else
+ echo "Error: Query $q failed" >&2
+ failed=$((failed+1))
+ QUERY_STATUS+=("error")
+ QUERY_ERROR_MSG+=("Query execution failed (non-zero exit)")
+ # Add empty stats entry for failed queries to maintain array alignment
+ QUERY_STATS+=("")
+ fi
+ rm -f "$temp_output"
+ fi
+done
+
+# Summary
+echo ""
+echo "========================================="
+echo "SSB benchmark completed!"
+echo "Total queries executed: $count"
+if [[ $failed -gt 0 ]]; then
+ echo "Failed queries: $failed"
+fi
+if $RUN_STATS; then
+ echo "Statistics: enabled"
+else
+ echo "Statistics: disabled"
+fi
+
+# Display run metadata summary
+echo ""
+echo "========================================="
+echo "RUN METADATA SUMMARY"
+echo "========================================="
+echo "Timestamp: $RUN_TIMESTAMP"
+echo "Hostname: $RUN_HOSTNAME"
+echo "Seed: $SEED"
+echo ""
+echo "Software Versions:"
+echo " SystemDS: $RUN_SYSTEMDS_VERSION"
+echo " JDK: $RUN_JDK_VERSION"
+echo ""
+echo "System Resources:"
+echo " CPU: $RUN_CPU_INFO"
+echo " RAM: $RUN_RAM_INFO"
+echo ""
+echo "Data Build Info:"
+echo " SSB Data: $RUN_DATA_DISPLAY"
+echo "========================================="
+
+# Generate metadata JSON file (include all attempted queries with status and error messages)
+{
+ echo "{"
+ echo " \"benchmark_type\": \"ssb_systemds\","
+ echo " \"timestamp\": \"$RUN_TIMESTAMP\","
+ echo " \"hostname\": \"$RUN_HOSTNAME\","
+ echo " \"seed\": $SEED,"
+ echo " \"software_versions\": {"
+ echo " \"systemds\": \"$RUN_SYSTEMDS_VERSION\","
+ echo " \"jdk\": \"$RUN_JDK_VERSION\""
+ echo " },"
+ echo " \"system_resources\": {"
+ echo " \"cpu\": \"$RUN_CPU_INFO\","
+ echo " \"ram\": \"$RUN_RAM_INFO\""
+ echo " },"
+ echo -e " \"data_build_info\": $RUN_DATA_INFO,"
+ echo " \"run_configuration\": {"
+ echo " \"statistics_enabled\": $(if $RUN_STATS; then echo "true"; else echo "false"; fi),"
+ echo " \"queries_selected\": ${#QUERIES[@]},"
+ echo " \"queries_executed\": $count,"
+ echo " \"queries_failed\": $failed"
+ echo " },"
+ echo " \"results\": ["
+ for i in "${!ALL_RUN_QUERIES[@]}"; do
+ query="${ALL_RUN_QUERIES[$i]}"
+ status="${QUERY_STATUS[$i]:-error}"
+ error_msg="${QUERY_ERROR_MSG[$i]:-}"
+ # Find matching full_result and stats by searching SUCCESSFUL_QUERIES index
+ full_result=""
+ stats_result=""
+ if [[ "$status" == "success" ]]; then
+ # Find index in SUCCESSFUL_QUERIES
+ for j in "${!SUCCESSFUL_QUERIES[@]}"; do
+ if [[ "${SUCCESSFUL_QUERIES[$j]}" == "$query" ]]; then
+ full_result="${QUERY_FULL_RESULTS[$j]}"
+ stats_result="${QUERY_STATS[$j]}"
+ break
+ fi
+ done
+ fi
+ # Escape quotes and newlines for JSON
+ escaped_result=$(echo "$full_result" | sed 's/\\/\\\\/g' | sed 's/"/\\"/g' | tr '\n' ' ')
+ escaped_error=$(echo "$error_msg" | sed 's/\\/\\\\/g' | sed 's/"/\\"/g' | tr '\n' ' ')
+
+ echo " {"
+ echo " \"query\": \"${query%.dml}\","
+ echo " \"status\": \"$status\","
+ echo " \"error_message\": \"$escaped_error\","
+ echo " \"result\": \"$escaped_result\""
+ if [[ -n "$stats_result" ]]; then
+ echo " ,\"stats\": ["
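+      # Emit each escaped statistics line as a quoted, comma-separated JSON array element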
+ echo "$stats_result" | sed 's/\\/\\\\/g' | sed 's/"/\\"/g' | sed 's/\t/ /g' | awk '
+ BEGIN { first = 1 }
+ {
+ if (!first) printf ",\n"
+ printf " \"%s\"", $0
+ first = 0
+ }
+ END { if (!first) printf "\n" }
+ '
+ echo " ]"
+ fi
+ if [[ $i -lt $((${#ALL_RUN_QUERIES[@]} - 1)) ]]; then
+ echo " },"
+ else
+ echo " }"
+ fi
+ done
+ echo " ]"
+ echo "}"
+} > "$OUTPUT_METADATA_FILE"
+
+echo ""
+echo "Metadata saved to $OUTPUT_METADATA_FILE"
+echo "Output directory: $OUTPUT_BASE_DIR"
+echo " - TXT files: $OUTPUT_TXT_DIR"
+echo " - CSV files: $OUTPUT_CSV_DIR"
+echo " - JSON files: $OUTPUT_JSON_DIR"
+
+# Detailed per-query summary (show status and error messages if any)
+if [[ ${#ALL_RUN_QUERIES[@]} -gt 0 ]]; then
+ echo ""
+ echo "==================================================="
+ echo "QUERIES SUMMARY"
+ echo "==================================================="
+ printf "%-4s %-15s %-30s %s\n" "No." "Query" "Result" "Status"
+ echo "---------------------------------------------------"
+ for i in "${!ALL_RUN_QUERIES[@]}"; do
+ query="${ALL_RUN_QUERIES[$i]}"
+ query_display="${query%.dml}" # Remove .dml extension for display
+ status="${QUERY_STATUS[$i]:-error}"
+ if [[ "$status" == "success" ]]; then
+ # Find index in SUCCESSFUL_QUERIES to fetch result
+ result=""
+ for j in "${!SUCCESSFUL_QUERIES[@]}"; do
+ if [[ "${SUCCESSFUL_QUERIES[$j]}" == "$query" ]]; then
+ result="${QUERY_RESULTS[$j]}"
+ break
+ fi
+ done
+ printf "%-4d %-15s %-30s %s\n" "$((i+1))" "$query_display" "$result" "✓ Success"
+ else
+ err="${QUERY_ERROR_MSG[$i]:-Unknown error}"
+ printf "%-4d %-15s %-30s %s\n" "$((i+1))" "$query_display" "N/A" "ERROR: ${err}"
+ fi
+ done
+echo "==================================================="
+fi
+
+# Display long outputs for queries that had table results
+if [[ ${#SUCCESSFUL_QUERIES[@]} -gt 0 ]]; then
+ # Check if we have any long outputs to display
+ has_long_outputs=false
+ for i in "${!LONG_OUTPUTS[@]}"; do
+ if [[ -n "${LONG_OUTPUTS[$i]}" ]]; then
+ has_long_outputs=true
+ break
+ fi
+ done
+
+ if $has_long_outputs; then
+ echo ""
+ echo "========================================="
+ echo "DETAILED QUERY RESULTS"
+ echo "========================================="
+ for i in "${!SUCCESSFUL_QUERIES[@]}"; do
+ if [[ -n "${LONG_OUTPUTS[$i]}" ]]; then
+ query="${SUCCESSFUL_QUERIES[$i]}"
+ query_display="${query%.dml}" # Remove .dml extension for display
+ echo ""
+ echo "[$((i+1))] Results for $query_display:"
+ echo "----------------------------------------"
+ echo "${LONG_OUTPUTS[$i]}"
+ echo "----------------------------------------"
+ fi
+ done
+ echo "========================================="
+ fi
+fi
+
+# Exit with appropriate code
+if [[ $failed -gt 0 ]]; then
+ exit 1
+fi
diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 917aecc4ab3..10c41e3d0a8 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -19,6 +19,10 @@
package org.apache.sysds.api;
+import java.nio.file.Files;
+import java.nio.file.InvalidPathException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
@@ -66,6 +70,10 @@ public class DMLOptions {
public boolean gpu = false; // Whether to use the GPU
public boolean forceGPU = false; // Whether to ignore memory & estimates and always use the GPU
public boolean ooc = false; // Whether to use the OOC backend
+ public boolean oocLogEvents = false; // Whether to record I/O and task compute events (fine-grained; may impact performance with many small tasks)
+ public String oocLogPath = "./"; // Directory in which to save the recorded event logs (CSV)
+ public boolean oocStats = false; // Whether to record and print coarse-grained OOC statistics
+ public int oocStatsCount = 10; // Default number of OOC heavy hitters to report
public boolean debug = false; // to go into debug mode to be able to step through a program
public String filePath = null; // path to script
public String script = null; // the script itself
@@ -105,7 +113,11 @@ public String toString() {
", fedStats=" + fedStats +
", fedStatsCount=" + fedStatsCount +
", fedMonitoring=" + fedMonitoring +
- ", fedMonitoringAddress" + fedMonitoringAddress +
+ ", fedMonitoringAddress=" + fedMonitoringAddress +
+ ", oocStats=" + oocStats +
+ ", oocStatsCount=" + oocStatsCount +
+ ", oocLogEvents=" + oocLogEvents +
+ ", oocLogPath=" + oocLogPath +
", memStats=" + memStats +
", explainType=" + explainType +
", execMode=" + execMode +
@@ -193,7 +205,7 @@ else if (lineageType.equalsIgnoreCase("debugger"))
else if (execMode.equalsIgnoreCase("hybrid")) dmlOptions.execMode = ExecMode.HYBRID;
else if (execMode.equalsIgnoreCase("spark")) dmlOptions.execMode = ExecMode.SPARK;
else throw new org.apache.commons.cli.ParseException("Invalid argument specified for -exec option, must be one of [hadoop, singlenode, hybrid, HYBRID, spark]");
- }
+ }
if (line.hasOption("explain")) {
dmlOptions.explainType = ExplainType.RUNTIME;
String explainType = line.getOptionValue("explain");
@@ -259,6 +271,33 @@ else if (lineageType.equalsIgnoreCase("debugger"))
}
}
+ dmlOptions.oocStats = line.hasOption("oocStats");
+ if (dmlOptions.oocStats) {
+ String oocStatsCount = line.getOptionValue("oocStats");
+ if (oocStatsCount != null) {
+ try {
+ dmlOptions.oocStatsCount = Integer.parseInt(oocStatsCount);
+ } catch (NumberFormatException e) {
+ throw new org.apache.commons.cli.ParseException("Invalid argument specified for -oocStats option, must be a valid integer");
+ }
+ }
+ }
+
+ dmlOptions.oocLogEvents = line.hasOption("oocLogEvents");
+ if (dmlOptions.oocLogEvents) {
+ String eventLogPath = line.getOptionValue("oocLogEvents");
+ if (eventLogPath != null) {
+ try {
+ Path p = Paths.get(eventLogPath);
+ if (!Files.isDirectory(p))
+ throw new org.apache.commons.cli.ParseException("Invalid argument specified for -oocLogEvents option, must be a valid directory");
+ } catch (InvalidPathException e) {
+ throw new org.apache.commons.cli.ParseException("Invalid argument specified for -oocLogEvents option, must be a valid path");
+ }
+ dmlOptions.oocLogPath = eventLogPath;
+ }
+ }
+
dmlOptions.memStats = line.hasOption("mem");
dmlOptions.clean = line.hasOption("clean");
@@ -387,6 +426,12 @@ private static Options createCLIOptions() {
Option fedStatsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter is 10 unless overridden; default off")
.hasOptionalArg().create("fedStats");
+ Option oocStatsOpt = OptionBuilder
+ .withDescription("monitors and reports summary execution statistics of OOC operators and tasks; heavy hitter is 10 unless overridden; default off")
+ .hasOptionalArg().create("oocStats");
+ Option oocLogEventsOpt = OptionBuilder
+ .withDescription("records fine-grained events of compute tasks, I/O, and cache; -oocLogEvents [dir='./']")
+ .hasOptionalArg().create("oocLogEvents");
Option memOpt = OptionBuilder.withDescription("monitors and reports max memory consumption in CP; default off")
.create("mem");
Option explainOpt = OptionBuilder.withArgName("level")
@@ -452,6 +497,8 @@ private static Options createCLIOptions() {
options.addOption(statsOpt);
options.addOption(ngramsOpt);
options.addOption(fedStatsOpt);
+ options.addOption(oocStatsOpt);
+ options.addOption(oocLogEventsOpt);
options.addOption(memOpt);
options.addOption(explainOpt);
options.addOption(execOpt);
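
For reference, a minimal sketch (not part of the patch) of how the new OOC flags are expected to surface on the parsed options, assuming the existing DMLOptions.parseCLIArguments entry point that DMLScript already uses; the script name and log directory below are hypothetical.

import org.apache.sysds.api.DMLOptions;

public class OocOptionsSketch {
	public static void main(String[] args) throws Exception {
		String[] cli = new String[] {
			"-f", "script.dml",    // hypothetical DML script
			"-ooc",                // enable the OOC backend
			"-oocStats", "20",     // optional count, overrides the default heavy-hitter count of 10
			"-oocLogEvents", "./"  // optional directory, must already exist; defaults to "./"
		};
		DMLOptions opts = DMLOptions.parseCLIArguments(cli);
		// expected: oocStats=true, oocStatsCount=20, oocLogEvents=true, oocLogPath=./
		System.out.println(opts.oocStats + " " + opts.oocStatsCount);
		System.out.println(opts.oocLogEvents + " " + opts.oocLogPath);
	}
}
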
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index 65805b5c2ed..81acb9deacd 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -75,6 +75,7 @@
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
+import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.LocalFileUtils;
@@ -149,6 +150,12 @@ public class DMLScript
public static boolean SYNCHRONIZE_GPU = true;
// Set OOC backend
public static boolean USE_OOC = DMLOptions.defaultOptions.ooc;
+ // Record and print OOC statistics
+ public static boolean OOC_STATISTICS = DMLOptions.defaultOptions.oocStats;
+ public static int OOC_STATISTICS_COUNT = DMLOptions.defaultOptions.oocStatsCount;
+ // Record and save fine-grained OOC event logs as CSV to the specified directory
+ public static boolean OOC_LOG_EVENTS = DMLOptions.defaultOptions.oocLogEvents;
+ public static String OOC_LOG_PATH = DMLOptions.defaultOptions.oocLogPath;
// Enable eager CUDA free on rmvar
public static boolean EAGER_CUDA_FREE = false;
@@ -272,6 +279,10 @@ public static boolean executeScript( String[] args )
USE_ACCELERATOR = dmlOptions.gpu;
FORCE_ACCELERATOR = dmlOptions.forceGPU;
USE_OOC = dmlOptions.ooc;
+ OOC_STATISTICS = dmlOptions.oocStats;
+ OOC_STATISTICS_COUNT = dmlOptions.oocStatsCount;
+ OOC_LOG_EVENTS = dmlOptions.oocLogEvents;
+ OOC_LOG_PATH = dmlOptions.oocLogPath;
EXPLAIN = dmlOptions.explainType;
EXEC_MODE = dmlOptions.execMode;
LINEAGE = dmlOptions.lineage;
@@ -323,11 +334,14 @@ public static boolean executeScript( String[] args )
LineageCacheConfig.setCachePolicy(LINEAGE_POLICY);
LineageCacheConfig.setEstimator(LINEAGE_ESTIMATE);
+ if (dmlOptions.oocLogEvents)
+ OOCEventLog.setup(100000);
+
String dmlScriptStr = readDMLScript(isFile, fileOrScript);
Map<String, String> argVals = dmlOptions.argVals;
DML_FILE_PATH_ANTLR_PARSER = dmlOptions.filePath;
-
+
//Step 3: invoke dml script
printInvocationInfo(fileOrScript, fnameOptConfig, argVals);
execute(dmlScriptStr, fnameOptConfig, argVals, args);
diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
index 3b1864d71dd..1a74ba0ea49 100644
--- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java
+++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
@@ -1,18 +1,18 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
+ * or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
+ * KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
@@ -24,9 +24,11 @@
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysds.api.jmlc.Connection;
+import org.apache.sysds.common.Types.ValueType;
-import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UnixPipeUtils;
@@ -79,7 +81,7 @@ public static void main(String[] args) throws Exception {
* therefore use logging framework. and terminate program.
*/
LOG.info("failed startup", p4e);
- System.exit(-1);
+ exitHandler.exit(-1);
}
catch(Exception e) {
throw new DMLException("Failed startup and maintaining Python gateway", e);
@@ -116,59 +118,59 @@ public void openPipes(String path, int num) throws IOException {
}
}
- public MatrixBlock startReadingMbFromPipe(int id, int rlen, int clen, Types.ValueType type) throws IOException {
+ public MatrixBlock startReadingMbFromPipe(int id, int rlen, int clen, ValueType type) throws IOException {
long limit = (long) rlen * clen;
LOG.debug("trying to read matrix from "+id+" with "+rlen+" rows and "+clen+" columns. Total size: "+limit);
if(limit > Integer.MAX_VALUE)
throw new DMLRuntimeException("Dense NumPy array of size " + limit +
" cannot be converted to MatrixBlock");
- MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
+ MatrixBlock mb;
if(fromPython != null){
BufferedInputStream pipe = fromPython.get(id);
double[] denseBlock = new double[(int) limit];
- UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, (int) limit, type, denseBlock, 0);
- mb.init(denseBlock, rlen, clen);
+ long nnz = UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, (int) limit, type, denseBlock, 0);
+ mb = new MatrixBlock(rlen, clen, denseBlock);
+ mb.setNonZeros(nnz);
} else {
throw new DMLRuntimeException("FIFO Pipes are not initialized.");
}
- mb.recomputeNonZeros();
- mb.examSparsity();
LOG.debug("Reading from Python finished");
+ mb.examSparsity();
return mb;
}
- public MatrixBlock startReadingMbFromPipes(int[] blockSizes, int rlen, int clen, Types.ValueType type) throws ExecutionException, InterruptedException {
+ public MatrixBlock startReadingMbFromPipes(int[] blockSizes, int rlen, int clen, ValueType type) throws ExecutionException, InterruptedException {
long limit = (long) rlen * clen;
if(limit > Integer.MAX_VALUE)
throw new DMLRuntimeException("Dense NumPy array of size " + limit +
" cannot be converted to MatrixBlock");
- MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
+ MatrixBlock mb = new MatrixBlock(rlen, clen, false, rlen*clen);
if(fromPython != null){
ExecutorService pool = CommonThreadPool.get();
double[] denseBlock = new double[(int) limit];
int offsetOut = 0;
- List<Future<?>> futures = new ArrayList<>();
+ List<Future<Long>> futures = new ArrayList<>();
for (int i = 0; i < blockSizes.length; i++) {
BufferedInputStream pipe = fromPython.get(i);
int id = i, blockSize = blockSizes[i], _offsetOut = offsetOut;
- Callable<Object> task = () -> {
- UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, blockSize, type, denseBlock, _offsetOut);
- return null;
+ Callable<Long> task = () -> {
+ return UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, blockSize, type, denseBlock, _offsetOut);
};
futures.add(pool.submit(task));
offsetOut += blockSize;
}
- // Wait for all tasks and propagate exceptions
- for (Future<?> f : futures) {
- f.get();
+ // Wait for all tasks and propagate exceptions, sum up nonzeros
+ long nnz = 0;
+ for (Future<Long> f : futures) {
+ nnz += f.get();
}
- mb.init(denseBlock, rlen, clen);
+ mb = new MatrixBlock(rlen, clen, denseBlock);
+ mb.setNonZeros(nnz);
} else {
throw new DMLRuntimeException("FIFO Pipes are not initialized.");
}
- mb.recomputeNonZeros();
mb.examSparsity();
return mb;
}
@@ -181,7 +183,7 @@ public void startWritingMbToPipe(int id, MatrixBlock mb) throws IOException {
LOG.debug("Trying to write matrix ["+baseDir + "-"+ id+"] with "+rlen+" rows and "+clen+" columns. Total size: "+numElem*8);
BufferedOutputStream out = toPython.get(id);
- long bytes = UnixPipeUtils.writeNumpyArrayInBatches(out, id, BATCH_SIZE, numElem, Types.ValueType.FP64, mb);
+ long bytes = UnixPipeUtils.writeNumpyArrayInBatches(out, id, BATCH_SIZE, numElem, ValueType.FP64, mb);
LOG.debug("Writing of " + bytes +" Bytes to Python ["+baseDir + "-"+ id+"] finished");
} else {
@@ -189,6 +191,43 @@ public void startWritingMbToPipe(int id, MatrixBlock mb) throws IOException {
}
}
+ public void startReadingColFromPipe(int id, FrameBlock fb, int rows, int totalBytes, int col, ValueType type, boolean any) throws IOException {
+ if (fromPython == null) {
+ throw new DMLRuntimeException("FIFO Pipes are not initialized.");
+ }
+
+ BufferedInputStream pipe = fromPython.get(id);
+ LOG.debug("Start reading FrameBlock column from pipe #" + id + " with type " + type);
+
+ // Delegate to UnixPipeUtils
+ Array<?> arr = UnixPipeUtils.readFrameColumnFromPipe(pipe, id, rows, totalBytes, BATCH_SIZE, type);
+ // Set column into FrameBlock
+ fb.setColumn(col, arr);
+ ValueType[] schema = fb.getSchema();
+ // inplace update the schema for cases: int8 -> int32
+ schema[col] = arr.getValueType();
+
+ LOG.debug("Finished reading FrameBlock column from pipe #" + id);
+ }
+
+ public void startWritingColToPipe(int id, FrameBlock fb, int col) throws IOException {
+ if (toPython == null) {
+ throw new DMLRuntimeException("FIFO Pipes are not initialized.");
+ }
+
+ BufferedOutputStream pipe = toPython.get(id);
+ ValueType type = fb.getSchema()[col];
+ int rows = fb.getNumRows();
+ Array<?> array = fb.getColumn(col);
+
+ LOG.debug("Start writing FrameBlock column #" + col + " to pipe #" + id + " with type " + type + " and " + rows + " rows");
+
+ // Delegate to UnixPipeUtils
+ long bytes = UnixPipeUtils.writeFrameColumnToPipe(pipe, id, BATCH_SIZE, array, type);
+
+ LOG.debug("Finished writing FrameBlock column #" + col + " to pipe #" + id + ". Total bytes: " + bytes);
+ }
+
public void closePipes() throws IOException {
LOG.debug("Closing all pipes in Java");
for (BufferedInputStream pipe : fromPython.values())
@@ -198,6 +237,20 @@ public void closePipes() throws IOException {
LOG.debug("Closed all pipes in Java");
}
+ @FunctionalInterface
+ public interface ExitHandler {
+ void exit(int status);
+ }
+
+ private static volatile ExitHandler exitHandler = System::exit;
+
+ public static void setExitHandler(ExitHandler handler) {
+ exitHandler = handler == null ? System::exit : handler;
+ }
+
+ public static void resetExitHandler() {
+ exitHandler = System::exit;
+ }
protected static class DMLGateWayListener extends DefaultGatewayServerListener {
private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName());
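
A minimal sketch (not part of the patch) of how the new exit hook can be used, e.g. from a JVM-side test, so that a failed Py4J gateway startup raises an observable exception instead of terminating the process; the exception type and surrounding wiring are illustrative only.

import org.apache.sysds.api.PythonDMLScript;

public class ExitHandlerSketch {
	public static void main(String[] args) {
		// replace the default System::exit with a handler that surfaces the status code
		PythonDMLScript.setExitHandler(status -> {
			throw new IllegalStateException("PythonDMLScript requested exit with status " + status);
		});
		try {
			// ... start the gateway / exercise the code under test here ...
		}
		finally {
			PythonDMLScript.resetExitHandler(); // restore the default System::exit behavior
		}
	}
}
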
diff --git a/src/main/java/org/apache/sysds/api/ScriptExecutorUtils.java b/src/main/java/org/apache/sysds/api/ScriptExecutorUtils.java
index fd344e44e9b..0d15218bf9a 100644
--- a/src/main/java/org/apache/sysds/api/ScriptExecutorUtils.java
+++ b/src/main/java/org/apache/sysds/api/ScriptExecutorUtils.java
@@ -36,6 +36,7 @@
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysds.runtime.lineage.LineageEstimatorStatistics;
import org.apache.sysds.runtime.lineage.LineageGPUCacheEviction;
+import org.apache.sysds.runtime.ooc.cache.OOCCacheManager;
import org.apache.sysds.utils.Statistics;
public class ScriptExecutorUtils {
@@ -127,6 +128,9 @@ public static void executeRuntimeProgram(Program rtprog, ExecutionContext ec, DM
if (DMLScript.LINEAGE_ESTIMATE)
System.out.println(LineageEstimatorStatistics.displayLineageEstimates());
+
+ if (DMLScript.USE_OOC)
+ OOCCacheManager.reset();
}
}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 4feab311c76..dc1f23b83fc 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -392,6 +392,8 @@ public enum Builtins {
RMEMPTY("removeEmpty", false, true),
SCALE("scale", true, false),
SCALEAPPLY("scaleApply", true, false),
+ SCALEROBUST("scaleRobust", true, false),
+ SCALEROBUSTAPPLY("scaleRobustApply", true, false),
SCALE_MINMAX("scaleMinMax", true, false),
TIME("time", false),
TOKENIZE("tokenize", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java
index 251f773a18c..1b0536416d6 100644
--- a/src/main/java/org/apache/sysds/common/Opcodes.java
+++ b/src/main/java/org/apache/sysds/common/Opcodes.java
@@ -19,7 +19,25 @@
package org.apache.sysds.common;
-import org.apache.sysds.lops.*;
+import org.apache.sysds.lops.Append;
+import org.apache.sysds.lops.DataGen;
+import org.apache.sysds.lops.LeftIndex;
+import org.apache.sysds.lops.RightIndex;
+import org.apache.sysds.lops.Compression;
+import org.apache.sysds.lops.DeCompression;
+import org.apache.sysds.lops.Local;
+import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.lops.WeightedCrossEntropy;
+import org.apache.sysds.lops.WeightedCrossEntropyR;
+import org.apache.sysds.lops.WeightedDivMM;
+import org.apache.sysds.lops.WeightedDivMMR;
+import org.apache.sysds.lops.WeightedSigmoid;
+import org.apache.sysds.lops.WeightedSigmoidR;
+import org.apache.sysds.lops.WeightedSquaredLoss;
+import org.apache.sysds.lops.WeightedSquaredLossR;
+import org.apache.sysds.lops.WeightedUnaryMM;
+import org.apache.sysds.lops.WeightedUnaryMMR;
+
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.hops.FunctionOp;
diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java
index 6962beadcbc..d752316a526 100644
--- a/src/main/java/org/apache/sysds/hops/NaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/NaryOp.java
@@ -165,7 +165,7 @@ else if ( areDimsBelowThreshold() )
setRequiresRecompileIfNecessary();
//ensure cp exec type for single-node operations
- if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST
+ if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST || _op == OpOpN.EINSUM
//TODO: cbind/rbind of lists only support in CP right now
|| (_op == OpOpN.CBIND && getInput().get(0).getDataType().isList())
|| (_op == OpOpN.RBIND && getInput().get(0).getDataType().isList())
diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index 5fc73e2bd3f..f43d1cc2baf 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -173,7 +173,9 @@ _op, getDataType(), getValueType(), et,
for (int i = 0; i < 2; i++)
linputs[i] = getInput().get(i).constructLops();
- Transform transform1 = new Transform(linputs, _op, getDataType(), getValueType(), et, 1);
+ Transform transform1 = new Transform(
+ linputs, _op, getDataType(), getValueType(), et,
+ OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
setOutputDimensions(transform1);
setLineNumbers(transform1);
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
index 0c5d6c0290e..fe40c83e690 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
@@ -19,6 +19,12 @@
package org.apache.sysds.hops.fedplanner;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
@@ -28,7 +34,6 @@
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
-import java.util.*;
/**
* Cost estimator for federated execution plans.
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanRewireTransTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanRewireTransTable.java
index f41d924999e..9a41509e81f 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanRewireTransTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanRewireTransTable.java
@@ -23,15 +23,29 @@
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ParamBuiltinOp;
-import org.apache.sysds.hops.*;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataGenOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.DnnOp;
+import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.FunctionOp.FunctionType;
-import org.apache.sysds.parser.*;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.IndexingOp;
+import org.apache.sysds.hops.LeftIndexingOp;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ParameterizedBuiltinOp;
+import org.apache.sysds.hops.QuaternaryOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
-import java.util.*;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import java.net.InetAddress;
@@ -49,8 +63,26 @@
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.VariableSet;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+
import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
import java.util.List;
+import java.util.Map;
+import java.util.Set;
public class FederatedPlanRewireTransTable {
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedCostBased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedCostBased.java
index 99043860bd7..9204b65c793 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedCostBased.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedCostBased.java
@@ -64,23 +64,23 @@ public void rewriteFunctionDynamic(FunctionStatementBlock function, LocalVariabl
private void rewriteHop(FedPlan optimalPlan, FederatedMemoTable memoTable, Set<Long> visited) {
long hopID = optimalPlan.getHopRef().getHopID();
- if (visited.contains(hopID)) {
- return;
- } else {
- visited.add(hopID);
- }
+ if (visited.contains(hopID)) {
+ return;
+ } else {
+ visited.add(hopID);
+ }
- for (Pair<Long, FederatedOutput> childFedPlanPair : optimalPlan.getChildFedPlans()) {
- FedPlan childPlan = memoTable.getFedPlanAfterPrune(childFedPlanPair);
-
- // DEBUG: Check if getFedPlanAfterPrune returns null
- if (childPlan == null) {
+ for (Pair<Long, FederatedOutput> childFedPlanPair : optimalPlan.getChildFedPlans()) {
+ FedPlan childPlan = memoTable.getFedPlanAfterPrune(childFedPlanPair);
+
+ // DEBUG: Check if getFedPlanAfterPrune returns null
+ if (childPlan == null) {
FederatedPlannerLogger.logNullChildPlanDebug(childFedPlanPair, optimalPlan, memoTable);
- continue;
- }
-
+ continue;
+ }
+
rewriteHop(childPlan, memoTable, visited);
- }
+ }
if (optimalPlan.getFedOutType() == FEDInstruction.FederatedOutput.LOUT) {
optimalPlan.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerLogger.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerLogger.java
index 35742ce1e4f..d2c8134c765 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerLogger.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerLogger.java
@@ -39,560 +39,560 @@
* This class integrates the functionality of the former FederatedMemoTablePrinter.
*/
public class FederatedPlannerLogger {
-
- /**
- * Logs hop information including name, hop ID, child hop IDs, privacy constraint, and ftype
- * @param hop The hop to log information for
- * @param privacyConstraintMap Map containing privacy constraints for hops
- * @param fTypeMap Map containing FType information for hops
- * @param logPrefix Prefix string to identify the log source
- */
- public static void logHopInfo(Hop hop, Map privacyConstraintMap,
- Map fTypeMap, String logPrefix) {
- StringBuilder childIds = new StringBuilder();
- if (hop.getInput() != null && !hop.getInput().isEmpty()) {
- for (int i = 0; i < hop.getInput().size(); i++) {
- if (i > 0) childIds.append(",");
- childIds.append(hop.getInput().get(i).getHopID());
- }
- } else {
- childIds.append("none");
- }
-
- Privacy privacyConstraint = privacyConstraintMap.get(hop.getHopID());
- FType ftype = fTypeMap.get(hop.getHopID());
-
- // Get hop type and opcode information
- String hopType = hop.getClass().getSimpleName();
- String opCode = hop.getOpString();
-
- System.out.println("[" + logPrefix + "] (ID:" + hop.getHopID() + " Name:" + hop.getName() +
- ") Type:" + hopType + " OpCode:" + opCode +
- " ChildIDs:(" + childIds.toString() + ") Privacy:" +
- (privacyConstraint != null ? privacyConstraint : "null") +
- " FType:" + (ftype != null ? ftype : "null"));
- }
-
- /**
- * Logs basic hop information without privacy and FType details
- * @param hop The hop to log information for
- * @param logPrefix Prefix string to identify the log source
- */
- public static void logBasicHopInfo(Hop hop, String logPrefix) {
- StringBuilder childIds = new StringBuilder();
- if (hop.getInput() != null && !hop.getInput().isEmpty()) {
- for (int i = 0; i < hop.getInput().size(); i++) {
- if (i > 0) childIds.append(",");
- childIds.append(hop.getInput().get(i).getHopID());
- }
- } else {
- childIds.append("none");
- }
-
- String hopType = hop.getClass().getSimpleName();
- String opCode = hop.getOpString();
-
- System.out.println("[" + logPrefix + "] (ID:" + hop.getHopID() + " Name:" + hop.getName() +
- ") Type:" + hopType + " OpCode:" + opCode +
- " ChildIDs:(" + childIds.toString() + ")");
- }
-
- /**
- * Logs detailed hop information with dimension and data type
- * @param hop The hop to log information for
- * @param privacyConstraintMap Map containing privacy constraints for hops
- * @param fTypeMap Map containing FType information for hops
- * @param logPrefix Prefix string to identify the log source
- */
- public static void logDetailedHopInfo(Hop hop, Map privacyConstraintMap,
- Map fTypeMap, String logPrefix) {
- StringBuilder childIds = new StringBuilder();
- if (hop.getInput() != null && !hop.getInput().isEmpty()) {
- for (int i = 0; i < hop.getInput().size(); i++) {
- if (i > 0) childIds.append(",");
- childIds.append(hop.getInput().get(i).getHopID());
- }
- } else {
- childIds.append("none");
- }
-
- Privacy privacyConstraint = privacyConstraintMap.get(hop.getHopID());
- FType ftype = fTypeMap.get(hop.getHopID());
-
- String hopType = hop.getClass().getSimpleName();
- String opCode = hop.getOpString();
- String dataType = hop.getDataType().toString();
- String dimensions = "[" + hop.getDim1() + "x" + hop.getDim2() + "]";
-
- System.out.println("[" + logPrefix + "] (ID:" + hop.getHopID() + " Name:" + hop.getName() +
- ") Type:" + hopType + " OpCode:" + opCode + " DataType:" + dataType +
- " Dims:" + dimensions + " ChildIDs:(" + childIds.toString() + ") Privacy:" +
- (privacyConstraint != null ? privacyConstraint : "null") +
- " FType:" + (ftype != null ? ftype : "null"));
- }
-
- /**
- * Logs error information for null fed plan scenarios
- * @param hopID The hop ID that caused the error
- * @param logPrefix Prefix string to identify the log source
- */
- public static void logNullFedPlanError(long hopID, String logPrefix) {
- System.err.println("[" + logPrefix + "] childFedPlan is null for hopID: " + hopID);
- }
-
- /**
- * Logs detailed error information for conflict resolution scenarios
- * @param hopID The hop ID that caused the error
- * @param fedPlan The federated plan with error details
- * @param logPrefix Prefix string to identify the log source
- */
- public static void logConflictResolutionError(long hopID, Object fedPlan, String logPrefix) {
- System.err.println("[" + logPrefix + "] confilctLOutFedPlan or confilctFOutFedPlan is null for hopID: " + hopID);
- System.err.println(" Child Hop Details:");
- if (fedPlan != null) {
- // Note: This assumes fedPlan has a getHopRef() method
- // In actual implementation, you might need to cast or handle differently
- System.err.println(" - Class: N/A");
- System.err.println(" - Name: N/A");
- System.err.println(" - OpString: N/A");
- System.err.println(" - HopID: " + hopID);
- }
- }
-
- /**
- * Logs debug information for getFederatedType function
- * @param hop The hop being analyzed
- * @param returnFType The FType that will be returned
- * @param reason The reason for the FType decision
- */
- public static void logGetFederatedTypeDebug(Hop hop, FType returnFType, String reason) {
- String hopName = hop.getName() != null ? hop.getName() : "null";
- long hopID = hop.getHopID();
- String operationType = hop.getClass().getSimpleName();
- String opCode = hop.getOpString();
-
- System.out.println("[GetFederatedType] HopName: " + hopName + " | HopID: " + hopID +
- " | OperationType: " + operationType + " | OpCode: " + opCode +
- " | ReturnFType: " + (returnFType != null ? returnFType : "null") +
- " | Reason: " + reason);
- }
-
- /**
- * Logs detailed hop error information with complete hop details
- * @param hop The hop that caused the error
- * @param logPrefix Prefix string to identify the log source
- * @param additionalMessage Additional error message
- */
- public static void logHopErrorDetails(Hop hop, String logPrefix, String additionalMessage) {
- System.err.println("[" + logPrefix + "] " + additionalMessage);
- System.err.println(" Child Hop Details:");
- System.err.println(" - Class: " + hop.getClass().getSimpleName());
- System.err.println(" - Name: " + (hop.getName() != null ? hop.getName() : "null"));
- System.err.println(" - OpString: " + hop.getOpString());
- System.err.println(" - HopID: " + hop.getHopID());
- }
-
- /**
- * Logs detailed null child plan debugging information
- * @param childFedPlanPair The child federated plan pair that is null
- * @param optimalPlan The current optimal plan (parent)
- * @param memoTable The memo table for lookups
- */
- public static void logNullChildPlanDebug(Pair childFedPlanPair,
- FedPlan optimalPlan,
- org.apache.sysds.hops.fedplanner.FederatedMemoTable memoTable) {
- FederatedOutput alternativeFedType = (childFedPlanPair.getRight() == FederatedOutput.LOUT) ?
- FederatedOutput.FOUT : FederatedOutput.LOUT;
- FedPlan alternativeChildPlan = memoTable.getFedPlanAfterPrune(childFedPlanPair.getLeft(), alternativeFedType);
-
- // Get child hop info
- Hop childHop = null;
- String childInfo = "UNKNOWN";
- if (alternativeChildPlan != null) {
- childHop = alternativeChildPlan.getHopRef();
- // Check if required fed type plan exists
- String requiredExists = memoTable.getFedPlanAfterPrune(childFedPlanPair.getLeft(), childFedPlanPair.getRight()) != null ? "O" : "X";
- // Check if alternative fed type plan exists
- String altExists = alternativeChildPlan != null ? "O" : "X";
-
- childInfo = String.format("ID:%d|Name:%s|Op:%s|RequiredFedType:%s(%s)|AltFedType:%s(%s)",
- childHop.getHopID(),
- childHop.getName() != null ? childHop.getName() : "null",
- childHop.getOpString(),
- childFedPlanPair.getRight(),
- requiredExists,
- alternativeFedType,
- altExists);
- }
-
- // Current parent hop info
- String currentParentInfo = String.format("ID:%d|Name:%s|Op:%s|FedType:%s|RequiredChild:%s",
- optimalPlan.getHopID(),
- optimalPlan.getHopRef().getName() != null ? optimalPlan.getHopRef().getName() : "null",
- optimalPlan.getHopRef().getOpString(),
- optimalPlan.getFedOutType(),
- childFedPlanPair.getRight());
-
- // Alternative parent info (if child has other parents)
- String alternativeParentInfo = "NONE";
- if (childHop != null) {
- List parents = childHop.getParent();
- for (Hop parent : parents) {
- if (parent.getHopID() != optimalPlan.getHopID()) {
- // Try to find alt parent's fed plan info
- String altParentFedType = "UNKNOWN";
- String altParentRequiredChild = "UNKNOWN";
-
- // Check both LOUT and FOUT plans for alt parent
- FedPlan altParentPlanLOUT = memoTable.getFedPlanAfterPrune(parent.getHopID(), FederatedOutput.LOUT);
- FedPlan altParentPlanFOUT = memoTable.getFedPlanAfterPrune(parent.getHopID(), FederatedOutput.FOUT);
-
- if (altParentPlanLOUT != null) {
- altParentFedType = "LOUT";
- // Find what this alt parent expects from child
- for (Pair altChildPair : altParentPlanLOUT.getChildFedPlans()) {
- if (altChildPair.getLeft() == childHop.getHopID()) {
- altParentRequiredChild = altChildPair.getRight().toString();
- break;
- }
- }
- } else if (altParentPlanFOUT != null) {
- altParentFedType = "FOUT";
- // Find what this alt parent expects from child
- for (Pair altChildPair : altParentPlanFOUT.getChildFedPlans()) {
- if (altChildPair.getLeft() == childHop.getHopID()) {
- altParentRequiredChild = altChildPair.getRight().toString();
- break;
- }
- }
- }
-
- alternativeParentInfo = String.format("ID:%d|Name:%s|Op:%s|FedType:%s|RequiredChild:%s",
- parent.getHopID(),
- parent.getName() != null ? parent.getName() : "null",
- parent.getOpString(),
- altParentFedType,
- altParentRequiredChild);
- break;
- }
- }
- }
-
- System.err.println("[DEBUG] NULL CHILD PLAN DETECTED:");
- System.err.println(" Child: " + childInfo);
- System.err.println(" Current Parent: " + currentParentInfo);
- System.err.println(" Alt Parent: " + alternativeParentInfo);
- System.err.println(" Alt Plan Exists: " + (alternativeChildPlan != null));
- }
-
- /**
- * Logs debugging information for TransRead hop rewiring process
- * @param hopName The name of the TransRead hop
- * @param hopID The ID of the TransRead hop
- * @param childHops List of child hops found during rewiring
- * @param isEmptyChildHops Whether the child hops list is empty
- * @param logPrefix Prefix string to identify the log source
- */
- public static void logTransReadRewireDebug(String hopName, long hopID, List childHops,
- boolean isEmptyChildHops, String logPrefix) {
- if (isEmptyChildHops) {
- System.err.println("[" + logPrefix + "] (hopName: " + hopName + ", hopID: " + hopID + ") child hops is empty");
- }
- }
-
- /**
- * Logs debugging information for filtered child hops during TransRead rewiring
- * @param hopName The name of the TransRead hop
- * @param hopID The ID of the TransRead hop
- * @param filteredChildHops List of filtered child hops
- * @param isEmptyFilteredChildHops Whether the filtered child hops list is empty
- * @param logPrefix Prefix string to identify the log source
- */
- public static void logFilteredChildHopsDebug(String hopName, long hopID, List filteredChildHops,
- boolean isEmptyFilteredChildHops, String logPrefix) {
- if (isEmptyFilteredChildHops) {
- System.err.println("[" + logPrefix + "] (hopName: " + hopName + ", hopID: " + hopID + ") filtered child hops is empty");
- }
- }
-
- /**
- * Logs detailed FType mismatch error information for TransRead hop
- * @param hop The TransRead hop with FType mismatch
- * @param filteredChildHops List of filtered child hops
- * @param fTypeMap Map containing FType information for hops
- * @param expectedFType The expected FType
- * @param mismatchedFType The mismatched FType
- * @param mismatchIndex The index where mismatch occurred
- */
- public static void logFTypeMismatchError(Hop hop, List filteredChildHops, Map fTypeMap,
- FType expectedFType, FType mismatchedFType, int mismatchIndex) {
- String hopName = hop.getName();
- long hopID = hop.getHopID();
-
- System.err.println("[Error] FType MISMATCH DETECTED for TransRead (hopName: " + hopName + ", hopID: " + hopID + ")");
- System.err.println("[Error] TRANSREAD HOP DETAILS - Type: " + hop.getClass().getSimpleName() +
- ", OpType: " + (hop instanceof org.apache.sysds.hops.DataOp ?
- ((org.apache.sysds.hops.DataOp)hop).getOp() : "N/A") +
- ", DataType: " + hop.getDataType() +
- ", Dims: [" + hop.getDim1() + "x" + hop.getDim2() + "]");
- System.err.println("[Error] FILTERED CHILD HOPS FTYPE ANALYSIS:");
-
- for (int j = 0; j < filteredChildHops.size(); j++) {
- Hop childHop = filteredChildHops.get(j);
- FType childFType = fTypeMap.get(childHop.getHopID());
- System.err.println("[Error] FilteredChild[" + j + "] - Name: " + childHop.getName() +
- ", ID: " + childHop.getHopID() +
- ", FType: " + childFType +
- ", Type: " + childHop.getClass().getSimpleName() +
- ", OpType: " + (childHop instanceof org.apache.sysds.hops.DataOp ?
- ((org.apache.sysds.hops.DataOp)childHop).getOp().toString() : "N/A") +
- ", Dims: [" + childHop.getDim1() + "x" + childHop.getDim2() + "]");
- }
-
- System.err.println("[Error] Expected FType: " + expectedFType +
- ", Mismatched FType: " + mismatchedFType +
- " at child index: " + mismatchIndex);
- }
-
- /**
- * Logs FType debug information for DataOp operations (FEDERATED, TRANSIENTWRITE, TRANSIENTREAD)
- * @param hop The DataOp hop being analyzed
- * @param fType The FType that was determined for this operation
- * @param opType The operation type (FEDERATED, TRANSIENTWRITE, TRANSIENTREAD)
- * @param reason The reason for the FType decision
- */
- public static void logDataOpFTypeDebug(Hop hop, FType fType, String opType, String reason) {
- String hopName = hop.getName() != null ? hop.getName() : "null";
- long hopID = hop.getHopID();
- String hopClass = hop.getClass().getSimpleName();
- String dimensions = "[" + hop.getDim1() + "x" + hop.getDim2() + "]";
-
- System.out.println("[GetFederatedType] HopName: " + hopName +
- " | HopID: " + hopID +
- " | HopClass: " + hopClass +
- " | OpType: " + opType +
- " | Dims: " + dimensions +
- " | FType: " + (fType != null ? fType : "null") +
- " | Reason: " + reason);
- }
-
- // ========== FederatedMemoTable Printing Methods ==========
-
- /**
- * Recursively prints a tree representation of the DAG starting from the given root FedPlan.
- * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node.
- * Additionally, prints the additional total cost once at the beginning.
- *
- * @param rootFedPlan The starting point FedPlan to print
- * @param rootHopStatSet Set of root hop statistics
- * @param memoTable The memoization table containing FedPlan variants
- * @param additionalTotalCost The additional cost to be printed once
- */
- public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, Set rootHopStatSet,
- FederatedMemoTable memoTable, double additionalTotalCost) {
- System.out.println("Additional Cost: " + additionalTotalCost);
- Set visited = new HashSet<>();
- printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0);
- for (Long hopID : rootHopStatSet) {
- FedPlan plan = memoTable.getFedPlanAfterPrune(hopID, FederatedOutput.LOUT);
- if (plan == null){
- plan = memoTable.getFedPlanAfterPrune(hopID, FederatedOutput.FOUT);
- }
- printNotReferencedFedPlanRecursive(plan, memoTable, visited, 1);
- }
- }
+ /**
+ * Logs hop information including name, hop ID, child hop IDs, privacy constraint, and ftype
+ * @param hop The hop to log information for
+ * @param privacyConstraintMap Map containing privacy constraints for hops
+ * @param fTypeMap Map containing FType information for hops
+ * @param logPrefix Prefix string to identify the log source
+ */
+ public static void logHopInfo(Hop hop, Map privacyConstraintMap,
+ Map fTypeMap, String logPrefix) {
+ StringBuilder childIds = new StringBuilder();
+ if (hop.getInput() != null && !hop.getInput().isEmpty()) {
+ for (int i = 0; i < hop.getInput().size(); i++) {
+ if (i > 0) childIds.append(",");
+ childIds.append(hop.getInput().get(i).getHopID());
+ }
+ } else {
+ childIds.append("none");
+ }
+
+ Privacy privacyConstraint = privacyConstraintMap.get(hop.getHopID());
+ FType ftype = fTypeMap.get(hop.getHopID());
+
+ // Get hop type and opcode information
+ String hopType = hop.getClass().getSimpleName();
+ String opCode = hop.getOpString();
+
+ System.out.println("[" + logPrefix + "] (ID:" + hop.getHopID() + " Name:" + hop.getName() +
+ ") Type:" + hopType + " OpCode:" + opCode +
+ " ChildIDs:(" + childIds.toString() + ") Privacy:" +
+ (privacyConstraint != null ? privacyConstraint : "null") +
+ " FType:" + (ftype != null ? ftype : "null"));
+ }
+
+ /**
+ * Logs basic hop information without privacy and FType details
+ * @param hop The hop to log information for
+ * @param logPrefix Prefix string to identify the log source
+ */
+ public static void logBasicHopInfo(Hop hop, String logPrefix) {
+ StringBuilder childIds = new StringBuilder();
+ if (hop.getInput() != null && !hop.getInput().isEmpty()) {
+ for (int i = 0; i < hop.getInput().size(); i++) {
+ if (i > 0) childIds.append(",");
+ childIds.append(hop.getInput().get(i).getHopID());
+ }
+ } else {
+ childIds.append("none");
+ }
+
+ String hopType = hop.getClass().getSimpleName();
+ String opCode = hop.getOpString();
+
+ System.out.println("[" + logPrefix + "] (ID:" + hop.getHopID() + " Name:" + hop.getName() +
+ ") Type:" + hopType + " OpCode:" + opCode +
+ " ChildIDs:(" + childIds.toString() + ")");
+ }
+
+ /**
+ * Logs detailed hop information with dimension and data type
+ * @param hop The hop to log information for
+ * @param privacyConstraintMap Map containing privacy constraints for hops
+ * @param fTypeMap Map containing FType information for hops
+ * @param logPrefix Prefix string to identify the log source
+ */
+ public static void logDetailedHopInfo(Hop hop, Map privacyConstraintMap,
+ Map fTypeMap, String logPrefix) {
+ StringBuilder childIds = new StringBuilder();
+ if (hop.getInput() != null && !hop.getInput().isEmpty()) {
+ for (int i = 0; i < hop.getInput().size(); i++) {
+ if (i > 0) childIds.append(",");
+ childIds.append(hop.getInput().get(i).getHopID());
+ }
+ } else {
+ childIds.append("none");
+ }
+
+ Privacy privacyConstraint = privacyConstraintMap.get(hop.getHopID());
+ FType ftype = fTypeMap.get(hop.getHopID());
+
+ String hopType = hop.getClass().getSimpleName();
+ String opCode = hop.getOpString();
+ String dataType = hop.getDataType().toString();
+ String dimensions = "[" + hop.getDim1() + "x" + hop.getDim2() + "]";
+
+ System.out.println("[" + logPrefix + "] (ID:" + hop.getHopID() + " Name:" + hop.getName() +
+ ") Type:" + hopType + " OpCode:" + opCode + " DataType:" + dataType +
+ " Dims:" + dimensions + " ChildIDs:(" + childIds.toString() + ") Privacy:" +
+ (privacyConstraint != null ? privacyConstraint : "null") +
+ " FType:" + (ftype != null ? ftype : "null"));
+ }
+
+ /**
+ * Logs error information for null fed plan scenarios
+ * @param hopID The hop ID that caused the error
+ * @param logPrefix Prefix string to identify the log source
+ */
+ public static void logNullFedPlanError(long hopID, String logPrefix) {
+ System.err.println("[" + logPrefix + "] childFedPlan is null for hopID: " + hopID);
+ }
+
+ /**
+ * Logs detailed error information for conflict resolution scenarios
+ * @param hopID The hop ID that caused the error
+ * @param fedPlan The federated plan with error details
+ * @param logPrefix Prefix string to identify the log source
+ */
+ public static void logConflictResolutionError(long hopID, Object fedPlan, String logPrefix) {
+ System.err.println("[" + logPrefix + "] confilctLOutFedPlan or confilctFOutFedPlan is null for hopID: " + hopID);
+ System.err.println(" Child Hop Details:");
+ if (fedPlan != null) {
+ // Note: This assumes fedPlan has a getHopRef() method
+ // In actual implementation, you might need to cast or handle differently
+ System.err.println(" - Class: N/A");
+ System.err.println(" - Name: N/A");
+ System.err.println(" - OpString: N/A");
+ System.err.println(" - HopID: " + hopID);
+ }
+ }
+
+ /**
+ * Logs debug information for getFederatedType function
+ * @param hop The hop being analyzed
+ * @param returnFType The FType that will be returned
+ * @param reason The reason for the FType decision
+ */
+ public static void logGetFederatedTypeDebug(Hop hop, FType returnFType, String reason) {
+ String hopName = hop.getName() != null ? hop.getName() : "null";
+ long hopID = hop.getHopID();
+ String operationType = hop.getClass().getSimpleName();
+ String opCode = hop.getOpString();
+
+ System.out.println("[GetFederatedType] HopName: " + hopName + " | HopID: " + hopID +
+ " | OperationType: " + operationType + " | OpCode: " + opCode +
+ " | ReturnFType: " + (returnFType != null ? returnFType : "null") +
+ " | Reason: " + reason);
+ }
+
+ /**
+ * Logs detailed hop error information with complete hop details
+ * @param hop The hop that caused the error
+ * @param logPrefix Prefix string to identify the log source
+ * @param additionalMessage Additional error message
+ */
+ public static void logHopErrorDetails(Hop hop, String logPrefix, String additionalMessage) {
+ System.err.println("[" + logPrefix + "] " + additionalMessage);
+ System.err.println(" Child Hop Details:");
+ System.err.println(" - Class: " + hop.getClass().getSimpleName());
+ System.err.println(" - Name: " + (hop.getName() != null ? hop.getName() : "null"));
+ System.err.println(" - OpString: " + hop.getOpString());
+ System.err.println(" - HopID: " + hop.getHopID());
+ }
+
+ /**
+ * Logs detailed null child plan debugging information
+ * @param childFedPlanPair The child federated plan pair that is null
+ * @param optimalPlan The current optimal plan (parent)
+ * @param memoTable The memo table for lookups
+ */
+ public static void logNullChildPlanDebug(Pair childFedPlanPair,
+ FedPlan optimalPlan,
+ org.apache.sysds.hops.fedplanner.FederatedMemoTable memoTable) {
+ FederatedOutput alternativeFedType = (childFedPlanPair.getRight() == FederatedOutput.LOUT) ?
+ FederatedOutput.FOUT : FederatedOutput.LOUT;
+ FedPlan alternativeChildPlan = memoTable.getFedPlanAfterPrune(childFedPlanPair.getLeft(), alternativeFedType);
+
+ // Get child hop info
+ Hop childHop = null;
+ String childInfo = "UNKNOWN";
+ if (alternativeChildPlan != null) {
+ childHop = alternativeChildPlan.getHopRef();
+ // Check if required fed type plan exists
+ String requiredExists = memoTable.getFedPlanAfterPrune(childFedPlanPair.getLeft(), childFedPlanPair.getRight()) != null ? "O" : "X";
+ // Check if alternative fed type plan exists
+ String altExists = alternativeChildPlan != null ? "O" : "X";
+
+ childInfo = String.format("ID:%d|Name:%s|Op:%s|RequiredFedType:%s(%s)|AltFedType:%s(%s)",
+ childHop.getHopID(),
+ childHop.getName() != null ? childHop.getName() : "null",
+ childHop.getOpString(),
+ childFedPlanPair.getRight(),
+ requiredExists,
+ alternativeFedType,
+ altExists);
+ }
+
+ // Current parent hop info
+ String currentParentInfo = String.format("ID:%d|Name:%s|Op:%s|FedType:%s|RequiredChild:%s",
+ optimalPlan.getHopID(),
+ optimalPlan.getHopRef().getName() != null ? optimalPlan.getHopRef().getName() : "null",
+ optimalPlan.getHopRef().getOpString(),
+ optimalPlan.getFedOutType(),
+ childFedPlanPair.getRight());
+
+ // Alternative parent info (if child has other parents)
+ String alternativeParentInfo = "NONE";
+ if (childHop != null) {
+ List parents = childHop.getParent();
+ for (Hop parent : parents) {
+ if (parent.getHopID() != optimalPlan.getHopID()) {
+ // Try to find alt parent's fed plan info
+ String altParentFedType = "UNKNOWN";
+ String altParentRequiredChild = "UNKNOWN";
+
+ // Check both LOUT and FOUT plans for alt parent
+ FedPlan altParentPlanLOUT = memoTable.getFedPlanAfterPrune(parent.getHopID(), FederatedOutput.LOUT);
+ FedPlan altParentPlanFOUT = memoTable.getFedPlanAfterPrune(parent.getHopID(), FederatedOutput.FOUT);
+
+ if (altParentPlanLOUT != null) {
+ altParentFedType = "LOUT";
+ // Find what this alt parent expects from child
+ for (Pair altChildPair : altParentPlanLOUT.getChildFedPlans()) {
+ if (altChildPair.getLeft() == childHop.getHopID()) {
+ altParentRequiredChild = altChildPair.getRight().toString();
+ break;
+ }
+ }
+ } else if (altParentPlanFOUT != null) {
+ altParentFedType = "FOUT";
+ // Find what this alt parent expects from child
+ for (Pair altChildPair : altParentPlanFOUT.getChildFedPlans()) {
+ if (altChildPair.getLeft() == childHop.getHopID()) {
+ altParentRequiredChild = altChildPair.getRight().toString();
+ break;
+ }
+ }
+ }
+
+ alternativeParentInfo = String.format("ID:%d|Name:%s|Op:%s|FedType:%s|RequiredChild:%s",
+ parent.getHopID(),
+ parent.getName() != null ? parent.getName() : "null",
+ parent.getOpString(),
+ altParentFedType,
+ altParentRequiredChild);
+ break;
+ }
+ }
+ }
+
+ System.err.println("[DEBUG] NULL CHILD PLAN DETECTED:");
+ System.err.println(" Child: " + childInfo);
+ System.err.println(" Current Parent: " + currentParentInfo);
+ System.err.println(" Alt Parent: " + alternativeParentInfo);
+ System.err.println(" Alt Plan Exists: " + (alternativeChildPlan != null));
+ }
+
+ /**
+ * Logs debugging information for TransRead hop rewiring process
+ * @param hopName The name of the TransRead hop
+ * @param hopID The ID of the TransRead hop
+ * @param childHops List of child hops found during rewiring
+ * @param isEmptyChildHops Whether the child hops list is empty
+ * @param logPrefix Prefix string to identify the log source
+ */
+ public static void logTransReadRewireDebug(String hopName, long hopID, List<Hop> childHops,
+ boolean isEmptyChildHops, String logPrefix) {
+ if (isEmptyChildHops) {
+ System.err.println("[" + logPrefix + "] (hopName: " + hopName + ", hopID: " + hopID + ") child hops is empty");
+ }
+ }
+
+ /**
+ * Logs debugging information for filtered child hops during TransRead rewiring
+ * @param hopName The name of the TransRead hop
+ * @param hopID The ID of the TransRead hop
+ * @param filteredChildHops List of filtered child hops
+ * @param isEmptyFilteredChildHops Whether the filtered child hops list is empty
+ * @param logPrefix Prefix string to identify the log source
+ */
+ public static void logFilteredChildHopsDebug(String hopName, long hopID, List<Hop> filteredChildHops,
+ boolean isEmptyFilteredChildHops, String logPrefix) {
+ if (isEmptyFilteredChildHops) {
+ System.err.println("[" + logPrefix + "] (hopName: " + hopName + ", hopID: " + hopID + ") filtered child hops is empty");
+ }
+ }
+
+ /**
+ * Logs detailed FType mismatch error information for TransRead hop
+ * @param hop The TransRead hop with FType mismatch
+ * @param filteredChildHops List of filtered child hops
+ * @param fTypeMap Map containing FType information for hops
+ * @param expectedFType The expected FType
+ * @param mismatchedFType The mismatched FType
+ * @param mismatchIndex The index where mismatch occurred
+ */
+ public static void logFTypeMismatchError(Hop hop, List<Hop> filteredChildHops, Map<Long, FType> fTypeMap,
+ FType expectedFType, FType mismatchedFType, int mismatchIndex) {
+ String hopName = hop.getName();
+ long hopID = hop.getHopID();
+
+ System.err.println("[Error] FType MISMATCH DETECTED for TransRead (hopName: " + hopName + ", hopID: " + hopID + ")");
+ System.err.println("[Error] TRANSREAD HOP DETAILS - Type: " + hop.getClass().getSimpleName() +
+ ", OpType: " + (hop instanceof org.apache.sysds.hops.DataOp ?
+ ((org.apache.sysds.hops.DataOp)hop).getOp() : "N/A") +
+ ", DataType: " + hop.getDataType() +
+ ", Dims: [" + hop.getDim1() + "x" + hop.getDim2() + "]");
+ System.err.println("[Error] FILTERED CHILD HOPS FTYPE ANALYSIS:");
+
+ for (int j = 0; j < filteredChildHops.size(); j++) {
+ Hop childHop = filteredChildHops.get(j);
+ FType childFType = fTypeMap.get(childHop.getHopID());
+ System.err.println("[Error] FilteredChild[" + j + "] - Name: " + childHop.getName() +
+ ", ID: " + childHop.getHopID() +
+ ", FType: " + childFType +
+ ", Type: " + childHop.getClass().getSimpleName() +
+ ", OpType: " + (childHop instanceof org.apache.sysds.hops.DataOp ?
+ ((org.apache.sysds.hops.DataOp)childHop).getOp().toString() : "N/A") +
+ ", Dims: [" + childHop.getDim1() + "x" + childHop.getDim2() + "]");
+ }
+
+ System.err.println("[Error] Expected FType: " + expectedFType +
+ ", Mismatched FType: " + mismatchedFType +
+ " at child index: " + mismatchIndex);
+ }
+
+ /**
+ * Logs FType debug information for DataOp operations (FEDERATED, TRANSIENTWRITE, TRANSIENTREAD)
+ * @param hop The DataOp hop being analyzed
+ * @param fType The FType that was determined for this operation
+ * @param opType The operation type (FEDERATED, TRANSIENTWRITE, TRANSIENTREAD)
+ * @param reason The reason for the FType decision
+ */
+ public static void logDataOpFTypeDebug(Hop hop, FType fType, String opType, String reason) {
+ String hopName = hop.getName() != null ? hop.getName() : "null";
+ long hopID = hop.getHopID();
+ String hopClass = hop.getClass().getSimpleName();
+ String dimensions = "[" + hop.getDim1() + "x" + hop.getDim2() + "]";
+
+ System.out.println("[GetFederatedType] HopName: " + hopName +
+ " | HopID: " + hopID +
+ " | HopClass: " + hopClass +
+ " | OpType: " + opType +
+ " | Dims: " + dimensions +
+ " | FType: " + (fType != null ? fType : "null") +
+ " | Reason: " + reason);
+ }
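+
+ // Illustrative usage sketch (editor's note; the FType value and reason string are assumptions):
+ //   logDataOpFTypeDebug(hop, FType.ROW, "TRANSIENTREAD", "inherited from matching transient write");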
+
+ // ========== FederatedMemoTable Printing Methods ==========
+
+ /**
+ * Recursively prints a tree representation of the DAG starting from the given root FedPlan.
+ * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node.
+ * The additional total cost is printed once at the beginning.
+ *
+ * @param rootFedPlan The starting point FedPlan to print
+ * @param rootHopStatSet Set of root hop statistics
+ * @param memoTable The memoization table containing FedPlan variants
+ * @param additionalTotalCost The additional cost to be printed once
+ */
+ public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, Set<Long> rootHopStatSet,
+ FederatedMemoTable memoTable, double additionalTotalCost) {
+ System.out.println("Additional Cost: " + additionalTotalCost);
+ Set<Long> visited = new HashSet<>();
+ printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0);
- /**
- * Helper method to recursively print the FedPlan tree for not referenced plans.
- *
- * @param plan The current FedPlan to print
- * @param memoTable The memoization table containing FedPlan variants
- * @param visited Set to keep track of visited FedPlans (prevents cycles)
- * @param depth The current depth level for indentation
- */
- private static void printNotReferencedFedPlanRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable,
- Set<Long> visited, int depth) {
- long hopID = plan.getHopRef().getHopID();
+ for (Long hopID : rootHopStatSet) {
+ FedPlan plan = memoTable.getFedPlanAfterPrune(hopID, FederatedOutput.LOUT);
+ if (plan == null){
+ plan = memoTable.getFedPlanAfterPrune(hopID, FederatedOutput.FOUT);
+ }
+ printNotReferencedFedPlanRecursive(plan, memoTable, visited, 1);
+ }
+ }
- if (visited.contains(hopID)) {
- return;
- }
+ /**
+ * Helper method to recursively print the FedPlan tree for not referenced plans.
+ *
+ * @param plan The current FedPlan to print
+ * @param memoTable The memoization table containing FedPlan variants
+ * @param visited Set to keep track of visited FedPlans (prevents cycles)
+ * @param depth The current depth level for indentation
+ */
+ private static void printNotReferencedFedPlanRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable,
+ Set<Long> visited, int depth) {
+ long hopID = plan.getHopRef().getHopID();
- visited.add(hopID);
- printFedPlan(plan, memoTable, depth, true);
+ if (visited.contains(hopID)) {
+ return;
+ }
- // Process child nodes
- List<Pair<Long, FederatedOutput>> childFedPlanPairs = plan.getChildFedPlans();
- for (int i = 0; i < childFedPlanPairs.size(); i++) {
- Pair<Long, FederatedOutput> childFedPlanPair = childFedPlanPairs.get(i);
- FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair);
- if (childVariants == null || childVariants.isEmpty())
- continue;
+ visited.add(hopID);
+ printFedPlan(plan, memoTable, depth, true);
- for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) {
- printNotReferencedFedPlanRecursive(childPlan, memoTable, visited, depth + 1);
- }
- }
- }
+ // Process child nodes
+ List<Pair<Long, FederatedOutput>> childFedPlanPairs = plan.getChildFedPlans();
+ for (int i = 0; i < childFedPlanPairs.size(); i++) {
+ Pair<Long, FederatedOutput> childFedPlanPair = childFedPlanPairs.get(i);
+ FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair);
+ if (childVariants == null || childVariants.isEmpty())
+ continue;
- /**
- * Helper method to recursively print the FedPlan tree.
- *
- * @param plan The current FedPlan to print
- * @param memoTable The memoization table containing FedPlan variants
- * @param visited Set to keep track of visited FedPlans (prevents cycles)
- * @param depth The current depth level for indentation
- */
- private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable,
- Set<Long> visited, int depth) {
- long hopID = 0;
+ for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) {
+ printNotReferencedFedPlanRecursive(childPlan, memoTable, visited, depth + 1);
+ }
+ }
+ }
- if (depth == 0) {
- hopID = -1;
- } else {
- hopID = plan.getHopRef().getHopID();
- }
+ /**
+ * Helper method to recursively print the FedPlan tree.
+ *
+ * @param plan The current FedPlan to print
+ * @param memoTable The memoization table containing FedPlan variants
+ * @param visited Set to keep track of visited FedPlans (prevents cycles)
+ * @param depth The current depth level for indentation
+ */
+ private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable,
+ Set<Long> visited, int depth) {
+ long hopID = 0;
- if (visited.contains(hopID)) {
- return;
- }
+ if (depth == 0) {
+ hopID = -1;
+ } else {
+ hopID = plan.getHopRef().getHopID();
+ }
- visited.add(hopID);
- printFedPlan(plan, memoTable, depth, false);
-
- // Process child nodes
- List<Pair<Long, FederatedOutput>> childFedPlanPairs = plan.getChildFedPlans();
- for (int i = 0; i < childFedPlanPairs.size(); i++) {
- Pair<Long, FederatedOutput> childFedPlanPair = childFedPlanPairs.get(i);
- FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair);
- if (childVariants == null || childVariants.isEmpty())
- continue;
+ if (visited.contains(hopID)) {
+ return;
+ }
- for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) {
- printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1);
- }
- }
- }
+ visited.add(hopID);
+ printFedPlan(plan, memoTable, depth, false);
+
+ // Process child nodes
+ List<Pair<Long, FederatedOutput>> childFedPlanPairs = plan.getChildFedPlans();
+ for (int i = 0; i < childFedPlanPairs.size(); i++) {
+ Pair<Long, FederatedOutput> childFedPlanPair = childFedPlanPairs.get(i);
+ FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair);
+ if (childVariants == null || childVariants.isEmpty())
+ continue;
- /**
- * Prints detailed information about a FedPlan including costs, dimensions, and memory estimates.
- *
- * @param plan The FedPlan to print
- * @param memoTable The memoization table containing FedPlan variants
- * @param depth The current depth level for indentation
- * @param isNotReferenced Whether this plan is not referenced
- */
- private static void printFedPlan(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, int depth, boolean isNotReferenced) {
- StringBuilder sb = new StringBuilder();
- Hop hop = null;
+ for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) {
+ printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1);
+ }
+ }
+ }
- if (depth == 0){
- sb.append("(R) ROOT [Root]");
- } else {
- hop = plan.getHopRef();
- // Add FedPlan information
- sb.append(String.format("(%d) ", hop.getHopID()))
- .append(hop.getOpString())
- .append(" [");
+ /**
+ * Prints detailed information about a FedPlan including costs, dimensions, and memory estimates.
+ *
+ * @param plan The FedPlan to print
+ * @param memoTable The memoization table containing FedPlan variants
+ * @param depth The current depth level for indentation
+ * @param isNotReferenced Whether this plan is not referenced
+ */
+ private static void printFedPlan(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, int depth, boolean isNotReferenced) {
+ StringBuilder sb = new StringBuilder();
+ Hop hop = null;
- if (isNotReferenced) {
- if (depth == 1) {
- sb.append("NRef(TOP)");
- } else {
- sb.append("NRef");
- }
- } else{
- sb.append(plan.getFedOutType());
- }
- sb.append("]");
- }
+ if (depth == 0){
+ sb.append("(R) ROOT [Root]");
+ } else {
+ hop = plan.getHopRef();
+ // Add FedPlan information
+ sb.append(String.format("(%d) ", hop.getHopID()))
+ .append(hop.getOpString())
+ .append(" [");
- StringBuilder childs = new StringBuilder();
- childs.append(" (");
+ if (isNotReferenced) {
+ if (depth == 1) {
+ sb.append("NRef(TOP)");
+ } else {
+ sb.append("NRef");
+ }
+ } else{
+ sb.append(plan.getFedOutType());
+ }
+ sb.append("]");
+ }
- boolean childAdded = false;
- for (Pair<Long, FederatedOutput> childPair : plan.getChildFedPlans()){
- childs.append(childAdded?",":"");
- childs.append(childPair.getLeft());
- childAdded = true;
- }
-
- childs.append(")");
+ StringBuilder childs = new StringBuilder();
+ childs.append(" (");
- if (childAdded)
- sb.append(childs.toString());
+ boolean childAdded = false;
+ for (Pair<Long, FederatedOutput> childPair : plan.getChildFedPlans()){
+ childs.append(childAdded?",":"");
+ childs.append(childPair.getLeft());
+ childAdded = true;
+ }
+
+ childs.append(")");
- if (depth == 0){
- sb.append(String.format(" {Total: %.1f}", plan.getCumulativeCost()));
- System.out.println(sb);
- return;
- }
+ if (childAdded)
+ sb.append(childs.toString());
- sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f, Weight: %.1f}",
- plan.getCumulativeCost(),
- plan.getSelfCost(),
- plan.getForwardingCost(),
- plan.getComputeWeight()));
+ if (depth == 0){
+ sb.append(String.format(" {Total: %.1f}", plan.getCumulativeCost()));
+ System.out.println(sb);
+ return;
+ }
- // Add matrix characteristics
- sb.append(" [")
- .append(hop.getDim1()).append(", ")
- .append(hop.getDim2()).append(", ")
- .append(hop.getBlocksize()).append(", ")
- .append(hop.getNnz());
+ sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f, Weight: %.1f}",
+ plan.getCumulativeCost(),
+ plan.getSelfCost(),
+ plan.getForwardingCost(),
+ plan.getComputeWeight()));
- if (hop.getUpdateType().isInPlace()) {
- sb.append(", ").append(hop.getUpdateType().toString().toLowerCase());
- }
- sb.append("]");
+ // Add matrix characteristics
+ sb.append(" [")
+ .append(hop.getDim1()).append(", ")
+ .append(hop.getDim2()).append(", ")
+ .append(hop.getBlocksize()).append(", ")
+ .append(hop.getNnz());
- // Add memory estimates
- sb.append(" [")
- .append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
- .append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
- .append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
- .append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
+ if (hop.getUpdateType().isInPlace()) {
+ sb.append(", ").append(hop.getUpdateType().toString().toLowerCase());
+ }
+ sb.append("]");
- // Add reblock and checkpoint requirements
- if (hop.requiresReblock() && hop.requiresCheckpoint()) {
- sb.append(" [rblk, chkpt]");
- } else if (hop.requiresReblock()) {
- sb.append(" [rblk]");
- } else if (hop.requiresCheckpoint()) {
- sb.append(" [chkpt]");
- }
+ // Add memory estimates
+ sb.append(" [")
+ .append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
+ .append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
+ .append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
+ .append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
- // Add execution type
- if (hop.getExecType() != null) {
- sb.append(", ").append(hop.getExecType());
- }
-
- if (childAdded){
- sb.append(" [Edges]{");
- for (Pair<Long, FederatedOutput> childPair : plan.getChildFedPlans()){
- // Add forwarding weight for each edge
- FedPlan childPlan = memoTable.getFedPlanAfterPrune(childPair.getLeft(), childPair.getRight());
-
- if (childPlan == null) {
- sb.append(String.format("(ID:%d, NULL)", childPair.getLeft()));
- } else {
- String isForwardingCostOccured = "";
- if (childPair.getRight() == plan.getFedOutType()){
- isForwardingCostOccured = "X";
- } else {
- isForwardingCostOccured = "O";
- }
- sb.append(String.format("(ID:%d, %s, C:%.1f, F:%.1f, FW:%.1f)", childPair.getLeft(), isForwardingCostOccured,
- childPlan.getCumulativeCostPerParents(),
- plan.getChildForwardingWeight(childPlan.getLoopContext()) * childPlan.getForwardingCostPerParents(),
- plan.getChildForwardingWeight(childPlan.getLoopContext())));
- }
- sb.append(childAdded?",":"");
- }
- sb.append("}");
- }
+ // Add reblock and checkpoint requirements
+ if (hop.requiresReblock() && hop.requiresCheckpoint()) {
+ sb.append(" [rblk, chkpt]");
+ } else if (hop.requiresReblock()) {
+ sb.append(" [rblk]");
+ } else if (hop.requiresCheckpoint()) {
+ sb.append(" [chkpt]");
+ }
- System.out.println(sb);
- }
+ // Add execution type
+ if (hop.getExecType() != null) {
+ sb.append(", ").append(hop.getExecType());
+ }
+
+ if (childAdded){
+ sb.append(" [Edges]{");
+ for (Pair<Long, FederatedOutput> childPair : plan.getChildFedPlans()){
+ // Add forwarding weight for each edge
+ FedPlan childPlan = memoTable.getFedPlanAfterPrune(childPair.getLeft(), childPair.getRight());
+
+ if (childPlan == null) {
+ sb.append(String.format("(ID:%d, NULL)", childPair.getLeft()));
+ } else {
+ String isForwardingCostOccured = "";
+ if (childPair.getRight() == plan.getFedOutType()){
+ isForwardingCostOccured = "X";
+ } else {
+ isForwardingCostOccured = "O";
+ }
+ sb.append(String.format("(ID:%d, %s, C:%.1f, F:%.1f, FW:%.1f)", childPair.getLeft(), isForwardingCostOccured,
+ childPlan.getCumulativeCostPerParents(),
+ plan.getChildForwardingWeight(childPlan.getLoopContext()) * childPlan.getForwardingCostPerParents(),
+ plan.getChildForwardingWeight(childPlan.getLoopContext())));
+ }
+ sb.append(childAdded?",":"");
+ }
+ sb.append("}");
+ }
+
+ System.out.println(sb);
+ }
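+
+ // Example of a printed plan line (editor's illustration; the values are made up, the layout
+ // follows the String.format calls above, optional [rblk]/[chkpt], exec type and edge info omitted):
+ //   (42) ba(+*) [LOUT] (17,23) {Total: 120.5, Self: 80.0, Net: 40.5, Weight: 1.0} [1000, 100, 1000, -1] [0, 8, 1 -> 9MB]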
}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index c2602dba510..aa82adcfdc5 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -69,15 +69,14 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
//initialize StatementBlock rewrite ruleSet (with fixed rewrite order)
_sbRuleSet = new ArrayList<>();
-
-
+
+
//STATIC REWRITES (which do not rely on size information)
if( staticRewrites )
{
//add static HOP DAG rewrite rules
_dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize
_dagRuleSet.add( new RewriteBlockSizeAndReblock() );
- _dagRuleSet.add( new RewriteInjectOOCTee() );
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
@@ -94,6 +93,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
_dagRuleSet.add( new RewriteQuantizationFusedCompression() );
+
//add statement block rewrite rules
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
@@ -152,6 +152,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
_sbRuleSet.add( new RewriteRemoveEmptyBasicBlocks() );
_sbRuleSet.add( new RewriteRemoveEmptyForLoops() );
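+ // Illustrative note (not part of the original patch): based on the RewriteInjectOOCTee changes
+ // further below, the OOC tee injection is registered here as a statement-block rewrite so that
+ // transient reads can be counted across DAG boundaries before any TeeOp is injected.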
+ _sbRuleSet.add( new RewriteInjectOOCTee() );
}
/**
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
index 54dffa263eb..7abfb15d1d4 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
@@ -25,6 +25,7 @@
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.parser.StatementBlock;
import java.util.ArrayList;
import java.util.HashMap;
@@ -49,73 +50,20 @@
* 2. Apply Rewrites (Modification): Iterate over the collected candidates, insert a
* {@code TeeOp} for each, and safely rewire the graph.
*/
-public class RewriteInjectOOCTee extends HopRewriteRule {
+public class RewriteInjectOOCTee extends StatementBlockRewriteRule {
public static boolean APPLY_ONLY_XtX_PATTERN = false;
+
+ private static final Map<String, Integer> _transientVars = new HashMap<>();
+ private static final Map<String, List<Hop>> _transientHops = new HashMap<>();
+ private static final Set<String> teeTransientVars = new HashSet<>();
private static final Set<Long> rewrittenHops = new HashSet<>();
private static final Map handledHop = new HashMap<>();
// Maintain a list of candidates to rewrite in the second pass
private final List<Hop> rewriteCandidates = new ArrayList<>();
-
- /**
- * Handle a generic (last-level) hop DAG with multiple roots.
- *
- * @param roots high-level operator roots
- * @param state program rewrite status
- * @return list of high-level operators
- */
- @Override
- public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
- if (roots == null) {
- return null;
- }
-
- // Clear candidates for this pass
- rewriteCandidates.clear();
-
- // PASS 1: Identify candidates without modifying the graph
- for (Hop root : roots) {
- root.resetVisitStatus();
- findRewriteCandidates(root);
- }
-
- // PASS 2: Apply rewrites to identified candidates
- for (Hop candidate : rewriteCandidates) {
- applyTopDownTeeRewrite(candidate);
- }
-
- return roots;
- }
-
- /**
- * Handle a predicate hop DAG with exactly one root.
- *
- * @param root high-level operator root
- * @param state program rewrite status
- * @return high-level operator
- */
- @Override
- public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
- if (root == null) {
- return null;
- }
-
- // Clear candidates for this pass
- rewriteCandidates.clear();
-
- // PASS 1: Identify candidates without modifying the graph
- root.resetVisitStatus();
- findRewriteCandidates(root);
-
- // PASS 2: Apply rewrites to identified candidates
- for (Hop candidate : rewriteCandidates) {
- applyTopDownTeeRewrite(candidate);
- }
-
- return root;
- }
+ private boolean forceTee = false;
/**
* First pass: Find candidates for rewrite without modifying the graph.
@@ -137,6 +85,35 @@ private void findRewriteCandidates(Hop hop) {
findRewriteCandidates(input);
}
+ boolean isRewriteCandidate = DMLScript.USE_OOC
+ && hop.getDataType().isMatrix()
+ && !HopRewriteUtils.isData(hop, OpOpData.TEE)
+ && hop.getParent().size() > 1
+ && (!APPLY_ONLY_XtX_PATTERN || isSelfTranposePattern(hop));
+
+ if (HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) && hop.getDataType().isMatrix()) {
+ _transientVars.compute(hop.getName(), (key, ctr) -> {
+ int incr = (isRewriteCandidate || forceTee) ? 2 : 1;
+
+ int ret = ctr == null ? 0 : ctr;
+ ret += incr;
+
+ if (ret > 1)
+ teeTransientVars.add(hop.getName());
+
+ return ret;
+ });
+
+ _transientHops.compute(hop.getName(), (key, hops) -> {
+ if (hops == null)
+ return new ArrayList<>(List.of(hop));
+ hops.add(hop);
+ return hops;
+ });
+
+ return; // We do not tee transient reads but rather inject before TWrite or PRead as caching stream
+ }
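+
+ // Illustrative note (editor's summary of the counting logic above, not in the original patch):
+ // a transient read counts +2 when it is itself a tee candidate (multiple consumers) or forceTee
+ // is set, and +1 otherwise; once the per-variable count exceeds 1 the variable is recorded in
+ // teeTransientVars, and every recorded TRead hop is later handed to applyTopDownTeeRewrite()
+ // from rewriteStatementBlock()/rewriteStatementBlocks().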
+
// Check if this hop is a candidate for OOC Tee injection
if (DMLScript.USE_OOC
&& hop.getDataType().isMatrix()
@@ -160,11 +137,17 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
return;
}
+ int consumerCount = sharedInput.getParent().size();
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Inject tee for hop " + sharedInput.getHopID() + " ("
+ + sharedInput.getName() + "), consumers=" + consumerCount);
+ }
+
// Take a defensive copy of consumers before modifying the graph
ArrayList<Hop> consumers = new ArrayList<>(sharedInput.getParent());
// Create the new TeeOp with the original hop as input
- DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
+ DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
sharedInput.getDataType(), sharedInput.getValueType(), Types.OpOpData.TEE, null,
sharedInput.getDim1(), sharedInput.getDim2(), sharedInput.getNnz(), sharedInput.getBlocksize());
HopRewriteUtils.addChildReference(teeOp, sharedInput);
@@ -177,6 +160,11 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
// Record that we've handled this hop
handledHop.put(sharedInput.getHopID(), teeOp);
rewrittenHops.add(sharedInput.getHopID());
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Created tee hop " + teeOp.getHopID() + " -> "
+ + teeOp.getName());
+ }
}
@SuppressWarnings("unused")
@@ -196,4 +184,108 @@ else if (HopRewriteUtils.isMatrixMultiply(parent)) {
}
return hasTransposeConsumer && hasMatrixMultiplyConsumer;
}
+
+ @Override
+ public boolean createsSplitDag() {
+ return false;
+ }
+
+ @Override
+ public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+ if (!DMLScript.USE_OOC)
+ return List.of(sb);
+
+ rewriteSB(sb, state);
+
+ for (String tVar : teeTransientVars) {
+ List<Hop> tHops = _transientHops.get(tVar);
+
+ if (tHops == null)
+ continue;
+
+ for (Hop affectedHops : tHops) {
+ applyTopDownTeeRewrite(affectedHops);
+ }
+
+ tHops.clear();
+ }
+
+ removeRedundantTeeChains(sb);
+
+ return List.of(sb);
+ }
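+
+ // Illustrative note (editor's summary): rewriteStatementBlocks below applies the same three
+ // steps across all blocks: collect candidates per block, tee the recorded transient reads,
+ // then drop redundant tee-over-tee chains.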
+
+ @Override
+ public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+ if (!DMLScript.USE_OOC)
+ return sbs;
+
+ for (StatementBlock sb : sbs)
+ rewriteSB(sb, state);
+
+ for (String tVar : teeTransientVars) {
+ List<Hop> tHops = _transientHops.get(tVar);
+
+ if (tHops == null)
+ continue;
+
+ for (Hop affectedHops : tHops) {
+ applyTopDownTeeRewrite(affectedHops);
+ }
+ }
+
+ for (StatementBlock sb : sbs)
+ removeRedundantTeeChains(sb);
+
+ return sbs;
+ }
+
+ private void rewriteSB(StatementBlock sb, ProgramRewriteStatus state) {
+ rewriteCandidates.clear();
+
+ if (sb.getHops() != null) {
+ for(Hop hop : sb.getHops()) {
+ hop.resetVisitStatus();
+ findRewriteCandidates(hop);
+ }
+ }
+
+ for (Hop candidate : rewriteCandidates) {
+ applyTopDownTeeRewrite(candidate);
+ }
+ }
+
+ private void removeRedundantTeeChains(StatementBlock sb) {
+ if (sb == null || sb.getHops() == null)
+ return;
+
+ Hop.resetVisitStatus(sb.getHops());
+ for (Hop hop : sb.getHops())
+ removeRedundantTeeChains(hop);
+ Hop.resetVisitStatus(sb.getHops());
+ }
+
+ private void removeRedundantTeeChains(Hop hop) {
+ if (hop.isVisited())
+ return;
+
+ ArrayList<Hop> inputs = new ArrayList<>(hop.getInput());
+ for (Hop in : inputs)
+ removeRedundantTeeChains(in);
+
+ if (HopRewriteUtils.isData(hop, OpOpData.TEE) && hop.getInput().size() == 1) {
+ Hop teeInput = hop.getInput().get(0);
+ if (HopRewriteUtils.isData(teeInput, OpOpData.TEE)) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Remove redundant tee hop " + hop.getHopID()
+ + " (" + hop.getName() + ") -> " + teeInput.getHopID()
+ + " (" + teeInput.getName() + ")");
+ }
+ HopRewriteUtils.rewireAllParentChildReferences(hop, teeInput);
+ HopRewriteUtils.removeAllChildReferences(hop);
+ }
+ }
+
+ hop.setVisited();
+ }
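+
+ // Illustrative note (editor's example, not from the original patch): a chain produced by repeated
+ // injection, e.g. parent -> tee_out_X -> tee_out_X -> X, collapses to parent -> tee_out_X -> X,
+ // because the outer TEE's parents are rewired onto its TEE input and the outer node is unlinked.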
}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
index fdd2f8343fe..960560c254e 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
@@ -210,7 +210,7 @@ protected void optimizeMMChain(Hop hop, List mmChain, List mmOperators
* Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein
* Introduction to Algorithms, Third Edition, MIT Press, page 395.
*/
- private static int[][] mmChainDP(double[] dimArray, int size)
+ public static int[][] mmChainDP(double[] dimArray, int size)
{
double[][] dpMatrix = new double[size][size]; //min cost table
int[][] split = new int[size][size]; //min cost index table
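+ // Illustrative note (editor's example, not in the original patch): the tables are filled with the
+ // classic recurrence cost(i,j) = min_k cost(i,k) + cost(k+1,j) + dimArray[i]*dimArray[k+1]*dimArray[j+1];
+ // e.g. for dims {10, 100, 5, 50} (A:10x100, B:100x5, C:5x50) the optimal split is ((AB)C)
+ // with 10*100*5 + 10*5*50 = 7500 multiplications, vs 75000 for (A(BC)).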
diff --git a/src/main/java/org/apache/sysds/lops/Transform.java b/src/main/java/org/apache/sysds/lops/Transform.java
index 0d2e79f83a8..d9537dcca6c 100644
--- a/src/main/java/org/apache/sysds/lops/Transform.java
+++ b/src/main/java/org/apache/sysds/lops/Transform.java
@@ -180,7 +180,7 @@ private String getInstructions(String input1, int numInputs, String output) {
sb.append( this.prepOutputOperand(output));
if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED || getExecType()==ExecType.OOC)
- && (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) {
+ && (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT || _operation == ReOrgOp.ROLL) ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _numThreads );
if ( getExecType()==ExecType.FED ) {
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 092fbffe36d..c6e7188d7bc 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1392,7 +1392,7 @@ public void constructHopsForWhileControlBlock(WhileStatementBlock sb) {
public void constructHopsForConditionalPredicate(StatementBlock passedSB) {
- HashMap<String, Hop> _ids = new HashMap<>();
+ HashMap<String, Hop> ids = new HashMap<>();
// set conditional predicate
ConditionalPredicate cp = null;
@@ -1428,7 +1428,7 @@ else if (passedSB instanceof IfStatementBlock) {
null, actualDim1, actualDim2, var.getNnz(), var.getBlocksize());
read.setParseInfo(var);
}
- _ids.put(varName, read);
+ ids.put(varName, read);
}
DataIdentifier target = new DataIdentifier(Expression.getTempName());
@@ -1439,12 +1439,12 @@ else if (passedSB instanceof IfStatementBlock) {
Expression predicate = cp.getPredicate();
if (predicate instanceof RelationalExpression) {
- predicateHops = processRelationalExpression((RelationalExpression) cp.getPredicate(), target, _ids);
+ predicateHops = processRelationalExpression((RelationalExpression) cp.getPredicate(), target, ids);
} else if (predicate instanceof BooleanExpression) {
- predicateHops = processBooleanExpression((BooleanExpression) cp.getPredicate(), target, _ids);
+ predicateHops = processBooleanExpression((BooleanExpression) cp.getPredicate(), target, ids);
} else if (predicate instanceof DataIdentifier) {
// handle data identifier predicate
- predicateHops = processExpression(cp.getPredicate(), null, _ids);
+ predicateHops = processExpression(cp.getPredicate(), null, ids);
} else if (predicate instanceof ConstIdentifier) {
// handle constant identifier
// a) translate 0 --> FALSE; translate 1 --> TRUE
@@ -1463,7 +1463,7 @@ else if (passedSB instanceof IfStatementBlock) {
throw new ParseException(predicate.printErrorLocation() + "String value '" + predicate.toString()
+ "' is not allowed for iterable predicate");
}
- predicateHops = processExpression(cp.getPredicate(), null, _ids);
+ predicateHops = processExpression(cp.getPredicate(), null, ids);
}
//create transient write to internal variable name on top of expression
@@ -1487,7 +1487,7 @@ else if (passedSB instanceof IfStatementBlock)
*/
public void constructHopsForIterablePredicate(ForStatementBlock fsb)
{
- HashMap<String, Hop> _ids = new HashMap<>();
+ HashMap<String, Hop> ids = new HashMap<>();
// set iterable predicate
ForStatement fs = (ForStatement) fsb.getStatement(0);
@@ -1513,13 +1513,13 @@ public void constructHopsForIterablePredicate(ForStatementBlock fsb)
null, actualDim1, actualDim2, var.getNnz(), var.getBlocksize());
read.setParseInfo(var);
}
- _ids.put(varName, read);
+ ids.put(varName, read);
}
}
//create transient write to internal variable name on top of expression
//in order to ensure proper instruction generation
- Hop predicateHops = processTempIntExpression(expr, _ids);
+ Hop predicateHops = processTempIntExpression(expr, ids);
if( predicateHops != null )
predicateHops = HopRewriteUtils.createDataOp(
ProgramBlock.PRED_VAR, predicateHops, OpOpData.TRANSIENTWRITE);
diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java
index 1b9afb41b68..22dbe21c187 100644
--- a/src/main/java/org/apache/sysds/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysds/parser/DataExpression.java
@@ -1019,6 +1019,18 @@ public void validateExpression(HashMap ids, HashMap args, HashMap
DMLOptions dmlOptions =DMLOptions.defaultOptions;
dmlOptions.argVals = args;
- String dmlScriptStr = readDMLScript(true, filePath);
+ String dmlScriptStr = DMLScript.readDMLScript(true, filePath);
Map argVals = dmlOptions.argVals;
ParserWrapper parser = ParserFactory.createParser();
@@ -235,7 +256,7 @@ else if (originBlock instanceof ForProgramBlock) //incl parfor
public static void setSingleNodeResourceConfigs(long nodeMemory, int nodeCores) {
DMLScript.setGlobalExecMode(Types.ExecMode.SINGLE_NODE);
// use 90% of the node's memory for the JVM heap -> rest needed for the OS
- long effectiveSingleNodeMemory = (long) (nodeMemory * JVM_MEMORY_FACTOR);
+ long effectiveSingleNodeMemory = (long) (nodeMemory * CloudUtils.JVM_MEMORY_FACTOR);
// CPU core would be shared with OS -> no further limitation
InfrastructureAnalyzer.setLocalMaxMemory(effectiveSingleNodeMemory);
InfrastructureAnalyzer.setLocalPar(nodeCores);
@@ -259,9 +280,9 @@ public static void setSparkClusterResourceConfigs(long driverMemory, int driverC
// ------------------- CP (driver) configurations -------------------
// use at most 90% of the node's memory for the JVM heap -> rest needed for the OS and resource management
// adapt the minimum based on the need for YARN RM
- long effectiveDriverMemory = calculateEffectiveDriverMemoryBudget(driverMemory, numExecutors*executorCores);
+ long effectiveDriverMemory = CloudUtils.calculateEffectiveDriverMemoryBudget(driverMemory, numExecutors*executorCores);
// require that always at least half of the memory budget is left for driver memory or 1GB
- if (effectiveDriverMemory <= GBtoBytes(1) || driverMemory > 2*effectiveDriverMemory) {
+ if (effectiveDriverMemory <= CloudUtils.GBtoBytes(1) || driverMemory > 2*effectiveDriverMemory) {
throw new IllegalArgumentException("Driver resources are not sufficient to handle the cluster");
}
// CPU core would be shared -> no further limitation
@@ -279,7 +300,7 @@ public static void setSparkClusterResourceConfigs(long driverMemory, int driverC
// ------------------ Dynamic Spark Configurations -------------------
// calculate the effective resource that would be available for the executor containers in YARN
- int[] effectiveValues = getEffectiveExecutorResources(executorMemory, executorCores, numExecutors);
+ int[] effectiveValues = CloudUtils.getEffectiveExecutorResources(executorMemory, executorCores, numExecutors);
int effectiveExecutorMemory = effectiveValues[0];
int effectiveExecutorCores = effectiveValues[1];
int effectiveNumExecutor = effectiveValues[2];
diff --git a/src/main/java/org/apache/sysds/resource/ResourceOptimizer.java b/src/main/java/org/apache/sysds/resource/ResourceOptimizer.java
index ad75de8abc1..cfa7573152b 100644
--- a/src/main/java/org/apache/sysds/resource/ResourceOptimizer.java
+++ b/src/main/java/org/apache/sysds/resource/ResourceOptimizer.java
@@ -38,7 +38,11 @@
import org.apache.commons.configuration2.io.FileHandler;
import java.io.IOException;
-import java.nio.file.*;
+import java.nio.file.FileAlreadyExistsException;
+import java.nio.file.Files;
+import java.nio.file.InvalidPathException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
diff --git a/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
index 0558e17c9cd..cd344612b51 100644
--- a/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
+++ b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
@@ -22,9 +22,28 @@
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.lops.*;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.instructions.cp.*;
+import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.DnnCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.QuantilePickCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.QuantileSortCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
diff --git a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
index 3d4b460cb18..14c477190b3 100644
--- a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
@@ -23,28 +23,97 @@
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.MapMult;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.resource.CloudInstance;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.*;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
+import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
+import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.cp.*;
-import org.apache.sysds.runtime.instructions.spark.*;
+import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.CompressionCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.DeCompressionCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ParamservBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ScalarBuiltinNaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AggregateUnarySketchSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryFrameFrameSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryFrameMatrixSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CSVReblockSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CastSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CheckpointSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.LIBSVMReblockSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MapmmChainSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MatrixIndexingSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MatrixReshapeSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.PMapmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ParameterizedBuiltinSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.PmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.RandSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.SPInstruction;
+import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.Tsmm2SPInstruction;
+import org.apache.sysds.runtime.instructions.spark.TsmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnaryFrameSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ZipmmSPInstruction;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import static org.apache.sysds.lops.Data.PREAD_PREFIX;
-import static org.apache.sysds.lops.DataGen.*;
import static org.apache.sysds.resource.cost.CPCostUtils.opcodeRequiresScan;
-import static org.apache.sysds.resource.cost.IOCostUtils.*;
import static org.apache.sysds.resource.cost.SparkCostUtils.getRandInstTime;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+
/**
* Class for estimating the execution time of a program.
@@ -87,7 +156,7 @@ public CostEstimator(Program program, CloudInstance driverNode, CloudInstance ex
_functions = new HashSet<>();
localMemoryLimit = (long) (OptimizerUtils.getLocalMemBudget() * MEM_ALLOCATION_LIMIT_FRACTION);
freeLocalMemory = localMemoryLimit;
- driverMetrics = new IOMetrics(driverNode);
+ driverMetrics = new IOCostUtils.IOMetrics(driverNode);
if (executorNode == null) {
// estimation for single node execution -> no executor resources
executorMetrics = null;
@@ -100,7 +169,7 @@ public CostEstimator(Program program, CloudInstance driverNode, CloudInstance ex
((double) executorNode.getVCPUs() / dedicatedExecutorCores));
// adapting the rest of the metrics not needed since the OS and resource management tasks
// would not consume large portion of the memory/storage/network bandwidth in the general case
- executorMetrics = new IOMetrics(
+ executorMetrics = new IOCostUtils.IOMetrics(
effectiveExecutorFlops,
dedicatedExecutorCores,
executorNode.getMemoryBandwidth(),
@@ -645,7 +714,7 @@ public double parseSPInst(SPInstruction inst) throws CostEstimationException {
RandSPInstruction rinst = (RandSPInstruction) inst;
String opcode = rinst.getOpcode();
int randType = -1; // default for non-random object generation operations
- if (opcode.equals(RAND_OPCODE) || opcode.equals(FRAME_OPCODE)) {
+ if (opcode.equals(DataGen.RAND_OPCODE) || opcode.equals(DataGen.FRAME_OPCODE)) {
if (rinst.getMinValue() == 0d && rinst.getMaxValue() == 0d) { // empty matrix
randType = 0;
} else if (rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue()) { // allocate, array fill
@@ -747,7 +816,10 @@ public double parseSPInst(SPInstruction inst) throws CostEstimationException {
} else {
throw new RuntimeException("Unsupported Unary Spark instruction of type " + inst.getClass().getName());
}
- } else if (inst instanceof BinaryFrameFrameSPInstruction || inst instanceof BinaryFrameMatrixSPInstruction || inst instanceof BinaryMatrixMatrixSPInstruction || inst instanceof BinaryMatrixScalarSPInstruction) {
+ } else if (inst instanceof BinaryFrameFrameSPInstruction
+ || inst instanceof BinaryFrameMatrixSPInstruction
+ || inst instanceof BinaryMatrixMatrixSPInstruction
+ || inst instanceof BinaryMatrixScalarSPInstruction) {
BinarySPInstruction binst = (BinarySPInstruction) inst;
VarStats input1 = getStatsWithDefaultScalar((binst).input1.getName());
VarStats input2 = getStatsWithDefaultScalar((binst).input2.getName());
@@ -942,7 +1014,7 @@ public double getTimeEstimateSparkJob(VarStats varToCollect) {
collectTime = IOCostUtils.getSparkCollectTime(varToCollect.rddStats, driverMetrics, executorMetrics);
} else {
// redirect through HDFS (writing to HDFS on executors and reading back on driver)
- varToCollect.fileInfo = new Object[] {HDFS_SOURCE_IDENTIFIER, FileFormat.BINARY};
+ varToCollect.fileInfo = new Object[] {IOCostUtils.HDFS_SOURCE_IDENTIFIER, FileFormat.BINARY};
collectTime = IOCostUtils.getHadoopWriteTime(varToCollect, executorMetrics) +
IOCostUtils.getFileSystemReadTime(varToCollect, driverMetrics);
}
@@ -985,7 +1057,7 @@ private double loadCPVarStatsAndEstimateTime(VarStats input) throws CostEstimati
// loading from a file
if (input.fileInfo == null || input.fileInfo.length != 2) {
throw new RuntimeException("Time estimation is not possible without file info.");
- } else if (isInvalidDataSource((String) input.fileInfo[0])) {
+ } else if (IOCostUtils.isInvalidDataSource((String) input.fileInfo[0])) {
throw new RuntimeException("Time estimation is not possible for data source: " + input.fileInfo[0]);
}
loadTime = IOCostUtils.getFileSystemReadTime(input, driverMetrics);
@@ -1057,7 +1129,7 @@ private double loadRDDStatsAndEstimateTime(VarStats input) {
if (input.allocatedMemory >= 0) { // generated object locally
if (inputRDD.distributedSize < freeLocalMemory && inputRDD.distributedSize < (0.1 * localMemoryLimit)) {
// in this case transfer the data object over HDF (first set the fileInfo of the input)
- input.fileInfo = new Object[] {HDFS_SOURCE_IDENTIFIER, FileFormat.BINARY};
+ input.fileInfo = new Object[] {IOCostUtils.HDFS_SOURCE_IDENTIFIER, FileFormat.BINARY};
ret = IOCostUtils.getFileSystemWriteTime(input, driverMetrics);
ret += IOCostUtils.getHadoopReadTime(input, executorMetrics);
} else {
diff --git a/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java b/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java
index addf37d350d..6695a0d820b 100644
--- a/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java
+++ b/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java
@@ -24,28 +24,59 @@
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.lops.*;
+import org.apache.sysds.lops.DataGen;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.resource.cost.IOCostUtils.IOMetrics;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
-import org.apache.sysds.runtime.instructions.spark.*;
+import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AggregateUnarySketchSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryFrameFrameSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryFrameMatrixSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CastSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CumulativeAggregateSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MapmmChainSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.PMapmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ParameterizedBuiltinSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.PmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType;
+import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.Tsmm2SPInstruction;
+import org.apache.sysds.runtime.instructions.spark.TsmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ZipmmSPInstruction;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-import static org.apache.sysds.lops.DataGen.*;
-import static org.apache.sysds.resource.cost.IOCostUtils.*;
public class SparkCostUtils {
public static double getReblockInstTime(String opcode, VarStats input, VarStats output, IOMetrics executorMetrics) {
// Reblock triggers a new stage
// old stage: read text file + shuffle the intermediate text rdd
- double readTime = getHadoopReadTime(input, executorMetrics);
+ double readTime = IOCostUtils.getHadoopReadTime(input, executorMetrics);
long sizeTextFile = OptimizerUtils.estimateSizeTextOutput(input.getM(), input.getN(), input.getNNZ(), (Types.FileFormat) input.fileInfo[1]);
RDDStats textRdd = new RDDStats(sizeTextFile, -1);
- double shuffleTime = getSparkShuffleTime(textRdd, executorMetrics, false);
+ double shuffleTime = IOCostUtils.getSparkShuffleTime(textRdd, executorMetrics, false);
double timeStage1 = readTime + shuffleTime;
// new stage: transform partitioned shuffled text object into partitioned binary object
long nflop = getInstNFLOP(SPType.Reblock, opcode, output);
@@ -53,19 +84,19 @@ public static double getReblockInstTime(String opcode, VarStats input, VarStats
return timeStage1 + timeStage2;
}
- public static double getRandInstTime(String opcode, int randType, VarStats output, IOMetrics executorMetrics) {
- if (opcode.equals(SAMPLE_OPCODE)) {
+ public static double getRandInstTime(String opcode, int randType, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
+ if (opcode.equals(DataGen.SAMPLE_OPCODE)) {
// sample uses sortByKey() op. and it should be handled differently
- throw new RuntimeException("Spark operation Rand with opcode " + SAMPLE_OPCODE + " is not supported yet");
+ throw new RuntimeException("Spark operation Rand with opcode " + DataGen.SAMPLE_OPCODE + " is not supported yet");
}
long nflop;
- if (opcode.equals(RAND_OPCODE) || opcode.equals(FRAME_OPCODE)) {
+ if (opcode.equals(DataGen.RAND_OPCODE) || opcode.equals(DataGen.FRAME_OPCODE)) {
if (randType == 0) return 0; // empty matrix
else if (randType == 1) nflop = 8; // allocate, array fill
else if (randType == 2) nflop = 32; // full rand
else throw new RuntimeException("Unknown type of random instruction");
- } else if (opcode.equals(SEQ_OPCODE)) {
+ } else if (opcode.equals(DataGen.SEQ_OPCODE)) {
nflop = 1;
} else {
throw new DMLRuntimeException("Rand operation with opcode '" + opcode + "' is not supported by SystemDS");
@@ -93,7 +124,7 @@ public static double getAggUnaryInstTime(UnarySPInstruction inst, VarStats input
((AggregateUnarySketchSPInstruction) inst).getAggType();
double shuffleTime;
if (inst instanceof CumulativeAggregateSPInstruction) {
- shuffleTime = getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ shuffleTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
output.rddStats.hashPartitioned = true;
} else {
if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
@@ -111,9 +142,9 @@ public static double getAggUnaryInstTime(UnarySPInstruction inst, VarStats input
input.getNNZ()
);
RDDStats filteredRDD = new RDDStats(diagonalBlockSize, input.rddStats.numPartitions);
- shuffleTime = getSparkShuffleTime(filteredRDD, executorMetrics, true);
+ shuffleTime = IOCostUtils.getSparkShuffleTime(filteredRDD, executorMetrics, true);
} else {
- shuffleTime = getSparkShuffleTime(input.rddStats, executorMetrics, true);
+ shuffleTime = IOCostUtils.getSparkShuffleTime(input.rddStats, executorMetrics, true);
}
output.rddStats.hashPartitioned = true;
output.rddStats.numPartitions = input.rddStats.numPartitions;
@@ -137,17 +168,17 @@ public static double getIndexingInstTime(IndexingSPInstruction inst, VarStats in
int blockSize = ConfigurationManager.getBlocksize();
if (output.getM() <= blockSize && output.getN() <= blockSize) {
// represents single block and multi block cases
- dataTransmissionTime = getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
output.rddStats.isCollected = true;
} else {
// represents general indexing: worst case: shuffling required
- dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
}
} else if (opcode.equals(Opcodes.LEFT_INDEX.toString())) {
// model combineByKey() with shuffling the second input
- dataTransmissionTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics, true);
} else { // mapLeftIndex
- dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
}
long nflop = getInstNFLOP(SPType.MatrixIndexing, opcode, output);
// scan only the size of the output since filter is applied first
@@ -169,34 +200,34 @@ public static double getBinaryInstTime(SPInstruction inst, VarStats input1, VarS
if (inst instanceof BinaryMatrixMatrixSPInstruction) {
if (inst instanceof BinaryMatrixBVectorSPInstruction) {
// the second matrix is always the broadcast one
- dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
// flatMapToPair() or mapPartitionsToPair() invoked -> no shuffling
output.rddStats.numPartitions = input1.rddStats.numPartitions;
output.rddStats.hashPartitioned = input1.rddStats.hashPartitioned;
} else { // regular BinaryMatrixMatrixSPInstruction
// join() input1 and input2
- dataTransmissionTime = getSparkShuffleWriteTime(input1.rddStats, executorMetrics) +
- getSparkShuffleWriteTime(input2.rddStats, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleWriteTime(input1.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleWriteTime(input2.rddStats, executorMetrics);
if (input1.rddStats.hashPartitioned) {
output.rddStats.numPartitions = input1.rddStats.numPartitions;
if (!input2.rddStats.hashPartitioned || !(input1.rddStats.numPartitions == input2.rddStats.numPartitions)) {
// shuffle needed for join() -> actual shuffle only for input2
- dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
- getSparkShuffleReadTime(input2.rddStats, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleReadTime(input2.rddStats, executorMetrics);
} else { // no shuffle needed for join() -> only read from local disk
- dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
- getSparkShuffleReadStaticTime(input2.rddStats, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleReadStaticTime(input2.rddStats, executorMetrics);
}
} else if (input2.rddStats.hashPartitioned) {
output.rddStats.numPartitions = input2.rddStats.numPartitions;
// input1 not hash partitioned: shuffle needed for join() -> actual shuffle only for input2
- dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
- getSparkShuffleReadTime(input2.rddStats, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleReadTime(input2.rddStats, executorMetrics);
} else {
// repartition all data needed
output.rddStats.numPartitions = 2 * output.rddStats.numPartitions;
- dataTransmissionTime += getSparkShuffleReadTime(input1.rddStats, executorMetrics) +
- getSparkShuffleReadTime(input2.rddStats, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleReadTime(input1.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleReadTime(input2.rddStats, executorMetrics);
}
output.rddStats.hashPartitioned = true;
}
@@ -217,16 +248,16 @@ public static double getBinaryInstTime(SPInstruction inst, VarStats input1, VarS
public static double getAppendInstTime(AppendSPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
double dataTransmissionTime;
if (inst instanceof AppendMSPInstruction) {
- dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
output.rddStats.hashPartitioned = true;
} else if (inst instanceof AppendRSPInstruction) {
- dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, false);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, false);
} else if (inst instanceof AppendGAlignedSPInstruction) {
// only changing matrix indexing
dataTransmissionTime = 0;
} else { // AppendGSPInstruction
// shuffle the whole appended matrix
- dataTransmissionTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics, true);
output.rddStats.hashPartitioned = true;
}
// opcode not relevant for the nflop estimation of append instructions;
@@ -241,7 +272,7 @@ public static double getReorgInstTime(UnarySPInstruction inst, VarStats input, V
double dataTransmissionTime;
switch (opcode) {
case "rshape":
- dataTransmissionTime = getSparkShuffleTime(input.rddStats, executorMetrics, true);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(input.rddStats, executorMetrics, true);
output.rddStats.hashPartitioned = true;
break;
case "r'":
@@ -249,7 +280,7 @@ public static double getReorgInstTime(UnarySPInstruction inst, VarStats input, V
output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
break;
case "rev":
- dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
output.rddStats.hashPartitioned = true;
break;
case "rdiag":
@@ -267,8 +298,8 @@ public static double getReorgInstTime(UnarySPInstruction inst, VarStats input, V
shuffleFactor = 4;// estimate cost for 2 shuffles
}
// assume case: 4 times shuffling the output
- dataTransmissionTime = getSparkShuffleWriteTime(output.rddStats, executorMetrics) +
- getSparkShuffleReadTime(output.rddStats, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleWriteTime(output.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleReadTime(output.rddStats, executorMetrics);
dataTransmissionTime *= shuffleFactor;
break;
}
@@ -285,7 +316,7 @@ public static double getTSMMInstTime(UnarySPInstruction inst, VarStats input, Va
if (inst instanceof TsmmSPInstruction) {
type = ((TsmmSPInstruction) inst).getMMTSJType();
// fold() used but result is still a whole matrix block
- dataTransmissionTime = getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
output.rddStats.isCollected = true;
} else { // Tsmm2SPInstruction
type = ((Tsmm2SPInstruction) inst).getMMTSJType();
@@ -296,9 +327,9 @@ public static double getTSMMInstTime(UnarySPInstruction inst, VarStats input, Va
input.getN() - input.characteristics.getBlocksize();
VarStats broadcast = new VarStats("tmp1", new MatrixCharacteristics(rowsRange, colsRange));
broadcast.rddStats = new RDDStats(broadcast);
- dataTransmissionTime = getSparkCollectTime(broadcast.rddStats, driverMetrics, executorMetrics);
- dataTransmissionTime += getSparkBroadcastTime(broadcast, driverMetrics, executorMetrics);
- dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkCollectTime(broadcast.rddStats, driverMetrics, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkBroadcastTime(broadcast, driverMetrics, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
}
opcode += type.isLeft() ? "_left" : "_right";
long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output, input);
@@ -312,8 +343,8 @@ public static double getCentralMomentInstTime(CentralMomentSPInstruction inst, V
double dataTransmissionTime = 0;
if (weights != null) {
- dataTransmissionTime = getSparkShuffleWriteTime(weights.rddStats, executorMetrics) +
- getSparkShuffleReadTime(weights.rddStats, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleWriteTime(weights.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleReadTime(weights.rddStats, executorMetrics);
}
output.rddStats.isCollected = true;
@@ -327,8 +358,8 @@ public static double getCentralMomentInstTime(CentralMomentSPInstruction inst, V
public static double getCastInstTime(CastSPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) {
double shuffleTime = 0;
if (input.getN() > input.characteristics.getBlocksize()) {
- shuffleTime = getSparkShuffleWriteTime(input.rddStats, executorMetrics) +
- getSparkShuffleReadTime(input.rddStats, executorMetrics);
+ shuffleTime = IOCostUtils.getSparkShuffleWriteTime(input.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleReadTime(input.rddStats, executorMetrics);
output.rddStats.hashPartitioned = true;
}
long nflop = getInstNFLOP(inst.getSPInstructionType(), inst.getOpcode(), output, input);
@@ -341,11 +372,11 @@ public static double getQSortInstTime(QuantileSortSPInstruction inst, VarStats i
double shuffleTime = 0;
if (weights != null) {
opcode += "_wts";
- shuffleTime += getSparkShuffleWriteTime(weights.rddStats, executorMetrics) +
- getSparkShuffleReadTime(weights.rddStats, executorMetrics);
+ shuffleTime += IOCostUtils.getSparkShuffleWriteTime(weights.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleReadTime(weights.rddStats, executorMetrics);
}
- shuffleTime += getSparkShuffleWriteTime(output.rddStats, executorMetrics) +
- getSparkShuffleReadTime(output.rddStats, executorMetrics);
+ shuffleTime += IOCostUtils.getSparkShuffleWriteTime(output.rddStats, executorMetrics) +
+ IOCostUtils.getSparkShuffleReadTime(output.rddStats, executorMetrics);
output.rddStats.hashPartitioned = true;
long nflop = getInstNFLOP(SPType.QSort, opcode, output, input, weights);
@@ -363,12 +394,12 @@ public static double getMatMulInstTime(BinarySPInstruction inst, VarStats input1
// estimate for in1.join(in2)
long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize;
RDDStats joinedRDD = new RDDStats(joinedSize, -1);
- dataTransmissionTime = getSparkShuffleTime(joinedRDD, executorMetrics, true);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(joinedRDD, executorMetrics, true);
if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
- dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
output.rddStats.isCollected = true;
} else {
- dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
output.rddStats.hashPartitioned = true;
}
numPartitionsForMapping = joinedRDD.numPartitions;
@@ -376,33 +407,33 @@ public static double getMatMulInstTime(BinarySPInstruction inst, VarStats input1
// estimate for in1.join(in2)
long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize;
RDDStats joinedRDD = new RDDStats(joinedSize, -1);
- dataTransmissionTime = getSparkShuffleTime(joinedRDD, executorMetrics, true);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(joinedRDD, executorMetrics, true);
// estimate for out.combineByKey() per partition
- dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, false);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, false);
output.rddStats.hashPartitioned = true;
numPartitionsForMapping = joinedRDD.numPartitions;
} else if (inst instanceof MapmmSPInstruction) {
- dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
MapmmSPInstruction mapmminst = (MapmmSPInstruction) inst;
AggBinaryOp.SparkAggType aggType = mapmminst.getAggType();
if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
- dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
output.rddStats.isCollected = true;
} else {
- dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
output.rddStats.hashPartitioned = true;
}
numPartitionsForMapping = input1.rddStats.numPartitions;
} else if (inst instanceof PmmSPInstruction) {
- dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
output.rddStats.numPartitions = input1.rddStats.numPartitions;
- dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
output.rddStats.hashPartitioned = true;
numPartitionsForMapping = input1.rddStats.numPartitions;
} else if (inst instanceof ZipmmSPInstruction) {
// assume always a shuffle without data re-distribution
- dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, false);
- dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, false);
+ dataTransmissionTime += IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
numPartitionsForMapping = input1.rddStats.numPartitions;
output.rddStats.isCollected = true;
} else if (inst instanceof PMapmmSPInstruction) {
@@ -425,10 +456,10 @@ public static double getMatMulChainInstTime(MapmmChainSPInstruction inst, VarSta
IOMetrics driverMetrics, IOMetrics executorMetrics) {
double dataTransmissionTime = 0;
if (input3 != null) {
- dataTransmissionTime += getSparkBroadcastTime(input3, driverMetrics, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkBroadcastTime(input3, driverMetrics, executorMetrics);
}
- dataTransmissionTime += getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
- dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ dataTransmissionTime += IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
output.rddStats.isCollected = true;
long nflop = getInstNFLOP(SPType.MAPMMCHAIN, inst.getOpcode(), output, input1, input2);
@@ -441,20 +472,20 @@ public static double getCtableInstTime(CtableSPInstruction tableInst, VarStats i
double shuffleTime;
if (opcode.equals(Opcodes.CTABLEEXPAND.toString()) || !input2.isScalar() && input3.isScalar()) { // CTABLE_EXPAND_SCALAR_WEIGHT/CTABLE_TRANSFORM_SCALAR_WEIGHT
// in1.join(in2)
- shuffleTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+ shuffleTime = IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics, true);
} else if (input2.isScalar() && input3.isScalar()) { // CTABLE_TRANSFORM_HISTOGRAM
// no joins
shuffleTime = 0;
} else if (input2.isScalar() && !input3.isScalar()) { // CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM
// in1.join(in3)
- shuffleTime = getSparkShuffleTime(input3.rddStats, executorMetrics, true);
+ shuffleTime = IOCostUtils.getSparkShuffleTime(input3.rddStats, executorMetrics, true);
} else { // CTABLE_TRANSFORM
// in1.join(in2).join(in3)
- shuffleTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
- shuffleTime += getSparkShuffleTime(input3.rddStats, executorMetrics, true);
+ shuffleTime = IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+ shuffleTime += IOCostUtils.getSparkShuffleTime(input3.rddStats, executorMetrics, true);
}
// combineByKey()
- shuffleTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ shuffleTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
output.rddStats.hashPartitioned = true;
long nflop = getInstNFLOP(SPType.Ctable, opcode, output, input1, input2, input3);
@@ -470,16 +501,16 @@ public static double getParameterizedBuiltinInstTime(ParameterizedBuiltinSPInstr
switch (opcode) {
case "rmempty":
if (input2.rddStats == null) // broadcast
- dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
else // join
- dataTransmissionTime = getSparkShuffleTime(input1.rddStats, executorMetrics, true);
- dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ dataTransmissionTime = IOCostUtils.getSparkShuffleTime(input1.rddStats, executorMetrics, true);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
break;
case "contains":
if (input2.isScalar()) {
dataTransmissionTime = 0;
} else {
- dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ dataTransmissionTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
// ignore reduceByKey() cost
}
output.rddStats.isCollected = true;
@@ -505,32 +536,32 @@ public static double getTernaryInstTime(TernarySPInstruction tInst, VarStats inp
if (!input1.isScalar() && !input2.isScalar()) {
inputRddStats = new RDDStats[]{input1.rddStats, input2.rddStats};
// input1.join(input2)
- dataTransmissionTime += getSparkShuffleTime(input1.rddStats, executorMetrics,
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(input1.rddStats, executorMetrics,
input1.rddStats.hashPartitioned);
- dataTransmissionTime += getSparkShuffleTime(input2.rddStats, executorMetrics,
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics,
input2.rddStats.hashPartitioned);
} else if (!input1.isScalar() && !input3.isScalar()) {
inputRddStats = new RDDStats[]{input1.rddStats, input3.rddStats};
// input1.join(input3)
- dataTransmissionTime += getSparkShuffleTime(input1.rddStats, executorMetrics,
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(input1.rddStats, executorMetrics,
input1.rddStats.hashPartitioned);
- dataTransmissionTime += getSparkShuffleTime(input3.rddStats, executorMetrics,
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(input3.rddStats, executorMetrics,
input3.rddStats.hashPartitioned);
} else if (!input2.isScalar() || !input3.isScalar()) {
inputRddStats = new RDDStats[]{input2.rddStats, input3.rddStats};
// input2.join(input3)
- dataTransmissionTime += getSparkShuffleTime(input2.rddStats, executorMetrics,
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics,
input2.rddStats.hashPartitioned);
- dataTransmissionTime += getSparkShuffleTime(input3.rddStats, executorMetrics,
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(input3.rddStats, executorMetrics,
input3.rddStats.hashPartitioned);
} else if (!input1.isScalar() && !input2.isScalar() && !input3.isScalar()) {
inputRddStats = new RDDStats[]{input1.rddStats, input2.rddStats, input3.rddStats};
// input1.join(input2).join(input3)
- dataTransmissionTime += getSparkShuffleTime(input1.rddStats, executorMetrics,
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(input1.rddStats, executorMetrics,
input1.rddStats.hashPartitioned);
- dataTransmissionTime += getSparkShuffleTime(input2.rddStats, executorMetrics,
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics,
input2.rddStats.hashPartitioned);
- dataTransmissionTime += getSparkShuffleTime(input3.rddStats, executorMetrics,
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(input3.rddStats, executorMetrics,
input3.rddStats.hashPartitioned);
}
@@ -547,12 +578,12 @@ public static double getQuaternaryInstTime(QuaternarySPInstruction quatInst, Var
throw new RuntimeException("Spark Quaternary reduce-operations are not supported yet");
}
double dataTransmissionTime;
- dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics)
- + getSparkBroadcastTime(input3, driverMetrics, executorMetrics); // for map-side ops only
+ dataTransmissionTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics)
+ + IOCostUtils.getSparkBroadcastTime(input3, driverMetrics, executorMetrics); // for map-side ops only
if (opcode.equals("mapwsloss") || opcode.equals("mapwcemm")) {
output.rddStats.isCollected = true;
} else if (opcode.equals("mapwdivmm")) {
- dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ dataTransmissionTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
}
long nflop = getInstNFLOP(quatInst.getSPInstructionType(), opcode, output, input1);
@@ -580,12 +611,12 @@ public static double getCPUTime(long nflop, int numPartitions, IOMetrics executo
for (RDDStats input: inputs) {
if (input == null) continue;
// compensates for spill-overs to account for non-compute bound operations
- memScanTime += getMemReadTime(input, executorMetrics);
+ memScanTime += IOCostUtils.getMemReadTime(input, executorMetrics);
}
double numWaves = Math.ceil((double) numPartitions / SparkExecutionContext.getDefaultParallelism(false));
double scaledNFLOP = (numWaves * nflop) / numPartitions;
double cpuComputationTime = scaledNFLOP / executorMetrics.cpuFLOPS;
- double memWriteTime = output != null? getMemWriteTime(output, executorMetrics) : 0;
+ double memWriteTime = output != null? IOCostUtils.getMemWriteTime(output, executorMetrics) : 0;
return Math.max(memScanTime, cpuComputationTime) + memWriteTime;
}
diff --git a/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java b/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java
index fa22f6c4f73..67a657c5d25 100644
--- a/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java
+++ b/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java
@@ -19,9 +19,12 @@
package org.apache.sysds.resource.enumeration;
-import org.apache.sysds.resource.CloudInstance;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.TreeMap;
-import java.util.*;
+import org.apache.sysds.resource.CloudInstance;
public class EnumerationUtils {
/**
diff --git a/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java b/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java
index b2de8410a5e..fb986c32650 100644
--- a/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java
+++ b/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java
@@ -33,8 +33,14 @@
import org.apache.sysds.resource.enumeration.EnumerationUtils.ConfigurationPoint;
import org.apache.sysds.resource.enumeration.EnumerationUtils.SolutionPoint;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
import java.util.Map.Entry;
+import java.util.Set;
+import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicReference;
public abstract class Enumerator {
diff --git a/src/main/java/org/apache/sysds/resource/enumeration/GridBasedEnumerator.java b/src/main/java/org/apache/sysds/resource/enumeration/GridBasedEnumerator.java
index 571fc929b5c..5a13a4283f0 100644
--- a/src/main/java/org/apache/sysds/resource/enumeration/GridBasedEnumerator.java
+++ b/src/main/java/org/apache/sysds/resource/enumeration/GridBasedEnumerator.java
@@ -19,7 +19,7 @@
package org.apache.sysds.resource.enumeration;
-import java.util.*;
+import java.util.ArrayList;
public class GridBasedEnumerator extends Enumerator {
/** sets the step size at iterating over number of executors at enumeration */
diff --git a/src/main/java/org/apache/sysds/resource/enumeration/InterestBasedEnumerator.java b/src/main/java/org/apache/sysds/resource/enumeration/InterestBasedEnumerator.java
index 349d44312f5..aed8682dcd9 100644
--- a/src/main/java/org/apache/sysds/resource/enumeration/InterestBasedEnumerator.java
+++ b/src/main/java/org/apache/sysds/resource/enumeration/InterestBasedEnumerator.java
@@ -22,10 +22,19 @@
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.StatementBlock;
-import org.apache.sysds.runtime.controlprogram.*;
import org.apache.sysds.resource.enumeration.EnumerationUtils.InstanceSearchSpace;
+import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
+import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
+import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
import java.util.stream.Collectors;
import static org.apache.sysds.resource.CloudUtils.JVM_MEMORY_FACTOR;
diff --git a/src/main/java/org/apache/sysds/resource/enumeration/PruneBasedEnumerator.java b/src/main/java/org/apache/sysds/resource/enumeration/PruneBasedEnumerator.java
index 300188d5e6e..8b2ee747326 100644
--- a/src/main/java/org/apache/sysds/resource/enumeration/PruneBasedEnumerator.java
+++ b/src/main/java/org/apache/sysds/resource/enumeration/PruneBasedEnumerator.java
@@ -23,7 +23,13 @@
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.resource.CloudInstance;
import org.apache.sysds.resource.ResourceCompiler;
-import org.apache.sysds.runtime.controlprogram.*;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
+import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
+import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import java.util.HashMap;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 48637595741..e08f731e829 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -75,7 +75,7 @@
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.CTableMap;
@@ -816,12 +816,12 @@ public MatrixBlock zeroOutOperations(MatrixValue result, IndexRange range) {
}
@Override
- public CM_COV_Object cmOperations(CMOperator op) {
+ public CmCovObject cmOperations(CMOperator op) {
return CLALibCMOps.centralMoment(this, op);
}
@Override
- public CM_COV_Object cmOperations(CMOperator op, MatrixBlock weights) {
+ public CmCovObject cmOperations(CMOperator op, MatrixBlock weights) {
printDecompressWarning("cmOperations");
MatrixBlock right = getUncompressed(weights);
if(isEmpty())
@@ -833,13 +833,13 @@ public CM_COV_Object cmOperations(CMOperator op, MatrixBlock weights) {
}
@Override
- public CM_COV_Object covOperations(COVOperator op, MatrixBlock that) {
+ public CmCovObject covOperations(COVOperator op, MatrixBlock that) {
MatrixBlock right = getUncompressed(that);
return getUncompressed("covOperations", op.getNumThreads()).covOperations(op, right);
}
@Override
- public CM_COV_Object covOperations(COVOperator op, MatrixBlock that, MatrixBlock weights) {
+ public CmCovObject covOperations(COVOperator op, MatrixBlock that, MatrixBlock weights) {
MatrixBlock right1 = getUncompressed(that);
MatrixBlock right2 = getUncompressed(weights);
return getUncompressed("covOperations", op.getNumThreads()).covOperations(op, right1, right2);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
index f6321bc1b6d..af944fce750 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
@@ -133,11 +133,14 @@ public class CompressionSettings {
public final double[] scaleFactors;
+ public final boolean preferDeltaEncoding;
+
protected CompressionSettings(double samplingRatio, double samplePower, boolean allowSharedDictionary,
String transposeInput, int seed, boolean lossy, EnumSet validCompressions,
boolean sortValuesByLength, PartitionerType columnPartitioner, int maxColGroupCoCode, double coCodePercentage,
int minimumSampleSize, int maxSampleSize, EstimationType estimationType, CostType costComputationType,
- double minimumCompressionRatio, boolean isInSparkInstruction, SORT_TYPE sdcSortType, double[] scaleFactors) {
+ double minimumCompressionRatio, boolean isInSparkInstruction, SORT_TYPE sdcSortType, double[] scaleFactors,
+ boolean preferDeltaEncoding) {
this.samplingRatio = samplingRatio;
this.samplePower = samplePower;
this.allowSharedDictionary = allowSharedDictionary;
@@ -157,6 +160,7 @@ protected CompressionSettings(double samplingRatio, double samplePower, boolean
this.isInSparkInstruction = isInSparkInstruction;
this.sdcSortType = sdcSortType;
this.scaleFactors = scaleFactors;
+ this.preferDeltaEncoding = preferDeltaEncoding;
if(!printedStatus && LOG.isDebugEnabled()) {
printedStatus = true;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
index ae6a0b2d231..02c9f97498d 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
@@ -53,6 +53,7 @@ public class CompressionSettingsBuilder {
private boolean isInSparkInstruction = false;
private SORT_TYPE sdcSortType = SORT_TYPE.MATERIALIZE;
private double[] scaleFactors = null;
+ private boolean preferDeltaEncoding = false;
public CompressionSettingsBuilder() {
@@ -101,6 +102,7 @@ public CompressionSettingsBuilder copySettings(CompressionSettings that) {
this.maxColGroupCoCode = that.maxColGroupCoCode;
this.coCodePercentage = that.coCodePercentage;
this.minimumSampleSize = that.minimumSampleSize;
+ this.preferDeltaEncoding = that.preferDeltaEncoding;
return this;
}
@@ -336,6 +338,19 @@ public CompressionSettingsBuilder setSDCSortType(SORT_TYPE sdcSortType) {
return this;
}
+ /**
+ * Set whether to prefer delta encoding during compression estimation.
+ * When enabled, the compression estimator will use delta encoding statistics
+ * instead of regular encoding statistics.
+ *
+ * @param preferDeltaEncoding Whether to prefer delta encoding
+ * @return The CompressionSettingsBuilder
+ */
+ public CompressionSettingsBuilder setPreferDeltaEncoding(boolean preferDeltaEncoding) {
+ this.preferDeltaEncoding = preferDeltaEncoding;
+ return this;
+ }
+
/**
* Create the CompressionSettings object to use in the compression.
*
@@ -345,6 +360,6 @@ public CompressionSettings create() {
return new CompressionSettings(samplingRatio, samplePower, allowSharedDictionary, transposeInput, seed, lossy,
validCompressions, sortValuesByLength, columnPartitioner, maxColGroupCoCode, coCodePercentage,
minimumSampleSize, maxSampleSize, estimationType, costType, minimumCompressionRatio, isInSparkInstruction,
- sdcSortType, scaleFactors);
+ sdcSortType, scaleFactors, preferDeltaEncoding);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
index ec502d6d122..003703f86a4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
@@ -45,7 +45,7 @@
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.functionobjects.Plus;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -610,7 +610,7 @@ public AColGroup addVector(double[] v) {
* @param nRows The number of rows contained in the ColumnGroup.
* @return A Central Moment object.
*/
- public abstract CM_COV_Object centralMoment(CMOperator op, int nRows);
+ public abstract CmCovObject centralMoment(CMOperator op, int nRows);
/**
* Expand the column group to multiple columns. (one hot encode the column group)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
index 0cde289b30f..45358c7ce46 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
@@ -26,7 +26,7 @@
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.functionobjects.Builtin;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
public abstract class AColGroupValue extends ADictBasedColGroup {
@@ -189,7 +189,7 @@ public AColGroup replace(double pattern, double replace) {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
return _dict.centralMoment(op.fn, getCounts(), nRows);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
index 21c6a0e1d80..94137eb6381 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
@@ -50,7 +50,7 @@
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.functionobjects.Builtin;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
@@ -518,8 +518,8 @@ protected double[] preAggBuiltinRows(Builtin builtin) {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
- CM_COV_Object ret = new CM_COV_Object();
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
+ CmCovObject ret = new CmCovObject();
op.fn.execute(ret, _dict.getValue(0), nRows);
return ret;
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
index bbcefd134c1..03af6cad162 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
@@ -33,6 +33,7 @@
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.DeltaDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
@@ -43,6 +44,9 @@
import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
+import org.apache.sysds.runtime.compress.utils.ACount;
+import org.apache.sysds.runtime.compress.utils.DblArray;
+import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.colgroup.scheme.DDCScheme;
@@ -79,7 +83,7 @@ public class ColGroupDDC extends APreAgg implements IMapToDataGroup {
static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED;
- private ColGroupDDC(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts) {
+ protected ColGroupDDC(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts) {
super(colIndexes, dict, cachedCounts);
_data = data;
@@ -1107,4 +1111,57 @@ protected boolean allowShallowIdentityRightMult() {
return true;
}
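+ /**
+ * Convert this DDC group into a delta-encoded DDC group: each row is materialized from the
+ * dictionary, row-wise differences are taken (row 0 keeps its absolute values), and a new
+ * dictionary is built over the distinct delta tuples.
+ *
+ * @return An equivalent delta-encoded column group (or an empty group if no tuples are produced)
+ */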
+ public AColGroup convertToDeltaDDC() {
+ int numCols = _colIndexes.size();
+ int numRows = _data.size();
+
+ DblArrayCountHashMap map = new DblArrayCountHashMap(Math.max(numRows, 64));
+ double[] rowDelta = new double[numCols];
+ double[] prevRow = new double[numCols];
+ DblArray dblArray = new DblArray(rowDelta);
+ int[] rowToDictId = new int[numRows];
+
+ double[] dictVals = _dict.getValues();
+
+ for(int i = 0; i < numRows; i++) {
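+ // Pass 1: materialize each row from the DDC dictionary, compute its delta to the previous row
+ // (row 0 keeps its absolute values), and register the delta tuple in the count map.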
+ int dictIdx = _data.getIndex(i);
+ int off = dictIdx * numCols;
+ for(int j = 0; j < numCols; j++) {
+ double val = dictVals[off + j];
+ if(i == 0) {
+ rowDelta[j] = val;
+ prevRow[j] = val;
+ } else {
+ rowDelta[j] = val - prevRow[j];
+ prevRow[j] = val;
+ }
+ }
+
+ rowToDictId[i] = map.increment(dblArray);
+ }
+
+ if(map.size() == 0)
+ return new ColGroupEmpty(_colIndexes);
+
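+ // Pass 2: build the delta dictionary from the distinct delta tuples and remap the per-row ids
+ // to the new compact dictionary ids.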
+ ACount[] vals = map.extractValues();
+ final int nVals = vals.length;
+ final double[] dictValues = new double[nVals * numCols];
+ final int[] oldIdToNewId = new int[map.size()];
+ int idx = 0;
+ for(int i = 0; i < nVals; i++) {
+ final ACount dac = vals[i];
+ final double[] arrData = dac.key().getData();
+ System.arraycopy(arrData, 0, dictValues, idx, numCols);
+ oldIdToNewId[dac.id] = i;
+ idx += numCols;
+ }
+
+ DeltaDictionary deltaDict = new DeltaDictionary(dictValues, numCols);
+ AMapToData newData = MapToFactory.create(numRows, nVals);
+ for(int i = 0; i < numRows; i++) {
+ newData.set(i, oldIdToNewId[rowToDictId[i]]);
+ }
+ return ColGroupDeltaDDC.create(_colIndexes, deltaDict, newData, null);
+ }
+
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
index 70191a27936..d2ee8cd6673 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
@@ -48,7 +48,7 @@
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
@@ -403,9 +403,9 @@ public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
// should be guaranteed to be one column therefore only one reference value.
- CM_COV_Object ret = _dict.centralMomentWithReference(op.fn, getCounts(), _reference[0], nRows);
+ CmCovObject ret = _dict.centralMomentWithReference(op.fn, getCounts(), _reference[0], nRows);
return ret;
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDeltaDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDeltaDDC.java
index 2666860ca68..08bdfd1e1d8 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDeltaDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDeltaDDC.java
@@ -19,62 +19,559 @@
package org.apache.sysds.runtime.compress.colgroup;
+import java.io.DataInput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.DeltaDictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
+import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
+import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
+import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
+import org.apache.sysds.runtime.compress.utils.ACount;
+import org.apache.sysds.runtime.compress.utils.DblArray;
+import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap;
+import org.apache.sysds.runtime.compress.utils.DoubleCountHashMap;
+import org.apache.sysds.runtime.compress.utils.Util;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockMCSR;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+
/**
* Class to encapsulate information about a column group that is first delta encoded then encoded with dense dictionary
* encoding (DeltaDDC).
*/
-public class ColGroupDeltaDDC { // extends ColGroupDDC
-
-// private static final long serialVersionUID = -1045556313148564147L;
-
-// /** Constructor for serialization */
-// protected ColGroupDeltaDDC() {
-// }
-
-// private ColGroupDeltaDDC(int[] colIndexes, ADictionary dict, AMapToData data, int[] cachedCounts) {
-// super();
-// LOG.info("Carefully use of DeltaDDC since implementation is not finished.");
-// _colIndexes = colIndexes;
-// _dict = dict;
-// _data = data;
-// }
-
-// public static AColGroup create(int[] colIndices, ADictionary dict, AMapToData data, int[] cachedCounts) {
-// if(dict == null)
-// throw new NotImplementedException("Not implemented constant delta group");
-// else
-// return new ColGroupDeltaDDC(colIndices, dict, data, cachedCounts);
-// }
-
-// public CompressionType getCompType() {
-// return CompressionType.DeltaDDC;
-// }
-
-// @Override
-// protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, int offR, int offC,
-// double[] values) {
-// final int nCol = _colIndexes.length;
-// for(int i = rl, offT = rl + offR; i < ru; i++, offT++) {
-// final double[] c = db.values(offT);
-// final int off = db.pos(offT) + offC;
-// final int rowIndex = _data.getIndex(i) * nCol;
-// final int prevOff = (off == 0) ? off : off - nCol;
-// for(int j = 0; j < nCol; j++) {
-// // Here we use the values in the previous row to compute current values along with the delta
-// double newValue = c[prevOff + j] + values[rowIndex + j];
-// c[off + _colIndexes[j]] += newValue;
-// }
-// }
-// }
-
-// @Override
-// protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC,
-// double[] values) {
-// throw new NotImplementedException();
-// }
-
-// @Override
-// public AColGroup scalarOperation(ScalarOperator op) {
-// return new ColGroupDeltaDDC(_colIndexes, _dict.applyScalarOp(op), _data, getCachedCounts());
-// }
+public class ColGroupDeltaDDC extends ColGroupDDC {
+ private static final long serialVersionUID = -1045556313148564147L;
+
+ private ColGroupDeltaDDC(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts) {
+ super(colIndexes, dict, data, cachedCounts);
+ if(CompressedMatrixBlock.debug) {
+ if(!(dict instanceof DeltaDictionary))
+ throw new DMLCompressionException("DeltaDDC must use DeltaDictionary");
+ }
+ }
+
+ public static AColGroup create(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts) {
+ if(dict == null)
+ return new ColGroupEmpty(colIndexes);
+
+ if(!(dict instanceof DeltaDictionary))
+ throw new DMLCompressionException("ColGroupDeltaDDC must use DeltaDictionary");
+
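+ // If every row maps to a single all-zero delta tuple, row 0 and all deltas are zero,
+ // so the whole group can be represented as a constant column group instead.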
+ if(data.getUnique() == 1) {
+ DeltaDictionary deltaDict = (DeltaDictionary) dict;
+ double[] values = deltaDict.getValues();
+ final int nCol = colIndexes.size();
+ boolean allZeros = true;
+ for(int i = 0; i < nCol; i++) {
+ if(!Util.eq(values[i], 0.0)) {
+ allZeros = false;
+ break;
+ }
+ }
+ if(allZeros) {
+ double[] constValues = new double[nCol];
+ System.arraycopy(values, 0, constValues, 0, nCol);
+ return ColGroupConst.create(colIndexes, Dictionary.create(constValues));
+ }
+ }
+
+ return new ColGroupDeltaDDC(colIndexes, dict, data, cachedCounts);
+ }
+
+ @Override
+ public CompressionType getCompType() {
+ return CompressionType.DeltaDDC;
+ }
+
+ @Override
+ public ColGroupType getColGroupType() {
+ return ColGroupType.DeltaDDC;
+ }
+
+ public static ColGroupDeltaDDC read(DataInput in) throws IOException {
+ IColIndex cols = ColIndexFactory.read(in);
+ IDictionary dict = DictionaryFactory.read(in);
+ AMapToData data = MapToFactory.readIn(in);
+ return new ColGroupDeltaDDC(cols, dict, data, null);
+ }
+
+ @Override
+ protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, int offR, int offC,
+ double[] values) {
+ final int nCol = _colIndexes.size();
+ final double[] prevRow = new double[nCol];
+
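+ // When decompression does not start at row 0, reconstruct the running prefix sums up to
+ // row rl-1 so that the first decompressed row has the correct absolute values.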
+ if(rl > 0) {
+ final int dictIdx0 = _data.getIndex(0);
+ final int rowIndex0 = dictIdx0 * nCol;
+ for(int j = 0; j < nCol; j++) {
+ prevRow[j] = values[rowIndex0 + j];
+ }
+ for(int i = 1; i < rl; i++) {
+ final int dictIdx = _data.getIndex(i);
+ final int rowIndex = dictIdx * nCol;
+ for(int j = 0; j < nCol; j++) {
+ prevRow[j] += values[rowIndex + j];
+ }
+ }
+ }
+
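+ // Fast path: the target dense block is contiguous and this group covers all output columns,
+ // so absolute values can be written with a single running prefix sum per column.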
+ if(db.isContiguous() && nCol == db.getDim(1) && offC == 0) {
+ final int nColOut = db.getDim(1);
+ final double[] c = db.values(0);
+ for(int i = rl; i < ru; i++) {
+ final int dictIdx = _data.getIndex(i);
+ final int rowIndex = dictIdx * nCol;
+ final int rowBaseOff = (i + offR) * nColOut;
+
+ if(i == 0 && rl == 0) {
+ for(int j = 0; j < nCol; j++) {
+ final double value = values[rowIndex + j];
+ c[rowBaseOff + j] = value;
+ prevRow[j] = value;
+ }
+ }
+ else {
+ for(int j = 0; j < nCol; j++) {
+ final double delta = values[rowIndex + j];
+ final double newValue = prevRow[j] + delta;
+ c[rowBaseOff + j] = newValue;
+ prevRow[j] = newValue;
+ }
+ }
+ }
+ }
+ else {
+ for(int i = rl, offT = rl + offR; i < ru; i++, offT++) {
+ final double[] c = db.values(offT);
+ final int off = db.pos(offT) + offC;
+ final int dictIdx = _data.getIndex(i);
+ final int rowIndex = dictIdx * nCol;
+
+ if(i == 0 && rl == 0) {
+ for(int j = 0; j < nCol; j++) {
+ final double value = values[rowIndex + j];
+ final int colIdx = _colIndexes.get(j);
+ c[off + colIdx] = value;
+ prevRow[j] = value;
+ }
+ }
+ else {
+ for(int j = 0; j < nCol; j++) {
+ final double delta = values[rowIndex + j];
+ final double newValue = prevRow[j] + delta;
+ final int colIdx = _colIndexes.get(j);
+ c[off + colIdx] = newValue;
+ prevRow[j] = newValue;
+ }
+ }
+ }
+ }
+ }
+
+ @Override
+ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC,
+ double[] values) {
+ final int nCol = _colIndexes.size();
+ final double[] prevRow = new double[nCol];
+
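+ // Same prefix-sum reconstruction as in the dense case: seed prevRow with the accumulated
+ // values up to row rl, then append the absolute values row by row.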
+ if(rl > 0) {
+ final int dictIdx0 = _data.getIndex(0);
+ final int rowIndex0 = dictIdx0 * nCol;
+ for(int j = 0; j < nCol; j++) {
+ prevRow[j] = values[rowIndex0 + j];
+ }
+ for(int i = 1; i < rl; i++) {
+ final int dictIdx = _data.getIndex(i);
+ final int rowIndex = dictIdx * nCol;
+ for(int j = 0; j < nCol; j++) {
+ prevRow[j] += values[rowIndex + j];
+ }
+ }
+ }
+
+ for(int i = rl, offT = rl + offR; i < ru; i++, offT++) {
+ final int dictIdx = _data.getIndex(i);
+ final int rowIndex = dictIdx * nCol;
+
+ if(i == 0 && rl == 0) {
+ for(int j = 0; j < nCol; j++) {
+ final double value = values[rowIndex + j];
+ final int colIdx = _colIndexes.get(j);
+ ret.append(offT, colIdx + offC, value);
+ prevRow[j] = value;
+ }
+ }
+ else {
+ for(int j = 0; j < nCol; j++) {
+ final double delta = values[rowIndex + j];
+ final double newValue = prevRow[j] + delta;
+ final int colIdx = _colIndexes.get(j);
+ ret.append(offT, colIdx + offC, newValue);
+ prevRow[j] = newValue;
+ }
+ }
+ }
+ }
+
+ @Override
+ protected void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC,
+ SparseBlock sb) {
+ throw new NotImplementedException("Dense block decompression from sparse dictionary for DeltaDDC not yet implemented");
+ }
+
+ @Override
+ protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC,
+ SparseBlock sb) {
+ throw new NotImplementedException("Sparse block decompression from sparse dictionary for DeltaDDC not yet implemented");
+ }
+
+ @Override
+ protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) {
+ throw new NotImplementedException("Transposed dense block decompression from sparse dictionary for DeltaDDC not yet implemented");
+ }
+
+ @Override
+ protected void decompressToDenseBlockTransposedDenseDictionary(DenseBlock db, int rl, int ru, double[] dict) {
+ throw new NotImplementedException("Transposed dense block decompression from dense dictionary for DeltaDDC not yet implemented");
+ }
+
+ @Override
+ protected void decompressToSparseBlockTransposedSparseDictionary(SparseBlockMCSR sbr, SparseBlock sb, int nColOut) {
+ throw new NotImplementedException("Transposed sparse block decompression from sparse dictionary for DeltaDDC not yet implemented");
+ }
+
+ @Override
+ protected void decompressToSparseBlockTransposedDenseDictionary(SparseBlockMCSR sbr, double[] dict, int nColOut) {
+ throw new NotImplementedException("Transposed sparse block decompression from dense dictionary for DeltaDDC not yet implemented");
+ }
+
+ @Override
+ public AColGroup scalarOperation(ScalarOperator op) {
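+ // Multiply and divide are applied directly to the delta dictionary (scaling the row
+ // differences scales the reconstructed values); plus and minus only shift row 0's tuple
+ // (see scalarOperationShift); all other operators fall back to a plain DDC conversion.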
+ if(op.fn instanceof Multiply || op.fn instanceof Divide) {
+ double[] val = _dict.getValues();
+ double[] newVal = new double[val.length];
+ for(int i = 0; i < val.length; i++)
+ newVal[i] = op.executeScalar(val[i]);
+ return create(_colIndexes, new DeltaDictionary(newVal, _colIndexes.size()), _data, getCounts());
+ }
+ else if(op.fn instanceof Plus || op.fn instanceof Minus) {
+ return scalarOperationShift(op);
+ }
+ else {
+ AColGroup ddc = convertToDDC();
+ return ddc.scalarOperation(op);
+ }
+ }
+
+ private AColGroup scalarOperationShift(ScalarOperator op) {
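+ // Adding or subtracting a constant leaves all deltas unchanged; only the tuple referenced by
+ // row 0 changes. If that tuple is not shared it is rewritten in place, otherwise an existing
+ // matching tuple is reused (or a new one appended) and row 0 is remapped to it.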
+ final int nCol = _colIndexes.size();
+ final int id0 = _data.getIndex(0);
+ final double[] vals = _dict.getValues();
+ final double[] tuple0 = new double[nCol];
+ for(int j = 0; j < nCol; j++)
+ tuple0[j] = vals[id0 * nCol + j];
+
+ final double[] tupleNew = new double[nCol];
+ for(int j = 0; j < nCol; j++)
+ tupleNew[j] = op.executeScalar(tuple0[j]);
+
+ int[] counts = getCounts();
+ if(counts[id0] == 1) {
+ double[] newVals = vals.clone();
+ for(int j = 0; j < nCol; j++)
+ newVals[id0 * nCol + j] = tupleNew[j];
+ return create(_colIndexes, new DeltaDictionary(newVals, nCol), _data, counts);
+ }
+ else {
+ int idNew = -1;
+ int nEntries = vals.length / nCol;
+ for(int k = 0; k < nEntries; k++) {
+ boolean match = true;
+ for(int j = 0; j < nCol; j++) {
+ if(vals[k * nCol + j] != tupleNew[j]) {
+ match = false;
+ break;
+ }
+ }
+ if(match) {
+ idNew = k;
+ break;
+ }
+ }
+
+ IDictionary newDict = _dict;
+ if(idNew == -1) {
+ double[] newVals = Arrays.copyOf(vals, vals.length + nCol);
+ System.arraycopy(tupleNew, 0, newVals, vals.length, nCol);
+ newDict = new DeltaDictionary(newVals, nCol);
+ idNew = nEntries;
+ }
+
+ AMapToData newData = MapToFactory.create(_data.size(), Math.max(_data.getUpperBoundValue(), idNew) + 1);
+ for(int i = 0; i < _data.size(); i++)
+ newData.set(i, _data.getIndex(i));
+ newData.set(0, idNew);
+
+ return create(_colIndexes, newDict, newData, null);
+ }
+ }
+
+ @Override
+ public AColGroup unaryOperation(UnaryOperator op) {
+ AColGroup ddc = convertToDDC();
+ return ddc.unaryOperation(op);
+ }
+
+ @Override
+ public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
+ throw new NotImplementedException("Left matrix multiplication not supported for DeltaDDC");
+ }
+
+ @Override
+ public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru) {
+ throw new NotImplementedException("Right matrix multiplication not supported for DeltaDDC");
+ }
+
+ @Override
+ public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu) {
+ throw new NotImplementedException("Pre-aggregate dense not supported for DeltaDDC");
+ }
+
+ @Override
+ public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu) {
+ throw new NotImplementedException("Pre-aggregate sparse not supported for DeltaDDC");
+ }
+
+ @Override
+ public void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret) {
+ throw new NotImplementedException("Pre-aggregate DDC structure not supported for DeltaDDC");
+ }
+
+ @Override
+ public void preAggregateThatSDCZerosStructure(ColGroupSDCZeros that, Dictionary ret) {
+ throw new NotImplementedException("Pre-aggregate SDCZeros structure not supported for DeltaDDC");
+ }
+
+ @Override
+ public void preAggregateThatSDCSingleZerosStructure(ColGroupSDCSingleZeros that, Dictionary ret) {
+ throw new NotImplementedException("Pre-aggregate SDCSingleZeros structure not supported for DeltaDDC");
+ }
+
+ @Override
+ protected void preAggregateThatRLEStructure(ColGroupRLE that, Dictionary ret) {
+ throw new NotImplementedException("Pre-aggregate RLE structure not supported for DeltaDDC");
+ }
+
+ @Override
+ protected double computeMxx(double c, Builtin builtin) {
+ throw new NotImplementedException("Compute Min/Max not supported for DeltaDDC");
+ }
+
+ @Override
+ protected void computeColMxx(double[] c, Builtin builtin) {
+ throw new NotImplementedException("Compute Column Min/Max not supported for DeltaDDC");
+ }
+
+ @Override
+ protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
+ throw new NotImplementedException("Compute Row Min/Max not supported for DeltaDDC");
+ }
+
+ @Override
+ protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
+ throw new NotImplementedException("Compute Row Sums not supported for DeltaDDC");
+ }
+
+ @Override
+ protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
+ throw new NotImplementedException("Compute Row Product not supported for DeltaDDC");
+ }
+
+ @Override
+ public boolean containsValue(double pattern) {
+ throw new NotImplementedException("Contains value not supported for DeltaDDC");
+ }
+
+ @Override
+ public AColGroup append(AColGroup g) {
+ throw new NotImplementedException("Append not supported for DeltaDDC");
+ }
+
+ @Override
+ public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) {
+ throw new NotImplementedException("AppendN not supported for DeltaDDC");
+ }
+
+ @Override
+ public long getNumberNonZeros(int nRows) {
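+ // Non-zeros must be counted on the reconstructed absolute values, not on the deltas,
+ // so each row is materialized via a running prefix sum.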
+ long nnz = 0;
+ final int nCol = _colIndexes.size();
+ final double[] prevRow = new double[nCol];
+
+ for(int i = 0; i < nRows; i++) {
+ final int dictIdx = _data.getIndex(i);
+ final double[] vals = _dict.getValues();
+ final int rowIndex = dictIdx * nCol;
+
+ if(i == 0) {
+ for(int j = 0; j < nCol; j++) {
+ double val = vals[rowIndex + j];
+ prevRow[j] = val;
+ if(val != 0)
+ nnz++;
+ }
+ }
+ else {
+ for(int j = 0; j < nCol; j++) {
+ double val = prevRow[j] + vals[rowIndex + j];
+ prevRow[j] = val;
+ if(val != 0)
+ nnz++;
+ }
+ }
+ }
+ return nnz;
+ }
+
+ @Override
+ public AColGroup sliceRows(int rl, int ru) {
+ final int nCol = _colIndexes.size();
+ double[] firstRowValues = new double[nCol];
+ double[] dictVals = ((DeltaDictionary)_dict).getValues();
+
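+ // Materialize the absolute values of row rl by accumulating deltas from row 0; the slice must start with an absolute (non-delta) first row.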
+ for(int i = 0; i <= rl; i++) {
+ int dictIdx = _data.getIndex(i);
+ int dictOffset = dictIdx * nCol;
+ if(i == 0) {
+ for(int j = 0; j < nCol; j++) firstRowValues[j] = dictVals[dictOffset + j];
+ } else {
+ for(int j = 0; j < nCol; j++) firstRowValues[j] += dictVals[dictOffset + j];
+ }
+ }
+
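+ // Look for an existing dictionary tuple equal to the materialized first row; if none exists it is appended below.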
+ int nEntries = dictVals.length / nCol;
+ int newId = -1;
+ for(int k = 0; k < nEntries; k++) {
+ boolean match = true;
+ for(int j = 0; j < nCol; j++) {
+ if(dictVals[k * nCol + j] != firstRowValues[j]) {
+ match = false;
+ break;
+ }
+ }
+ if(match) {
+ newId = k;
+ break;
+ }
+ }
+
+ IDictionary newDict = _dict;
+ if(newId == -1) {
+ double[] newDictVals = Arrays.copyOf(dictVals, dictVals.length + nCol);
+ System.arraycopy(firstRowValues, 0, newDictVals, dictVals.length, nCol);
+ newDict = new DeltaDictionary(newDictVals, nCol);
+ newId = nEntries;
+ }
+
+ int numRows = ru - rl;
+ AMapToData slicedData = MapToFactory.create(numRows, Math.max(_data.getUpperBoundValue(), newId) + 1);
+ for(int i = 0; i < numRows; i++)
+ slicedData.set(i, _data.getIndex(rl + i));
+
+ slicedData.set(0, newId);
+ return ColGroupDeltaDDC.create(_colIndexes, newDict, slicedData, null);
+ }
+
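+ // Decompresses the delta encoding into absolute values and recompresses them as a standard DDC group.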
+ private AColGroup convertToDDC() {
+ final int nCol = _colIndexes.size();
+ final int nRow = _data.size();
+ double[] values = new double[nRow * nCol];
+
+ double[] prevRow = new double[nCol];
+ for(int i = 0; i < nRow; i++) {
+ final int dictIdx = _data.getIndex(i);
+ final double[] dictVals = _dict.getValues();
+ final int rowIndex = dictIdx * nCol;
+
+ for(int j = 0; j < nCol; j++) {
+ if(i == 0) {
+ prevRow[j] = dictVals[rowIndex + j];
+ }
+ else {
+ prevRow[j] = prevRow[j] + dictVals[rowIndex + j];
+ }
+ values[i * nCol + j] = prevRow[j];
+ }
+ }
+
+ return compress(values, _colIndexes);
+ }
+
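+ // Re-compresses materialized absolute values into a CONST or DDC group; used by convertToDDC.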
+ private static AColGroup compress(double[] values, IColIndex colIndexes) {
+ int nRow = values.length / colIndexes.size();
+ int nCol = colIndexes.size();
+
+ if(nCol == 1) {
+ DoubleCountHashMap map = new DoubleCountHashMap(16);
+ AMapToData mapData = MapToFactory.create(nRow, 256);
+ for(int i = 0; i < nRow; i++) {
+ int id = map.increment(values[i]);
+ if(id >= mapData.getUpperBoundValue()) {
+ mapData = mapData.resize(Math.max(mapData.getUpperBoundValue() * 2, id + 1));
+ }
+ mapData.set(i, id);
+ }
+ if(map.size() == 1)
+ return ColGroupConst.create(colIndexes, Dictionary.create(new double[] {map.getMostFrequent()}));
+
+ IDictionary dict = Dictionary.create(map.getDictionary());
+ return ColGroupDDC.create(colIndexes, dict, mapData.resize(map.size()), null);
+ }
+ else {
+ DblArrayCountHashMap map = new DblArrayCountHashMap(16);
+ AMapToData mapData = MapToFactory.create(nRow, 256);
+ DblArray dblArray = new DblArray(new double[nCol]);
+ for(int i = 0; i < nRow; i++) {
+ System.arraycopy(values, i * nCol, dblArray.getData(), 0, nCol);
+ int id = map.increment(dblArray);
+ if(id >= mapData.getUpperBoundValue()) {
+ mapData = mapData.resize(Math.max(mapData.getUpperBoundValue() * 2, id + 1));
+ }
+ mapData.set(i, id);
+ }
+ if(map.size() == 1) {
+ ACount[] counts = map.extractValues();
+ return ColGroupConst.create(colIndexes, Dictionary.create(counts[0].key().getData()));
+ }
+
+ ACount[] counts = map.extractValues();
+ Arrays.sort(counts, Comparator.comparingInt(x -> x.id));
+
+ double[] dictValues = new double[counts.length * nCol];
+ for(int i = 0; i < counts.length; i++) {
+ System.arraycopy(counts[i].key().getData(), 0, dictValues, i * nCol, nCol);
+ }
+
+ IDictionary dict = Dictionary.create(dictValues);
+ return ColGroupDDC.create(colIndexes, dict, mapData.resize(map.size()), null);
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java
index ba547a8d7aa..6d7872fce54 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java
@@ -49,7 +49,7 @@
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
@@ -303,8 +303,8 @@ protected double[] preAggBuiltinRows(Builtin builtin) {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
- CM_COV_Object ret = new CM_COV_Object();
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
+ CmCovObject ret = new CmCovObject();
op.fn.execute(ret, 0.0, nRows);
return ret;
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index c6a098f5c32..273df9ff26f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -38,6 +38,7 @@
import org.apache.sysds.runtime.compress.bitmap.ABitmap;
import org.apache.sysds.runtime.compress.bitmap.BitmapEncoder;
import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.DeltaDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
@@ -288,6 +289,12 @@ else if((ct == CompressionType.SDC || ct == CompressionType.CONST) //
else if(ct == CompressionType.DDC) {
return directCompressDDC(colIndexes, cg);
}
+ else if(ct == CompressionType.DeltaDDC) {
+ return directCompressDeltaDDC(colIndexes, cg);
+ }
+ else if(ct == CompressionType.CONST && cs.preferDeltaEncoding) {
+ return directCompressDeltaDDC(colIndexes, cg);
+ }
else if(ct == CompressionType.LinearFunctional) {
if(cs.scaleFactors != null) {
throw new NotImplementedException(); // quantization-fused compression NOT allowed
@@ -684,6 +691,129 @@ private AColGroup directCompressDDCMultiCol(IColIndex colIndexes, CompressedSize
return ColGroupDDC.create(colIndexes, dict, resData, null);
}
+ private AColGroup directCompressDeltaDDC(IColIndex colIndexes, CompressedSizeInfoColGroup cg) throws Exception {
+ if(cs.transposed) {
+ throw new NotImplementedException("Delta encoding for transposed matrices not yet implemented");
+ }
+ if(cs.scaleFactors != null) {
+ throw new NotImplementedException("Delta encoding with quantization not yet implemented");
+ }
+
+ if(colIndexes.size() > 1) {
+ return directCompressDeltaDDCMultiCol(colIndexes, cg);
+ }
+ else {
+ return directCompressDeltaDDCSingleCol(colIndexes, cg);
+ }
+ }
+
+ private AColGroup directCompressDeltaDDCSingleCol(IColIndex colIndexes, CompressedSizeInfoColGroup cg) {
+ final AMapToData d = MapToFactory.create(nRow, Math.max(Math.min(cg.getNumOffs() + 1, nRow), 126));
+ final DoubleCountHashMap map = new DoubleCountHashMap(cg.getNumVals());
+
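+ // The delta reader returns the first row as-is and every later row as the difference to its predecessor.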
+ ReaderColumnSelection reader = ReaderColumnSelection.createDeltaReader(in, colIndexes, cs.transposed, 0, nRow);
+ DblArray cellVals = reader.nextRow();
+ int r = 0;
+ while(r < nRow && cellVals != null) {
+ final int row = reader.getCurrentRowIndex();
+ if(row == r) {
+ final double val = cellVals.getData()[0];
+ final int id = map.increment(val);
+ d.set(row, id);
+ cellVals = reader.nextRow();
+ r++;
+ }
+ else {
+ r = row;
+ }
+ }
+
+ if(map.size() == 0)
+ return new ColGroupEmpty(colIndexes);
+
+ final double[] dictValues = map.getDictionary();
+ IDictionary dict = new DeltaDictionary(dictValues, 1);
+
+ final int nUnique = map.size();
+ final AMapToData resData = d.resize(nUnique);
+ return ColGroupDeltaDDC.create(colIndexes, dict, resData, null);
+ }
+
+ private AColGroup directCompressDeltaDDCMultiCol(IColIndex colIndexes, CompressedSizeInfoColGroup cg) throws Exception {
+ final AMapToData d = MapToFactory.create(nRow, Math.max(Math.min(cg.getNumOffs() + 1, nRow), 126));
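+ // Pre-fill the mapping with a sentinel id so rows skipped by the reader (all-zero delta rows) can be detected afterwards.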
+ final int fill = d.getUpperBoundValue();
+ d.fill(fill);
+
+ final DblArrayCountHashMap map = new DblArrayCountHashMap(Math.max(cg.getNumVals(), 64));
+ boolean extra;
+ if(nRow < CompressionSettings.PAR_DDC_THRESHOLD || k < csi.getNumberColGroups() || pool == null) {
+ extra = readToMapDeltaDDC(colIndexes, map, d, 0, nRow, fill);
+ }
+ else {
+ throw new NotImplementedException("Parallel delta DDC compression not yet implemented");
+ }
+
+ if(map.size() == 0)
+ return new ColGroupEmpty(colIndexes);
+
+ final ACount[] vals = map.extractValues();
+ final int nVals = vals.length;
+ final int nTuplesOut = nVals + (extra ? 1 : 0);
+ final double[] dictValues = new double[nTuplesOut * colIndexes.size()];
+ final int[] oldIdToNewId = new int[map.size()];
+ int idx = 0;
+ for(int i = 0; i < nVals; i++) {
+ final ACount dac = vals[i];
+ final double[] arrData = dac.key().getData();
+ System.arraycopy(arrData, 0, dictValues, idx, colIndexes.size());
+ oldIdToNewId[dac.id] = i;
+ idx += colIndexes.size();
+ }
+ IDictionary dict = new DeltaDictionary(dictValues, colIndexes.size());
+
+ if(extra)
+ d.replace(fill, map.size());
+ final int nUnique = map.size() + (extra ? 1 : 0);
+ final AMapToData resData = d.resize(nUnique);
+ for(int i = 0; i < nRow; i++) {
+ final int oldId = resData.getIndex(i);
+ if(extra && oldId == map.size()) {
+ resData.set(i, nVals);
+ }
+ else if(oldId < oldIdToNewId.length) {
+ resData.set(i, oldIdToNewId[oldId]);
+ }
+ }
+ return ColGroupDeltaDDC.create(colIndexes, dict, resData, null);
+ }
+
+ private boolean readToMapDeltaDDC(IColIndex colIndexes, DblArrayCountHashMap map, AMapToData data, int rl, int ru,
+ int fill) {
+ ReaderColumnSelection reader = ReaderColumnSelection.createDeltaReader(in, colIndexes, cs.transposed, rl, ru);
+
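+ // Rows the reader skips carry an all-zero delta; report them via the returned 'extra' flag so the caller can add an empty tuple.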
+ DblArray cellVals = reader.nextRow();
+ boolean extra = false;
+ int r = rl;
+ while(r < ru && cellVals != null) {
+ final int row = reader.getCurrentRowIndex();
+ if(row == r) {
+ final int id = map.increment(cellVals);
+ data.set(row, id);
+ cellVals = reader.nextRow();
+ r++;
+ }
+ else {
+ r = row;
+ extra = true;
+ }
+ }
+
+ if(r < ru)
+ extra = true;
+
+ return extra;
+ }
+
private boolean readToMapDDC(IColIndex colIndexes, DblArrayCountHashMap map, AMapToData data, int rl, int ru,
int fill) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java
index 91442281317..b47100d4e64 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java
@@ -105,10 +105,12 @@ public static AColGroup readColGroup(DataInput in, int nRows) throws IOException
switch(ctype) {
case DDC:
return ColGroupDDC.read(in);
- case DDCFOR:
- return ColGroupDDCFOR.read(in);
- case OLE:
- return ColGroupOLE.read(in, nRows);
+ case DDCFOR:
+ return ColGroupDDCFOR.read(in);
+ case DeltaDDC:
+ return ColGroupDeltaDDC.read(in);
+ case OLE:
+ return ColGroupOLE.read(in, nRows);
case RLE:
return ColGroupRLE.read(in, nRows);
case CONST:
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java
index 45b4fbeb026..4e9fffaf718 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java
@@ -40,7 +40,7 @@
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -618,7 +618,7 @@ public long estimateInMemorySize() {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
throw new NotImplementedException();
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index 1270823bfdc..4340637a737 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -50,7 +50,7 @@
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
@@ -493,7 +493,7 @@ public AColGroup subtractDefaultTuple() {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
return _dict.centralMomentWithDefault(op.fn, getCounts(), _defaultTuple[0], nRows);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
index 41fb7ac5709..675c1120c38 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
@@ -51,7 +51,7 @@
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
@@ -431,7 +431,7 @@ public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
// should be guaranteed to be one column therefore only one reference value.
return _dict.centralMomentWithReference(op.fn, getCounts(), _reference[0], nRows);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
index fa5772c0c3e..a954f380a04 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
@@ -48,7 +48,7 @@
import org.apache.sysds.runtime.compress.utils.IntArrayList;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.functionobjects.Builtin;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
@@ -460,7 +460,7 @@ public long getNumberNonZeros(int nRows) {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
return _dict.centralMomentWithDefault(op.fn, getCounts(), _defaultTuple[0], nRows);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index 1c3bce2e16c..8d446575975 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -56,7 +56,7 @@
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -949,7 +949,7 @@ public void computeColSums(double[] c, int nRows) {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
return _data.cmOperations(op);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java
index 31e29341645..08cbab30bcc 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java
@@ -28,7 +28,7 @@
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.frame.data.columns.Array;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -203,7 +203,7 @@ public void computeColSums(double[] c, int nRows) {
}
@Override
- public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+ public CmCovObject centralMoment(CMOperator op, int nRows) {
throw new UnsupportedOperationException("Unimplemented method 'centralMoment'");
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index 31f2c9fb3c4..0b768ef4a2a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -26,7 +26,7 @@
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -40,16 +40,16 @@ public abstract class ADictionary implements IDictionary, Serializable {
public abstract IDictionary clone();
- public final CM_COV_Object centralMoment(ValueFunction fn, int[] counts, int nRows) {
- return centralMoment(new CM_COV_Object(), fn, counts, nRows);
+ public final CmCovObject centralMoment(ValueFunction fn, int[] counts, int nRows) {
+ return centralMoment(new CmCovObject(), fn, counts, nRows);
}
- public final CM_COV_Object centralMomentWithDefault(ValueFunction fn, int[] counts, double def, int nRows) {
- return centralMomentWithDefault(new CM_COV_Object(), fn, counts, def, nRows);
+ public final CmCovObject centralMomentWithDefault(ValueFunction fn, int[] counts, double def, int nRows) {
+ return centralMomentWithDefault(new CmCovObject(), fn, counts, def, nRows);
}
- public final CM_COV_Object centralMomentWithReference(ValueFunction fn, int[] counts, double reference, int nRows) {
- return centralMomentWithReference(new CM_COV_Object(), fn, counts, reference, nRows);
+ public final CmCovObject centralMomentWithReference(ValueFunction fn, int[] counts, double reference, int nRows) {
+ return centralMomentWithReference(new CmCovObject(), fn, counts, reference, nRows);
}
@Override
@@ -144,7 +144,7 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
}
@Override
- public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) {
+ public CmCovObject centralMoment(CmCovObject ret, ValueFunction fn, int[] counts, int nRows) {
return getMBDict().centralMoment(ret, fn, counts, nRows);
}
@@ -154,13 +154,13 @@ public double getSparsity() {
}
@Override
- public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def,
+ public CmCovObject centralMomentWithDefault(CmCovObject ret, ValueFunction fn, int[] counts, double def,
int nRows) {
return getMBDict().centralMomentWithDefault(ret, fn, counts, def, nRows);
}
@Override
- public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference,
+ public CmCovObject centralMomentWithReference(CmCovObject ret, ValueFunction fn, int[] counts, double reference,
int nRows) {
return getMBDict().centralMomentWithReference(ret, fn, counts, reference, nRows);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
index d67ab95f824..d667e76ed5e 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
@@ -19,14 +19,13 @@
package org.apache.sysds.runtime.compress.colgroup.dictionary;
+import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.functionobjects.Divide;
-import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
/**
@@ -50,26 +49,22 @@ public double[] getValues(){
return _values;
}
+ @Override
+ public double getValue(int i, int col, int nCol) {
+ return _values[i * nCol + col];
+ }
+
@Override
public DeltaDictionary applyScalarOp(ScalarOperator op) {
- final double[] retV = new double[_values.length];
if(op.fn instanceof Multiply || op.fn instanceof Divide) {
+ final double[] retV = new double[_values.length];
for(int i = 0; i < _values.length; i++)
retV[i] = op.executeScalar(_values[i]);
+ return new DeltaDictionary(retV, _numCols);
}
- else if(op.fn instanceof Plus || op.fn instanceof Minus) {
- // With Plus and Minus only the first row needs to be updated when delta encoded
- for(int i = 0; i < _values.length; i++) {
- if(i < _numCols)
- retV[i] = op.executeScalar(_values[i]);
- else
- retV[i] = _values[i];
- }
+ else {
+ throw new NotImplementedException("Scalar op " + op.fn.getClass().getSimpleName() + " not supported in DeltaDictionary");
}
- else
- throw new NotImplementedException();
-
- return new DeltaDictionary(retV, _numCols);
}
@Override
@@ -79,17 +74,30 @@ public long getInMemorySize() {
@Override
public void write(DataOutput out) throws IOException {
- throw new NotImplementedException();
+ out.writeByte(DictionaryFactory.Type.DELTA_DICT.ordinal());
+ out.writeInt(_numCols);
+ out.writeInt(_values.length);
+ for(int i = 0; i < _values.length; i++)
+ out.writeDouble(_values[i]);
+ }
+
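+ // Counterpart to write(); the leading dictionary-type byte has already been consumed by DictionaryFactory.read.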
+ public static DeltaDictionary read(DataInput in) throws IOException {
+ int numCols = in.readInt();
+ int numValues = in.readInt();
+ double[] values = new double[numValues];
+ for(int i = 0; i < numValues; i++)
+ values[i] = in.readDouble();
+ return new DeltaDictionary(values, numCols);
}
@Override
public long getExactSizeOnDisk() {
- throw new NotImplementedException();
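+ // 1 byte dictionary type + 4 bytes numCols + 4 bytes value count + 8 bytes per double value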
+ return 1 + 4 + 4 + 8L * _values.length;
}
@Override
public DictType getDictType() {
- throw new NotImplementedException();
+ return DictType.Delta;
}
@Override
@@ -104,12 +112,19 @@ public int getNumberOfColumns(int nrow){
@Override
public String getString(int colIndexes) {
- throw new NotImplementedException();
+ StringBuilder sb = new StringBuilder();
+ for(int i = 0; i < _values.length; i++) {
+ sb.append(_values[i]);
+ if(i != _values.length - 1) {
+ sb.append((i + 1) % colIndexes == 0 ? "\n" : ", ");
+ }
+ }
+ return sb.toString();
}
@Override
public long getNumberNonZeros(int[] counts, int nCol) {
- throw new NotImplementedException();
+ throw new NotImplementedException("Cannot calculate non-zeros from DeltaDictionary alone");
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index 939b48bf424..e94cbd7c570 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -37,7 +37,7 @@
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -1078,7 +1078,7 @@ else if(!Double.isInfinite(ret[0]))
}
@Override
- public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) {
+ public CmCovObject centralMoment(CmCovObject ret, ValueFunction fn, int[] counts, int nRows) {
// should be guaranteed to only contain one value per tuple in dictionary.
for(int i = 0; i < _values.length; i++)
fn.execute(ret, _values[i], counts[i]);
@@ -1089,7 +1089,7 @@ public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] co
}
@Override
- public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def,
+ public CmCovObject centralMomentWithDefault(CmCovObject ret, ValueFunction fn, int[] counts, double def,
int nRows) {
// should be guaranteed to only contain one value per tuple in dictionary.
for(int i = 0; i < _values.length; i++)
@@ -1101,7 +1101,7 @@ public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction f
}
@Override
- public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference,
+ public CmCovObject centralMomentWithReference(CmCovObject ret, ValueFunction fn, int[] counts, double reference,
int nRows) {
// should be guaranteed to only contain one value per tuple in dictionary.
for(int i = 0; i < _values.length; i++)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
index f88ac99b87b..005d14f9ce1 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
@@ -52,7 +52,7 @@ public interface DictionaryFactory {
static final Log LOG = LogFactory.getLog(DictionaryFactory.class.getName());
public enum Type {
- FP64_DICT, MATRIX_BLOCK_DICT, INT8_DICT, IDENTITY, IDENTITY_SLICE, PLACE_HOLDER
+ FP64_DICT, MATRIX_BLOCK_DICT, INT8_DICT, IDENTITY, IDENTITY_SLICE, PLACE_HOLDER, DELTA_DICT
}
public static IDictionary read(DataInput in) throws IOException {
@@ -68,6 +68,8 @@ public static IDictionary read(DataInput in) throws IOException {
return IdentityDictionary.read(in);
case IDENTITY_SLICE:
return IdentityDictionarySlice.read(in);
+ case DELTA_DICT:
+ return DeltaDictionary.read(in);
case MATRIX_BLOCK_DICT:
default:
return MatrixBlockDictionary.read(in);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
index dddea0eec7a..49330ba2748 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
@@ -29,7 +29,7 @@
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -700,7 +700,7 @@ public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colI
* @param nRows The number of rows in total of the column group
* @return The central moment Object
*/
- public CM_COV_Object centralMoment(ValueFunction fn, int[] counts, int nRows);
+ public CmCovObject centralMoment(ValueFunction fn, int[] counts, int nRows);
/**
* Central moment function to calculate the central moment of this column group. MUST be on a single column
@@ -712,7 +712,7 @@ public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colI
* @param nRows The number of rows in total of the column group
* @return The central moment Object
*/
- public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows);
+ public CmCovObject centralMoment(CmCovObject ret, ValueFunction fn, int[] counts, int nRows);
/**
* Central moment function to calculate the central moment of this column group with a default offset on all missing
@@ -724,7 +724,7 @@ public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colI
* @param nRows The number of rows in total of the column group
* @return The central moment Object
*/
- public CM_COV_Object centralMomentWithDefault(ValueFunction fn, int[] counts, double def, int nRows);
+ public CmCovObject centralMomentWithDefault(ValueFunction fn, int[] counts, double def, int nRows);
/**
* Central moment function to calculate the central moment of this column group with a default offset on all missing
@@ -737,7 +737,7 @@ public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colI
* @param nRows The number of rows in total of the column group
* @return The central moment Object
*/
- public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def,
+ public CmCovObject centralMomentWithDefault(CmCovObject ret, ValueFunction fn, int[] counts, double def,
int nRows);
/**
@@ -750,7 +750,7 @@ public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction f
* @param nRows The number of rows in total of the column group
* @return The central moment Object
*/
- public CM_COV_Object centralMomentWithReference(ValueFunction fn, int[] counts, double reference, int nRows);
+ public CmCovObject centralMomentWithReference(ValueFunction fn, int[] counts, double reference, int nRows);
/**
* Central moment function to calculate the central moment of this column group with a reference offset on each
@@ -763,7 +763,7 @@ public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction f
* @param nRows The number of rows in total of the column group
* @return The central moment Object
*/
- public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference,
+ public CmCovObject centralMomentWithReference(CmCovObject ret, ValueFunction fn, int[] counts, double reference,
int nRows);
/**
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
index 24776f3adc4..71a4112f157 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
@@ -51,7 +51,7 @@
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
@@ -2425,7 +2425,7 @@ else if(!Double.isInfinite(ret[0]))
}
@Override
- public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) {
+ public CmCovObject centralMoment(CmCovObject ret, ValueFunction fn, int[] counts, int nRows) {
// should be guaranteed to only contain one value per tuple in dictionary.
if(_data.isInSparseFormat())
throw new DMLCompressionException("The dictionary should not be sparse with one column");
@@ -2438,7 +2438,7 @@ public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] co
}
@Override
- public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def,
+ public CmCovObject centralMomentWithDefault(CmCovObject ret, ValueFunction fn, int[] counts, double def,
int nRows) {
// should be guaranteed to only contain one value per tuple in dictionary.
if(_data.isInSparseFormat())
@@ -2453,7 +2453,7 @@ public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction f
}
@Override
- public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference,
+ public CmCovObject centralMomentWithReference(CmCovObject ret, ValueFunction fn, int[] counts, double reference,
int nRows) {
// should be guaranteed to only contain one value per tuple in dictionary.
if(_data.isInSparseFormat())
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java b/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java
index 2dce0bafe4e..ef7981e941b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java
@@ -197,7 +197,10 @@ public final CompressedSizeInfoColGroup combine(IColIndex combinedColumns, Compr
return null; // This combination is clearly not a good idea; return null to indicate that.
else if(g1.getMap() == null || g2.getMap() == null)
// the previous information did not contain maps, therefore fall back to extract from sample
- return getColGroupInfo(combinedColumns, Math.max(g1V, g2V), (int) max);
+ if(_cs.preferDeltaEncoding)
+ return getDeltaColGroupInfo(combinedColumns, Math.max(g1V, g2V), (int) max);
+ else
+ return getColGroupInfo(combinedColumns, Math.max(g1V, g2V), (int) max);
else // Default combine the previous subject to max value calculated.
return combine(combinedColumns, g1, g2, (int) max);
}
@@ -254,8 +257,12 @@ private List<CompressedSizeInfoColGroup> CompressedSizeInfoColGroupSingleThread(
List<CompressedSizeInfoColGroup> ret = new ArrayList<>(clen);
if(!_cs.transposed && !_data.isEmpty() && _data.isInSparseFormat())
nnzCols = LibMatrixReorg.countNnzPerColumn(_data);
- for(int col = 0; col < clen; col++)
- ret.add(getColGroupInfo(new SingleIndex(col)));
+ for(int col = 0; col < clen; col++) {
+ if(_cs.preferDeltaEncoding)
+ ret.add(getDeltaColGroupInfo(new SingleIndex(col)));
+ else
+ ret.add(getColGroupInfo(new SingleIndex(col)));
+ }
return ret;
}
@@ -286,9 +293,14 @@ private List<CompressedSizeInfoColGroup> CompressedSizeInfoColGroupParallel(int
for(int col = 0; col < clen; col += blkz) {
final int start = col;
final int end = Math.min(clen, col + blkz);
+ final boolean useDelta = _cs.preferDeltaEncoding;
tasks.add(pool.submit(() -> {
- for(int c = start; c < end; c++)
- res[c] = getColGroupInfo(new SingleIndex(c));
+ for(int c = start; c < end; c++) {
+ if(useDelta)
+ res[c] = getDeltaColGroupInfo(new SingleIndex(c));
+ else
+ res[c] = getColGroupInfo(new SingleIndex(c));
+ }
return null;
}));
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
index 963a044d14f..df353931c0b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
@@ -255,7 +255,11 @@ private static double getCompressionSize(IColIndex cols, CompressionType ct, Est
case LinearFunctional:
return ColGroupSizes.estimateInMemorySizeLinearFunctional(numCols, contiguousColumns);
case DeltaDDC:
- throw new NotImplementedException();
+ // DeltaDDC has the same size estimation as DDC since it uses the same structure
+ // The delta encoding is just a different way of interpreting the data
+ nv = fact.numVals + (fact.numOffs < fact.numRows ? 1 : 0);
+ return ColGroupSizes.estimateInMemorySizeDDC(numCols, contiguousColumns, nv, fact.numRows,
+ fact.tupleSparsity, fact.lossy);
case DDC:
nv = fact.numVals + (fact.numOffs < fact.numRows ? 1 : 0);
return ColGroupSizes.estimateInMemorySizeDDC(numCols, contiguousColumns, nv, fact.numRows,
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
index b196da658c3..efe1e14c47a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
@@ -88,9 +88,8 @@ else if(rowCols.size() == 1) {
* @return A delta encoded encoding.
*/
public static IEncode createFromMatrixBlockDelta(MatrixBlock m, boolean transposed, IColIndex rowCols) {
- throw new NotImplementedException();
- // final int sampleSize = transposed ? m.getNumColumns() : m.getNumRows();
- // return createFromMatrixBlockDelta(m, transposed, rowCols, sampleSize);
+ final int sampleSize = transposed ? m.getNumColumns() : m.getNumRows();
+ return createFromMatrixBlockDelta(m, transposed, rowCols, sampleSize);
}
/**
@@ -107,7 +106,7 @@ public static IEncode createFromMatrixBlockDelta(MatrixBlock m, boolean transpos
*/
public static IEncode createFromMatrixBlockDelta(MatrixBlock m, boolean transposed, IColIndex rowCols,
int sampleSize) {
- throw new NotImplementedException();
+ return createWithDeltaReader(m, rowCols, transposed, sampleSize);
}
/**
@@ -691,4 +690,68 @@ private static IEncode createWithReaderSparse(MatrixBlock m, DblArrayCountHashMa
public static SparseEncoding createSparse(AMapToData map, AOffset off, int nRows) {
return new SparseEncoding(map, off, nRows);
}
+
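+ // Builds a sample encoding over delta values: pass one collects the distinct delta tuples and their offsets,
+ // pass two fills a dense or sparse mapping depending on how many rows are non-zero.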
+ private static IEncode createWithDeltaReader(MatrixBlock m, IColIndex rowCols, boolean transposed, int sampleSize) {
+ final int rl = 0;
+ final int ru = Math.min(sampleSize, transposed ? m.getNumColumns() : m.getNumRows());
+ final ReaderColumnSelection reader1 = ReaderColumnSelection.createDeltaReader(m, rowCols, transposed, rl, ru);
+ final DblArrayCountHashMap map = new DblArrayCountHashMap();
+ final IntArrayList offsets = new IntArrayList();
+ DblArray cellVals = reader1.nextRow();
+ boolean isFirstRow = true;
+
+ while(cellVals != null) {
+ map.increment(cellVals);
+ if(isFirstRow || !cellVals.isEmpty())
+ offsets.appendValue(reader1.getCurrentRowIndex());
+ isFirstRow = false;
+ cellVals = reader1.nextRow();
+ }
+
+ if(offsets.size() == 0)
+ return new EmptyEncoding();
+ else if(map.size() == 1 && offsets.size() == ru)
+ return new ConstEncoding(ru);
+
+ final ReaderColumnSelection reader2 = ReaderColumnSelection.createDeltaReader(m, rowCols, transposed, rl, ru);
+ if(offsets.size() < ru / 4)
+ return createWithDeltaReaderSparse(m, map, rowCols, offsets, ru, reader2);
+ else
+ return createWithDeltaReaderDense(m, map, rowCols, ru, offsets.size() < ru, reader2);
+ }
+
+ private static IEncode createWithDeltaReaderDense(MatrixBlock m, DblArrayCountHashMap map, IColIndex rowCols,
+ int nRows, boolean zero, ReaderColumnSelection reader2) {
+ final int unique = map.size() + (zero ? 1 : 0);
+ final AMapToData d = MapToFactory.create(nRows, unique);
+
+ DblArray cellVals;
+ if(zero)
+ while((cellVals = reader2.nextRow()) != null)
+ d.set(reader2.getCurrentRowIndex(), map.getId(cellVals) + 1);
+ else
+ while((cellVals = reader2.nextRow()) != null)
+ d.set(reader2.getCurrentRowIndex(), map.getId(cellVals));
+
+ return new DenseEncoding(d);
+ }
+
+ private static IEncode createWithDeltaReaderSparse(MatrixBlock m, DblArrayCountHashMap map, IColIndex rowCols,
+ IntArrayList offsets, int nRows, ReaderColumnSelection reader2) {
+ DblArray cellVals = reader2.nextRow();
+ final AMapToData d = MapToFactory.create(offsets.size(), map.size());
+
+ int i = 0;
+ boolean isFirstRow = true;
+ while(cellVals != null) {
+ if(isFirstRow || !cellVals.isEmpty()) {
+ d.set(i++, map.getId(cellVals));
+ }
+ isFirstRow = false;
+ cellVals = reader2.nextRow();
+ }
+
+ final AOffset o = OffsetFactory.createOffset(offsets);
+ return new SparseEncoding(d, o, nRows);
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java
index c4d9db367bb..1ba0ba61d10 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java
@@ -46,7 +46,7 @@
import org.apache.sysds.runtime.compress.lib.CLALibSeparator;
import org.apache.sysds.runtime.compress.lib.CLALibSeparator.SeparatedGroups;
import org.apache.sysds.runtime.compress.lib.CLALibSlice;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.instructions.ooc.OOCStream;
import org.apache.sysds.runtime.instructions.spark.CompressionSPInstruction.CompressionFunction;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils;
@@ -410,7 +410,7 @@ public Object call() throws Exception {
}
@Override
- public long writeMatrixFromStream(String fname, LocalTaskQueue<IndexedMatrixValue> stream, long rlen, long clen, int blen) {
+ public long writeMatrixFromStream(String fname, OOCStream<IndexedMatrixValue> stream, long rlen, long clen, int blen) {
throw new UnsupportedOperationException("Writing from an OOC stream is not supported for the HDF5 format.");
};
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCMOps.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCMOps.java
index 3e77aaefb9e..644e73697a4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCMOps.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCMOps.java
@@ -23,7 +23,7 @@
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
@@ -34,7 +34,7 @@ private CLALibCMOps() {
// private constructor
}
- public static CM_COV_Object centralMoment(CompressedMatrixBlock cmb, CMOperator op) {
+ public static CmCovObject centralMoment(CompressedMatrixBlock cmb, CMOperator op) {
MatrixBlock.checkCMOperations(cmb, op);
if(cmb.isEmpty())
return LibMatrixAgg.aggregateCmCov(cmb, null, null, op.fn);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUnary.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUnary.java
index f858f15b746..cc0ff901df4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUnary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUnary.java
@@ -21,10 +21,15 @@
import java.util.ArrayList;
import java.util.List;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
+import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
@@ -43,10 +48,40 @@ public static MatrixBlock unaryOperations(CompressedMatrixBlock m, UnaryOperator
final boolean overlapping = m.isOverlapping();
final int r = m.getNumRows();
final int c = m.getNumColumns();
+
// early aborts:
if(m.isEmpty())
return new MatrixBlock(r, c, 0).unaryOperations(op, result);
- else if(overlapping) {
+
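+ // The row-to-row deltas of a CUMSUM result are the original (compressible) values, so decompress,
+ // apply the op, and recompress the result with delta encoding.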
+ if(Builtin.isBuiltinCode(op.fn, BuiltinCode.CUMSUM)) {
+ MatrixBlock uncompressed = m.getUncompressed("CUMSUM requires uncompression", op.getNumThreads());
+ MatrixBlock opResult = uncompressed.unaryOperations(op, null);
+
+ CompressionSettingsBuilder csb = new CompressionSettingsBuilder();
+ csb.clearValidCompression();
+ csb.setPreferDeltaEncoding(true);
+ csb.addValidCompression(CompressionType.DeltaDDC);
+ csb.addValidCompression(CompressionType.UNCOMPRESSED);
+ csb.setTransposeInput("false");
+ Pair<MatrixBlock, CompressionStatistics> compressedPair = CompressedMatrixBlockFactory.compress(opResult, op.getNumThreads(), csb);
+ MatrixBlock compressedResult = compressedPair.getLeft();
+
+ if(compressedResult == null) {
+ compressedResult = opResult;
+ }
+
+ CompressedMatrixBlock finalResult;
+ if(compressedResult instanceof CompressedMatrixBlock) {
+ finalResult = (CompressedMatrixBlock) compressedResult;
+ }
+ else {
+ finalResult = CompressedMatrixBlockFactory.genUncompressedCompressedMatrixBlock(compressedResult);
+ }
+
+ return finalResult;
+ }
+
+ if(overlapping) {
// when in overlapping state it is guaranteed that there is no infinites, NA, or NANs.
if(Builtin.isBuiltinCode(op.fn, BuiltinCode.ISINF, BuiltinCode.ISNA, BuiltinCode.ISNAN))
return new MatrixBlock(r, c, 0);
@@ -64,8 +99,9 @@ else if(Builtin.isBuiltinCode(op.fn, BuiltinCode.ISINF, BuiltinCode.ISNAN, Built
return new MatrixBlock(r, c, 0); // avoid unnecessary allocation
else if(LibMatrixAgg.isSupportedUnaryOperator(op)) {
String message = "Unary Op not supported: " + op.fn.getClass().getSimpleName();
- // e.g., cumsum/cumprod/cummin/cumax/cumsumprod
- return m.getUncompressed(message, op.getNumThreads()).unaryOperations(op, null);
+ MatrixBlock uncompressed = m.getUncompressed(message, op.getNumThreads());
+ MatrixBlock opResult = uncompressed.unaryOperations(op, null);
+ return opResult;
}
else {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java
index d6ec60336f0..1734d39f4ce 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java
@@ -193,13 +193,63 @@ else if(rawBlock.isInSparseFormat()) {
else {
return new ReaderColumnSelectionDenseSingleBlockQuantized(rawBlock, colIndices, rl, ru, scaleFactors);
}
- }
+ }
+
+ /**
+ * Create a reader of the matrix block that computes delta values (current row - previous row) on-the-fly.
+ *
+ * Note that the reader reuses the returned array; copy the returned rows if they are needed beyond the next call.
+ * The first row is returned as-is (no delta computation).
+ *
+ * @param rawBlock The block to iterate through
+ * @param colIndices The column indexes to extract and insert into the double array
+ * @param transposed If the raw block should be treated as transposed
+ * @return A delta reader of the columns specified
+ */
+ public static ReaderColumnSelection createDeltaReader(MatrixBlock rawBlock, IColIndex colIndices, boolean transposed) {
+ final int rl = 0;
+ final int ru = transposed ? rawBlock.getNumColumns() : rawBlock.getNumRows();
+ return createDeltaReader(rawBlock, colIndices, transposed, rl, ru);
+ }
+
+ /**
+ * Create a reader of the matrix block that computes delta values (current row - previous row) on-the-fly.
+ *
+ * Note that the reader reuses the returned array; copy the returned rows if they are needed beyond the next call.
+ * The first row is returned as-is (no delta computation).
+ *
+ * @param rawBlock The block to iterate through
+ * @param colIndices The column indexes to extract and insert into the double array
+ * @param transposed If the raw block should be treated as transposed
+ * @param rl The row to start at
+ * @param ru The row to end at (not inclusive)
+ * @return A delta reader of the columns specified
+ */
+ public static ReaderColumnSelection createDeltaReader(MatrixBlock rawBlock, IColIndex colIndices, boolean transposed,
+ int rl, int ru) {
+ checkInput(rawBlock, colIndices, rl, ru, transposed);
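+ // The readers pre-increment their row cursor, so start one row before rl.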
+ rl = rl - 1;
+ if(rawBlock.isEmpty()) {
+ LOG.warn("Reading an empty block is likely an error, but it is supported!");
+ return new ReaderColumnSelectionEmpty(rawBlock, colIndices, rl, ru, transposed);
+ }
+
+ if(transposed) {
+ throw new NotImplementedException("Delta encoding for transposed matrices not yet implemented");
+ }
+
+ if(rawBlock.isInSparseFormat())
+ return new ReaderColumnSelectionSparseDelta(rawBlock, colIndices, rl, ru);
+ else if(rawBlock.getDenseBlock().numBlocks() > 1)
+ return new ReaderColumnSelectionDenseMultiBlockDelta(rawBlock, colIndices, rl, ru);
+ return new ReaderColumnSelectionDenseSingleBlockDelta(rawBlock, colIndices, rl, ru);
+ }
private static void checkInput(final MatrixBlock rawBlock, final IColIndex colIndices, final int rl, final int ru,
final boolean transposed) {
- if(colIndices.size() <= 1)
+ if(colIndices.size() < 1)
throw new DMLCompressionException(
- "Column selection reader should not be done on single column groups: " + colIndices);
+ "Column selection reader should not be done on empty column groups: " + colIndices);
else if(rl >= ru)
throw new DMLCompressionException("Invalid inverse range for reader " + rl + " to " + ru);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlockDelta.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlockDelta.java
new file mode 100644
index 00000000000..f700ebd94b7
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlockDelta.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.readers;
+
+import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.compress.utils.DblArray;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+public class ReaderColumnSelectionDenseMultiBlockDelta extends ReaderColumnSelection {
+ private final DenseBlock _data;
+ private final double[] _previousRow;
+ private boolean _isFirstRow;
+
+ protected ReaderColumnSelectionDenseMultiBlockDelta(MatrixBlock data, IColIndex colIndices, int rl, int ru) {
+ super(colIndices, rl, Math.min(ru, data.getNumRows()) - 1);
+ _data = data.getDenseBlock();
+ _previousRow = new double[colIndices.size()];
+ _isFirstRow = true;
+ }
+
+ protected DblArray getNextRow() {
+ _rl++;
+
+ if(_isFirstRow) {
+ for(int i = 0; i < _colIndexes.size(); i++) {
+ final double val = _data.get(_rl, _colIndexes.get(i));
+ _previousRow[i] = val;
+ reusableArr[i] = val;
+ }
+ _isFirstRow = false;
+ }
+ else {
+ for(int i = 0; i < _colIndexes.size(); i++) {
+ final double currentVal = _data.get(_rl, _colIndexes.get(i));
+ reusableArr[i] = currentVal - _previousRow[i];
+ _previousRow[i] = currentVal;
+ }
+ }
+
+ return reusableReturn;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockDelta.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockDelta.java
new file mode 100644
index 00000000000..65f13343201
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockDelta.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.readers;
+
+import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.compress.utils.DblArray;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+public class ReaderColumnSelectionDenseSingleBlockDelta extends ReaderColumnSelection {
+ private final double[] _data;
+ private final int _numCols;
+ private final double[] _previousRow;
+ private boolean _isFirstRow;
+
+ protected ReaderColumnSelectionDenseSingleBlockDelta(MatrixBlock data, IColIndex colIndices, int rl, int ru) {
+ super(colIndices, rl, Math.min(ru, data.getNumRows()) - 1);
+ _data = data.getDenseBlockValues();
+ _numCols = data.getNumColumns();
+ _previousRow = new double[colIndices.size()];
+ _isFirstRow = true;
+ }
+
+ protected DblArray getNextRow() {
+ _rl++;
+ final int indexOff = _rl * _numCols;
+
+ if(_isFirstRow) {
+ for(int i = 0; i < _colIndexes.size(); i++) {
+ final double val = _data[indexOff + _colIndexes.get(i)];
+ _previousRow[i] = val;
+ reusableArr[i] = val;
+ }
+ _isFirstRow = false;
+ }
+ else {
+ for(int i = 0; i < _colIndexes.size(); i++) {
+ final double currentVal = _data[indexOff + _colIndexes.get(i)];
+ reusableArr[i] = currentVal - _previousRow[i];
+ _previousRow[i] = currentVal;
+ }
+ }
+
+ return reusableReturn;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparseDelta.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparseDelta.java
new file mode 100644
index 00000000000..8ea1fff3396
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparseDelta.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.readers;
+
+import java.util.Arrays;
+
+import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.compress.utils.DblArray;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+public class ReaderColumnSelectionSparseDelta extends ReaderColumnSelection {
+
+ private final SparseBlock _a;
+ private final double[] _previousRow;
+ private boolean _isFirstRow;
+
+ protected ReaderColumnSelectionSparseDelta(MatrixBlock data, IColIndex colIndexes, int rl, int ru) {
+ super(colIndexes, rl, Math.min(ru, data.getNumRows()) - 1);
+ _a = data.getSparseBlock();
+ _previousRow = new double[colIndexes.size()];
+ _isFirstRow = true;
+ }
+
+ protected final DblArray getNextRow() {
+ _rl++;
+ for(int i = 0; i < _colIndexes.size(); i++)
+ reusableArr[i] = 0.0;
+
+ if(!_a.isEmpty(_rl))
+ processInRange(_rl);
+
+ if(_isFirstRow) {
+ for(int i = 0; i < _colIndexes.size(); i++)
+ _previousRow[i] = reusableArr[i];
+ _isFirstRow = false;
+ }
+ else {
+ for(int i = 0; i < _colIndexes.size(); i++) {
+ final double currentVal = reusableArr[i];
+ reusableArr[i] = currentVal - _previousRow[i];
+ _previousRow[i] = currentVal;
+ }
+ }
+
+ return reusableReturn;
+ }
+
+ final void processInRange(final int r) {
+ final int apos = _a.pos(r);
+ final int alen = _a.size(r) + apos;
+ final int[] aix = _a.indexes(r);
+ final double[] avals = _a.values(r);
+ int skip = 0;
+ int j = Arrays.binarySearch(aix, apos, alen, _colIndexes.get(0));
+ if(j < 0)
+ j = Math.abs(j + 1);
+
+ while(skip < _colIndexes.size() && j < alen) {
+ if(_colIndexes.get(skip) == aix[j]) {
+ reusableArr[skip] = avals[j];
+ skip++;
+ j++;
+ }
+ else if(_colIndexes.get(skip) > aix[j])
+ j++;
+ else
+ skip++;
+ }
+ }
+}
+
+
+
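The sparse reader above first zeroes the reusable row, then fills the selected columns via a two-pointer merge between the sorted selection indexes and the sparse row's index array, starting from a binary search for the first candidate position; only afterwards does it apply the same first-row/difference logic as the dense readers. A simplified, standalone version of that merge (illustrative names, not taken from the patch):

    import java.util.Arrays;

    // Illustrative sketch only: copy the values of the selected (sorted) columns out of one
    // sparse row given its index/value arrays; columns without an entry stay 0 in 'out'.
    public class SparseSelectMergeSketch {
        public static void selectRow(int[] aix, double[] avals, int apos, int alen, int[] cols, double[] out) {
            int j = Arrays.binarySearch(aix, apos, alen, cols[0]);
            if (j < 0)
                j = -j - 1; // insertion point: first entry >= cols[0]
            int k = 0;
            while (k < cols.length && j < alen) {
                if (cols[k] == aix[j])
                    out[k++] = avals[j++]; // selected column present in this row
                else if (cols[k] > aix[j])
                    j++; // skip sparse entries of unselected columns
                else
                    k++; // selected column has no entry, stays 0
            }
        }
    }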
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 34a8aa18631..36637ee8959 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -43,7 +43,6 @@
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer.RPolicy;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
@@ -471,12 +470,12 @@ public boolean hasBroadcastHandle() {
return _bcHandle != null && _bcHandle.hasBackReference();
}
- public OOCStream getStreamHandle() {
+ public synchronized OOCStream getStreamHandle() {
if( !hasStreamHandle() ) {
final SubscribableTaskQueue _mStream = new SubscribableTaskQueue<>();
- _streamHandle = _mStream;
DataCharacteristics dc = getDataCharacteristics();
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
+ _streamHandle = _mStream;
LongStream.range(0, dc.getNumBlocks())
.mapToObj(i -> UtilFunctions.createIndexedMatrixBlock(src, dc, i))
.forEach( blk -> {
@@ -489,7 +488,14 @@ public OOCStream getStreamHandle() {
_mStream.closeInput();
}
- return _streamHandle.getReadStream();
+ OOCStream stream = _streamHandle.getReadStream();
+ if (!stream.hasStreamCache())
+ _streamHandle = null; // To ensure read once
+ return stream;
+ }
+
+ public OOCStreamable getStreamable() {
+ return _streamHandle;
}
/**
@@ -499,7 +505,7 @@ public OOCStream getStreamHandle() {
* @return true if existing, false otherwise
*/
public boolean hasStreamHandle() {
- return _streamHandle != null && !_streamHandle.isProcessed();
+ return _streamHandle != null;
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@@ -626,7 +632,7 @@ && getRDDHandle() == null) ) {
_requiresLocalWrite = false;
}
else if( hasStreamHandle() ) {
- _data = readBlobFromStream( getStreamHandle().toLocalTaskQueue() );
+ _data = readBlobFromStream( getStreamHandle() );
}
else if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() ) {
if( DMLScript.STATISTICS )
@@ -1161,7 +1167,7 @@ protected abstract T readBlobFromHDFS(String fname, long[] dims)
protected abstract T readBlobFromRDD(RDDObject rdd, MutableBoolean status)
throws IOException;
- protected abstract T readBlobFromStream(LocalTaskQueue stream)
+ protected abstract T readBlobFromStream(OOCStream stream)
throws IOException;
// Federated read
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
index f4d20bb55a0..7151d87211c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
@@ -33,8 +33,8 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.instructions.ooc.OOCStream;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
@@ -316,7 +316,7 @@ protected void writeBlobFromRDDtoHDFS(RDDObject rdd, String fname, String ofmt)
}
@Override
- protected FrameBlock readBlobFromStream(LocalTaskQueue stream) throws IOException {
+ protected FrameBlock readBlobFromStream(OOCStream stream) throws IOException {
// TODO Auto-generated method stub
return null;
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 8191040eb18..0b1a1ee27cb 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -45,6 +45,7 @@
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
+import org.apache.sysds.runtime.instructions.ooc.OOCStream;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
@@ -528,17 +529,16 @@ protected MatrixBlock readBlobFromRDD(RDDObject rdd, MutableBoolean writeStatus)
@Override
- protected MatrixBlock readBlobFromStream(LocalTaskQueue stream) throws IOException {
+ protected MatrixBlock readBlobFromStream(OOCStream stream) throws IOException {
boolean dimsUnknown = getNumRows() < 0 || getNumColumns() < 0;
int nrows = (int)getNumRows();
int ncols = (int)getNumColumns();
MatrixBlock ret = dimsUnknown ? null : new MatrixBlock((int)getNumRows(), (int)getNumColumns(), false);
- // TODO if stream is CachingStream, block parts might be evicted resulting in null pointer exceptions
List<IndexedMatrixValue> blockCache = dimsUnknown ? new ArrayList<>() : null;
IndexedMatrixValue tmp = null;
try {
int blen = getBlocksize(), lnnz = 0;
- while( (tmp = stream.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS ) {
+ while( (tmp = stream.dequeue()) != LocalTaskQueue.NO_MORE_TASKS ) {
// compute row/column block offsets
final int row_offset = (int) (tmp.getIndexes().getRowIndex() - 1) * blen;
final int col_offset = (int) (tmp.getIndexes().getColumnIndex() - 1) * blen;
@@ -559,12 +559,12 @@ protected MatrixBlock readBlobFromStream(LocalTaskQueue stre
if (dimsUnknown) {
ret = new MatrixBlock(nrows, ncols, false);
- for (IndexedMatrixValue _tmp : blockCache) {
+ for (IndexedMatrixValue tmp2 : blockCache) {
// compute row/column block offsets
- final int row_offset = (int) (_tmp.getIndexes().getRowIndex() - 1) * blen;
- final int col_offset = (int) (_tmp.getIndexes().getColumnIndex() - 1) * blen;
+ final int row_offset = (int) (tmp2.getIndexes().getRowIndex() - 1) * blen;
+ final int col_offset = (int) (tmp2.getIndexes().getColumnIndex() - 1) * blen;
- ((MatrixBlock) _tmp.getValue()).putInto(ret, row_offset, col_offset, true);
+ ((MatrixBlock) tmp2.getValue()).putInto(ret, row_offset, col_offset, true);
}
}
@@ -636,7 +636,7 @@ protected long writeStreamToHDFS(String fname, String ofmt, int rep, FileFormatP
MetaDataFormat iimd = (MetaDataFormat) _metaData;
FileFormat fmt = (ofmt != null ? FileFormat.safeValueOf(ofmt) : iimd.getFileFormat());
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(fmt, rep, fprop);
- return writer.writeMatrixFromStream(fname, getStreamHandle().toLocalTaskQueue(),
+ return writer.writeMatrixFromStream(fname, getStreamHandle(),
getNumRows(), getNumColumns(), ConfigurationManager.getBlocksize());
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
index d0111a34300..474db9a65fe 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
@@ -30,9 +30,9 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
+import org.apache.sysds.runtime.instructions.ooc.OOCStream;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
@@ -210,7 +210,7 @@ protected void writeBlobFromRDDtoHDFS(RDDObject rdd, String fname, String ofmt)
@Override
- protected TensorBlock readBlobFromStream(LocalTaskQueue stream) throws IOException {
+ protected TensorBlock readBlobFromStream(OOCStream stream) throws IOException {
// TODO Auto-generated method stub
return null;
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 83b612de054..333c889b7c1 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -448,10 +448,10 @@ protected void updateAndBroadcastModel(ListObject new_model, Timing tAgg, boolea
}
protected ListObject weightModels(ListObject params, int numWorkers) {
- double _averagingFactor = 1d / numWorkers;
+ double averagingFactor = 1d / numWorkers;
- if( _averagingFactor != 1) {
- double final_averagingFactor = _averagingFactor;
+ if( averagingFactor != 1) {
+ double final_averagingFactor = averagingFactor;
params.getData().parallelStream().forEach((matrix) -> {
MatrixObject matrixObject = (MatrixObject) matrix;
MatrixBlock input = matrixObject.acquireReadAndRelease().scalarOperations(
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
index 783981e0f12..1849ad066b3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
@@ -45,7 +45,7 @@ public class LocalTaskQueue
protected LinkedList<T> _data = null;
protected boolean _closedInput = false;
- private DMLRuntimeException _failure = null;
+ protected DMLRuntimeException _failure = null;
private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName());
public LocalTaskQueue()
@@ -103,6 +103,10 @@ public synchronized T dequeueTask()
return t;
}
+ public synchronized boolean hasNext() {
+ return !_data.isEmpty() || _closedInput;
+ }
+
/**
* Synchronized (logical) insert of a NO_MORE_TASKS symbol at the end of the FIFO queue in order to
* mark that no more tasks will be inserted into the queue.
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java
new file mode 100644
index 00000000000..665e5ae5588
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.commons.logging.Log;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public abstract class EOpNode {
+ public Character c1;
+ public Character c2;
+ public Integer dim1;
+ public Integer dim2;
+ public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) {
+ this.c1 = c1;
+ this.c2 = c2;
+ this.dim1 = dim1;
+ this.dim2 = dim2;
+ }
+
+ public String getOutputString() {
+ if(c1 == null) return "''";
+ if(c2 == null) return c1.toString();
+ return c1.toString() + c2.toString();
+ }
+ public abstract List<EOpNode> getChildren();
+
+ public String[] recursivePrintString(){
+ List<String[]> inpStrings = new ArrayList<>();
+ for (EOpNode node : getChildren()) {
+ inpStrings.add(node.recursivePrintString());
+ }
+ String[] inpRes = inpStrings.stream().flatMap(Arrays::stream).toArray(String[]::new);
+ String[] res = new String[1 + inpRes.length];
+
+ res[0] = this.toString();
+
+ for (int i=0; i<inpRes.length; i++){
+ res[i+1] = inpRes[i];
+ }
+ return res;
+ }
+
+ public abstract MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int numOfThreads, Log LOG);
+
+ public abstract EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2);
+}
+
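For orientation on the einsum operator tree added below (illustration, not part of the patch): every EOpNode carries at most two subscript characters (c1, c2) and their dimension sizes, so a matrix bound to subscript "ij" becomes a leaf with c1='i', c2='j', while a vector bound to "j" has no second index character. A hypothetical construction using the EOpNodeData leaf class from this patch (sizes made up):

    import org.apache.sysds.runtime.einsum.EOpNode;
    import org.apache.sysds.runtime.einsum.EOpNodeData;

    public class EinsumLeafSketch {
        public static void main(String[] args) {
            // Hypothetical wiring only: a 100x20 matrix bound to "ij" as operand 0,
            // and a length-20 vector bound to "j" as operand 1 (second index char is null).
            EOpNode matrixLeaf = new EOpNodeData('i', 'j', 100, 20, 0);
            EOpNode vectorLeaf = new EOpNodeData('j', null, 20, null, 1);
            System.out.println(matrixLeaf + " | " + vectorLeaf);
        }
    }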
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java
new file mode 100644
index 00000000000..071fb6706a2
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.Triple;
+import org.apache.commons.logging.Log;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.function.Predicate;
+
+import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector;
+import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector;
+
+public class EOpNodeBinary extends EOpNode {
+
+ public enum EBinaryOperand { // upper case: char remains, lower case: summed (reduced) dimension
+ ////// mm: //////
+ Ba_aC, // -> BC
+ aB_Ca, // -> CB
+ Ba_Ca, // -> BC
+ aB_aC, // -> BC
+
+ ////// element-wise multiplications and sums //////
+ aB_aB,// elemwise and colsum -> B
+ Ab_Ab, // elemwise and rowsum ->A
+ Ab_bA, // elemwise, either colsum or rowsum -> A
+ aB_Ba,
+ ab_ab,//M-M sum all
+ ab_ba, //M-M.T sum all
+ aB_a,// -> B
+ Ab_b, // -> A
+
+ ////// elementwise, no summations: //////
+ A_A,// v-elemwise -> A
+ AB_AB,// M-M elemwise -> AB
+ AB_BA, // M-M.T elemwise -> AB
+ AB_A, // M-v colwise -> BA!?
+ AB_B, // M-v rowwise -> AB
+
+ ////// other //////
+ a_a,// dot ->
+ A_B, // outer mult -> AB
+ A_scalar, // v-scalar
+ AB_scalar, // m-scalar
+ scalar_scalar
+ }
+ public EOpNode left;
+ public EOpNode right;
+ public EBinaryOperand operand;
+ private boolean transposeResult;
+ public EOpNodeBinary(EOpNode left, EOpNode right, EBinaryOperand operand){
+ super(null,null,null, null);
+ Character c1, c2;
+ Integer dim1, dim2;
+ switch(operand){
+ case Ba_aC -> {
+ c1=left.c1;
+ c2=right.c2;
+ dim1=left.dim1;
+ dim2=right.dim2;
+ }
+ case aB_Ca -> {
+ c1=left.c2;
+ c2=right.c1;
+ dim1=left.dim2;
+ dim2=right.dim1;
+ }
+ case Ba_Ca -> {
+ c1=left.c1;
+ c2=right.c1;
+ dim1=left.dim1;
+ dim2=right.dim1;
+ }
+ case aB_aC -> {
+ c1=left.c2;
+ c2=right.c2;
+ dim1=left.dim2;
+ dim2=right.dim2;
+ }
+ case aB_aB, aB_Ba, aB_a -> {
+ c1=left.c2;
+ c2=null;
+ dim1=left.dim2;
+ dim2=null;
+ }
+ case Ab_Ab, Ab_bA, Ab_b, A_A, A_scalar -> {
+ c1=left.c1;
+ c2=null;
+ dim1=left.dim1;
+ dim2=null;
+ }
+ case ab_ab, ab_ba, a_a, scalar_scalar -> {
+ c1=null;
+ c2=null;
+ dim1=null;
+ dim2=null;
+ }
+ case AB_AB, AB_BA, AB_A, AB_B, AB_scalar ->{
+ c1=left.c1;
+ c2=left.c2;
+ dim1=left.dim1;
+ dim2=left.dim2;
+ }
+ case A_B -> {
+ c1=left.c1;
+ c2=right.c1;
+ dim1=left.dim1;
+ dim2=right.dim1;
+ }
+ default -> throw new IllegalStateException("EOpNodeBinary Unexpected type: " + operand);
+ }
+ // super(c1, c2, dim1, dim2); // unavailable in JDK < 22
+ this.c1 = c1;
+ this.c2 = c2;
+ this.dim1 = dim1;
+ this.dim2 = dim2;
+ this.left = left;
+ this.right = right;
+ this.operand = operand;
+ }
+
+ public void setTransposeResult(boolean transposeResult){
+ this.transposeResult = transposeResult;
+ }
+
+ public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) {
+ if (left.c2 == right.c1) { return new EOpNodeBinary(left, right, EBinaryOperand.Ba_aC); }
+ if (left.c2 == right.c2) { return new EOpNodeBinary(left, right, EBinaryOperand.Ba_Ca); }
+ if (left.c1 == right.c1) { return new EOpNodeBinary(left, right, EBinaryOperand.aB_aC); }
+ if (left.c1 == right.c2) {
+ var res = new EOpNodeBinary(left, right, EBinaryOperand.aB_Ca);
+ res.setTransposeResult(true);
+ return res;
+ }
+ throw new RuntimeException("EOpNodeBinary::combineMatrixMultiply: invalid matrix operation");
+ }
+
+ @Override
+ public List<EOpNode> getChildren() {
+ return List.of(this.left, this.right);
+ }
+ @Override
+ public String toString() {
+ return this.getClass().getSimpleName()+" ("+ operand.toString()+") "+getOutputString();
+ }
+
+ @Override
+ public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int numThreads, Log LOG) {
+ EOpNodeBinary bin = this;
+ MatrixBlock left = this.left.computeEOpNode(inputs, numThreads, LOG);
+ MatrixBlock right = this.right.computeEOpNode(inputs, numThreads, LOG);
+
+ //AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
+
+ MatrixBlock res;
+
+ switch (bin.operand){
+ case AB_AB -> {
+ res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
+ }
+ case A_A -> {
+ ensureMatrixBlockColumnVector(left);
+ ensureMatrixBlockColumnVector(right);
+ res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
+ }
+ case a_a -> {
+ ensureMatrixBlockColumnVector(left);
+ ensureMatrixBlockColumnVector(right);
+ res = new MatrixBlock(0.0);
+ res.allocateDenseBlock();
+ res.getDenseBlockValues()[0] = LibMatrixMult.dotProduct(left.getDenseBlockValues(), right.getDenseBlockValues(), 0,0 , left.getNumRows());
+ }
+ case Ab_Ab -> {
+ res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
+ null, numThreads);
+ }
+ case aB_aB -> {
+ res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
+ null, numThreads);
+ }
+ case ab_ab -> {
+ res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
+ null, numThreads);
+ }
+ case ab_ba -> {
+ res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
+ null, numThreads);
+ }
+ case Ab_bA -> {
+ res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
+ null, numThreads);
+ }
+ case aB_Ba -> {
+ res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
+ null, numThreads);
+ }
+ case AB_BA -> {
+ ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+ right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
+ res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
+ }
+ case Ba_aC -> {
+ res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads);
+ }
+ case aB_Ca -> {
+ res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), numThreads);
+ }
+ case Ba_Ca -> {
+ ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+ right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
+ res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads);
+ }
+ case aB_aC -> {
+ ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+ left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
+ res = LibMatrixMult.matrixMult(left, right, new MatrixBlock(), numThreads);
+ }
+ case A_scalar, AB_scalar -> {
+ res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
+ }
+ case AB_B -> {
+ ensureMatrixBlockRowVector(right);
+ res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ }
+ case Ab_b -> {
+ res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), new ArrayList<>(), List.of(right), new ArrayList<>(), new ArrayList<>(),
+ null, numThreads);
+ }
+ case AB_A -> {
+ ensureMatrixBlockColumnVector(right);
+ res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ }
+ case aB_a -> {
+ res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), new ArrayList<>(), new ArrayList<>(), List.of(right), new ArrayList<>(),
+ null, numThreads);
+ }
+ case A_B -> {
+ ensureMatrixBlockColumnVector(left);
+ ensureMatrixBlockRowVector(right);
+ res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ }
+ case scalar_scalar -> {
+ return new MatrixBlock(left.get(0,0)*right.get(0,0));
+ }
+ default -> {
+ throw new IllegalArgumentException("Unexpected value: " + bin.operand.toString());
+ }
+
+ }
+ if(transposeResult){
+ ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+ res = res.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
+ }
+ if(c2 == null) ensureMatrixBlockColumnVector(res);
+ return res;
+ }
+
+ @Override
+ public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) {
+ if (this.operand ==EBinaryOperand.aB_aC){
+ if(this.right.c2 == outChar1) { // result is CB so Swap aB and aC
+ var tmpLeft = left; left = right; right = tmpLeft;
+ var tmpC1 = c1; c1 = c2; c2 = tmpC1;
+ var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1;
+ }
+ if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left instanceof EOpNodeFuse fuse && fuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB
+ && (!EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK || ((fuse.dim1 * fuse.dim2 *(fuse.ABs.size()+fuse.BAs.size())) + (right.dim1*right.dim2)) * 8 > 6 * 1024 * 1024)
+ && LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, false)) {
+ fuse.AZs.add(right);
+ fuse.einsumRewriteType = EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__BZ;
+ fuse.c1 = fuse.c2;
+ fuse.c2 = right.c2;
+ return fuse;
+ }
+
+ left = left.reorderChildrenAndOptimize(this, left.c2, left.c1); // maybe can be reordered
+ if(left.c2 == right.c1) { // check if change happened:
+ this.operand = EBinaryOperand.Ba_aC;
+ }
+ right = right.reorderChildrenAndOptimize(this, right.c1, right.c2);
+ }else if (this.operand ==EBinaryOperand.Ba_Ca){
+ if(this.right.c1 == outChar1) { // result is CB so Swap Ba and Ca
+ var tmpLeft = left; left = right; right = tmpLeft;
+ var tmpC1 = c1; c1 = c2; c2 = tmpC1;
+ var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1;
+ }
+
+ right = right.reorderChildrenAndOptimize(this, right.c2, right.c1); // maybe can be reordered
+ if(left.c2 == right.c1) { // check if change happened:
+ this.operand = EBinaryOperand.Ba_aC;
+ }
+ left = left.reorderChildrenAndOptimize(this, left.c1, left.c2);
+ }else {
+ left = left.reorderChildrenAndOptimize(this, left.c1, left.c2); // just recurse
+ right = right.reorderChildrenAndOptimize(this, right.c1, right.c2);
+ }
+ return this;
+ }
+
+ // used in the old approach
+ public static Triple<Integer, EBinaryOperand, Pair<Character, Character>> tryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character, Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character outChar1, Character outChar2){
+ Predicate<Character> cannotBeSummed = (c) ->
+ c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2;
+
+ if(n1.c1 == null) {
+ // n2.c1 also has to be null
+ return Triple.of(1, EBinaryOperand.scalar_scalar, Pair.of(null, null));
+ }
+
+ if(n2.c1 == null) {
+ if(n1.c2 == null)
+ return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_scalar, Pair.of(n1.c1, null));
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2));
+ }
+
+ if(n1.c1 == n2.c1){
+ if(n1.c2 != null){
+ if ( n1.c2 == n2.c2){
+ if( cannotBeSummed.test(n1.c1)){
+ if(cannotBeSummed.test(n1.c2)){
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2));
+ }
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ab_Ab, Pair.of(n1.c1, null));
+ }
+
+ if(cannotBeSummed.test(n1.c2)){
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_aB, Pair.of(n1.c2, null));
+ }
+
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ab, Pair.of(null, null));
+
+ }
+
+ else if(n2.c2 == null){
+ if(cannotBeSummed.test(n1.c1)){
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2));
+ }
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, n1.c2)
+ }
+ else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
+ return null;// AB,AC
+ }
+ else {
+ return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2
+ }
+ }else{ // n1.c2 = null -> c2.c2 = null
+ if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){
+ return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null));
+ }
+ return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.a_a, Pair.of(null, null));
+ }
+
+
+ }else{ // n1.c1 != n2.c1
+ if(n1.c2 == null) {
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1));
+ }
+ else if(n2.c2 == null) { // ab,c
+ if (n1.c2 == n2.c1) {
+ if(cannotBeSummed.test(n1.c2)){
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.AB_B, Pair.of(n1.c1, n1.c2));
+ }
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ab_b, Pair.of(n1.c1, null));
+ }
+ return null; // AB,C
+ }
+ else if (n1.c2 == n2.c1) {
+ if(n1.c1 == n2.c2){ // ab,ba
+ if(cannotBeSummed.test(n1.c1)){
+ if(cannotBeSummed.test(n1.c2)){
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2));
+ }
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ab_bA, Pair.of(n1.c1, null));
+ }
+ if(cannotBeSummed.test(n1.c2)){
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_Ba, Pair.of(n1.c2, null));
+ }
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null));
+ }
+ if(cannotBeSummed.test(n1.c2)){
+ return null; // AB_B
+ }else{
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2));
+ }
+ }
+ if(n1.c1 == n2.c2) {
+ if(cannotBeSummed.test(n1.c1)){
+ return null; // AB_B
+ }
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult
+ }
+ else if (n1.c2 == n2.c2) {
+ if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){
+ return null; // BA_CA
+ }else{
+ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1
+ }
+ }
+ else { // something like AB,CD
+ return null;
+ }
+ }
+ }
+}
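combineMatrixMultiply above classifies a matrix-matrix product purely by which subscript character the two operands share: "ik" times "kj" maps to Ba_aC (a plain multiply), "ki" times "kj" to aB_aC (left side transposed before the multiply), and so on, with the output indices taken from the non-shared characters. A hypothetical use, continuing the leaf example (sizes made up):

    import org.apache.sysds.runtime.einsum.EOpNode;
    import org.apache.sysds.runtime.einsum.EOpNodeBinary;
    import org.apache.sysds.runtime.einsum.EOpNodeData;

    public class EinsumMMSketch {
        public static void main(String[] args) {
            // Hypothetical only: "ik" (100x50) times "kj" (50x20) -> operand Ba_aC, output indices "ij".
            EOpNode left  = new EOpNodeData('i', 'k', 100, 50, 0);
            EOpNode right = new EOpNodeData('k', 'j', 50, 20, 1);
            EOpNodeBinary mm = EOpNodeBinary.combineMatrixMultiply(left, right);
            System.out.println(mm); // operand Ba_aC, c1 == 'i', c2 == 'j'
        }
    }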
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java
new file mode 100644
index 00000000000..4906323f8a4
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.commons.logging.Log;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.List;
+
+public class EOpNodeData extends EOpNode {
+ public int matrixIdx;
+ public EOpNodeData(Character c1, Character c2, Integer dim1, Integer dim2, int matrixIdx){
+ super(c1,c2,dim1,dim2);
+ this.matrixIdx = matrixIdx;
+ }
+
+ @Override
+ public List<EOpNode> getChildren() {
+ return List.of();
+ }
+ @Override
+ public String toString() {
+ return this.getClass().getSimpleName()+" ("+matrixIdx+") "+getOutputString();
+ }
+ @Override
+ public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int numOfThreads, Log LOG) {
+ return inputs.get(matrixIdx);
+ }
+
+ @Override
+ public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) {
+ return this;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java
new file mode 100644
index 00000000000..5accf93bcca
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java
@@ -0,0 +1,465 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.sysds.runtime.codegen.SpoofRowwise;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.jetbrains.annotations.NotNull;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+
+import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector;
+import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector;
+
+public class EOpNodeFuse extends EOpNode {
+
+ private EOpNode scalar = null;
+
+ public enum EinsumRewriteType{
+ // B -> row*vec, A -> row*scalar
+ AB_BA_B_A__AB,
+ AB_BA_A__B,
+ AB_BA_B_A__A,
+ AB_BA_B_A__,
+
+ // scalar from row(AB).dot(B) multiplied by row(AZ)
+ AB_BA_B_A_AZ__Z,
+
+ // AZ: last step is outer matrix multiplication using vector Z
+ AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB,
+ }
+
+ public EinsumRewriteType einsumRewriteType;
+ public List<EOpNode> ABs;
+ public List<EOpNode> BAs;
+ public List<EOpNode> Bs;
+ public List<EOpNode> As;
+ public List<EOpNode> AZs;
+ @Override
+ public List<EOpNode> getChildren(){
+ List<EOpNode> all = new ArrayList<>();
+ all.addAll(ABs);
+ all.addAll(BAs);
+ all.addAll(Bs);
+ all.addAll(As);
+ all.addAll(AZs);
+ if (scalar != null) all.add(scalar);
+ return all;
+ };
+ private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer dim2, EinsumRewriteType einsumRewriteType, List<EOpNode> ABs, List<EOpNode> BAs, List<EOpNode> Bs, List<EOpNode> As, List<EOpNode> AZs) {
+ super(c1,c2, dim1, dim2);
+ this.einsumRewriteType = einsumRewriteType;
+ this.ABs = ABs;
+ this.BAs = BAs;
+ this.Bs = Bs;
+ this.As = As;
+ this.AZs = AZs;
+ }
+ public EOpNodeFuse(EinsumRewriteType einsumRewriteType, List<EOpNode> ABs, List<EOpNode> BAs, List<EOpNode> Bs, List<EOpNode> As, List<EOpNode> AZs, List<Pair<List<EOpNode>, List<EOpNode>>> AXsAndXs) {
+ super(null,null,null, null);
+ switch(einsumRewriteType) {
+ case AB_BA_B_A__A->{
+ c1 = ABs.get(0).c1;
+ dim1 = ABs.get(0).dim1;
+ }case AB_BA_A__B -> {
+ c1 = ABs.get(0).c2;
+ dim1 = ABs.get(0).dim2;
+ }case AB_BA_B_A__ -> {
+ }case AB_BA_B_A__AB -> {
+ c1 = ABs.get(0).c1;
+ dim1 = ABs.get(0).dim1;
+ c2 = ABs.get(0).c2;
+ dim2 = ABs.get(0).dim2;
+ }case AB_BA_B_A_AZ__Z -> {
+ c1 = AZs.get(0).c1;
+ dim1 = AZs.get(0).dim2;
+ }case AB_BA_A_AZ__BZ ->{
+ c1 = ABs.get(0).c2;
+ dim1 = ABs.get(0).dim2;
+ c2 = AZs.get(0).c2;
+ dim2 = AZs.get(0).dim2;
+ }case AB_BA_A_AZ__ZB ->{
+ c2 = ABs.get(0).c2;
+ dim2 = ABs.get(0).dim2;
+ c1 = AZs.get(0).c2;
+ dim1 = AZs.get(0).dim2;
+ }
+ }
+ this.einsumRewriteType = einsumRewriteType;
+ this.ABs = ABs;
+ this.BAs = BAs;
+ this.Bs = Bs;
+ this.As = As;
+ this.AZs = AZs;
+ }
+
+ @Override
+ public String toString() {
+ return this.getClass().getSimpleName()+" ("+einsumRewriteType.toString()+") "+this.getOutputString();
+ }
+
+ public void addScalarAsIntermediate(EOpNode scalar) {
+ if(einsumRewriteType == EinsumRewriteType.AB_BA_B_A__A || einsumRewriteType == EinsumRewriteType.AB_BA_B_A_AZ__Z)
+ this.scalar = scalar;
+ else
+ throw new RuntimeException("EOpNodeFuse.addScalarAsIntermediate: scalar is undefined for type "+einsumRewriteType.toString());
+ }
+
+ public static List<EOpNodeFuse> findFuseOps(List<EOpNode> operands, Character outChar1, Character outChar2,
+ Map<Character, Integer> charToSize, Map<Character, Integer> charToOccurences, List<EOpNode> ret)
+ {
+ List<EOpNodeFuse> result = new ArrayList<>();
+ Set<String> matricesChars = new HashSet<>();
+ Map<Character, HashSet<String>> matricesCharsStartingWithChar = new HashMap<>();
+ Map<String, ArrayList<EOpNode>> charsToMatrices = new HashMap<>();
+
+ for(EOpNode operand1 : operands) {
+ String k;
+
+ if(operand1.c2 != null) {
+ k = operand1.c1.toString() + operand1.c2;
+ matricesChars.add(k);
+ if(matricesCharsStartingWithChar.containsKey(operand1.c1)) {
+ matricesCharsStartingWithChar.get(operand1.c1).add(k);
+ }
+ else {
+ HashSet<String> set = new HashSet<>();
+ set.add(k);
+ matricesCharsStartingWithChar.put(operand1.c1, set);
+ }
+ }
+ else {
+ k = operand1.c1.toString();
+ }
+
+ if(charsToMatrices.containsKey(k)) {
+ charsToMatrices.get(k).add(operand1);
+ }
+ else {
+ ArrayList<EOpNode> matrices = new ArrayList<>();
+ matrices.add(operand1);
+ charsToMatrices.put(k, matrices);
+ }
+ }
+ ArrayList<Pair<Integer, String>> matricesCharsSorted = new ArrayList<>(matricesChars.stream()
+ .map(x -> Pair.of(charsToMatrices.get(x).get(0).dim1 * charsToMatrices.get(x).get(0).dim2, x)).toList());
+ matricesCharsSorted.sort(Comparator.comparing(Pair::getLeft));
+ ArrayList<EOpNode> AZs = new ArrayList<>();
+
+ HashSet<String> usedMatricesChars = new HashSet<>();
+ HashSet<EOpNode> usedOperands = new HashSet<>();
+
+ for(String ABCandidate : matricesCharsSorted.stream().map(Pair::getRight).toList()) {
+ if(usedMatricesChars.contains(ABCandidate)) continue;
+
+ char a = ABCandidate.charAt(0);
+ char b = ABCandidate.charAt(1);
+ String AB = ABCandidate;
+ String BA = "" + b + a;
+
+ int BAsCount = (charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0);
+ int ABsCount = charsToMatrices.get(AB).size();
+
+ if(BAsCount > ABsCount + 1) {
+ BA = "" + a + b;
+ AB = "" + b + a;
+ char tmp = a;
+ a = b;
+ b = tmp;
+ int tmp2 = ABsCount;
+ ABsCount = BAsCount;
+ BAsCount = tmp2;
+ }
+ String A = "" + a;
+ String B = "" + b;
+ int AsCount = (charsToMatrices.containsKey(A) && !usedMatricesChars.contains(A) ? charsToMatrices.get(A).size() : 0);
+ int BsCount = (charsToMatrices.containsKey(B) && !usedMatricesChars.contains(B) ? charsToMatrices.get(B).size() : 0);
+
+ if(AsCount == 0 && BsCount == 0 && (ABsCount + BAsCount) < 2) { // no elementwise multiplication possible
+ continue;
+ }
+
+ int usedBsCount = BsCount + ABsCount + BAsCount;
+
+ boolean doSumA = false;
+ boolean doSumB = charToOccurences.get(b) == usedBsCount && (outChar1 == null || b != outChar1) && (outChar2 == null || b != outChar2);
+ HashSet<String> AZCandidates = matricesCharsStartingWithChar.get(a);
+
+ String AZ = null;
+ Character z = null;
+ boolean includeAZ = AZCandidates.size() == 2;
+
+ if(includeAZ) {
+ for(var AZCandidate : AZCandidates) {
+ if(AB.equals(AZCandidate)) {continue;}
+ AZs = charsToMatrices.get(AZCandidate);
+ z = AZCandidate.charAt(1);
+ String Z = "" + z;
+ AZ = "" + a + z;
+ int AZsCount= AZs.size();
+ int ZsCount= charsToMatrices.containsKey(Z) ? charsToMatrices.get(Z).size() : 0;
+ doSumA = AZsCount + ABsCount + BAsCount + AsCount == charToOccurences.get(a) && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2);
+ boolean doSumZ = AZsCount + ZsCount == charToOccurences.get(z) && (outChar1 == null || z != outChar1) && (outChar2 == null || z != outChar2);
+ if(!doSumA){
+ includeAZ = false;
+ } else if(!doSumB && doSumZ){ // swap the order, to have only one fusion AB,...,AZ->Z
+ b = z;
+ z = AB.charAt(1);
+ AB = "" + a + b;
+ BA = "" + b + a;
+ A = "" + a;
+ B = "" + b;
+ AZ = "" + a + z;
+ AZs = charsToMatrices.get(AZ);
+ doSumB = true;
+ } else if(!doSumB && !doSumZ){ // outer between B and Z
+ if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY
+ || (EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK && ((charToSize.get(a) * charToSize.get(b) *(ABsCount + BAsCount)) + (charToSize.get(a)*charToSize.get(z)*(AZsCount))) * 8 < 6 * 1024 * 1024)
+ || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)), charToSize.get(AZCandidates.iterator().next().charAt(1)),false)) {
+ includeAZ = false;
+ }
+ } else if(doSumB && doSumZ){
+ // it will be two separate templates and then multiply the resulting vectors
+ } else if (doSumB && !doSumZ) {
+ // ->Z template OK
+ }
+ break;
+ }
+ }
+
+ if(!includeAZ) {
+ doSumA = charToOccurences.get(a) == AsCount + ABsCount + BAsCount && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2);
+ }
+
+ ArrayList<EOpNode> ABs = charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>();
+ ArrayList<EOpNode> BAs = charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA) : new ArrayList<>();
+ ArrayList<EOpNode> As = charsToMatrices.containsKey(A) && !usedMatricesChars.contains(A) ? charsToMatrices.get(A) : new ArrayList<>();
+ ArrayList<EOpNode> Bs = charsToMatrices.containsKey(B) && !usedMatricesChars.contains(B) ? charsToMatrices.get(B) : new ArrayList<>();
+ Character c1 = null, c2 = null;
+ Integer dim1 = null, dim2 = null;
+ EinsumRewriteType type = null;
+
+ if(includeAZ) {
+ if(doSumB) {
+ type = EinsumRewriteType.AB_BA_B_A_AZ__Z;
+ c1 = z;
+ }
+ else if((outChar1 != null && outChar2 != null) && outChar1 == z && outChar2 == b) {
+ type = EinsumRewriteType.AB_BA_A_AZ__ZB;
+ c1 = z; c2 = b;
+ }
+ else if((outChar1 != null && outChar2 != null) && outChar1 == b && outChar2 == z) {
+ type = EinsumRewriteType.AB_BA_A_AZ__BZ;
+ c1 = b; c2 = z;
+ }
+ else {
+ type = EinsumRewriteType.AB_BA_A_AZ__ZB;
+ c1 = z; c2 = b;
+ }
+ }
+ else {
+ AZs= new ArrayList<>();
+ if(doSumA) {
+ if(doSumB) {
+ type = EinsumRewriteType.AB_BA_B_A__;
+ }
+ else {
+ type = EinsumRewriteType.AB_BA_A__B;
+ c1 = AB.charAt(1);
+ }
+ }
+ else if(doSumB) {
+ type = EinsumRewriteType.AB_BA_B_A__A;
+ c1 = AB.charAt(0);
+ }
+ else {
+ type = EinsumRewriteType.AB_BA_B_A__AB;
+ c1 = AB.charAt(0); c2 = AB.charAt(1);
+ }
+ }
+
+ if(c1 != null) {
+ charToOccurences.put(c1, charToOccurences.get(c1) + 1);
+ dim1 = charToSize.get(c1);
+ }
+ if(c2 != null) {
+ charToOccurences.put(c2, charToOccurences.get(c2) + 1);
+ dim2 = charToSize.get(c2);
+ }
+ boolean includeB = type != EinsumRewriteType.AB_BA_A__B && type != EinsumRewriteType.AB_BA_A_AZ__BZ && type != EinsumRewriteType.AB_BA_A_AZ__ZB;
+
+ usedOperands.addAll(ABs);
+ usedOperands.addAll(BAs);
+ usedOperands.addAll(As);
+ if (includeB) usedOperands.addAll(Bs);
+ if (includeAZ) usedOperands.addAll(AZs);
+
+ usedMatricesChars.add(AB);
+ usedMatricesChars.add(BA);
+ usedMatricesChars.add(A);
+ if (includeB) usedMatricesChars.add(B);
+ if (includeAZ) usedMatricesChars.add(AZ);
+
+ var e = new EOpNodeFuse(c1, c2, dim1, dim2, type, ABs, BAs, includeB ? Bs : new ArrayList<>(), As, AZs);
+
+ result.add(e);
+ }
+
+ for(EOpNode n : operands) {
+ if(!usedOperands.contains(n)){
+ ret.add(n);
+ } else {
+ charToOccurences.put(n.c1, charToOccurences.get(n.c1) - 1);
+ if(charToOccurences.get(n.c2)!= null)
+ charToOccurences.put(n.c2, charToOccurences.get(n.c2)-1);
+ }
+ }
+
+ return result;
+ }
+ @SuppressWarnings("unused")
+ public static MatrixBlock compute(EinsumRewriteType rewriteType, List<MatrixBlock> ABsInput, List<MatrixBlock> mbBAs, List<MatrixBlock> mbBs, List<MatrixBlock> mbAs, List<MatrixBlock> mbAZs,
+ Double scalar, int numThreads){
+ boolean isResultAB =rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB;
+ boolean isResultA = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A;
+ boolean isResultB = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A__B;
+ boolean isResult_ = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__;
+ boolean isResultZ = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z;
+ boolean isResultBZ =rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__BZ;
+ boolean isResultZB =rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__ZB;
+
+ List<MatrixBlock> mbABs = new ArrayList<>(ABsInput);
+ int bSize = mbABs.get(0).getNumColumns();
+ int aSize = mbABs.get(0).getNumRows();
+ if (!mbBAs.isEmpty()) {
+ ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+ for(MatrixBlock mb : mbBAs) //BA->AB
+ mbABs.add(mb.reorgOperations(transpose, null, 0, 0, 0));
+ }
+
+ if(mbAs.size() > 1) mbAs = multiplyVectorsIntoOne(mbAs, aSize);
+ if(mbBs.size() > 1) mbBs = multiplyVectorsIntoOne(mbBs, bSize);
+
+ int constDim2 = -1;
+ int zSize = 0;
+ int azCount = 0;
+ switch(rewriteType){
+ case AB_BA_B_A_AZ__Z, AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB -> {
+ constDim2 = mbAZs.get(0).getNumColumns();
+ zSize = mbAZs.get(0).getNumColumns();
+ azCount = mbAZs.size();
+ }
+ }
+
+ SpoofRowwise.RowType rowType = switch(rewriteType){
+ case AB_BA_B_A__AB -> SpoofRowwise.RowType.NO_AGG;
+ case AB_BA_A__B -> SpoofRowwise.RowType.COL_AGG_T;
+ case AB_BA_B_A__A -> SpoofRowwise.RowType.ROW_AGG;
+ case AB_BA_B_A__ -> SpoofRowwise.RowType.FULL_AGG;
+ case AB_BA_B_A_AZ__Z -> SpoofRowwise.RowType.COL_AGG_CONST;
+ case AB_BA_A_AZ__BZ -> SpoofRowwise.RowType.COL_AGG_B1_T;
+ case AB_BA_A_AZ__ZB -> SpoofRowwise.RowType.COL_AGG_B1;
+ };
+ EinsumSpoofRowwise r = new EinsumSpoofRowwise(rewriteType, rowType, constDim2,
+ mbABs.size()-1, !mbBs.isEmpty() && (!isResultBZ && !isResultZB && !isResultB), mbAs.size(), azCount, zSize);
+
+ ArrayList<MatrixBlock> fuseInputs = new ArrayList<>();
+ fuseInputs.addAll(mbABs);
+ if(!isResultBZ && !isResultZB && !isResultB)
+ fuseInputs.addAll(mbBs);
+ fuseInputs.addAll(mbAs);
+ if (isResultZ || isResultBZ || isResultZB)
+ fuseInputs.addAll(mbAZs);
+
+ ArrayList<ScalarObject> scalarObjects = new ArrayList<>();
+ if(scalar != null){
+ scalarObjects.add(new DoubleObject(scalar));
+ }
+ MatrixBlock out = r.execute(fuseInputs, scalarObjects, new MatrixBlock(), numThreads);
+
+ if(isResultB && !mbBs.isEmpty()){
+ LibMatrixMult.vectMultiply(mbBs.get(0).getDenseBlockValues(), out.getDenseBlockValues(), 0,0, mbABs.get(0).getNumColumns());
+ }
+ if(isResultBZ && !mbBs.isEmpty()){
+ ensureMatrixBlockColumnVector(mbBs.get(0));
+ out = out.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), mbBs.get(0));
+ }
+ if(isResultZB && !mbBs.isEmpty()){
+ ensureMatrixBlockRowVector(mbBs.get(0));
+ out = out.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), mbBs.get(0));
+ }
+
+ if( isResultA || isResultB || isResultZ)
+ ensureMatrixBlockColumnVector(out);
+
+ return out;
+ }
+ @Override
+ public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int numThreads, Log LOG) {
+ final Function<EOpNode, MatrixBlock> eOpNodeToMatrixBlock = n -> n.computeEOpNode(inputs, numThreads, LOG);
+ List<MatrixBlock> mbABs = new ArrayList<>(ABs.stream().map(eOpNodeToMatrixBlock).toList());
+ List<MatrixBlock> mbBAs = BAs.stream().map(eOpNodeToMatrixBlock).toList();
+ List<MatrixBlock> mbBs = Bs.stream().map(eOpNodeToMatrixBlock).toList();
+ List<MatrixBlock> mbAs = As.stream().map(eOpNodeToMatrixBlock).toList();
+ List<MatrixBlock> mbAZs = AZs.stream().map(eOpNodeToMatrixBlock).toList();
+ Double scalar = this.scalar == null ? null : this.scalar.computeEOpNode(inputs, numThreads, LOG).get(0,0);
+ return EOpNodeFuse.compute(this.einsumRewriteType, mbABs, mbBAs, mbBs, mbAs, mbAZs , scalar, numThreads);
+ }
+
+ @Override
+ public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) {
+ ABs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+ BAs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+ As.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+ Bs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+ AZs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+ return this;
+ }
+
+ private static @NotNull List<MatrixBlock> multiplyVectorsIntoOne(List<MatrixBlock> mbs, int size) {
+ MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), mbs.get(0).getNumColumns(), false);
+ mb.allocateDenseBlock();
+ for(int i = 1; i< mbs.size(); i++) { // multiply Bs
+ if(i==1)
+ LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0), mbs.get(1).getDenseBlock().values(0), mb.getDenseBlock().values(0),0,0,0, size);
+ else
+ LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0, size);
+ }
+ return List.of(mb);
+ }
+}
+
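One small helper worth noting in EOpNodeFuse above: multiplyVectorsIntoOne collapses several equally-sized side vectors into their element-wise product before the fused row-wise template runs, so the template only ever sees one A-vector and one B-vector. Conceptually (plain-array sketch, not the MatrixBlock-based code from the patch):

    // Illustrative sketch only: element-wise product of k dense vectors of equal length.
    public class VectorProductSketch {
        public static double[] multiplyIntoOne(double[][] vecs) {
            double[] out = vecs[0].clone();
            for (int i = 1; i < vecs.length; i++)
                for (int j = 0; j < out.length; j++)
                    out[j] *= vecs[i][j];
            return out;
        }
    }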
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
new file mode 100644
index 00000000000..918f1dd3b24
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.commons.logging.Log;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.functionobjects.DiagIndex;
+import org.apache.sysds.runtime.functionobjects.KahanPlus;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.functionobjects.ReduceDiag;
+import org.apache.sysds.runtime.functionobjects.ReduceRow;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+
+import java.util.List;
+
+public class EOpNodeUnary extends EOpNode {
+ private final EUnaryOperand eUnaryOperand;
+ public EOpNode child;
+
+ public enum EUnaryOperand {
+ DIAG, TRACE, SUM, SUM_COLS, SUM_ROWS
+ }
+ public EOpNodeUnary(Character c1, Character c2, Integer dim1, Integer dim2, EOpNode child, EUnaryOperand eUnaryOperand) {
+ super(c1, c2, dim1, dim2);
+ this.child = child;
+ this.eUnaryOperand = eUnaryOperand;
+ }
+
+ @Override
+ public List<EOpNode> getChildren() {
+ return List.of(child);
+ }
+ @Override
+ public String toString() {
+ return this.getClass().getSimpleName()+" ("+eUnaryOperand.toString()+") "+this.getOutputString();
+ }
+
+ @Override
+ public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int numOfThreads, Log LOG) {
+ MatrixBlock mb = child.computeEOpNode(inputs, numOfThreads, LOG);
+ return switch(eUnaryOperand) {
+ case DIAG->{
+ ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject());
+ yield mb.reorgOperations(op, new MatrixBlock(),0,0,0);
+ }
+ case TRACE -> {
+ AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), Types.CorrectionLocationType.LASTCOLUMN);
+ AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceDiag.getReduceDiagFnObject(), numOfThreads);
+ MatrixBlock res = new MatrixBlock(10, 10, false);
+ mb.aggregateUnaryOperations(aggun, res,0,null);
+ yield res;
+ }
+ case SUM->{
+ AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
+ AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numOfThreads);
+ MatrixBlock res = new MatrixBlock(1, 1, false);
+ mb.aggregateUnaryOperations(aggun, res, 0, null);
+ yield res;
+ }
+ case SUM_COLS ->{
+ AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
+ AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numOfThreads);
+ MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false);
+ mb.aggregateUnaryOperations(aggun, res, 0, null);
+ yield res;
+ }
+ case SUM_ROWS ->{
+ AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
+ AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numOfThreads);
+ MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false);
+ mb.aggregateUnaryOperations(aggun, res, 0, null);
+ yield res;
+ }
+ };
+ }
+
+ @Override
+ public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) {
+ return this;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
index 6da39e59873..16c67e5399b 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
@@ -23,155 +23,100 @@
import java.util.ArrayList;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.Iterator;
-
+import java.util.List;
+import java.util.Map;
public class EinsumContext {
- public enum ContractDimensions {
- CONTRACT_LEFT,
- CONTRACT_RIGHT,
- CONTRACT_BOTH,
- }
- public Integer outRows;
- public Integer outCols;
- public Character outChar1;
- public Character outChar2;
- public HashMap<Character, Integer> charToDimensionSize;
- public String equationString;
- public boolean[] diagonalInputs;
- public HashSet summingChars;
- public HashSet contractDimsSet;
- public ContractDimensions[] contractDims;
- public ArrayList newEquationStringInputsSplit;
- public HashMap> characterAppearanceIndexes; // for each character, this tells in which inputs it appears
-
- private EinsumContext(){};
- public static EinsumContext getEinsumContext(String eqStr, ArrayList<MatrixBlock> inputs){
- EinsumContext res = new EinsumContext();
-
- res.equationString = eqStr;
- res.charToDimensionSize = new HashMap<>();
- HashSet<Character> summingChars = new HashSet<>();
- ContractDimensions[] contractDims = new ContractDimensions[inputs.size()];
- boolean[] diagonalInputs = new boolean[inputs.size()]; // all false by default
- HashSet<Character> contractDimsSet = new HashSet<>();
- HashMap<Character, ArrayList<Integer>> partsCharactersToIndices = new HashMap<>();
- ArrayList<String> newEquationStringSplit = new ArrayList<>();
-
- Iterator<MatrixBlock> it = inputs.iterator();
- MatrixBlock curArr = it.next();
- int arrSizeIterator = 0;
- int arrayIterator = 0;
- int i;
- // first iteration through string: collect information on character-size and what characters are summing characters
- for (i = 0; true; i++) {
- char c = eqStr.charAt(i);
- if(c == '-'){
- i+=2;
- break;
- }
- if(c == ','){
- arrayIterator++;
- curArr = it.next();
- arrSizeIterator = 0;
- }
- else{
- if (res.charToDimensionSize.containsKey(c)) { // sanity check if dims match, this is already checked at validation
- if(arrSizeIterator == 0 && res.charToDimensionSize.get(c) != curArr.getNumRows())
- throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes");
- else if(arrSizeIterator == 1 && res.charToDimensionSize.get(c) != curArr.getNumColumns())
- throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes");
- summingChars.add(c);
- } else {
- if(arrSizeIterator == 0)
- res.charToDimensionSize.put(c, curArr.getNumRows());
- else if(arrSizeIterator == 1)
- res.charToDimensionSize.put(c, curArr.getNumColumns());
- }
-
- arrSizeIterator++;
- }
- }
-
- int numOfRemainingChars = eqStr.length() - i;
-
- if (numOfRemainingChars > 2)
- throw new RuntimeException("Einsum: dim > 2 not supported");
-
- arrSizeIterator = 0;
-
- Character outChar1 = numOfRemainingChars > 0 ? eqStr.charAt(i) : null;
- Character outChar2 = numOfRemainingChars > 1 ? eqStr.charAt(i+1) : null;
- res.outRows=(numOfRemainingChars > 0 ? res.charToDimensionSize.get(outChar1) : 1);
- res.outCols=(numOfRemainingChars > 1 ? res.charToDimensionSize.get(outChar2) : 1);
-
- arrayIterator=0;
- // second iteration through string: collect remaining information
- for (i = 0; true; i++) {
- char c = eqStr.charAt(i);
- if (c == '-') {
- break;
- }
- if (c == ',') {
- arrayIterator++;
- arrSizeIterator = 0;
- continue;
- }
- String s = "";
-
- if(summingChars.contains(c)) {
- s+=c;
- if(!partsCharactersToIndices.containsKey(c))
- partsCharactersToIndices.put(c, new ArrayList<>());
- partsCharactersToIndices.get(c).add(arrayIterator);
- }
- else if((outChar1 != null && c == outChar1) || (outChar2 != null && c == outChar2)) {
- s+=c;
- }
- else {
- contractDimsSet.add(c);
- contractDims[arrayIterator] = ContractDimensions.CONTRACT_LEFT;
- }
-
- if(i + 1 < eqStr.length()) { // process next character together
- char c2 = eqStr.charAt(i + 1);
- i++;
- if (c2 == '-') { newEquationStringSplit.add(s); break;}
- if (c2 == ',') { arrayIterator++; newEquationStringSplit.add(s); continue; }
-
- if (c2 == c){
- diagonalInputs[arrayIterator] = true;
- if (contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT) contractDims[arrayIterator] = ContractDimensions.CONTRACT_BOTH;
- }
- else{
- if(summingChars.contains(c2)) {
- s+=c2;
- if(!partsCharactersToIndices.containsKey(c2))
- partsCharactersToIndices.put(c2, new ArrayList<>());
- partsCharactersToIndices.get(c2).add(arrayIterator);
- }
- else if((outChar1 != null && c2 == outChar1) || (outChar2 != null && c2 == outChar2)) {
- s+=c2;
- }
- else {
- contractDimsSet.add(c2);
- contractDims[arrayIterator] = contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT ? ContractDimensions.CONTRACT_BOTH : ContractDimensions.CONTRACT_RIGHT;
- }
- }
- }
- newEquationStringSplit.add(s);
- arrSizeIterator++;
- }
-
- res.contractDims = contractDims;
- res.contractDimsSet = contractDimsSet;
- res.diagonalInputs = diagonalInputs;
- res.summingChars = summingChars;
- res.outChar1 = outChar1;
- res.outChar2 = outChar2;
- res.newEquationStringInputsSplit = newEquationStringSplit;
- res.characterAppearanceIndexes = partsCharactersToIndices;
- return res;
- }
+ public Integer outRows;
+ public Integer outCols;
+ public Character outChar1;
+ public Character outChar2;
+ public Map<Character, Integer> charToDimensionSize;
+ public String equationString;
+ public List<String> newEquationStringInputsSplit;
+ public Map<Character, Integer> characterAppearanceCount;
+
+ private EinsumContext(){};
+ public static EinsumContext getEinsumContext(String eqStr, List<MatrixBlock> inputs){
+ EinsumContext res = new EinsumContext();
+
+ res.equationString = eqStr;
+ Map<Character, Integer> charToDimensionSize = new HashMap<>();
+ Map<Character, Integer> characterAppearanceCount = new HashMap<>();
+ List<String> newEquationStringSplit = new ArrayList<>();
+ Character outChar1 = null;
+ Character outChar2 = null;
+
+ Iterator<MatrixBlock> it = inputs.iterator();
+ MatrixBlock curArr = it.next();
+ int i = 0;
+
+ char c = eqStr.charAt(i);
+ for(i = 0; i < eqStr.length(); i++) {
+ StringBuilder sb = new StringBuilder(2);
+ for(;i < eqStr.length(); i++){
+ c = eqStr.charAt(i);
+ if (c == ' ') continue;
+ if (c == ',' || c == '-' ) break;
+ if (!Character.isAlphabetic(c)) {
+ throw new RuntimeException("Einsum: only alphabetic characters are supported for dimensions: "+c);
+ }
+ sb.append(c);
+ if (characterAppearanceCount.containsKey(c)) characterAppearanceCount.put(c, characterAppearanceCount.get(c) + 1) ;
+ else characterAppearanceCount.put(c, 1);
+ }
+ String s = sb.toString();
+ newEquationStringSplit.add(s);
+
+ if(s.length() > 0){
+ if (charToDimensionSize.containsKey(s.charAt(0)))
+ if (charToDimensionSize.get(s.charAt(0)) != curArr.getNumRows())
+ throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes");
+ charToDimensionSize.put(s.charAt(0), curArr.getNumRows());
+ }
+ if(s.length() > 1){
+ if (charToDimensionSize.containsKey(s.charAt(1)))
+ if (charToDimensionSize.get(s.charAt(1)) != curArr.getNumColumns())
+ throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes");
+ charToDimensionSize.put(s.charAt(1), curArr.getNumColumns());
+ }
+ if(s.length() > 2) throw new RuntimeException("Einsum: only up-to 2D inputs strings allowed ");
+
+ if( c==','){
+ curArr = it.next();
+ }
+ else if (c=='-') break;
+
+ if (i == eqStr.length() - 1) {throw new RuntimeException("Einsum: missing '->' in equation string");}
+ }
+
+ if (i == eqStr.length() - 1 || eqStr.charAt(i+1) != '>') throw new RuntimeException("Einsum: missing '->' in equation string");
+ i+=2;
+
+ StringBuilder sb = new StringBuilder(2);
+
+ for(;i < eqStr.length(); i++){
+ c = eqStr.charAt(i);
+ if (c == ' ') continue;
+ if (!Character.isAlphabetic(c)) {
+ throw new RuntimeException("Einsum: only alphabetic characters are supported for dimensions: "+c);
+ }
+ sb.append(c);
+ }
+ String s = sb.toString();
+ if(s.length() > 0) outChar1 = s.charAt(0);
+ if(s.length() > 1) outChar2 = s.charAt(1);
+ if(s.length() > 2) throw new RuntimeException("Einsum: only up-to 2D output allowed ");
+
+ res.outRows=(outChar1 == null ? 1 : charToDimensionSize.get(outChar1));
+ res.outCols=(outChar2 == null ? 1 : charToDimensionSize.get(outChar2));
+
+ res.outChar1 = outChar1;
+ res.outChar2 = outChar2;
+ res.newEquationStringInputsSplit = newEquationStringSplit;
+ res.characterAppearanceCount = characterAppearanceCount;
+ res.charToDimensionSize = charToDimensionSize;
+ return res;
+ }
}
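For orientation, a small sketch of what the rewritten parser produces for a plain matrix-multiply equation; the shapes are arbitrary and the commented results follow from the code above.

    // Hypothetical sketch; two dense blocks of shape 2x3 and 3x4.
    List<MatrixBlock> ins = List.of(new MatrixBlock(2, 3, false), new MatrixBlock(3, 4, false));
    EinsumContext ctx = EinsumContext.getEinsumContext("ij,jk->ik", ins);
    // ctx.newEquationStringInputsSplit : ["ij", "jk"]
    // ctx.characterAppearanceCount     : {i=1, j=2, k=1}  (input side only)
    // ctx.charToDimensionSize          : {i=2, j=3, k=4}
    // ctx.outChar1='i', ctx.outChar2='k', ctx.outRows=2, ctx.outCols=4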
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
index 5643159ef9a..417a1b760b2 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
@@ -32,113 +32,117 @@
public class EinsumEquationValidator {
- public static Triple<Long, Long, Types.DataType> validateEinsumEquationAndReturnDimensions(String equationString, List<HopOrIdentifier> expressionsOrIdentifiers) throws LanguageException {
- String[] eqStringParts = equationString.split("->"); // length 2 if "...->..." , length 1 if "...->"
- boolean isResultScalar = eqStringParts.length == 1;
+ public static Triple<Long, Long, Types.DataType> validateEinsumEquationAndReturnDimensions(String equationString, List<HopOrIdentifier> expressionsOrIdentifiers) throws LanguageException {
+ String[] eqStringParts = equationString.split("->"); // length 2 if "...->..." , length 1 if "...->"
+ boolean isResultScalar = eqStringParts.length == 1;
- if(expressionsOrIdentifiers == null)
- throw new RuntimeException("Einsum: called validateEinsumAndReturnDimensions with null list");
+ if(expressionsOrIdentifiers == null)
+ throw new RuntimeException("Einsum: called validateEinsumAndReturnDimensions with null list");
- HashMap<Character, Long> charToDimensionSize = new HashMap<>();
- Iterator<HopOrIdentifier> it = expressionsOrIdentifiers.iterator();
- HopOrIdentifier currArr = it.next();
- int arrSizeIterator = 0;
- int numberOfMatrices = 1;
- for (int i = 0; i < eqStringParts[0].length(); i++) {
- char c = equationString.charAt(i);
- if(c==' ') continue;
- if(c==','){
- if(!it.hasNext())
- throw new LanguageException("Einsum: Provided less operands than specified in equation str");
- currArr = it.next();
- arrSizeIterator = 0;
- numberOfMatrices++;
- } else{
- long thisCharDimension = getThisCharDimension(currArr, arrSizeIterator);
- if (charToDimensionSize.containsKey(c)){
- if (charToDimensionSize.get(c) != thisCharDimension)
- throw new LanguageException("Einsum: Character '" + c + "' expected to be dim " + charToDimensionSize.get(c) + ", but found " + thisCharDimension);
- }else{
- charToDimensionSize.put(c, thisCharDimension);
- }
- arrSizeIterator++;
- }
- }
- if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices)
- throw new LanguageException("Einsum: Provided more operands than specified in equation str");
+ HashMap<Character, Long> charToDimensionSize = new HashMap<>();
+ Iterator<HopOrIdentifier> it = expressionsOrIdentifiers.iterator();
+ HopOrIdentifier currArr = it.next();
+ int arrSizeIterator = 0;
+ int numberOfMatrices = 1;
+ for (int i = 0; i < eqStringParts[0].length(); i++) {
+ char c = equationString.charAt(i);
+ if(c==' ') continue;
+ if(c==','){
+ if(!it.hasNext())
+ throw new LanguageException("Einsum: Provided less operands than specified in equation str");
+ currArr = it.next();
+ arrSizeIterator = 0;
+ numberOfMatrices++;
+ } else{
+ long thisCharDimension = getThisCharDimension(currArr, arrSizeIterator);
+ if (charToDimensionSize.containsKey(c)){
+ if (charToDimensionSize.get(c) != thisCharDimension)
+ throw new LanguageException("Einsum: Character '" + c + "' expected to be dim " + charToDimensionSize.get(c) + ", but found " + thisCharDimension);
+ }else{
+ charToDimensionSize.put(c, thisCharDimension);
+ }
+ arrSizeIterator++;
+ }
+ }
+ if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices)
+ throw new LanguageException("Einsum: Provided more operands than specified in equation str");
- if (isResultScalar)
- return Triple.of(-1l,-1l, Types.DataType.SCALAR);
+ if (isResultScalar)
+ return Triple.of(-1l,-1l, Types.DataType.SCALAR);
- int numberOfOutDimensions = 0;
- Character dim1Char = null;
- long dim1 = 1;
- long dim2 = 1;
- for (int i = 0; i < eqStringParts[1].length(); i++) {
- char c = eqStringParts[1].charAt(i);
- if (c == ' ') continue;
- if (numberOfOutDimensions == 0) {
- dim1Char = c;
- dim1 = charToDimensionSize.get(c);
- } else {
- if(c==dim1Char) throw new LanguageException("Einsum: output character "+c+" provided multiple times");
- dim2 = charToDimensionSize.get(c);
- }
- numberOfOutDimensions++;
- }
- if (numberOfOutDimensions > 2) {
- throw new LanguageException("Einsum: output matrices with with no. dims > 2 not supported");
- } else {
- return Triple.of(dim1, dim2, Types.DataType.MATRIX);
- }
- }
+ int numberOfOutDimensions = 0;
+ Character dim1Char = null;
+ long dim1 = 1;
+ long dim2 = 1;
+ for (int i = 0; i < eqStringParts[1].length(); i++) {
+ char c = eqStringParts[1].charAt(i);
+ if (c == ' ') continue;
+ if (numberOfOutDimensions == 0) {
+ dim1Char = c;
+ if(!charToDimensionSize.containsKey(c))
+ throw new LanguageException("Einsum: Output dimension '"+c+"' not present in input operands");
+ dim1 = charToDimensionSize.get(c);
+ } else {
+ if(c==dim1Char) throw new LanguageException("Einsum: output character "+c+" provided multiple times");
+ if(!charToDimensionSize.containsKey(c))
+ throw new LanguageException("Einsum: Output dimension '"+c+"' not present in input operands");
+ dim2 = charToDimensionSize.get(c);
+ }
+ numberOfOutDimensions++;
+ }
+ if (numberOfOutDimensions > 2) {
+ throw new LanguageException("Einsum: output matrices with with no. dims > 2 not supported");
+ } else {
+ return Triple.of(dim1, dim2, Types.DataType.MATRIX);
+ }
+ }
- public static Types.DataType validateEinsumEquationNoDimensions(String equationString, int numberOfMatrixInputs) throws LanguageException {
- String[] eqStringParts = equationString.split("->"); // length 2 if "...->..." , length 1 if "...->"
- boolean isResultScalar = eqStringParts.length == 1;
+ public static Types.DataType validateEinsumEquationNoDimensions(String equationString, int numberOfMatrixInputs) throws LanguageException {
+ String[] eqStringParts = equationString.split("->"); // length 2 if "...->..." , length 1 if "...->"
+ boolean isResultScalar = eqStringParts.length == 1;
- int numberOfMatrices = 1;
- for (int i = 0; i < eqStringParts[0].length(); i++) {
- char c = eqStringParts[0].charAt(i);
- if(c == ' ') continue;
- if(c == ',')
- numberOfMatrices++;
- }
- if(numberOfMatrixInputs != numberOfMatrices){
- throw new LanguageException("Einsum: Invalid number of parameters, given: " + numberOfMatrixInputs + ", expected: " + numberOfMatrices);
- }
+ int numberOfMatrices = 1;
+ for (int i = 0; i < eqStringParts[0].length(); i++) {
+ char c = eqStringParts[0].charAt(i);
+ if(c == ' ') continue;
+ if(c == ',')
+ numberOfMatrices++;
+ }
+ if(numberOfMatrixInputs != numberOfMatrices){
+ throw new LanguageException("Einsum: Invalid number of parameters, given: " + numberOfMatrixInputs + ", expected: " + numberOfMatrices);
+ }
- if(isResultScalar){
- return Types.DataType.SCALAR;
- }else {
- int numberOfDimensions = 0;
- Character dim1Char = null;
- for (int i = 0; i < eqStringParts[1].length(); i++) {
- char c = eqStringParts[i].charAt(i);
- if(c == ' ') continue;
- numberOfDimensions++;
- if (numberOfDimensions == 1 && c == dim1Char)
- throw new LanguageException("Einsum: output character "+c+" provided multiple times");
- dim1Char = c;
- }
+ if(isResultScalar){
+ return Types.DataType.SCALAR;
+ }else {
+ int numberOfDimensions = 0;
+ Character dim1Char = null;
+ for (int i = 0; i < eqStringParts[1].length(); i++) {
+ char c = eqStringParts[1].charAt(i);
+ if(c == ' ') continue;
+ numberOfDimensions++;
+ if (dim1Char != null && c == dim1Char)
+ throw new LanguageException("Einsum: output character "+c+" provided multiple times");
+ dim1Char = c;
+ }
- if (numberOfDimensions > 2) {
- throw new LanguageException("Einsum: output matrices with with no. dims > 2 not supported");
- } else {
- return Types.DataType.MATRIX;
- }
- }
- }
+ if (numberOfDimensions > 2) {
+ throw new LanguageException("Einsum: output matrices with with no. dims > 2 not supported");
+ } else {
+ return Types.DataType.MATRIX;
+ }
+ }
+ }
- private static long getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) {
- long thisCharDimension;
- if(currArr instanceof Hop){
- thisCharDimension = arrSizeIterator == 0 ? ((Hop) currArr).getDim1() : ((Hop) currArr).getDim2();
- } else if(currArr instanceof Identifier){
- thisCharDimension = arrSizeIterator == 0 ? ((Identifier) currArr).getDim1() : ((Identifier) currArr).getDim2();
- } else {
- throw new RuntimeException("validateEinsumAndReturnDimensions called with expressions that are not Hop or Identifier: "+ currArr == null ? "null" : currArr.getClass().toString());
- }
- return thisCharDimension;
- }
+ private static long getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) {
+ long thisCharDimension;
+ if(currArr instanceof Hop){
+ thisCharDimension = arrSizeIterator == 0 ? ((Hop) currArr).getDim1() : ((Hop) currArr).getDim2();
+ } else if(currArr instanceof Identifier){
+ thisCharDimension = arrSizeIterator == 0 ? ((Identifier) currArr).getDim1() : ((Identifier) currArr).getDim2();
+ } else {
+ throw new RuntimeException("validateEinsumAndReturnDimensions called with expressions that are not Hop or Identifier: "+ currArr == null ? "null" : currArr.getClass().toString());
+ }
+ return thisCharDimension;
+ }
}
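As a quick illustration of the dimension-less validation path (the operand count is derived purely from the commas on the input side of the equation), a sketch, assuming it runs inside a method that declares throws LanguageException:

    Types.DataType dt = EinsumEquationValidator.validateEinsumEquationNoDimensions("ij,jk->ik", 2); // DataType.MATRIX
    Types.DataType sc = EinsumEquationValidator.validateEinsumEquationNoDimensions("ij,ij->", 2);   // DataType.SCALAR
    // Passing a mismatching operand count, e.g. ("ij,jk->ik", 3), raises a LanguageException.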
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java
new file mode 100644
index 00000000000..40b73ea3994
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import jdk.incubator.vector.DoubleVector;
+import jdk.incubator.vector.VectorSpecies;
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
+import org.apache.sysds.runtime.codegen.SpoofRowwise;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+
+public final class EinsumSpoofRowwise extends SpoofRowwise {
+ private static final long serialVersionUID = -5957679254041639561L;
+
+ private final int _ABCount;
+ private final boolean _Bsupplied;
+ private final int _ACount;
+ private final int _AZCount;
+ private final int _ZSize;
+ private final int _AZStartIndex;
+ private final EOpNodeFuse.EinsumRewriteType _EinsumRewriteType;
+
+ public EinsumSpoofRowwise(EOpNodeFuse.EinsumRewriteType einsumRewriteType, RowType rowType, long constDim2,
+ int abCount, boolean bSupplied, int aCount, int azCount, int zSize) {
+ super(rowType, constDim2, false, 1);
+ _ABCount = abCount;
+ _Bsupplied = bSupplied;
+ _ACount = aCount;
+ _AZStartIndex = abCount + (_Bsupplied ? 1 : 0) + aCount;
+ _AZCount = azCount;
+ _EinsumRewriteType = einsumRewriteType;
+ _ZSize = zSize;
+ }
+
+ protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix,
+ int rix) {
+ switch(_EinsumRewriteType) {
+ case AB_BA_B_A__AB -> {
+ genexecAB(a, ai, b, null, c, ci, len, grix, rix);
+ if(scalars.length != 0) { LibMatrixMult.vectMultiplyWrite(scalars[0], c, c, ci, ci, len); }
+ }
+ case AB_BA_A__B -> {
+ genexecB(a, ai, b, null, c, ci, len, grix, rix);
+ }
+ case AB_BA_B_A__A -> {
+ genexecAor(a, ai, b, null, c, ci, len, grix, rix);
+ if(scalars.length != 0) { c[rix] *= scalars[0]; }
+ }
+ case AB_BA_B_A__ -> {
+ genexecAor(a, ai, b, null, c, ci, len, grix, rix);
+ if(scalars.length != 0) { c[0] *= scalars[0]; }
+ }
+ case AB_BA_B_A_AZ__Z -> {
+ double[] temp = {0};
+ genexecAor(a, ai, b, null, temp, 0, len, grix, rix);
+ if(scalars.length != 0) { temp[0] *= scalars[0]; }
+ if(_AZCount > 1) {
+ double[] temp2 = new double[_ZSize];
+ int bi = _AZStartIndex;
+ LibMatrixMult.vectMultiplyWrite(b[bi++].values(0), b[bi++].values(0), temp2, _ZSize * rix,
+ _ZSize * rix, 0, _ZSize);
+ while(bi < _AZStartIndex + _AZCount) {
+ LibMatrixMult.vectMultiplyWrite(temp2, b[bi++].values(0), temp2, 0, _ZSize * rix, 0, _ZSize);
+ }
+ LibMatrixMult.vectMultiplyAdd(temp[0], temp2, c, 0, 0, _ZSize);
+ }
+ else
+ LibMatrixMult.vectMultiplyAdd(temp[0], b[_AZStartIndex].values(rix), c, _ZSize * rix, 0, _ZSize);
+ }
+ case AB_BA_A_AZ__BZ -> {
+ double[] temp = new double[len];
+ genexecB(a, ai, b, null, temp, 0, len, grix, rix);
+ if(scalars.length != 0) {
+ LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len);
+ }
+ if(_AZCount > 1) {
+ double[] temp2 = new double[_ZSize];
+ int bi = _AZStartIndex;
+ LibMatrixMult.vectMultiplyWrite(b[bi++].values(0), b[bi++].values(0), temp2, _ZSize * rix,
+ _ZSize * rix, 0, _ZSize);
+ while(bi < _AZStartIndex + _AZCount) {
+ LibMatrixMult.vectMultiplyWrite(temp2, b[bi++].values(0), temp2, 0, _ZSize * rix, 0, _ZSize);
+ }
+ LibSpoofPrimitives.vectOuterMultAdd(temp, temp2, c, 0, 0, 0, len, _ZSize);
+ }
+ else
+ LibSpoofPrimitives.vectOuterMultAdd(temp, b[_AZStartIndex].values(rix), c, 0, _ZSize * rix, 0, len, _ZSize);
+ }
+ case AB_BA_A_AZ__ZB -> {
+ double[] temp = new double[len];
+ genexecB(a, ai, b, null, temp, 0, len, grix, rix);
+ if(scalars.length != 0) {
+ LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len);
+ }
+ if(_AZCount > 1) {
+ double[] temp2 = new double[_ZSize];
+ int bi = _AZStartIndex;
+ LibMatrixMult.vectMultiplyWrite(b[bi++].values(0), b[bi++].values(0), temp2, _ZSize * rix,
+ _ZSize * rix, 0, _ZSize);
+ while(bi < _AZStartIndex + _AZCount) {
+ LibMatrixMult.vectMultiplyWrite(temp2, b[bi++].values(0), temp2, 0, _ZSize * rix, 0, _ZSize);
+ }
+ LibSpoofPrimitives.vectOuterMultAdd(temp2, temp, c, 0, 0, 0, _ZSize, len);
+ }
+ else
+ LibSpoofPrimitives.vectOuterMultAdd(b[_AZStartIndex].values(rix), temp, c, _ZSize * rix, 0, 0, _ZSize, len);
+ }
+ default -> throw new NotImplementedException();
+ }
+ }
+
+ private void genexecAB(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix,
+ int rix) {
+ int bi = 0;
+ double[] TMP1 = null;
+ if(_ABCount != 0) {
+ if(_ABCount == 1 && _ACount == 0 && !_Bsupplied) {
+ LibMatrixMult.vectMultiplyWrite(a, b[0].values(rix), c, ai, ai, ci, len);
+ return;
+ }
+ TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len);
+ while(bi < _ABCount) {
+ if(_ACount == 0 && !_Bsupplied && bi == _ABCount - 1) {
+ LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), c, 0, ai, ci, len);
+ }
+ else {
+ LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len);
+ }
+ }
+ }
+
+ if(_Bsupplied) {
+ if(_ACount == 1)
+ if(TMP1 == null)
+ vectMultiplyWrite(b[bi + 1].values(0)[rix], a, b[bi].values(0), c, ai, 0, ci, len);
+ else
+ vectMultiplyWrite(b[bi + 1].values(0)[rix], TMP1, b[bi].values(0), c, 0, 0, ci, len);
+ else if(TMP1 == null)
+ LibMatrixMult.vectMultiplyWrite(a, b[bi].values(0), c, ai, 0, ci, len);
+ else
+ LibMatrixMult.vectMultiplyWrite(TMP1, b[bi].values(0), c, 0, 0, ci, len);
+ }
+ else if(_ACount == 1) {
+ if(TMP1 == null)
+ LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix], a, c, ai, ci, len);
+ else
+ LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix], TMP1, c, 0, ci, len);
+ }
+ }
+
+ private void genexecB(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix,
+ int rix) {
+ int bi = 0;
+ double[] TMP1 = null;
+ if(_ABCount == 1 && _ACount == 0)
+ LibMatrixMult.vectMultiplyAdd(a, b[bi++].values(rix), c, ai, ai, 0, len);
+ else if(_ABCount != 0) {
+ TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len);
+ while(bi < _ABCount) {
+ if(_ACount == 0 && bi == _ABCount - 1) {
+ LibMatrixMult.vectMultiplyAdd(TMP1, b[bi++].values(rix), c, 0, ai, 0, len);
+ }
+ else {
+ LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len);
+ }
+ }
+ }
+
+ if(_ACount == 1) {
+ if(TMP1 == null)
+ LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], a, c, ai, 0, len);
+ else
+ LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], TMP1, c, 0, 0, len);
+ }
+ }
+
+ private void genexecAor(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) {
+ int bi = 0;
+ double[] TMP1 = null;
+ double TMP2 = 0;
+ if(_ABCount == 1 && !_Bsupplied)
+ TMP2 = LibSpoofPrimitives.dotProduct(a, b[bi++].values(rix), ai, ai, len);
+ else if(_ABCount != 0) {
+ TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len);
+ while(bi < _ABCount) {
+ if(!_Bsupplied && bi == _ABCount - 1)
+ TMP2 = LibSpoofPrimitives.dotProduct(TMP1, b[bi++].values(rix), 0, ai, len);
+ else
+ LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len);
+ }
+ }
+
+ if(_Bsupplied)
+ if(_ABCount != 0) TMP2 = LibSpoofPrimitives.dotProduct(TMP1, b[bi++].values(0), 0, 0, len);
+ else TMP2 = LibSpoofPrimitives.dotProduct(a, b[bi++].values(0), ai, 0, len);
+ else if(_ABCount == 0) TMP2 = LibSpoofPrimitives.vectSum(a, ai, len);
+
+ if(_ACount == 1) TMP2 *= b[bi].values(0)[rix];
+
+ if(_EinsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A) c[ci] = TMP2;
+ else c[0] += TMP2;
+ }
+
+ protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci,
+ int alen, int len, long grix, int rix) {
+ throw new RuntimeException("Sparse fused einsum not implemented");
+ }
+
+ // I am not sure if it is worth moving these helpers to LibMatrixMult, so for now they are added here
+ private static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;
+ private static final int vLen = SPECIES.length();
+
+ public static void vectMultiplyWrite(final double aval, double[] a, double[] b, double[] c, int ai, int bi, int ci,
+ final int len) {
+ final int bn = len % vLen;
+
+ //rest, not aligned to vLen-blocks
+ for(int j = 0; j < bn; j++, ai++, bi++, ci++)
+ c[ci] = aval * b[bi] * a[ai];
+
+ //vectorized vLen-blocks (Vector API, for better instruction-level parallelism)
+ DoubleVector avalVec = DoubleVector.broadcast(SPECIES, aval);
+ for(int j = bn; j < len; j += vLen, ai += vLen, bi += vLen, ci += vLen) {
+ DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai);
+ DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi);
+ avalVec.mul(bVec).mul(aVec).intoArray(c, ci);
+ }
+ }
+
+ public static void vectMultiplyAdd(final double aval, double[] a, double[] b, double[] c, int ai, int bi, int ci,
+ final int len) {
+ final int bn = len % vLen;
+
+ //rest, not aligned to vLen-blocks
+ for(int j = 0; j < bn; j++, ai++, bi++, ci++)
+ c[ci] += aval * b[bi] * a[ai];
+
+ //vectorized vLen-blocks (Vector API, for better instruction-level parallelism)
+ DoubleVector avalVec = DoubleVector.broadcast(SPECIES, aval);
+ for(int j = bn; j < len; j += vLen, ai += vLen, bi += vLen, ci += vLen) {
+ DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai);
+ DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi);
+ DoubleVector cVec = DoubleVector.fromArray(SPECIES, c, ci);
+ DoubleVector tmp = aVec.mul(bVec);
+ tmp.fma(avalVec, cVec).intoArray(c, ci);
+ }
+ }
+}
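The two Vector API helpers above compute c[ci+j] = aval * a[ai+j] * b[bi+j] (write) and c[ci+j] += aval * a[ai+j] * b[bi+j] (add). A scalar reference version, useful for cross-checking the vectorized path in tests, could look like this (sketch):

    // Scalar reference for EinsumSpoofRowwise.vectMultiplyWrite.
    static void vectMultiplyWriteRef(double aval, double[] a, double[] b, double[] c, int ai, int bi, int ci, int len) {
        for(int j = 0; j < len; j++)
            c[ci + j] = aval * a[ai + j] * b[bi + j];
    }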
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
index d0ab7a56305..f838eadc1d2 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
@@ -174,7 +174,7 @@ public void reset(int size) {
@Override
public byte[] getAsByteArray() {
- ByteBuffer floatBuffer = ByteBuffer.allocate(8 * _size);
+ ByteBuffer floatBuffer = ByteBuffer.allocate(4 * _size);
floatBuffer.order(ByteOrder.nativeOrder());
for(int i = 0; i < _size; i++)
floatBuffer.putFloat(_data[i]);
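Since a Java float occupies 4 bytes, the corrected allocation is 4 * _size; using Float.BYTES would make that intent explicit (equivalent sketch):

    ByteBuffer floatBuffer = ByteBuffer.allocate(Float.BYTES * _size); // Float.BYTES == 4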
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java
index 60c61e8f4af..a0eb9965d43 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java
@@ -65,16 +65,16 @@ else if(a.getNumColumns() == 0)
else if(b.getNumColumns() == 0)
return a;
- final ValueType[] _schema = addAll(a.getSchema(), b.getSchema());
- final ColumnMetadata[] _colmeta = addAll(a.getColumnMetadata(), b.getColumnMetadata());
- final Array<?>[] _coldata = addAll(a.getColumns(), b.getColumns());
- String[] _colnames = addAll(a.getColumnNames(), b.getColumnNames());
+ final ValueType[] schema = addAll(a.getSchema(), b.getSchema());
+ final ColumnMetadata[] colmeta = addAll(a.getColumnMetadata(), b.getColumnMetadata());
+ final Array<?>[] coldata = addAll(a.getColumns(), b.getColumns());
+ String[] colnames = addAll(a.getColumnNames(), b.getColumnNames());
// check and enforce unique columns names
- if(!Arrays.stream(_colnames).allMatch(new HashSet<>()::add))
- _colnames = null; // set to default of null to allocate on demand
+ if(!Arrays.stream(colnames).allMatch(new HashSet<>()::add))
+ colnames = null; // set to default of null to allocate on demand
- return new FrameBlock(_schema, _colnames, _colmeta, _coldata);
+ return new FrameBlock(schema, colnames, colmeta, coldata);
}
private static FrameBlock appendRbind(FrameBlock a, FrameBlock b) {
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java b/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
index 54a2d83adf8..f3b137c088d 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
@@ -20,7 +20,7 @@
package org.apache.sysds.runtime.functionobjects;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
@@ -88,7 +88,7 @@ public AggregateOperationTypes getAggOpType() {
*/
@Override
public Data execute(Data in1, double in2) {
- CM_COV_Object cm1=(CM_COV_Object) in1;
+ CmCovObject cm1=(CmCovObject) in1;
if(cm1.isCMAllZeros()) {
cm1.w=1;
@@ -203,7 +203,7 @@ public Data execute(Data in1, double in2) {
*/
@Override
public Data execute(Data in1, double in2, double w2) {
- CM_COV_Object cm1=(CM_COV_Object) in1;
+ CmCovObject cm1=(CmCovObject) in1;
if(cm1.isCMAllZeros())
{
@@ -320,8 +320,8 @@ public Data execute(Data in1, double in2, double w2) {
@Override
public Data execute(Data in1, Data in2)
{
- CM_COV_Object cm1=(CM_COV_Object) in1;
- CM_COV_Object cm2=(CM_COV_Object) in2;
+ CmCovObject cm1=(CmCovObject) in1;
+ CmCovObject cm2=(CmCovObject) in2;
if(cm1.isCMAllZeros())
{
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/COV.java b/src/main/java/org/apache/sysds/runtime/functionobjects/COV.java
index 89d8b235a05..836bf972ce6 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/COV.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/COV.java
@@ -19,7 +19,7 @@
package org.apache.sysds.runtime.functionobjects;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CmCovObject;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
@@ -62,7 +62,7 @@ private COV() {
@Override
public Data execute(Data in1, double u, double v, double w2)
{
- CM_COV_Object cov1=(CM_COV_Object) in1;
+ CmCovObject cov1=(CmCovObject) in1;
if(cov1.isCOVAllZeros())
{
cov1.w=w2;
@@ -94,7 +94,7 @@ public Data execute(Data in1, double u, double v, double w2)
@Override
public Data execute(Data in1, double u, double v)
{
- CM_COV_Object cov1=(CM_COV_Object) in1;
+ CmCovObject cov1=(CmCovObject) in1;
if(cov1.isCOVAllZeros())
{
cov1.w=1L;
@@ -118,8 +118,8 @@ public Data execute(Data in1, double u, double v)
@Override
public Data execute(Data in1, Data in2)
{
- CM_COV_Object cov1=(CM_COV_Object) in1;
- CM_COV_Object cov2=(CM_COV_Object) in2;
+ CmCovObject cov1=(CmCovObject) in1;
+ CmCovObject cov2=(CmCovObject) in2;
if(cov1.isCOVAllZeros())
{
cov1.w=cov2.w;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index feefe5f63d6..a2e64dd0bac 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -23,6 +23,7 @@
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.InstructionType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.ooc.AggregateTernaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.CSVReblockOOCInstruction;
@@ -33,6 +34,7 @@
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ParameterizedBuiltinOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.TernaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
@@ -64,10 +66,14 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
return CSVReblockOOCInstruction.parseInstruction(str);
case AggregateUnary:
return AggregateUnaryOOCInstruction.parseInstruction(str);
+ case AggregateTernary:
+ return AggregateTernaryOOCInstruction.parseInstruction(str);
case Unary:
return UnaryOOCInstruction.parseInstruction(str);
case Binary:
return BinaryOOCInstruction.parseInstruction(str);
+ case Ternary:
+ return TernaryOOCInstruction.parseInstruction(str);
case AggregateBinary:
case MAPMM:
return MatrixVectorBinaryOOCInstruction.parseInstruction(str);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index b35ca55dab6..b8d84ca3898 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -34,6 +34,7 @@
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstructionUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
public abstract class CPInstruction extends Instruction {
protected static final Log LOG = LogFactory.getLog(CPInstruction.class.getName());
@@ -52,6 +53,7 @@ public enum CPType {
protected final CPType _cptype;
protected final boolean _requiresLabelUpdate;
+ private long nanoTime;
protected CPInstruction(CPType type, String opcode, String istr) {
this(type, null, opcode, istr);
@@ -88,6 +90,8 @@ public String getGraphString() {
@Override
public Instruction preprocessInstruction(ExecutionContext ec) {
+ if (DMLScript.OOC_LOG_EVENTS)
+ nanoTime = System.nanoTime();
//default preprocess behavior (e.g., debug state, lineage)
Instruction tmp = super.preprocessInstruction(ec);
@@ -118,6 +122,10 @@ public Instruction preprocessInstruction(ExecutionContext ec) {
public void postprocessInstruction(ExecutionContext ec) {
if (DMLScript.LINEAGE_DEBUGGER)
ec.maintainLineageDebuggerInfo(this);
+ if (DMLScript.OOC_LOG_EVENTS) {
+ int callerId = OOCEventLog.registerCaller(getExtendedOpcode() + "_" + hashCode());
+ OOCEventLog.onComputeEvent(callerId, nanoTime, System.nanoTime());
+ }
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CentralMomentCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CentralMomentCPInstruction.java
index cca466182da..18fcfe6ae95 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CentralMomentCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CentralMomentCPInstruction.java
@@ -90,7 +90,7 @@ public void processInstruction( ExecutionContext ec ) {
if ( cm_op.getAggOpType() == AggregateOperationTypes.INVALID )
cm_op = cm_op.setCMAggOp((int)order.getLongValue());
- CM_COV_Object cmobj = null;
+ CmCovObject cmobj = null;
if (input3 == null ) {
cmobj = matBlock.cmOperations(cm_op);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CmCovObject.java
similarity index 95%
rename from src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java
rename to src/main/java/org/apache/sysds/runtime/instructions/cp/CmCovObject.java
index 8591c3b9ca9..09d08e8e448 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CmCovObject.java
@@ -27,7 +27,7 @@
import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
-public class CM_COV_Object extends Data
+public class CmCovObject extends Data
{
private static final long serialVersionUID = -5814207545197934085L;
@@ -48,7 +48,7 @@ public String toString() {
return "weight: "+w+", mean: "+mean+", m2: "+m2+", m3: "+m3+", m4: "+m4+", min: "+min+", max: "+max+", mean2: "+mean_v+", c2: "+c2;
}
- public CM_COV_Object()
+ public CmCovObject()
{
super(DataType.UNKNOWN, ValueType.UNKNOWN);
w=0;
@@ -75,7 +75,7 @@ public void reset()
max=0;
}
- public int compareTo(CM_COV_Object that)
+ public int compareTo(CmCovObject that)
{
if(w!=that.w)
return Double.compare(w, that.w);
@@ -100,10 +100,10 @@ else if(max!=that.max)
@Override
public boolean equals(Object o)
{
- if( o == null || !(o instanceof CM_COV_Object) )
+ if( o == null || !(o instanceof CmCovObject) )
return false;
- CM_COV_Object that = (CM_COV_Object)o;
+ CmCovObject that = (CmCovObject)o;
return (w==that.w && mean.equals(that.mean) && m2.equals(that.m2))
&& m3.equals(that.m3) && m4.equals(that.m4)
&& mean_v.equals(that.mean_v) && c2.equals(that.c2)
@@ -115,7 +115,7 @@ public int hashCode() {
throw new RuntimeException("hashCode() should never be called on instances of this class.");
}
- public void set(CM_COV_Object that)
+ public void set(CmCovObject that)
{
this.w=that.w;
this.mean.set(that.mean);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CovarianceCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CovarianceCPInstruction.java
index f0b8b597039..29b74f95741 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CovarianceCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CovarianceCPInstruction.java
@@ -62,7 +62,7 @@ public void processInstruction(ExecutionContext ec)
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
String output_name = output.getName();
COVOperator cov_op = (COVOperator)_optr;
- CM_COV_Object covobj = null;
+ CmCovObject covobj = null;
if ( input3 == null ) {
// Unweighted: cov.mvar0.mvar1.out
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
index c67dd290799..2b1074c80c9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
@@ -19,7 +19,6 @@
package org.apache.sysds.runtime.instructions.cp;
-import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.commons.logging.Log;
@@ -29,31 +28,49 @@
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.cplan.CNode;
-import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeCell;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
-import org.apache.sysds.hops.codegen.cplan.CNodeRow;
-import org.apache.sysds.runtime.codegen.*;
+import org.apache.sysds.runtime.codegen.CodegenUtils;
+import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.einsum.EOpNode;
+import org.apache.sysds.runtime.einsum.EOpNodeBinary;
+import org.apache.sysds.runtime.einsum.EOpNodeBinary.EBinaryOperand;
+import org.apache.sysds.runtime.einsum.EOpNodeData;
+import org.apache.sysds.runtime.einsum.EOpNodeFuse;
+import org.apache.sysds.runtime.einsum.EOpNodeUnary;
import org.apache.sysds.runtime.einsum.EinsumContext;
-import org.apache.sysds.runtime.functionobjects.*;
-import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
-import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
-import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+import org.apache.sysds.utils.Explain;
-import java.util.*;
-import java.util.function.Predicate;
+
+import static org.apache.sysds.api.DMLScript.EXPLAIN;
+import static org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization.mmChainDP;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
- public static boolean FORCE_CELL_TPL = false;
+ public static final boolean FORCE_CELL_TPL = false; // naive looped solution
+
+ public static final boolean FUSE_OUTER_MULTIPLY = true;
+ public static final boolean FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK = true;
+
+ public static final boolean PRINT_TRACE = true;
+
protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName());
public String eqStr;
private final int _numThreads;
@@ -62,15 +79,13 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand out, CPOperand... inputs)
{
super(op, opcode, istr, out, inputs);
- _numThreads = OptimizerUtils.getConstrainedNumThreads(-1);
+ _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2;
_in = inputs;
this.eqStr = inputs[0].getName();
- Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE);
+ if (PRINT_TRACE)
+ Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE);
}
- @SuppressWarnings("unused")
- private EinsumContext einc = null;
-
@Override
public void processInstruction(ExecutionContext ec) {
//get input matrices and scalars, incl pinning of matrices
@@ -81,119 +96,159 @@ public void processInstruction(ExecutionContext ec) {
if(mb instanceof CompressedMatrixBlock){
mb = ((CompressedMatrixBlock) mb).getUncompressed("Spoof instruction");
}
+ if(mb.getNumRows() == 1){
+ ensureMatrixBlockColumnVector(mb);
+ }
inputs.add(mb);
}
}
EinsumContext einc = EinsumContext.getEinsumContext(eqStr, inputs);
- this.einc = einc;
String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : "";
- if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols);
+ if( LOG.isTraceEnabled() ) LOG.trace("output: "+resultString +" "+einc.outRows+"x"+einc.outCols);
- ArrayList<String> inputsChars = einc.newEquationStringInputsSplit;
+ List<String> inputsChars = einc.newEquationStringInputsSplit;
if(LOG.isTraceEnabled()) LOG.trace(String.join(",",einc.newEquationStringInputsSplit));
-
- contractDimensionsAndComputeDiagonals(einc, inputs);
+ List<EOpNode> eOpNodes = new ArrayList<>(inputsChars.size());
+ List<EOpNode> eOpNodesScalars = new ArrayList<>(inputsChars.size()); // computed separately and only merged into the plan once it has been created
//make all vectors column vectors
for(int i = 0; i < inputs.size(); i++){
- if(inputs.get(i) != null && inputsChars.get(i).length() == 1) EnsureMatrixBlockColumnVector(inputs.get(i));
+ if(inputsChars.get(i).length() == 1) ensureMatrixBlockColumnVector(inputs.get(i));
}
- if(LOG.isTraceEnabled()) for(Character c : einc.characterAppearanceIndexes.keySet()){
- ArrayList<Integer> a = einc.characterAppearanceIndexes.get(c);
- LOG.trace(c+" count= "+a.size());
+ addSumDimensionsDiagonalsAndScalars(einc, inputsChars, eOpNodes, eOpNodesScalars, einc.charToDimensionSize);
+
+ Map<Character, Integer> characterToOccurences = einc.characterAppearanceCount;
+
+ for (int i = 0; i < inputsChars.size(); i++) {
+ if (inputsChars.get(i) == null) continue;
+ Character c1 = inputsChars.get(i).isEmpty() ? null : inputsChars.get(i).charAt(0);
+ Character c2 = inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null;
+ Integer dim1 = c1 == null ? null : einc.charToDimensionSize.get(c1);
+ Integer dim2 = c2 == null ? null : einc.charToDimensionSize.get(c2);
+ EOpNodeData n = new EOpNodeData(c1,c2,dim1,dim2, i);
+ eOpNodes.add(n);
}
- // compute scalar by suming-all matrices:
- Double scalar = null;
- for(int i=0;i< inputs.size(); i++){
- String s = inputsChars.get(i);
- if(s.equals("")){
- MatrixBlock mb = inputs.get(i);
- if (scalar == null) scalar = mb.get(0,0);
- else scalar*= mb.get(0,0);
- inputs.set(i,null);
- inputsChars.set(i,null);
+ List<EOpNode> ret = new ArrayList<>();
+ addVectorMultiplies(eOpNodes, eOpNodesScalars,characterToOccurences, einc.outChar1, einc.outChar2, ret);
+ eOpNodes = ret;
+
+ List<EOpNode> plan;
+ ArrayList<MatrixBlock> remainingMatrices;
+
+ if(!FORCE_CELL_TPL) {
+ plan = generateGreedyPlan(eOpNodes, eOpNodesScalars,
+ einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2);
+ if(!eOpNodesScalars.isEmpty()){
+ EOpNode l = eOpNodesScalars.get(0);
+ for(int i = 1; i < eOpNodesScalars.size(); i++){
+ l = new EOpNodeBinary(l, eOpNodesScalars.get(i), EBinaryOperand.scalar_scalar);
+ }
+
+ if(plan.isEmpty()) plan.add(l);
+ else {
+ int minCost = Integer.MAX_VALUE;
+ EOpNode addToNode = null;
+ int minIdx = -1;
+ for(int i = 0; i < plan.size(); i++) {
+ EOpNode n = plan.get(i);
+ Pair<Integer, EOpNode> costAndNode = addScalarToPlanFindMinCost(n, einc.charToDimensionSize);
+ if(costAndNode.getLeft() < minCost) {
+ minCost = costAndNode.getLeft();
+ addToNode = costAndNode.getRight();
+ minIdx = i;
+ }
+ }
+ plan.set(minIdx, mergeEOpNodeWithScalar(addToNode, l));
+ }
+
}
- }
- if (scalar != null) {
- inputsChars.add("");
- inputs.add(new MatrixBlock(scalar));
- }
+ if(plan.size() == 2 && plan.get(0).c2 == null && plan.get(1).c2 == null){
+ if (plan.get(0).c1 == einc.outChar1 && plan.get(1).c1 == einc.outChar2)
+ plan.set(0, new EOpNodeBinary(plan.get(0), plan.get(1), EBinaryOperand.A_B));
+ if (plan.get(0).c1 == einc.outChar2 && plan.get(1).c1 == einc.outChar1)
+ plan.set(0, new EOpNodeBinary(plan.get(1), plan.get(0), EBinaryOperand.A_B));
+ plan.remove(1);
+ }
- HashMap<Character, Integer> characterToOccurences = new HashMap<>();
- for (Character key :einc.characterAppearanceIndexes.keySet()) {
- characterToOccurences.put(key, einc.characterAppearanceIndexes.get(key).size());
- }
- for (Character key :einc.charToDimensionSize.keySet()) {
- if(!characterToOccurences.containsKey(key))
- characterToOccurences.put(key, 1);
- }
+ if(plan.size() == 1)
+ plan.set(0,plan.get(0).reorderChildrenAndOptimize(null, einc.outChar1, einc.outChar2));
- ArrayList<EOpNode> eOpNodes = new ArrayList<>(inputsChars.size());
- for (int i = 0; i < inputsChars.size(); i++) {
- if (inputsChars.get(i) == null) continue;
- EOpNodeData n = new EOpNodeData(inputsChars.get(i).length() > 0 ? inputsChars.get(i).charAt(0) : null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i);
- eOpNodes.add(n);
+ if (EXPLAIN != Explain.ExplainType.NONE ) {
+ System.out.println("Einsum plan:");
+ for(int i = 0; i < plan.size(); i++) {
+ System.out.println((i + 1) + ".");
+ System.out.println("- " + String.join("\n- ", plan.get(i).recursivePrintString()));
+ }
+ }
+
+ remainingMatrices = executePlan(plan, inputs);
+ }else{
+ plan = eOpNodes;
+ remainingMatrices = inputs;
}
- Pair<Integer, ArrayList<EOpNode>> plan = FORCE_CELL_TPL ? null : generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2);
- ArrayList<MatrixBlock> resMatrices = FORCE_CELL_TPL ? null : executePlan(plan.getRight(), inputs);
-// ArrayList<MatrixBlock> resMatrices = executePlan(plan.getRight(), inputs, true);
- if(!FORCE_CELL_TPL && resMatrices.size() == 1){
- EOpNode resNode = plan.getRight().get(0);
+ if(!FORCE_CELL_TPL && remainingMatrices.size() == 1){
+ EOpNode resNode = plan.get(0);
if (einc.outChar1 != null && einc.outChar2 != null){
if(resNode.c1 == einc.outChar1 && resNode.c2 == einc.outChar2){
- ec.setMatrixOutput(output.getName(), resMatrices.get(0));
+ ec.setMatrixOutput(output.getName(), remainingMatrices.get(0));
}
else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){
+ if( LOG.isTraceEnabled()) LOG.trace("Transposing the final result");
+
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
- MatrixBlock resM = resMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0);
+ MatrixBlock resM = remainingMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0);
ec.setMatrixOutput(output.getName(), resM);
}else{
- if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2);
- throw new RuntimeException("Einsum plan produced different result");
+ if(LOG.isTraceEnabled()) LOG.trace("Einsum error, expected: "+resultString + ", got: "+resNode.c1+resNode.c2);
+ throw new RuntimeException("Einsum plan produced different result, expected: "+resultString + ", got: "+resNode.c1+resNode.c2);
}
}else if (einc.outChar1 != null){
if(resNode.c1 == einc.outChar1 && resNode.c2 == null){
- ec.setMatrixOutput(output.getName(), resMatrices.get(0));
+ ensureMatrixBlockColumnVector(remainingMatrices.get(0));
+ ec.setMatrixOutput(output.getName(), remainingMatrices.get(0));
}else{
if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2);
throw new RuntimeException("Einsum plan produced different result");
}
}else{
if(resNode.c1 == null && resNode.c2 == null){
- ec.setScalarOutput(output.getName(), new DoubleObject(resMatrices.get(0).get(0, 0)));;
+ ec.setScalarOutput(output.getName(), new DoubleObject(remainingMatrices.get(0).get(0, 0)));
}
}
}else{
// use cell template with loops for remaining
- ArrayList<MatrixBlock> mbs = resMatrices;
- ArrayList<String> chars = new ArrayList<>();
+ ArrayList<MatrixBlock> mbs = remainingMatrices;
+ List<String> chars = new ArrayList<>();
- for (int i = 0; i < plan.getRight().size(); i++) {
+ for (int i = 0; i < plan.size(); i++) {
String s;
- if(plan.getRight().get(i).c1 == null) s = "";
- else if(plan.getRight().get(i).c2 == null) s = plan.getRight().get(i).c1.toString();
- else s = plan.getRight().get(i).c1.toString() + plan.getRight().get(i).c2;
+ if(plan.get(i).c1 == null) s = "";
+ else if(plan.get(i).c2 == null) s = plan.get(i).c1.toString();
+ else s = plan.get(i).c1.toString() + plan.get(i).c2;
chars.add(s);
}
- ArrayList<Character> summingChars = new ArrayList<>();
- for (Character c : einc.characterAppearanceIndexes.keySet()) {
+ List