diff --git a/index.html b/index.html
index 4121ace..c2cf5f9 100644
--- a/index.html
+++ b/index.html
@@ -18,9 +18,9 @@
Interactive Convolution Visualizer
@@ -43,6 +43,13 @@
Interactive Convolution Visualizer
+
+
+
+
diff --git a/script.js b/script.js
index 26cc027..e3bbe8d 100644
--- a/script.js
+++ b/script.js
@@ -10,6 +10,39 @@ document.addEventListener('DOMContentLoaded', function() {
let currentStep = 0;
let animationSpeed = 500; // milliseconds
+ // MNIST 28x28 grayscale image data representing the digit "5"
+ // Values are actual pixel intensities (0-255) creating a recognizable digit pattern
+ const mnistImageData = [
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 30, 60, 120, 180, 220, 255, 255, 255, 255, 255, 255, 255, 255, 220, 180, 120, 60, 30, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 50, 120, 200, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 200, 120, 50, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 70, 150, 220, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 220, 150, 70, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 80, 160, 230, 255, 255, 255, 200, 120, 80, 40, 20, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 90, 170, 240, 255, 255, 200, 100, 50, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 100, 180, 250, 255, 180, 80, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 110, 190, 255, 200, 80, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 120, 200, 255, 150, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 130, 210, 255, 200, 140, 100, 80, 60, 40, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 140, 220, 255, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 150, 230, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 160, 240, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 40, 80, 120, 160, 200, 240, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 20, 60, 120, 180, 220, 255, 255, 255, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 60, 120, 180, 220, 255, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 60, 120, 180, 220, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 60, 120, 180, 220, 255, 255, 255, 255, 255, 240, 180, 120],
+ [0, 0, 0, 0, 30, 60, 120, 180, 220, 255, 200, 140, 80, 40, 20, 40, 80, 140, 200, 255, 255, 255, 255, 255, 240, 180, 120, 60],
+ [0, 0, 0, 0, 50, 120, 200, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30],
+ [0, 0, 0, 0, 70, 150, 220, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30, 0],
+ [0, 0, 0, 0, 80, 160, 230, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30, 0, 0, 0, 0],
+ [0, 0, 0, 0, 30, 60, 120, 180, 220, 255, 255, 255, 255, 255, 255, 240, 180, 120, 60, 30, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ ];
+
// Get DOM elements
const inputHeightEl = document.getElementById('input-height');
const inputWidthEl = document.getElementById('input-width');
@@ -18,6 +51,7 @@ document.addEventListener('DOMContentLoaded', function() {
const paddingEl = document.getElementById('padding');
const dilationEl = document.getElementById('dilation');
const strideEl = document.getElementById('stride');
+ const inputTypeEl = document.getElementById('input-type');
const updateBtn = document.getElementById('update-btn');
const animateBtn = document.getElementById('animate-btn');
const stopBtn = document.getElementById('stop-btn');
@@ -180,10 +214,26 @@ document.addEventListener('DOMContentLoaded', function() {
}
function generateMatrices(inputHeight, inputWidth, kernelHeight, kernelWidth, padding, outputHeight, outputWidth) {
- // Generate input matrix with random values (0-9)
- inputMatrix = Array(inputHeight).fill().map(() =>
- Array(inputWidth).fill().map(() => Math.floor(Math.random() * 10))
- );
+ const inputType = inputTypeEl.value;
+
+ if (inputType === 'image') {
+ // Generate input matrix from MNIST image data (no normalization)
+ inputMatrix = Array(inputHeight).fill().map((_, i) =>
+ Array(inputWidth).fill().map((_, j) => {
+ // Map to image data coordinates
+ const imgRow = Math.floor((i / inputHeight) * mnistImageData.length);
+ const imgCol = Math.floor((j / inputWidth) * mnistImageData[0].length);
+ const pixelValue = mnistImageData[imgRow][imgCol];
+ // Return actual pixel values (0-255) without normalization
+ return pixelValue;
+ })
+ );
+ } else {
+ // Generate input matrix with random values (0-9)
+ inputMatrix = Array(inputHeight).fill().map(() =>
+ Array(inputWidth).fill().map(() => Math.floor(Math.random() * 10))
+ );
+ }
// Generate weight matrix with random values (-2 to 2)
weightMatrix = Array(kernelHeight).fill().map(() =>
@@ -233,9 +283,17 @@ document.addEventListener('DOMContentLoaded', function() {
function renderMatrix(matrix, container, type, padding = 0, inputHeight = 0, inputWidth = 0) {
container.innerHTML = '';
+ const inputType = inputTypeEl.value;
+
+ // Determine cell size based on matrix size and type
+ let cellSize = 40;
+ if (type === 'input' && inputType === 'image' && inputHeight >= 20) {
+ cellSize = Math.max(15, Math.min(25, 600 / Math.max(inputHeight, inputWidth)));
+ }
+
// Set grid dimensions
- container.style.gridTemplateRows = `repeat(${matrix.length}, 40px)`;
- container.style.gridTemplateColumns = `repeat(${matrix[0].length}, 40px)`;
+ container.style.gridTemplateRows = `repeat(${matrix.length}, ${cellSize}px)`;
+ container.style.gridTemplateColumns = `repeat(${matrix[0].length}, ${cellSize}px)`;
for (let i = 0; i < matrix.length; i++) {
for (let j = 0; j < matrix[i].length; j++) {
@@ -243,14 +301,79 @@ document.addEventListener('DOMContentLoaded', function() {
cell.className = 'cell';
cell.id = `${type}-${i}-${j}`;
+ // Apply dynamic sizing for image cells
+ if (type === 'input' && inputType === 'image' && inputHeight >= 20) {
+ cell.style.width = `${cellSize}px`;
+ cell.style.height = `${cellSize}px`;
+ cell.style.fontSize = `${Math.max(8, cellSize * 0.4)}px`;
+ }
+
// Check if this is a padding cell for the input matrix
if (type === 'input' && (i < padding || i >= inputHeight + padding || j < padding || j >= inputWidth + padding)) {
cell.classList.add('padding-cell');
} else {
cell.classList.add(`${type}-cell`);
+
+ // Add image background for input cells when using image input
+ if (type === 'input' && inputType === 'image' && !(i < padding || i >= inputHeight + padding || j < padding || j >= inputWidth + padding)) {
+ cell.classList.add('image-cell');
+
+ // Calculate the corresponding image pixel
+ const actualRow = i - padding;
+ const actualCol = j - padding;
+ const imgRow = Math.floor((actualRow / inputHeight) * mnistImageData.length);
+ const imgCol = Math.floor((actualCol / inputWidth) * mnistImageData[0].length);
+ const pixelValue = mnistImageData[imgRow][imgCol];
+
+ // Set background color based on pixel value (grayscale)
+ cell.style.backgroundColor = `rgb(${pixelValue}, ${pixelValue}, ${pixelValue})`;
+ }
+
+ // Add image-based background for output cells when using image input
+ if (type === 'output' && inputType === 'image') {
+ cell.classList.add('image-cell');
+
+ // Map output value to a color intensity
+ // Find min/max values in the output matrix for normalization
+ const flatOutput = outputMatrix.flat();
+ const minOutput = Math.min(...flatOutput);
+ const maxOutput = Math.max(...flatOutput);
+ const outputValue = matrix[i][j];
+
+ // Normalize output value to 0-255 range for visualization
+ let normalizedValue;
+ if (maxOutput !== minOutput) {
+ normalizedValue = Math.round(((outputValue - minOutput) / (maxOutput - minOutput)) * 255);
+ } else {
+ normalizedValue = 128; // Middle gray if all values are the same
+ }
+
+ // Apply a blue-to-red colormap for better visualization of positive/negative values
+ let red, green, blue;
+ if (outputValue < 0) {
+ // Negative values: blue tones
+ const intensity = Math.abs(outputValue - minOutput) / Math.abs(minOutput) * 255;
+ red = Math.max(0, 255 - intensity);
+ green = Math.max(0, 255 - intensity);
+ blue = 255;
+ } else {
+ // Positive values: red tones
+ const intensity = (outputValue / maxOutput) * 255;
+ red = 255;
+ green = Math.max(0, 255 - intensity);
+ blue = Math.max(0, 255 - intensity);
+ }
+
+ // Set background color with transparency to show the pattern
+ cell.style.backgroundColor = `rgba(${red}, ${green}, ${blue}, 0.6)`;
+ }
}
- cell.textContent = matrix[i][j];
+ // Use a span for the text content to ensure it's above the background
+ const textSpan = document.createElement('span');
+ textSpan.textContent = matrix[i][j];
+ cell.appendChild(textSpan);
+
container.appendChild(cell);
}
}
diff --git a/styles.css b/styles.css
index 7866001..cf070b7 100644
--- a/styles.css
+++ b/styles.css
@@ -71,6 +71,14 @@ input[type="number"] {
text-align: center;
}
+select {
+ padding: 8px;
+ border: 1px solid #ddd;
+ border-radius: 4px;
+ background-color: white;
+ min-width: 120px;
+}
+
button {
background-color: #4285f4;
color: white;
@@ -139,6 +147,36 @@ button:hover {
background-color: var(--input-color);
}
+.input-cell.image-cell {
+ background-size: cover;
+ background-position: center;
+ background-repeat: no-repeat;
+ color: white;
+ text-shadow: 1px 1px 3px rgba(0,0,0,1);
+ font-weight: 900;
+ position: relative;
+ font-size: 10px;
+}
+
+.input-cell.image-cell::after {
+ content: '';
+ position: absolute;
+ top: 0;
+ left: 0;
+ right: 0;
+ bottom: 0;
+ background-color: rgba(0, 0, 0, 0.1);
+ z-index: 1;
+ pointer-events: none;
+}
+
+.input-cell.image-cell span {
+ position: relative;
+ z-index: 2;
+ font-size: 10px;
+ font-weight: 900;
+}
+
.weight-cell {
background-color: var(--weight-color);
}
@@ -147,6 +185,19 @@ button:hover {
background-color: var(--output-color);
}
+.output-cell.image-cell {
+ color: white;
+ text-shadow: 1px 1px 3px rgba(0,0,0,1);
+ font-weight: 900;
+ position: relative;
+}
+
+.output-cell.image-cell span {
+ position: relative;
+ z-index: 2;
+ font-weight: 900;
+}
+
.padding-cell {
background-color: var(--padding-color);
color: #999;