juliobellano/CNN


📊 CNN for GuessTheCorrelation.com

This repository contains a PyTorch implementation of a Convolutional Neural Network (CNN) designed to predict the correlation coefficient from scatter plot images, based on the gameplay of guessthecorrelation.com. The model takes an image of a scatter plot and outputs a continuous value representing the Pearson correlation coefficient ($r$), which ranges from $-1$ to $1$.


🚀 Model Performance Summary

On the held-out test set, the model's predictions track the true correlation coefficients closely:

| Metric | Value | Interpretation |
| --- | --- | --- |
| Test Loss (MSE) | $0.0020$ | The average squared difference between the actual and predicted correlation is very small. |
| Correlation ($r$) between Actual & Predicted | $0.9976$ | The model's predictions have a near-perfect linear relationship with the actual correlation values. |
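As a rough illustration of how the two metrics in the table can be computed, here is a minimal sketch in plain Python. The function names `mse` and `pearson_r` and the toy values are assumptions made for this example; they are not taken from the repository and do not reproduce its results.

```python
import math


def mse(actual, predicted):
    """Mean squared error between two equal-length sequences."""
    return sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual)


def pearson_r(x, y):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


# Toy values for illustration only; these are NOT the repository's data.
actual = [-0.80, -0.25, 0.10, 0.55, 0.90]
predicted = [-0.76, -0.30, 0.14, 0.50, 0.93]

print(f"MSE: {mse(actual, predicted):.4f}")
print(f"r:   {pearson_r(actual, predicted):.4f}")
```

For scale, an MSE of $0.0020$ corresponds to a root-mean-square error of $\sqrt{0.0020} \approx 0.045$ in units of $r$.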

Prediction vs. Actual Scatter Plot

The scatter plot below visually confirms the strong performance. The predicted values closely follow the red dashed line, which represents perfect prediction (where predicted equals actual).

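[Image: predicted vs. actual correlation on the test set, with a red dashed line marking perfect prediction]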


🛠️ Implementation Details

1. Data and Preprocessing

  • Dataset: The model is trained on a custom dataset of scatter plot images and their corresponding correlation coefficients, sourced from the game data (responses.csv).
  • Data Split: The dataset is split into Training (80%), Validation (10%), and Test (10%) sets for robust evaluation.
    • Training Data: 120,000 samples
    • Validation Data: 15,000 samples
    • Test Data: 15,000 samples
  • Custom Dataset Class (ImageDataset): Handles image loading and data pairing.
    • Images are loaded from the input/images directory.
    • Images are converted to 1-bit black-and-white mode (convert('1')), since the scatter plots are black and white; this yields a single input channel.
  • Transforms:
    • transforms.ToTensor(): Converts the PIL image to a PyTorch tensor.
    • transforms.Normalize((0.5,), (0.5,)): Normalizes the single-channel data with mean $0.5$ and standard deviation $0.5$, mapping pixel values from $[0, 1]$ to $[-1, 1]$.

2. CNN Architecture (ConvNet)

The network is a small CNN with two convolutional layers followed by two fully connected layers, ending in a single output for this regression task.

[Image of a general Convolutional Neural Network architecture]

  • Input Channel: $1$ (Grayscale image).
  • Layers:
    1. Convolutional Layer 1 (conv1): $1$ input channel, $8$ output channels, $3\times3$ kernel.
    2. Max Pooling (pool): $2\times2$ kernel.
    3. 2D Dropout (dropout2d): $p=0.2$ for regularization.
    4. Convolutional Layer 2 (conv2): $8$ input channels, $16$ output channels, $3\times3$ kernel.
    5. Max Pooling (pool): $2\times2$ kernel.
    6. Flatten: Reshapes the feature maps into a vector of $16 \times 36 \times 36 = 20{,}736$ values for the fully connected layers.
    7. Fully Connected Layer 1 (fc1): Maps $16 \times 36 \times 36$ features to $256$ outputs.
    8. Dropout (dropout): $p=0.4$ for regularization.
    9. Fully Connected Layer 2 (fc2): Maps $256$ inputs to $1$ output (the predicted correlation).
    10. Activation: A $\text{Tanh}$ function is applied to the final output, which bounds the prediction within $(-1, 1)$ and so matches the $[-1, 1]$ range of the correlation coefficient.
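The layer list above can be sketched as a PyTorch module. Several details here are assumptions rather than facts from the README: ReLU is used as the hidden activation (the README does not name one), the convolutions use stride 1 with no padding, and the input is taken to be $150 \times 150$ pixels, which is consistent with the stated $16 \times 36 \times 36$ flatten size ($150 \to 148 \to 74 \to 72 \to 36$).

```python
import torch
import torch.nn as nn


class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3)
        self.pool = nn.MaxPool2d(2)
        self.dropout2d = nn.Dropout2d(p=0.2)
        self.fc1 = nn.Linear(16 * 36 * 36, 256)
        self.dropout = nn.Dropout(p=0.4)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        # ReLU is an assumption; the README does not specify hidden activations.
        x = self.pool(torch.relu(self.conv1(x)))   # 150 -> 148 -> 74
        x = self.dropout2d(x)
        x = self.pool(torch.relu(self.conv2(x)))   # 74 -> 72 -> 36
        x = torch.flatten(x, start_dim=1)          # (batch, 16 * 36 * 36)
        x = self.dropout(torch.relu(self.fc1(x)))
        return torch.tanh(self.fc2(x))             # bounded to (-1, 1)


# Shape check with a dummy batch of four 150x150 single-channel images.
model = ConvNet().eval()
with torch.no_grad():
    out = model(torch.zeros(4, 1, 150, 150))
print(out.shape)  # torch.Size([4, 1])
```

A forward pass on a dummy batch is a cheap way to confirm that the hard-coded `fc1` input size agrees with the actual image resolution; a mismatch raises a shape error at `fc1`.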

3. Training Details

  • Device: Configured to use MPS (Metal Performance Shaders) for Apple M1/M2 chips, falling back to CPU if unavailable.
  • Hyperparameters:
    • Epochs (n_epochs): $3$
    • Batch Size (batchsize): $64$
    • Learning Rate (learning_rate): $0.001$
  • Loss Function (criterion): Mean Squared Error (nn.MSELoss) is used for this regression problem.
  • Optimizer: Adam (torch.optim.Adam) is used for optimization.
  • Validation: The model is evaluated on the validation set every $200$ steps during training to monitor performance (both validation loss and actual vs. predicted correlation).
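A training loop matching these settings could be sketched as follows. The function names `train` and `evaluate`, the `eval_every` parameter, and the tiny synthetic dataset and stand-in model at the bottom are all hypothetical, added only so the sketch runs without the real data; the actual script may be organized differently and, according to the README, also tracks actual vs. predicted correlation during validation.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Prefer Apple's MPS backend when available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

n_epochs, batchsize, learning_rate = 3, 64, 0.001


def evaluate(model, loader, criterion):
    """Return the mean loss over a loader, with dropout disabled."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            total += criterion(model(images), targets).item() * len(images)
            count += len(images)
    model.train()
    return total / count


def train(model, train_loader, val_loader, eval_every=200):
    """Train with Adam and MSE loss, validating every `eval_every` steps."""
    model.to(device)
    model.train()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    history, step = [], 0
    for epoch in range(n_epochs):
        for images, targets in train_loader:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()
            step += 1
            if step % eval_every == 0:
                val_loss = evaluate(model, val_loader, criterion)
                history.append((step, val_loss))
                print(f"epoch {epoch + 1} step {step}: val MSE {val_loss:.4f}")
    return history


# Tiny synthetic stand-in (NOT the real data or ConvNet) so the sketch runs quickly.
torch.manual_seed(0)
images = torch.randn(256, 1, 8, 8)
targets = torch.rand(256, 1) * 2 - 1
loader = DataLoader(TensorDataset(images, targets), batch_size=batchsize)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(64, 1), nn.Tanh())
history = train(toy_model, loader, loader, eval_every=2)
```

With the real data, the loaders would wrap the training and validation splits and `eval_every` would stay at its default of 200.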

About

Convolutional Neural Network, Vision Transformers, Pytorch
