diff --git a/.gitignore b/.gitignore index d86e4ff..5bceaac 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,16 @@ cython_debug/ # PyPI configuration file .pypirc + +# ignore particular files +src/nn/permutation_pred.py +src/ordered_permutation_model_100_earlystop_butno.pth +src/ordered_permutation_model_weird_100_epoch.pth +src/ordered_permutation_model_5_qubit.pth +src/ordered_permutation_model3-4-4.pth +src/ordered_permutation_model3-4.pth +src/ordered_permutation_model3-1.pth +src/ordered_permutation_model2-3.pth +src/ordered_permutation_model2-2.pth +src/ordered_permutation_model2-1.pth +src/ordered_permutation_model1.pth \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..61c714d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "python.testing.unittestArgs": ["-v", "-s", "./src", "-p", "*test.py"], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true +} diff --git a/Technical report.pdf b/Technical report.pdf new file mode 100644 index 0000000..5903eb5 Binary files /dev/null and b/Technical report.pdf differ diff --git a/experiment_log.txt b/experiment_log.txt new file mode 100644 index 0000000..fff8206 --- /dev/null +++ b/experiment_log.txt @@ -0,0 +1,522 @@ +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: Adam ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0 +) +Number of epochs: 4709 +Patience: 1000 +Best training loss: 0.0870 +Best validation loss: 0.2626 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: Adam ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0 +) +Number of epochs: 3968 +Patience: 1000 +Best training loss: 0.1179 +Best validation loss: 0.2706 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: Adam ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 4204 +Patience: 1000 +Best training loss: 0.0967 +Best validation loss: 0.2644 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, 
out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: Adam ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 3671 +Patience: 1000 +Best training loss: 0.1028 +Best validation loss: 0.2579 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: Adam ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 4903 +Patience: 1000 +Best training loss: 0.1027 +Best validation loss: 0.2825 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 2022 +Patience: 1000 +Best training loss: 0.1695 +Best validation loss: 0.3072 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 3256 +Patience: 1000 +Best training loss: 0.1342 +Best validation loss: 0.2927 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 4499 +Patience: 1000 +Best training loss: 0.0847 +Best validation loss: 0.2784 + 
+================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 4616 +Patience: 1000 +Best training loss: 0.0907 +Best validation loss: 0.2940 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 4802 +Patience: 1000 +Best training loss: 0.1008 +Best validation loss: 0.2530 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 6218 +Patience: 1000 +Best training loss: 0.1074 +Best validation loss: 0.2965 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 2157 +Patience: 1000 +Best training loss: 0.8267 +Best validation loss: 0.7643 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + 
weight_decay: 0.0001 +) +Number of epochs: 2174 +Patience: 1000 +Best training loss: 0.8120 +Best validation loss: 0.7643 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-2): 3 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.3, inplace=False) +) +Number of hidden layers: 4 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 1321 +Patience: 1000 +Best training loss: 0.6876 +Best validation loss: 0.8739 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-8): 9 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.5, inplace=False) +) +Number of hidden layers: 10 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 1327 +Patience: 1000 +Best training loss: 0.7197 +Best validation loss: 0.7976 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-8): 9 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.5, inplace=False) +) +Number of hidden layers: 10 +Optimizer: AdamW ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 1521 +Patience: 1000 +Best training loss: 0.4948 +Best validation loss: 0.5179 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=64, bias=True) + (hidden_layers_list): ModuleList( + (0): Linear(in_features=64, out_features=64, bias=True) + ) + (fc_out): Linear(in_features=64, out_features=16, bias=True) + (dropout): Dropout(p=0.5, inplace=False) +) +Number of hidden layers: 2 +Optimizer: Adam ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 5323 +Patience: 1000 +Best training loss: 0.2356 +Best validation loss: 0.3459 + +================================================================================ + +Model: BestQubitModel( + (fc1): Linear(in_features=32, out_features=128, bias=True) + (hidden_layers_list): ModuleList( + (0-1): 2 x Linear(in_features=128, out_features=128, bias=True) + ) + (fc_out): Linear(in_features=128, out_features=16, bias=True) + (dropout): Dropout(p=0.5, inplace=False) +) +Number of hidden layers: 3 +Optimizer: Adam ( +Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + 
capturable: False + differentiable: False + eps: 1e-08 + foreach: None + fused: None + lr: 0.0001 + maximize: False + weight_decay: 0.0001 +) +Number of epochs: 5157 +Patience: 1000 +Best training loss: 0.2217 +Best validation loss: 0.3809 + +================================================================================ + diff --git a/models/basemodel_curriculum_nrgates_21.pt b/models/basemodel_curriculum_nrgates_21.pt new file mode 100644 index 0000000..a292d37 Binary files /dev/null and b/models/basemodel_curriculum_nrgates_21.pt differ diff --git a/models/best_model_plot.png b/models/best_model_plot.png new file mode 100644 index 0000000..afa1567 Binary files /dev/null and b/models/best_model_plot.png differ diff --git a/models/finetuned_model_up_to_nr_gates_10.pt b/models/finetuned_model_up_to_nr_gates_10.pt new file mode 100644 index 0000000..e9d8d44 Binary files /dev/null and b/models/finetuned_model_up_to_nr_gates_10.pt differ diff --git a/models/ordered_permutation_model.pth b/models/ordered_permutation_model.pth new file mode 100644 index 0000000..cf90c65 Binary files /dev/null and b/models/ordered_permutation_model.pth differ diff --git a/src/nn/README.md b/src/nn/README.md new file mode 100644 index 0000000..8fe3559 --- /dev/null +++ b/src/nn/README.md @@ -0,0 +1,19 @@ +# Permutation-based Approach + +1. The code includes several encoders, including a Transformer-based encoder, which makes training rather time-consuming without improving results enough. + + **The current configuration is highly recommended: CNN-based encoder + Transformer-based decoder.** + +2. The class `TableauPermutationDataset` is flexible. Please refer to for example usage. + `training_data_perm.pkl` contains 3200 4-qubit data points, while `training_data_perm_4_qubit.pkl` contains 9600 4-qubit data points. + **In our experiments, training with 3200 data points for 50 epochs is sufficient and takes around 20 minutes or less** (See `src/loss_history3-4.png`). If training with 9600 data points, it is better to increase to 100 or 200 epochs (See `src/loss_history3-4.png`). +3. `permutation_math.py` contains mixed code for **Model 2** and **Model 3**. + There are two example usages of **Model 3** at the end of `permutation_math.py` (See ). I usually uncomment the example usage code and simply run `permutation_math.py` for **fast debugging**. + These two examples illustrate using **Model 3** without and with supervised-learning-based (SL) fine-tuning. **Running without SL fine-tuning is recommended**, as fine-tuning takes time and barely improves the results. + There is also an RL fine-tuning function, but it is commented out because it is not ready to use and has been set aside. + Remember to comment out the example usage code in `permutation_math.py` when running `nn_eval_main.py`, otherwise an error will be raised. +4. If you decide to use SL fine-tuning, remember to use different datasets for pre-training and fine-tuning (See https://github.com/traffictse/2025DataScienceProject/blob/94e9b64160145f40ae8ba605f8951cbce3266f38/src/nn_eval_main.py#L169-L177) and to **call the right model** in `nn_eval_main.py` (See ). +5. The loss history will by default be **plotted over epochs, not over batches**, and saved as `.png` and `.pkl`; the model will be saved as `.pth`. +6. `dummy_perm_data_gen.py` generates training data in the desired format for this approach. All you need to do is specify the qubit number here (See ) and run `dummy_perm_data_gen.py`; a minimal sketch of the generate-then-load flow is shown below.
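   A minimal sketch of that flow (illustrative only; it assumes the import paths used elsewhere in this PR, with `generate_data` from `dummy_perm_data_gen.py` and `TableauPermutationDataset` as defined in `permutation_clean.py`):

   ```python
   from src.nn.dummy_perm_data_gen import generate_data
   from src.nn.permutation_clean import TableauPermutationDataset

   # Generates batch_size * num_epochs = 3200 random circuits, builds their tableaus,
   # brute-forces the best permutations, and pickles everything to
   # training_data_perm_4_qubit.pkl (the file name is derived from n_qubits).
   generate_data(n_qubits=4, nr_gates=1000, batch_size=32, num_epochs=100)

   # Load a random 320-sample subset for a quick sanity check before full training.
   dataset = TableauPermutationDataset(
       "training_data_perm_4_qubit.pkl", max_samples=320, shuffle=True
   )
   print(len(dataset), "samples,", dataset.n_qubits, "qubits")
   ```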
The training data will be saved under a name like `training_data_perm_4_qubit.pkl`. Generating 3200 5-qubit data points would take roughly 30-40 hours, so I terminated that run halfway through. +7. To save time, it is recommended to test initial performance over just 50-100 evaluations. +8. Good luck! diff --git a/src/nn/dummy_perm_data_gen.py b/src/nn/dummy_perm_data_gen.py new file mode 100644 index 0000000..eaa1d09 --- /dev/null +++ b/src/nn/dummy_perm_data_gen.py @@ -0,0 +1,58 @@ +import pickle +from src.utils import tableau_from_circuit, random_hscx_circuit +from pauliopt.clifford.tableau import CliffordTableau +from src.nn.brute_force_data import get_best_cnots +from pauliopt.topologies import Topology +import numpy as np +import warnings +from tqdm import tqdm # Import tqdm for progress bar + +# Suppress all overflow warnings globally +np.seterr(over="ignore") + +# Suppress FutureWarning +warnings.simplefilter(action="ignore", category=FutureWarning) + + +def generate_data(n_qubits=4, nr_gates=1000, batch_size=32, num_epochs=100): + # Configuration + total_data_points = batch_size * num_epochs # 3200 + filename = f"training_data_perm_{n_qubits}_qubit.pkl" + + # Generate and save data + data = [] + + # Create a progress bar + with tqdm(total=total_data_points, desc="Generating data", unit="circuits") as pbar: + for i in range(total_data_points): + # Generate a random H/S/CX circuit with nr_gates gates + circuit = random_hscx_circuit(nr_qubits=n_qubits, nr_gates=nr_gates) + tableau = tableau_from_circuit(CliffordTableau(n_qubits), circuit) + best_perms = get_best_cnots(tableau, Topology.complete(n_qubits)) + + # Store as (input_tableau, target_perms) for various tableau-to-graph implementations + data.append((tableau, best_perms)) + + # Update progress bar + pbar.update(1) + + # Still print epoch completion + if (i + 1) % batch_size == 0: + epoch = (i + 1) // batch_size + pbar.set_postfix({"Epochs": f"{epoch}/{num_epochs}"}) + + # Save to file + print(f"Saving {len(data)} data points to {filename}...") + with open(filename, "wb") as f: + pickle.dump(data, f) + + print(f"Generated {len(data)} data points saved to {filename}") + + +def main(): + generate_data(n_qubits=5, nr_gates=1000, batch_size=32, num_epochs=100) + return 0 + + +if __name__ == "__main__": + main() diff --git a/src/nn/nn_train_main.py b/src/nn/nn_train_main.py new file mode 100644 index 0000000..18ba02b --- /dev/null +++ b/src/nn/nn_train_main.py @@ -0,0 +1,204 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F # Import torch.nn.functional as F +from torch.utils.data import DataLoader, TensorDataset +from src.nn.best_qubit_model import BestQubitModel + +def load_data(train_path, val_path): + """ + Loads the training and validation data. + + Args: + train_path (str): Path to the training data file. + val_path (str): Path to the validation data file. + + Returns: + (Tensor, Tensor, Tensor, Tensor): X_train, y_train, X_val, y_val + """ + train_data = torch.load(train_path) + val_data = torch.load(val_path) + X_train, y_train = train_data + X_val, y_val = val_data + return X_train, y_train, X_val, y_val + +def create_dataloaders(X_train, y_train, batch_size=32): + """ + Creates the training DataLoader.
+ + Args: + X_train (Tensor): Training inputs + y_train (Tensor): Training targets + batch_size (int): Batch size for DataLoader + + Returns: + DataLoader: A DataLoader for training data + """ + train_dataset = TensorDataset(X_train, y_train) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + return train_loader + +def save_loss_plot(train_losses, val_losses, filename="training_loss.png"): + """ + Saves the training and validation loss plot as a PNG file. + + Args: + train_losses (list of float): List containing the training loss value per epoch. + val_losses (list of float): List containing the validation loss value per epoch. + filename (str): Filename for the saved plot (default: training_loss.png). + """ + import matplotlib.pyplot as plt + plt.figure(figsize=(10, 6)) + plt.plot(train_losses, label="Training Loss") + plt.plot(val_losses, label="Validation Loss") + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.title("Training and Validation Loss Progression") + plt.legend() + plt.grid(True) + plt.savefig(filename) + plt.close() + +def log_experiment_details(filename, model, optimizer, best_train_loss, best_val_loss, n_epochs, patience): + """ + Logs the experiment details to a text file. + + Args: + filename (str): Path to the log file. + model (nn.Module): The model being trained. + optimizer (torch.optim.Optimizer): The optimizer used for training. + best_train_loss (float): The best training loss achieved. + best_val_loss (float): The best validation loss achieved. + n_epochs (int): Number of epochs the model was trained for. + patience (int): Patience for early stopping. + """ + with open(filename, 'a') as f: + f.write(f"Model: {model}\n") + f.write(f"Number of hidden layers: {model.hidden_layers}\n") + f.write(f"Optimizer: {optimizer}\n") + f.write(f"Number of epochs: {n_epochs}\n") + f.write(f"Patience: {patience}\n") + f.write(f"Best training loss: {best_train_loss:.4f}\n") + f.write(f"Best validation loss: {best_val_loss:.4f}\n") + f.write("\n" + "="*80 + "\n\n") + +def custom_loss(output, target): + mse_loss = nn.MSELoss()(output, target) + penalty = torch.sum(F.relu(-output)) # Penalize negative values + return mse_loss + penalty + + +def train_model(model, train_loader, criterion, optimizer, X_train, y_train, X_val, y_val, n_epochs=30000, verbose=True, patience=1000, log_file="experiment_log.txt"): + """ + Main training loop for the model. + + Args: + model (nn.Module): Neural network model. + train_loader (DataLoader): DataLoader for training data. + criterion (nn.Module): Loss function. + optimizer (torch.optim.Optimizer): Optimizer for training. + X_train (Tensor): Training inputs for occasional sample prediction. + y_train (Tensor): Training targets for occasional sample comparison. + X_val (Tensor): Validation inputs. + y_val (Tensor): Validation targets. + n_epochs (int): Number of epochs to train. + verbose (bool): If True, prints updates to terminal. + log_file (str): Path to the log file. + + Returns: + None. 
+ """ + train_losses = [] + val_losses = [] + best_train_loss = float('inf') + best_val_loss = float('inf') + epochs_no_improve = 0 # Counter for early stopping + + for epoch in range(n_epochs): + model.train() + total_loss = 0 + + for batch_X, batch_y in train_loader: + optimizer.zero_grad() + outputs = model(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + total_loss += loss.item() + + # Compute average training loss for this epoch + avg_loss = total_loss / len(train_loader) + train_losses.append(avg_loss) + + # Evaluate on validation set + model.eval() + with torch.no_grad(): + val_outputs = model(X_val) + val_loss = criterion(val_outputs, y_val).item() + val_losses.append(val_loss) + + if verbose and epoch % 10 == 0: + current_lr = optimizer.param_groups[0]['lr'] + print(f'Epoch {epoch}, Training Loss: {avg_loss:.4f}, ' + f'Validation Loss: {val_loss:.4f}, LR: {current_lr:.6f}') + with torch.no_grad(): + for i in range(2): # Print predictions for the first i examples + test_input = X_train[i:i+1] + pred = model(test_input) + print(f"Example {i+1} - Predicted values:") + print(pred[0, 0]) + print(f"Example {i+1} - Actual values:") + print(y_train[i, 0]) + + # Save only the best model so far and check early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + best_train_loss = avg_loss + torch.save(model.state_dict(), "best_qubit_model_weights.pt") + epochs_no_improve = 0 + + else: + epochs_no_improve += 1 + + if epochs_no_improve >= patience: + print(f"Early stopping triggered after {epoch} epochs.") + break + + save_loss_plot(train_losses, val_losses) + model.load_state_dict(torch.load("best_qubit_model_weights.pt")) + + # Log experiment details when a new best validation loss is achieved + log_experiment_details(log_file, model, optimizer, best_train_loss, best_val_loss, epoch, patience) + +def main(): + # File paths + train_path = 'train_data_True_from_project_description.pt' + val_path = 'val_data_True_from_project_description.pt' + + # Load data + X_train, y_train, X_val, y_val = load_data(train_path, val_path) + + # Create model, criterion, optimizer + model = BestQubitModel(n_size=4, hidden_layers=3, hidden_size=128, dropout_rate=0.5) + criterion = custom_loss # Use the custom loss function + optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4) + + # Create data loader + train_loader = create_dataloaders(X_train, y_train, batch_size=32) + + # Train model with validation + train_model( + model=model, + train_loader=train_loader, + criterion=criterion, + optimizer=optimizer, + X_train=X_train, + y_train=y_train, + X_val=X_val, + y_val=y_val, + n_epochs=20000, + verbose=True + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/nn/permutation_clean.py b/src/nn/permutation_clean.py new file mode 100644 index 0000000..8837cc6 --- /dev/null +++ b/src/nn/permutation_clean.py @@ -0,0 +1,1671 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset +from src.utils import tableau_from_circuit, random_hscx_circuit +from pauliopt.clifford.tableau import CliffordTableau +from pauliopt.topologies import Topology +from pauliopt.clifford.tableau_synthesis import synthesize_tableau_perm_row_col +from pauliopt.circuits import Circuit +from scipy.optimize import linear_sum_assignment +import numpy as np +import warnings +import pickle +import tempfile +import os +import matplotlib.pyplot as plt + + +# Suppress all overflow warnings globally 
+np.seterr(over="ignore") + +# Suppress FutureWarning +warnings.simplefilter(action="ignore", category=FutureWarning) + + +# This current code file is slower running on GPU, as it needs to move data between CPU and GPU frequently. +def get_default_device(): + """Return the best available device (CUDA → MPS → CPU)""" + + if torch.cuda.is_available(): + return torch.device("cuda") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") + else: + return torch.device("cpu") + + +# # Try a Block-Aware Residual Encoder (Expected to be better for Tableau Structure, but empirically not) +# class ResidualBlockAwareEncoder(nn.Module): +# def __init__(self, n_qubits): +# super().__init__() +# self.n_qubits = n_qubits + +# # Extract features from X/Z blocks separately +# self.x_block_encoder = nn.Sequential( +# nn.Conv2d(2, 32, 3, padding=1), +# nn.GELU(), +# nn.Conv2d(32, 48, 3, padding=1), +# ) + +# self.z_block_encoder = nn.Sequential( +# nn.Conv2d(2, 32, 3, padding=1), +# nn.GELU(), +# nn.Conv2d(32, 48, 3, padding=1), +# ) + +# # Combine X/Z information +# self.combiner = nn.Conv2d(96, 64, 1) + +# # Final processing +# self.final = nn.Sequential( +# nn.GELU(), +# nn.Conv2d(64, 64, 3, padding=1), +# nn.GELU(), +# nn.AdaptiveAvgPool2d((4, 4)), +# ) + +# def forward(self, x): +# # Split X and Z blocks +# n_qubits = self.n_qubits +# x_part = x[:, :, :n_qubits, :] +# z_part = x[:, :, n_qubits:, :] + +# # Process blocks separately +# x_features = self.x_block_encoder(x_part) +# z_features = self.z_block_encoder(z_part) + +# # Combine features +# combined = torch.cat([x_features, z_features], dim=1) +# combined = self.combiner(combined) + +# # Final processing +# output = self.final(combined) +# return output.flatten(start_dim=1) + + +# # Try a tableau-aware encoder, also empirically not as good as expected +# class TableauStructureEncoder(nn.Module): +# def __init__(self, n_qubits): +# super().__init__() +# self.n_qubits = n_qubits + +# # Single pathway with structural awareness +# self.encoder = nn.Sequential( +# # Initial feature extraction +# nn.Conv2d(2, 32, 3, padding=1), +# nn.GELU(), +# nn.BatchNorm2d(32), +# # Capture larger patterns +# nn.Conv2d(32, 64, 5, padding=2), +# nn.GELU(), +# nn.BatchNorm2d(64), +# # Global context +# nn.Conv2d(64, 128, 3, padding=1), +# nn.GELU(), +# nn.BatchNorm2d(128), +# ) + +# # Position-aware readout +# self.pool = nn.AdaptiveAvgPool2d((4, 4)) + +# # Final projection +# self.project = nn.Linear(128 * 4 * 4, 256) + +# def forward(self, x): +# # Process the whole tableau together +# features = self.encoder(x) +# pooled = self.pool(features) +# return self.project(pooled.flatten(1)) + + +# # Empirically more time-consuming and not as good as expected +# class TransformerTableauEncoder(nn.Module): +# def __init__(self, n_qubits, dim=256, num_layers=4, num_heads=8): +# super().__init__() +# self.n_qubits = n_qubits + +# # Embedding for tableau entries +# self.embedding = nn.Linear(2, dim) # 2 channels to dimension + +# # 2D positional encoding +# self.row_pos = nn.Parameter(torch.randn(2 * n_qubits, dim // 2)) +# self.col_pos = nn.Parameter(torch.randn(2 * n_qubits, dim // 2)) + +# # Transformer encoder +# encoder_layer = nn.TransformerEncoderLayer( +# d_model=dim, +# nhead=num_heads, +# dim_feedforward=dim * 4, +# activation=F.gelu, +# batch_first=True, +# ) +# self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + +# # Output projection +# self.output_proj = nn.Linear(dim * (2 * n_qubits) ** 
2, dim) + +# def forward(self, x): +# B = x.size(0) +# n = 2 * self.n_qubits + +# # Reshape input to [batch, n*n, 2] +# x = x.permute(0, 2, 3, 1).reshape(B, n * n, 2) + +# # Embed features +# x = self.embedding(x) # [batch, n*n, dim] + +# # Add 2D positional encoding +# pos_indices = torch.arange(n, device=x.device) +# row_idx = pos_indices.repeat_interleave(n).view(n, n) +# col_idx = pos_indices.repeat(n, 1) + +# row_emb = self.row_pos[row_idx.flatten()] # [n*n, dim//2] +# col_emb = self.col_pos[col_idx.flatten()] # [n*n, dim//2] +# pos_emb = torch.cat([row_emb, col_emb], dim=-1) # [n*n, dim] + +# # Add positional embeddings +# x = x + pos_emb.unsqueeze(0) # [batch, n*n, dim] + +# # Pass through transformer +# x = self.transformer(x) # [batch, n*n, dim] + +# # Global pooling with attention +# x = x.flatten(1) # [batch, n*n*dim] +# x = self.output_proj(x) # [batch, dim] + +# return x + + +class SelfAttentionBlock(nn.Module): + """Self-attention block for feature extraction.""" + + def __init__(self, channels): + super().__init__() + self.query = nn.Conv2d(channels, channels // 8, 1) + self.key = nn.Conv2d(channels, channels // 8, 1) + self.value = nn.Conv2d(channels, channels, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + batch, c, h, w = x.size() + query = self.query(x).view(batch, -1, h * w).permute(0, 2, 1) + key = self.key(x).view(batch, -1, h * w) + energy = torch.bmm(query, key) + attention = F.softmax(energy, dim=-1) + value = self.value(x).view(batch, -1, h * w) + out = torch.bmm(value, attention.permute(0, 2, 1)) + out = out.view(batch, c, h, w) + return x + self.gamma * out + + +# Try a CX-aware decoder +class PermutationCXAwareDecoder(nn.Module): + """ + Permutation CX-aware decoder for sequence generation. + This decoder uses a transformer architecture with an additional CX-aware attention mechanism. + """ + + def __init__(self, dim, n_qubits, num_layers=4, nhead=4): + super().__init__() + # Base transformer decoder + self.decoder = nn.TransformerDecoder( + nn.TransformerDecoderLayer( + d_model=dim, nhead=nhead, dim_feedforward=4 * dim, activation=F.gelu + ), + num_layers=num_layers, + ) + + # CX-aware attention mechanism + self.cx_attention = nn.Linear(dim, n_qubits**2) + + # Final CX integration layer + self.cx_integration = nn.Sequential( + nn.Linear(dim + n_qubits**2, dim), nn.GELU(), nn.LayerNorm(dim) + ) + + def forward(self, tgt, memory): + # Regular transformer decoding + output = self.decoder(tgt, memory) + + # Generate CX-aware attention weights + cx_weights = torch.sigmoid(self.cx_attention(output)) + + # Reshape for visualization and further processing + batch_size = output.size(1) + seq_len = output.size(0) + cx_weights = cx_weights.view(seq_len, batch_size, -1) + + # Mix output with CX awareness + enhanced_output = self.cx_integration(torch.cat([output, cx_weights], dim=-1)) + + return enhanced_output, cx_weights + + +# Residual block for CX predictions +class ResidualCXBlock(nn.Module): + """Residual block for CX predictions.""" + + def __init__(self, dim): + super().__init__() + self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.LayerNorm(dim)) + + def forward(self, x): + return x + self.net(x) # Residual connection + + +class OrderedPermutationTransformer(nn.Module): + """The model for ordered permutation prediction. 
Experiments with different architectures were in the comments.""" + + def __init__(self, n_qubits, dim=256, num_layers=4, num_heads=8, dropout=0.1): + super().__init__() + self.n_qubits = n_qubits + self.dim = dim + + # # Simple CNN encoder, already worked very well + # self.tableau_encoder = nn.Sequential( + # nn.Conv2d(2, 32, 3, padding=1), + # nn.GELU(), + # nn.BatchNorm2d(32), # Added normalization + # nn.Conv2d(32, 64, 3, padding=1), + # nn.GELU(), + # nn.BatchNorm2d(64), # Added normalization + # nn.Conv2d(64, 64, 3, padding=1), # Added third layer + # nn.GELU(), + # nn.AdaptiveAvgPool2d((4, 4)), + # ) + + # Enhanced CNN encoder with self-attention + self.tableau_encoder = nn.Sequential( + nn.Conv2d(2, 32, 3, padding=1), + nn.GELU(), + nn.BatchNorm2d(32), + nn.Dropout2d(dropout / 2), + SelfAttentionBlock(32), # Add self-attention between conv layers + nn.Conv2d(32, 64, 3, padding=1), + nn.GELU(), + nn.BatchNorm2d(64), + nn.Dropout2d(dropout / 2), + SelfAttentionBlock(64), # Add self-attention between conv layers + nn.Conv2d(64, 64, 3, padding=1), + nn.GELU(), + nn.AdaptiveAvgPool2d((4, 4)), + ) + + # Linear layer to match the enhanced CNN encoder output to dim + self.linear = nn.Linear(64 * 4 * 4, dim) + + # # Simple sequence Decoder, already worked very well + # self.decoder = nn.TransformerDecoder( + # nn.TransformerDecoderLayer( + # d_model=dim, nhead=4, dim_feedforward=4 * dim, activation=F.gelu + # ), + # num_layers=num_layers, + # ) + + # Enhanced sequence Decoder with CX-awareness, empirically not as good as expected + self.decoder = PermutationCXAwareDecoder( + dim=dim, n_qubits=n_qubits, num_layers=num_layers, nhead=num_heads + ) + + # Position-Aware Prediction Heads + self.step_heads = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(dim, 2 * n_qubits**2), + nn.GELU(), + nn.Linear(2 * n_qubits**2, n_qubits**2), + ) + for _ in range(10) # Max sequence length + ] + ) + + # Adaptive Sequence Length Prediction + self.stop_head = nn.Linear(dim, 1) + + # # Simple CX head, already worked very well + # self.cx_head = nn.Sequential( + # nn.Linear(dim, dim // 2), + # nn.GELU(), + # nn.Dropout(0.1), + # nn.Linear(dim // 2, 1), + # nn.ReLU(), # Ensure non-negative predictions + # ) + + # # Try sophisticated CX head, directly connected to permutation prediction + # self.cx_head = nn.Sequential( + # nn.Linear(dim + n_qubits**2, 256), # Add permutation matrix features + # nn.GELU(), + # nn.Dropout(0.1), + # nn.LayerNorm(256), + # nn.Linear(256, 128), + # nn.GELU(), + # nn.Linear(128, 1), + # nn.Softplus(), # Ensure non-negative predictions + # ) + + # # Try more sophisticated CX head to accept the additional cx_aware features + # self.cx_head = nn.Sequential( + # nn.Linear( + # dim + n_qubits**2 + n_qubits**2, 256 + # ), # Added cx_awareness features + # nn.GELU(), + # nn.Dropout(0.1), + # nn.LayerNorm(256), + # nn.Linear(256, 128), + # nn.GELU(), + # nn.Linear(128, 1), + # nn.Softplus(), # Ensure non-negative predictions + # ) + + # Further enhanced more sophisticated CX-head with residual connections + self.cx_head = nn.Sequential( + nn.Linear(dim + n_qubits**2 + n_qubits**2, 384), # Wider network + nn.GELU(), + nn.Dropout(0.15), # Slightly higher dropout + nn.LayerNorm(384), + ResidualCXBlock(384), # Add residual connections + nn.Linear(384, 128), + nn.GELU(), + nn.Linear(128, 64), + nn.GELU(), + nn.Linear(64, 1), + nn.ReLU(), # Use ReLU instead of Softplus for sharper predictions + ) + + # Enhanced initialization, empirically no improvement + for p in self.parameters(): + if p.dim() > 
1: + nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") + + def forward(self, tableau): + """ + Forward pass through the model. + Args: + tableau: Input tableau tensor of shape [batch_size, 2, n_qubits, n_qubits] + Returns: + predictions: Predicted permutations of shape [batch_size, seq_len, n_qubits, n_qubits] + stop_logits: Logits for stop signal of shape [batch_size, seq_len] + cx_predictions: Predicted CX values of shape [batch_size, seq_len] + """ + # Encode Tableau + B = tableau.size(0) + x = self.tableau_encoder(tableau) + x = x.view(B, -1) + x = self.linear(x) # Ensure the dimension matches + + # Generate Sequence + memory = x.unsqueeze(0) # [1, batch_size, dim] + output = torch.zeros(B, self.dim, device=tableau.device) # [batch_size, dim] + + predictions = [] + stop_logits = [] + cx_predicitons = [] # For CX-prediction + + for t in range(10): # Max steps. When set to n_qubits, stop head is not needed + # Try CX-aware decoder + output_enhanced, cx_awareness = self.decoder( + tgt=output.unsqueeze(0), memory=memory + ) + output = output_enhanced.squeeze(0) + + # Step-Specific Prediction + pred = self.step_heads[t](output) # [batch_size, n_qubits*n_qubits] + + # Reshape for predictions + predictions.append(pred.view(B, self.n_qubits, self.n_qubits)) + + # Stop Prediction + stop_logits.append(self.stop_head(output)) + + # For CX-prediction + perm_features = pred.view(B, self.n_qubits, self.n_qubits).flatten(1) + + # Try mixing with cx_awareness information from decoder + cx_aware_features = torch.cat( + [output, perm_features, cx_awareness.squeeze(0)], dim=1 + ) + + # Updated CX head to use awareness features + cx_pred = self.cx_head(cx_aware_features) + cx_predicitons.append(cx_pred) + + # Stack along sequence dimension -> [batch_size, seq_len, n_qubits, n_qubits] + predictions = torch.stack(predictions, dim=1) + + # Stack stop logits -> [batch_size, seq_len] + stop_logits = torch.cat(stop_logits, dim=1) + + # Stack cx predictions -> [batch_size, seq_len] + cx_predicitons = torch.cat(cx_predicitons, dim=1) + + return predictions, stop_logits, cx_predicitons + + +class OrderedPermutationLoss(nn.Module): + """Loss function for ordered permutation prediction.""" + + def __init__(self, alpha=0.5, beta=0.1, gamma=2.0): + super().__init__() + self.alpha = alpha # Stop signal loss weight + self.beta = beta # Length regularization weight + self.gamma = gamma # CX prediction loss weight, gamma=8.0 sometimes result in stuck at loss 270.0 + + def forward( + self, preds, stop_logits, targets, masks, cx_preds=None, cx_targets=None + ): + """ + preds: [bs, seq_len, n, n] + stop_logits: [bs, seq_len] - This has shape [32, 10] + targets: [bs, seq_len, n, n] + masks: [bs, seq_len] - But this might have shape [32, 1] + cx_preds: [bs, seq_len] + cx_targets: [bs, seq_len] + """ + # 1. 
Computing permutation loss for each valid target and use the minimum + perm_losses = [] + for i in range(len(preds)): + losses = [] + for valid_seq in targets[i]: # valid_seq: [seq_len, n, n] + # Pad valid_seq if needed + if valid_seq.size(0) < preds[i].size( + 0 + ): # Compare sequence lengths (dim 0) + n = valid_seq.size(-1) # Get n_qubits dimension size + padding = torch.zeros( + preds[i].size(0) - valid_seq.size(0), + n, + n, + device=valid_seq.device, + ) + valid_seq_padded = torch.cat( + [valid_seq, padding], dim=0 + ) # Along seq dim + else: + valid_seq_padded = valid_seq[ + : preds[i].size(0) + ] # Truncate if too long + + # Ensure device matching + valid_seq_padded = valid_seq_padded.to(preds[i].device) + + loss = ( + F.mse_loss(preds[i], valid_seq_padded, reduction="none").mean( + dim=(-1, -2) + ) + * masks[i] + ) + loss = loss.sum() / masks[i].sum().clamp(min=1.0) + losses.append(loss) + perm_losses.append(torch.stack(losses).min()) + perm_loss = torch.stack(perm_losses).mean() + + # 2. Stop signal loss with properly sized tensors + # Create stop_labels with the right size [bs, seq_len] + stop_labels = torch.zeros_like(stop_logits) # Match size exactly + + # Fill first positions with 1s (assuming first step should always stop) + if masks.size(1) > 0: + stop_labels[:, 0] = 1.0 + + # If masks has more than 1 column, shift it + if masks.size(1) > 1: + stop_labels[:, 1:] = masks[:, :-1] # Shift right + + stop_loss = F.binary_cross_entropy_with_logits( + stop_logits, stop_labels, reduction="mean" + ) + + # 3. Length regularization + pred_lengths = torch.sigmoid(stop_logits).sum(dim=1) # [bs] + true_lengths = masks.sum(dim=1).clamp(min=1.0) # [bs], avoid zeros + length_loss = F.l1_loss(pred_lengths, true_lengths) + + # 4. CX prediction loss (if cx_targets is provided) + cx_loss = 0 + if cx_preds is not None and cx_targets is not None: + # Make sure cx_targets is a tensor + if not isinstance(cx_targets, torch.Tensor): + cx_targets = torch.tensor(cx_targets, device=cx_preds.device).float() + + # Ensure shapes match + if cx_targets.dim() == 1: + cx_targets = cx_targets.unsqueeze(1).expand(-1, cx_preds.size(1)) + + # Use MSE loss for regression (no sigmoid needed with ReLU output), grows quadratically + cx_loss = F.mse_loss(cx_preds, cx_targets, reduction="mean") + + return ( + perm_loss + + self.alpha * stop_loss + + self.beta * length_loss + + self.gamma * cx_loss + ) + + +def train(model, dataloader, epochs, device): + """ + Train the model with the given dataloader and number of epochs. + Args: + model: The model to train. + dataloader: The DataLoader for the training data. + epochs: Number of epochs to train. + device: Device to use for training (CPU or GPU). + """ + # Initialization + model = model.to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, max_lr=5e-4, total_steps=epochs * len(dataloader), pct_start=0.3 + ) + + # Loss function + criterion = OrderedPermutationLoss(gamma=5.0) # Adjust gamma as needed + + # Gradient accumulation (for larger batches) + accum_steps = 4 + + loss_history = [] # To track loss per epoch + + for epoch in range(epochs): + model.train() + total_loss = 0 + batches = 0 + + optimizer.zero_grad() + + for batch_idx, (tableaus, raw_targets) in enumerate(dataloader): + # 1. 
Prepare batch ------------------------------------------------- + # Convert raw targets to padded tensor + targets, masks = pad_targets_all_seq( + raw_targets, device + ) # returning all sequences + + # Move data to device + tableaus = tableaus.to(device) + + # Extract cx targets from `raw_targets` + cx_targets = [item[1] for item in raw_targets] + + # 2. Forward pass ------------------------------------------------- + raw_preds, stop_logits, cx_preds = model( + tableaus + ) # [batch_size, seq_len, n, n] + + # 3. Apply Sinkhorn normalization to batch-first format + sinkhorn_preds = [] + for i in range(raw_preds.size(0)): # For each batch item + seq_preds = [] + for j in range(raw_preds.size(1)): # For each sequence step + logits = raw_preds[i, j] + # Apply Sinkhorn + for _ in range(20): # n_iters + logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) + logits = logits - torch.logsumexp(logits, dim=-2, keepdim=True) + seq_preds.append(torch.exp(logits / 0.1)) + sinkhorn_preds.append(torch.stack(seq_preds)) + preds = torch.stack(sinkhorn_preds) # [batch_size, seq_len, n, n] + + # 4. Loss calculation ---------------------------------------------- + # Convert cx_targets from list to properly shaped tensor + cx_targets_tensor = torch.tensor(cx_targets, device=device).float() + + # Reshape to match cx_preds dimensions [batch_size, seq_len] + cx_targets_tensor = cx_targets_tensor.unsqueeze(1).expand( + -1, cx_preds.size(1) + ) + + loss = criterion( + preds, # [batch_size, seq_len, n, n] + stop_logits, # [batch_size, seq_len] + targets, # [batch_size, seq_len, n, n] + masks, # [batch_size, seq_len] + cx_preds=cx_preds, # [batch_size, seq_len] + cx_targets=cx_targets_tensor, # [batch_size, seq_len] + ) + + # 5. Backpropagation ---------------------------------------------- + loss.backward() + + # 6. Gradient accumulation ----------------------------------------- + if (batch_idx + 1) % accum_steps == 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + optimizer.zero_grad() + scheduler.step() # For OneCycleLR + + # 7. Logging ------------------------------------------------------ + total_loss += loss.item() + batches += 1 + + if batch_idx % 10 == 0: + print(f"Epoch {epoch+1} | Batch {batch_idx} | Loss: {loss.item():.4f}") + + # 8. Epoch summary ----------------------------------------------------- + avg_loss = total_loss / batches + loss_history.append(avg_loss) + print(f"Epoch {epoch+1} completed | Average Loss: {avg_loss:.4f}") + + # Plotting loss history + plt.figure(figsize=(10, 6)) + plt.plot(range(1, len(loss_history) + 1), loss_history, marker="o") + plt.title("Training Loss by Epoch") + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.grid(True) + plt.savefig("loss_history.png") + plt.close() + # Save the model + torch.save(model.state_dict(), "ordered_permutation_model.pth") + print("Model saved as ordered_permutation_model.pth") + # Save loss history + with open("loss_history.pkl", "wb") as f: + pickle.dump(loss_history, f) + print("Loss history saved as loss_history.pkl") + + return model # Return the trained model + + +def pad_targets(raw_targets, device): + """ + Handle the correct nested structure of permutation sequences. 
+ Args: + raw_targets: list of tuples (tableau, perm_sequences) + device: device to move tensors to (CPU or GPU) + Returns: + padded: [batch_size, max_len, n_qubits, n_qubits] tensor + mask: [batch_size, max_len] tensor (mask for the longest sequence in the batch) + """ + # Extract first sequence from each batch item + first_sequences = [] + + for item in raw_targets: + perm_sequences = item[0] # This is the all_perm_sequences list + # Take the first permutation sequence + first_seq = perm_sequences[0] + first_sequences.append(first_seq) + + # Determine max length and dimensions + max_len = max(len(seq) for seq in first_sequences) + n_qubits = first_sequences[0][0].shape[0] + + # Ensure max_len is at least 1 + max_len = max(max_len, 1) + + # Pad sequences to max_len + padded = torch.stack( + [ + torch.cat( + [ + ( + torch.stack(seq) + if len(seq) > 0 + else torch.eye(n_qubits).unsqueeze(0) + ), + torch.zeros(max_len - min(len(seq), max_len), n_qubits, n_qubits), + ] + ) + for seq in first_sequences + ] + ) + + # Create masks with the same length + mask = torch.stack( + [ + torch.cat( + [ + torch.ones(min(len(seq), max_len)), + torch.zeros(max_len - min(len(seq), max_len)), + ] + ) + for seq in first_sequences + ] + ) + + return padded.to(device), mask.to(device) + + +# Try returning all permutation sequences +def pad_targets_all_seq(raw_targets, device): + """ + Handle the correct nested structure of permutation sequences. + Args: + raw_targets: list of tuples (tableau, perm_sequences) + device: device to move tensors to (CPU or GPU) + Returns: + all_padded_sequences: list of [num_valid_seqs, max_len, n, n] tensors (per batch item) + masks: [batch_size, max_len] tensor (mask for the longest sequence in the batch) + """ + all_sequences = [] + max_len = 0 + n_qubits = None + + # Gather all valid sequences and find max length + for item in raw_targets: + perm_sequences = item[0] # list of valid sequences + all_sequences.append(perm_sequences) + for seq in perm_sequences: + if n_qubits is None and len(seq) > 0: + n_qubits = seq[0].shape[0] + max_len = max(max_len, len(seq)) + + # Pad all sequences to max_len + all_padded_sequences = [] + for perm_sequences in all_sequences: + padded_seqs = [] + for seq in perm_sequences: + seq_len = len(seq) + if seq_len < max_len: + pad = [torch.eye(n_qubits, device=device)] * (max_len - seq_len) + padded_seq = torch.stack(seq + pad) + else: + padded_seq = torch.stack(seq[:max_len]) + padded_seqs.append(padded_seq) + all_padded_sequences.append( + torch.stack(padded_seqs) + ) # [num_valid_seqs, max_len, n, n] + + # Create masks for the longest sequence in the batch (used for all) + mask = torch.zeros(len(all_sequences), max_len, device=device) + for i, perm_sequences in enumerate(all_sequences): + # Use the length of the longest sequence for this sample + seq_len = max(len(seq) for seq in perm_sequences) + mask[i, :seq_len] = 1 + + return all_padded_sequences, mask + + +class TableauPermutationDataset(Dataset): + """ + Custom dataset for loading tableau permutation data. 
+ Usage: + # For testing with a small random subset (320 samples) + dataset = TableauPermutationDataset( + data_file="training_data_perm.pkl", + max_samples=320, + shuffle=True # Random subset + ) + # For using the first 320 samples (deterministic) + dataset = TableauPermutationDataset( + data_file="training_data_perm.pkl", + max_samples=320, + shuffle=False # First N samples + ) + # For final training with all data + dataset = TableauPermutationDataset( + data_file="training_data_perm.pkl" # No max_samples means use all data + ) + """ + + def __init__(self, data_file, max_samples=None, shuffle=True): + print(f"Loading data from {data_file}...") + with open(data_file, "rb") as f: + all_data = pickle.load(f) + + # Detect n_qubits from the first tableau in the dataset + self.n_qubits = all_data[0][0].n_qubits + print(f"Auto-detected {self.n_qubits} qubits from data") + + # Option to load only a subset + if max_samples is not None and max_samples < len(all_data): + if shuffle: + # Randomly select subset + indices = torch.randperm(len(all_data))[:max_samples].tolist() + self.data = [all_data[i] for i in indices] + else: + # Take first max_samples + self.data = all_data[:max_samples] + print(f"Using subset of {max_samples} examples from {len(all_data)} total") + else: + self.data = all_data + print(f"Loaded all {len(self.data)} training examples") + + def __len__(self): + return len(self.data) + + def tableau_to_tensor(self, tableau): + """ + Convert a Clifford tableau to a tensor representation. + Args: + tableau: CliffordTableau object. + Returns: + Tensor representation of the tableau. + """ + # Get number of qubits + n_qubits = tableau.n_qubits + + # Create tensor with 2 channels - first for tableau, second for signs + combined = torch.zeros(2, 2 * n_qubits, 2 * n_qubits) + + # Get tableau data and signs + tableau_data = tableau.tableau + signs_data = tableau.signs + + # Convert numpy arrays to torch tensors if needed + if isinstance(tableau_data, np.ndarray): + tableau_data = torch.from_numpy(tableau_data).float() + if isinstance(signs_data, np.ndarray): + signs_data = torch.from_numpy(signs_data).float() + + # Fill first channel with tableau data + combined[0, :, :] = tableau_data + + # Add signs to second channel (broadcast along columns) + combined[1, :, 0] = signs_data + + return combined + + def __getitem__(self, idx): + """ + Get item from dataset. + Args: + idx: Index of the item. + Returns: + tableau: Tensor representation of the tableau. + best_perms: List of best permutation sequences and their CX counts. 
+ """ + tableau, best_perms = self.data[idx] + n_qubits = tableau.n_qubits + + # Get CX count from first permutation (all should be same) + cx_count = best_perms[0][1] if best_perms else float("inf") + + # Create a list to hold all permutation sequences + all_perm_sequences = [] + + # Process each permutation sequence in best_perms + for sequence in best_perms: + # Each `sequence` is a pair (a list of permutation tuples, cx_count) + + # For each sequence, create a list of matrices + perm_matrices = [] + current_perm = torch.eye(n_qubits) # Initialize with identity + + for step in sequence: + # Each step is a tuple (i,j) representing a swap + if isinstance(step, tuple) and len(step) == 2: + i, j = step + + # Convert to integers + if isinstance(i, (list, tuple)): + i = i[0] if i else 0 + if isinstance(j, (list, tuple)): + j = j[0] if j else 0 + + i = int(i) if hasattr(i, "__int__") else 0 + j = int(j) if hasattr(j, "__int__") else 0 + + # Create permutation matrix for this step + step_matrix = torch.eye(n_qubits) + + # Safe row swapping + temp = step_matrix[i].clone() + step_matrix[i] = step_matrix[j] + step_matrix[j] = temp + + # Compose with previous permutations + current_perm = step_matrix @ current_perm + perm_matrices.append(current_perm.clone()) + + # If this sequence produced valid matrices, add it + if perm_matrices: + all_perm_sequences.append(perm_matrices) + + # If no valid sequences were created, add a default identity sequence + if not all_perm_sequences: + all_perm_sequences.append([torch.eye(n_qubits)]) + + return self.tableau_to_tensor(tableau), (all_perm_sequences, cx_count) + + +def custom_collate_fn(batch): + """ + Custom collate function to handle variable-length permutation lists. + Args: + batch: List of tuples (tableau, best_perms). + Returns: + tableaus: Stacked tensor of tableaus. + targets: List of best permutation sequences and their CX counts. + """ + # Extract tableaus and targets + tableaus = [item[0] for item in batch] + targets = [item[1] for item in batch] + + # Stack tableaus (they should all have the same shape) + tableaus = torch.stack(tableaus) + + # Don't try to stack targets, just return them as a list + return tableaus, targets + + +def pretrain_a_model_from_file(data_file, max_samples, epochs, device): + """ + A wrapper function to pretrain a model from a file with the given parameters. + Args: + data_file: Path to the data file. + max_samples: Maximum number of samples to use. + epochs: Number of epochs to train. + device: Device to use for training (CPU or GPU). + Returns: + model: The trained model. + """ + dataset = TableauPermutationDataset(data_file, max_samples=max_samples) + model = OrderedPermutationTransformer( + n_qubits=dataset.n_qubits, dim=256, num_layers=12 + ) # Try a deeper model: 6 -> 12 layers + + # Use standard PyTorch DataLoader with custom collate function + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=32, + shuffle=True, + collate_fn=custom_collate_fn, + ) + + train(model, dataloader, epochs, device) # 100 epochs is too many; the loss plateaus before 50 + return model + + +def convert_tensor_to_tableau(tableau_tensor): + """ + Convert tensor representation back to CliffordTableau object. + Args: + tableau_tensor: Tensor representation of the tableau. + Returns: + tableau: CliffordTableau object. + """ + # Extract tableau data from tensor + n_qubits = tableau_tensor.shape[-1] // 2 + + # Convert to int8 for proper bitwise operations (important fix!)
+ tableau_data = tableau_tensor[0, :, :].cpu().numpy().astype(np.int8) + signs_data = tableau_tensor[1, :, 0].cpu().numpy().astype(np.int8) + + # Create a new tableau + tableau = CliffordTableau(n_qubits) + tableau.tableau = tableau_data + tableau.signs = signs_data + + return tableau + + +# Ready to use, but no big improvement as expected! So, did not used mostly. +def supervised_cx_fine_tune(model, dataset, epochs, device): + """ + Fine-tune with supervised learning focusing on CX reduction. + Args: + model: The model to fine-tune. + dataset: The dataset for fine-tuning. + epochs: Number of epochs to fine-tune. + device: Device to use for training (CPU or GPU). + """ + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + criterion = OrderedPermutationLoss() + + for epoch in range(epochs): + total_improvement = 0 + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=8, shuffle=True, collate_fn=custom_collate_fn + ) + + for batch_idx, (tableau_tensors, _) in enumerate(dataloader): + # For each tableau in batch + for tableau_tensor in tableau_tensors: + optimizer.zero_grad() + clifford_tableau = convert_tensor_to_tableau(tableau_tensor) + tableau_tensor = tableau_tensor.unsqueeze(0).to(device) + + # Get baseline CX count + topology = Topology.complete(clifford_tableau.n_qubits) + baseline_circuit = synthesize_tableau_perm_row_col( + clifford_tableau, topology + ) + baseline_cx = collect_circuit_data(baseline_circuit)["cx"] + + # Generate permutation candidates and find best one + with torch.no_grad(): + raw_preds, _, cx_preds = model(tableau_tensor) + target_perm = None + best_cx = baseline_cx + + # Try using CX predictions to guide step selection + cx_weights = torch.exp( + -cx_preds[0] * 2 + ) # Higher weight for lower CX predictions + step_probs = F.softmax(cx_weights, dim=0) + step_indices = torch.multinomial( + step_probs, min(8, raw_preds.shape[1]), replacement=False + ) + + # Try steps probabilistically selected based on predicted CX efficiency + for step_idx in step_indices: + # # Try different permutations + # for step_idx in range(min(3, raw_preds.shape[1])): + logits = raw_preds[0, step_idx] + for _ in range(20): + logits = logits - torch.logsumexp( + logits, dim=-1, keepdim=True + ) + logits = logits - torch.logsumexp( + logits, dim=-2, keepdim=True + ) + + perm_matrix = torch.exp(logits / 0.1).cpu().numpy() + row_ind, col_ind = linear_sum_assignment(-perm_matrix) + perm = [(int(i), int(j)) for i, j in zip(row_ind, col_ind)] + pred_iter = iter(perm) + + def pred_callback(G, remaining, remaining_rows, choice_fn=min): + try: + row, col = next(pred_iter) + # Convert to int if they're tensors + if isinstance(row, torch.Tensor): + row = row.item() + if isinstance(col, torch.Tensor): + col = col.item() + return int(row), int(col) # Ensure integers + # except StopIteration: + # row = choice_fn(remaining_rows) + # return row, row + + # Try smarter fallback, empirically no big improvement as expected + except StopIteration: + # Use graph analysis for better pivot selection + row = choice_fn(remaining_rows) + + if G[row]: + # Choose column with highest impact + impact_scores = {} + for col in G[row]: + # Count affected rows + impact = sum( + 1 for r in remaining_rows if col in G[r] + ) + impact_scores[col] = impact + + col = max( + impact_scores.items(), key=lambda x: x[1] + )[0] + else: + col = row + + return int(row), int(col) + + circuit = synthesize_tableau_perm_row_col( + clifford_tableau, + topology, + pick_pivot_callback=pred_callback, + ) + 
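+ # Count the CX gates of the synthesized candidate circuit; keep this permutation only if it beats the current best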
cx_count = collect_circuit_data(circuit)["cx"] + if cx_count < best_cx: + best_cx = cx_count + target_perm = perm + + # Try keeping consistent with OrderedPermutationLoss + # If we found a better permutation, train toward it + if target_perm is not None and best_cx < baseline_cx: + # Create target matrix + target = torch.zeros_like(raw_preds[0, 0]) + for i, j in target_perm: + target[i, j] = 1.0 + + # Prepare targets and masks in the format expected by OrderedPermutationLoss + target_tensor = target.unsqueeze(0).unsqueeze(0) # [1, 1, n, n] + mask_tensor = torch.ones(1, 1, device=device) # [1, 1] + + # Train model to predict this permutation + new_preds, stop_logits, new_cx_preds = model(tableau_tensor) + + # Apply Sinkhorn normalization (same as in train function) + logits = new_preds[0, 0] + for _ in range(20): + logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) + logits = logits - torch.logsumexp(logits, dim=-2, keepdim=True) + pred_tensor = ( + torch.exp(logits / 0.1).unsqueeze(0).unsqueeze(0) + ) # [1, 1, n, n] + + # Try adding strong explicit CX training (add this block) + cx_target = torch.tensor([best_cx], device=device).float() + cx_weight = 5.0 # Stronger weight for explicit CX training + cx_loss = cx_weight * F.mse_loss( + new_cx_preds[0, 0].unsqueeze(0), cx_target + ) + cx_loss.backward( + retain_graph=True + ) # First backward pass for CX only + + # Use OrderedPermutationLoss for consistent training objectives + cx_target_tensor = torch.tensor([[best_cx]], device=device).float() + + loss = criterion( + pred_tensor, # [1, 1, n, n] + stop_logits[:, 0:1], # [1, 1] + target_tensor, # [1, 1, n, n] + mask_tensor, # [1, 1] + cx_preds=new_cx_preds[:, 0:1], # [1, 1] + cx_targets=cx_target_tensor, # [1, 1] + ) + + loss.backward() + optimizer.step() + + total_improvement += baseline_cx - best_cx + + if batch_idx % 10 == 0: + print( + f"Epoch {epoch+1} | Batch {batch_idx} | Improvement: {total_improvement/(batch_idx+1):.2f}" + ) + + print( + f"Epoch {epoch+1} | Avg CX Improvement: {total_improvement/len(dataloader):.2f}" + ) + + return model + + +def tableau_to_tensor(tableau): + """ + Convert a Clifford tableau to a tensor representation. + Args: + tableau: CliffordTableau object. + Returns: + Tensor representation of the tableau. + """ + # Get number of qubits + n_qubits = tableau.n_qubits + + # Create tensor with 2 channels - first for tableau, second for signs + combined = torch.zeros(2, 2 * n_qubits, 2 * n_qubits) + + # Get tableau data and signs + tableau_data = tableau.tableau + signs_data = tableau.signs + + # Convert numpy arrays to torch tensors if needed + if isinstance(tableau_data, np.ndarray): + tableau_data = torch.from_numpy(tableau_data).float() + if isinstance(signs_data, np.ndarray): + signs_data = torch.from_numpy(signs_data).float() + + # Fill first channel with tableau data + combined[0, :, :] = tableau_data + + # Add signs to second channel (broadcast along columns) + combined[1, :, 0] = signs_data + + return combined + + +def collect_circuit_data(circuit: Circuit) -> dict: + """ + Collects data from a circuit object. + Args: + circuit: A Circuit object. + Returns: + A dictionary with the number of qubits, H gates, S gates, CX gates, and circuit depth. 
+ """ + circuit.final_permutation = None + ops = circuit.to_qiskit().count_ops() + return { + "num_qubits": circuit.n_qubits, + "h": ops.get("h", 0), + "s": ops.get("s", 0), + "cx": ops.get("cx", 0), + "depth": circuit.to_qiskit().depth(), + } + + +def compute_weighted_score(pred_metrics): + """ + Computes a weighted score based on the number of CX gates and circuit depth. + The weights can be adjusted based on the importance of each metric. + Args: + pred_metrics: A dictionary containing the predicted metrics. + Returns: + A weighted score based on the CX count and circuit depth. + """ + # Define weights for each metric + cx_weight = 10.0 # Try adjusting from 10 to 50; 50 not better + depth_weight = 1.0 # 1.0 seems better than 0.1 + # Compute weighted score, only keep cx seems worse + score = cx_weight * pred_metrics["cx"] + depth_weight * pred_metrics["depth"] + return score + + +# Try to adapt weights based on how close we are to optimum, but not much better +def compute_adaptive_score(pred_metrics, baseline_metrics): + """ + Adapt weights based on how close we are to optimum. + Args: + pred_metrics: A dictionary containing the predicted metrics. + baseline_metrics: A dictionary containing the baseline metrics. + Returns: + A weighted score based on the CX count and circuit depth. + """ + cx_ratio = pred_metrics["cx"] / baseline_metrics["cx"] + depth_ratio = pred_metrics["depth"] / baseline_metrics["depth"] + + # If CX is much worse than depth, focus more on CX + if cx_ratio > depth_ratio * 1.2: + cx_weight = 20.0 + depth_weight = 0.5 + else: + cx_weight = 10.0 + depth_weight = 1.0 + + return cx_weight * pred_metrics["cx"] + depth_weight * pred_metrics["depth"] + + +def gumbel_sinkhorn(logits, temp=0.1, n_samples=5, n_iters=20): + """ + Generate multiple permutation samples using Gumbel-Sinkhorn normalization. + Args: + logits: Tensor of logits to sample from. + temp: Temperature for Gumbel noise. + n_samples: Number of samples to generate. + Returns: + samples: List of sampled permutation matrices. + """ + samples = [] + for _ in range(n_samples): + # Generate fresh Gumbel noise for each sample + gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10) + noisy_logits = (logits + gumbels) / temp + + # Apply Sinkhorn normalization + s = noisy_logits.clone() # Important: clone to avoid in-place modification + for _ in range(n_iters): + s = s - torch.logsumexp(s, dim=-1, keepdim=True) + s = s - torch.logsumexp(s, dim=-2, keepdim=True) + # samples.append(torch.exp(s)) + # Try adding final softmax instead of exp() for better numerical stability + samples.append(F.softmax(s, dim=-1)) + + # Return all samples (not average - keep the diversity) + return samples + + +def predict_permutation_gumbel(model, clifford_tableau, device): + """Returns the best permutation after evaluating multiple candidates. + Args: + model: The trained model. + clifford_tableau: The input tableau to predict the permutation for. + device: Device to use for prediction (CPU or GPU). + Returns: + best_perm: The best permutation found. 
+ """ + device = torch.device(device) + model = model.to(device) + + # Get tableau tensor + tableau_tensor = tableau_to_tensor(clifford_tableau) + tableau_tensor = tableau_tensor.unsqueeze(0).to(device) + + # Store best permutation and its score + best_perm = None + best_score = float("inf") + topology = Topology.complete(clifford_tableau.n_qubits) + + with torch.no_grad(): + # Get predictions for all steps + raw_preds, _, cx_preds = model(tableau_tensor) + + # # Sort steps by cx count. Deterministic order + # step_indices = torch.argsort(cx_preds[0], descending=True).tolist() + + # Try a probabilistic approach. Weight step selection by predicted CX reduction potential + # Higher weight for lower CX predictions + cx_weights = torch.exp(-cx_preds[0] * 2) + step_probs = F.softmax(cx_weights, dim=0) + step_indices = torch.multinomial( + step_probs, min(8, raw_preds.shape[1]), replacement=False + ) + + for step_idx in step_indices: # Integrate cx_preds + # Apply Sinkhorn normalization + logits = raw_preds[0, step_idx] + + # Generate multiple permutation samples with Gumbel-Sinkhorn + perm_samples = gumbel_sinkhorn(logits, temp=0.1, n_samples=10) + for sample_matrix in perm_samples: + perm_matrix = sample_matrix.cpu().numpy() + + # Extract permutation using Hungarian algorithm + row_ind, col_ind = linear_sum_assignment(-perm_matrix) + perm = [(int(i), int(j)) for i, j in zip(row_ind, col_ind)] + + pred_iter = iter(perm) + + def pred_callback(G, remaining, remaining_rows, choice_fn=min): + try: + row, col = next(pred_iter) + # Convert to int if they're tensors + if isinstance(row, torch.Tensor): + row = row.item() + if isinstance(col, torch.Tensor): + col = col.item() + return int(row), int(col) # Ensure integers + # except StopIteration: + # row = choice_fn(remaining_rows) + # return row, row + + # Try smarter fallback, empirically no big improvement as expected + except StopIteration: + # Use graph analysis for better pivot selection + row = choice_fn(remaining_rows) + + if G[row]: + # Choose column with highest impact + impact_scores = {} + for col in G[row]: + # Count affected rows + impact = sum(1 for r in remaining_rows if col in G[r]) + impact_scores[col] = impact + + col = max(impact_scores.items(), key=lambda x: x[1])[0] + else: + col = row + + return int(row), int(col) + + pred_circuit = synthesize_tableau_perm_row_col( + clifford_tableau, topology, pick_pivot_callback=pred_callback + ) + pred_metrics = collect_circuit_data(pred_circuit) + + # Synthesize circuit and count gates + # score = pred_metrics["cx"] + pred_metrics["depth"] # Works fine already, weights = [1,1] + score = compute_weighted_score(pred_metrics) + # score = compute_adaptive_score(pred_metrics, collect_circuit_data(circuit)) + + # Keep track of best permutation + if score < best_score: + best_score = score + best_perm = perm + + return [best_perm] # Keep list format for compatibility + + +def entropy_guided_search(model, clifford_tableau, device, n_samples=5): + """Focus search on areas where model is uncertain to find better permutations. + Args: + model: The trained model. + clifford_tableau: The input tableau to predict the permutation for. + device: Device to use for prediction (CPU or GPU). + n_samples: Number of samples to generate for each step. + Returns: + best_perm: The best permutation found. 
+ """ + device = torch.device(device) + model = model.to(device) + + tableau_tensor = tableau_to_tensor(clifford_tableau) + tableau_tensor = tableau_tensor.unsqueeze(0).to(device) + topology = Topology.complete(clifford_tableau.n_qubits) + + with torch.no_grad(): + # Forward pass to get raw predictions + raw_preds, _, _ = model(tableau_tensor) + + # Calculate entropy for each position (high entropy = uncertainty) + entropies = [] + for step in range(raw_preds.shape[1]): + logits = raw_preds[0, step] + + # Apply Sinkhorn normalization + for _ in range(20): + logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) + logits = logits - torch.logsumexp(logits, dim=-2, keepdim=True) + + # Convert to probabilities + probs = torch.exp(logits) + + # Calculate entropy: -sum(p * log(p)) + entropy = -torch.sum(probs * torch.log(probs + 1e-10)) + entropies.append((step, entropy.item())) + + # Sort steps by uncertainty (highest entropy first) + uncertain_steps = sorted(entropies, key=lambda x: -x[1]) + + # Store best permutation and score + best_perm = None + best_score = float("inf") + + # Explore uncertain steps more thoroughly + for step_idx, entropy in uncertain_steps[:3]: # Focus on top-3 most uncertain + # Use adaptive temperature range - more temps for higher entropy + temps = [0.01, 0.03, 0.05, 0.1, 0.2, 0.5] + if entropy > 3.0: # Very uncertain areas + temps.extend([0.7, 1.0]) # Add higher temps for exploration + + # Sample more permutations for high-entropy areas + samples_for_step = n_samples + int(3 * entropy) + + for temp in temps: + # Generate multiple samples with Gumbel-Sinkhorn + perm_samples = gumbel_sinkhorn( + raw_preds[0, step_idx], temp=temp, n_samples=samples_for_step + ) + + for sample_matrix in perm_samples: + perm_matrix = sample_matrix.cpu().numpy() + row_ind, col_ind = linear_sum_assignment(-perm_matrix) + perm = [(int(i), int(j)) for i, j in zip(row_ind, col_ind)] + + # Evaluate permutation + pred_iter = iter(perm) + + def pred_callback(G, remaining, remaining_rows, choice_fn=min): + try: + row, col = next(pred_iter) + # Convert to int if they're tensors + if isinstance(row, torch.Tensor): + row = row.item() + if isinstance(col, torch.Tensor): + col = col.item() + return int(row), int(col) # Ensure integers + # except StopIteration: + # row = choice_fn(remaining_rows) + # return row, row + + # Try smarter fallback, empirically no big improvement as expected + except StopIteration: + # Use graph analysis for better pivot selection + row = choice_fn(remaining_rows) + + if G[row]: + # Choose column with highest impact + impact_scores = {} + for col in G[row]: + # Count affected rows + impact = sum( + 1 for r in remaining_rows if col in G[r] + ) + impact_scores[col] = impact + + col = max(impact_scores.items(), key=lambda x: x[1])[0] + else: + col = row + + return int(row), int(col) + + circuit = synthesize_tableau_perm_row_col( + clifford_tableau, topology, pick_pivot_callback=pred_callback + ) + metrics = collect_circuit_data(circuit) + score = compute_weighted_score(metrics) + + # Update best if improved + if score < best_score: + best_score = score + best_perm = perm + + # In case nothing improved, fall back to basic prediction + if best_perm is None: + # Use basic Gumbel-Sinkhorn with default temp + perm_samples = gumbel_sinkhorn(raw_preds[0, 0], temp=0.1, n_samples=1) + perm_matrix = perm_samples[0].cpu().numpy() + row_ind, col_ind = linear_sum_assignment(-perm_matrix) + best_perm = [(int(i), int(j)) for i, j in zip(row_ind, col_ind)] + + return [best_perm] # Keep 
list format for compatibility + + +def ensemble_predict_permutation(model, tableau, device): + """Run multiple permutation prediction strategies and select best result""" + candidates = [] + + # Run all prediction strategies + gumbel_perms = predict_permutation_gumbel(model, tableau, device) + entropy_perms = entropy_guided_search(model, tableau, device, n_samples=8) + + # Add more aggressive temperature sampling + tableau_tensor = tableau_to_tensor(tableau).unsqueeze(0).to(device) + with torch.no_grad(): + raw_preds, _, cx_preds = model(tableau_tensor) + # Try extreme temperatures for more diversity + for temp in [0.005, 1.0]: # Very low and very high temps + perm_samples = gumbel_sinkhorn(raw_preds[0, 0], temp=temp, n_samples=5) + for sample in perm_samples: + perm_matrix = sample.cpu().numpy() + row_ind, col_ind = linear_sum_assignment(-perm_matrix) + candidates.append([(int(i), int(j)) for i, j in zip(row_ind, col_ind)]) + + # Add the candidates from standard methods + candidates.extend(gumbel_perms) + candidates.extend(entropy_perms) + + # Evaluate all candidates + topology = Topology.complete(tableau.n_qubits) + best_perm = None + best_score = float("inf") + + for perm in candidates: + # Evaluate permutation + pred_iter = iter(perm) + + def pred_callback(G, remaining, remaining_rows, choice_fn=min): + try: + row, col = next(pred_iter) + # Convert to int if they're tensors + if isinstance(row, torch.Tensor): + row = row.item() + if isinstance(col, torch.Tensor): + col = col.item() + return int(row), int(col) # Ensure integers + # except StopIteration: + # row = choice_fn(remaining_rows) + # return row, row + + # Try smarter fallback, empirically no big improvement as expected + except StopIteration: + # Use graph analysis for better pivot selection + row = choice_fn(remaining_rows) + + if G[row]: + # Choose column with highest impact + impact_scores = {} + for col in G[row]: + # Count affected rows + impact = sum(1 for r in remaining_rows if col in G[r]) + impact_scores[col] = impact + + col = max(impact_scores.items(), key=lambda x: x[1])[0] + else: + col = row + + return int(row), int(col) + + circuit = synthesize_tableau_perm_row_col( + tableau, topology, pick_pivot_callback=pred_callback + ) + metrics = collect_circuit_data(circuit) + score = compute_weighted_score(metrics) + + if score < best_score: + best_score = score + best_perm = perm + + return [best_perm] + + +def curriculum_train(data_file, max_samples, epochs_per_stage, device): + """Curriculum training with staged data""" + # Load and preprocess your data + # `all_data` should be a list of (tableau, best_perms, cx_count) + all_data = TableauPermutationDataset(data_file=data_file, max_samples=max_samples) + model = OrderedPermutationTransformer( + n_qubits=all_data.n_qubits, dim=256, num_layers=6 + ) + all_data = all_data.data + # all_data: list of (tableau, best_perms, cx_count) + all_data = sorted(all_data, key=lambda x: x[1]) # x[2] = cx_count + + n = len(all_data) + stages = [ + all_data[: n // 4], + all_data[n // 4 : n // 2], + all_data[n // 2 : 3 * n // 4], + all_data[3 * n // 4 :], + ] + + for stage_idx, stage_data in enumerate(stages): + print( + f"Training on curriculum stage {stage_idx+1} with {len(stage_data)} samples" + ) + # Write stage_data to a temporary pickle file + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as tmp: + pickle.dump(stage_data, tmp) + tmp_filename = tmp.name + + dataset = TableauPermutationDataset(tmp_filename) + dataloader = torch.utils.data.DataLoader( + dataset, 
batch_size=32, shuffle=True, collate_fn=custom_collate_fn + ) + model = train(model, dataloader, epochs_per_stage, device) + + # Clean up the temporary file + os.remove(tmp_filename) + return model + + +def main(): + """ + Main function to run the example usage of the model. + """ + # Example usage without fine-tuning + device = "cpu" + n_qubit = 4 + model = pretrain_a_model_from_file("training_data_perm.pkl", 320, 5, device) + # # Try curriculum training + # model = curriculum_train( + # data_file="training_data_perm.pkl", max_samples=320, epochs_per_stage=5 + # ) + circuit = random_hscx_circuit(nr_qubits=n_qubit, nr_gates=1000) + tableau = tableau_from_circuit(CliffordTableau(n_qubit), circuit) + # permutations = predict_permutation_gumbel(model, tableau, device) # Alternative for prediction + permutations = entropy_guided_search( + model, tableau, device + ) # Alternative for prediction + # permutations = ensemble_predict_permutation(model, tableau, device) # Ensemble method for prediction + print(permutations) + + # # Example usage with SL fine-tuning + # model = pretrain_a_model_from_file("training_data_perm.pkl", 320, 5, device) + # # Try SL fine-tuning + # sl_dataset = TableauPermutationDataset( + # data_file="training_data_perm_4_qubit.pkl", max_samples=320 + # ) + # sl_model = supervised_cx_fine_tune(model, sl_dataset, 5, device) + # circuit = random_hscx_circuit(nr_qubits=4, nr_gates=1000) + # tableau = tableau_from_circuit(CliffordTableau(4), circuit) + # permutations = predict_permutation_gumbel(sl_model, tableau) + # print(permutations) + + return 0 + + +if __name__ == "__main__": + main() diff --git a/src/nn_eval_main.py b/src/nn_eval_main.py index 05ba411..08cba2a 100644 --- a/src/nn_eval_main.py +++ b/src/nn_eval_main.py @@ -1,22 +1,49 @@ -"""Code to possibly evaluate the NN training approach. Currently, this only compares our and the CNN compilation. """ +"""Code to possibly evaluate the NN training approach. 
Currently, this only compares our and the CNN compilation.""" + import warnings from typing import List - +import networkx as nx import numpy as np import pandas as pd +import torch +import matplotlib.pyplot as plt +import seaborn as sns + from pauliopt.circuits import Circuit from pauliopt.clifford.tableau import CliffordTableau from pauliopt.clifford.tableau_synthesis import synthesize_tableau_perm_row_col from pauliopt.topologies import Topology +from src.rl.env import CliffordTableauEnv +from src.rl.agent import DQNAgent from src.nn.brute_force_data import get_best_cnots from src.utils import random_hscx_circuit, tableau_from_circuit +from src.nn.permutation_clean import ( + pretrain_a_model_from_file, + OrderedPermutationTransformer, + TableauPermutationDataset, + supervised_cx_fine_tune, + predict_permutation_gumbel, + entropy_guided_search, + ensemble_predict_permutation, + curriculum_train, +) + # Suppress all overflow warnings globally -np.seterr(over='ignore') +np.seterr(over="ignore") # Suppress FutureWarning -warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action="ignore", category=FutureWarning) + +model_path = "models/finetuned_model_up_to_nr_gates_10.pt" +checkpoint = torch.load(model_path, map_location=torch.device("cpu")) +CONFIG = checkpoint["config"] +n_qubits = 4 +agent = DQNAgent(n_qubits=n_qubits, config=CONFIG) +agent.model.load_state_dict(checkpoint["model_state_dict"]) +agent.model.eval() +agent.epsilon = 0.0 def collect_circuit_data(circuit: Circuit) -> dict: @@ -27,10 +54,11 @@ def collect_circuit_data(circuit: Circuit) -> dict: "h": ops.get("h", 0), "s": ops.get("s", 0), "cx": ops.get("cx", 0), - "depth": circuit.to_qiskit().depth() + "depth": circuit.to_qiskit().depth(), } +# "number of repetitions", "repetition index"??? def our_compilation(circuit: Circuit, topology: Topology, n_rep: int): """ Compilation from previous paper as a baseline. @@ -44,7 +72,11 @@ def our_compilation(circuit: Circuit, topology: Topology, n_rep: int): clifford_tableau = tableau_from_circuit(clifford_tableau, circuit) circ_out = synthesize_tableau_perm_row_col(clifford_tableau, topology) - return {"n_rep": n_rep} | collect_circuit_data(circ_out) | {"method": "normal_heuristic"} + return ( + {"n_rep": n_rep} + | collect_circuit_data(circ_out) + | {"method": "normal_heuristic"} + ) def random_compilation(circuit: Circuit, topology: Topology, n_rep: int): @@ -58,14 +90,20 @@ def random_compilation(circuit: Circuit, topology: Topology, n_rep: int): clifford_tableau = CliffordTableau(circuit.n_qubits) clifford_tableau = tableau_from_circuit(clifford_tableau, circuit) - def pick_pivot_callback(G, remaining: "CliffordTableau", remaining_rows: List[int], choice_fn=min): + def pick_pivot_callback( + G, remaining: "CliffordTableau", remaining_rows: List[int], choice_fn=min + ): row = np.random.choice(remaining_rows) col = row return row, col - circ_out = synthesize_tableau_perm_row_col(clifford_tableau, topology, pick_pivot_callback=pick_pivot_callback) + circ_out = synthesize_tableau_perm_row_col( + clifford_tableau, topology, pick_pivot_callback=pick_pivot_callback + ) return {"n_rep": n_rep} | collect_circuit_data(circ_out) | {"method": "random"} + +# Bruteforce compilation, of course it is the optimal... def optimal_compilation(circuit: Circuit, topology: Topology, n_rep: int): """ Brute force compilation of the circuit (may be slow for >=4 qubits!) 
@@ -77,17 +115,134 @@ def optimal_compilation(circuit: Circuit, topology: Topology, n_rep: int): clifford_tableau = CliffordTableau(circuit.n_qubits) clifford_tableau = tableau_from_circuit(clifford_tableau, circuit) - best_permutation, score = get_best_cnots(clifford_tableau.inverse().inverse(), topology)[0] + best_permutation, score = get_best_cnots( + clifford_tableau.inverse().inverse(), topology + )[0] + + # print_perm = get_best_cnots(clifford_tableau.inverse().inverse(), topology) + # print(f"best_permutation: {print_perm}") + best_permutation = iter(best_permutation) - def pick_pivot_callback(G, remaining: "CliffordTableau", remaining_rows: List[int], choice_fn=min): + def pick_pivot_callback( + G, remaining: "CliffordTableau", remaining_rows: List[int], choice_fn=min + ): row, col = next(best_permutation) return row, col - circ_out = synthesize_tableau_perm_row_col(clifford_tableau, topology, pick_pivot_callback=pick_pivot_callback) + circ_out = synthesize_tableau_perm_row_col( + clifford_tableau, topology, pick_pivot_callback=pick_pivot_callback + ) return {"n_rep": n_rep} | collect_circuit_data(circ_out) | {"method": "optimum"} +def dummy_perm_compilation( + circuit: Circuit, + topology: Topology, + n_rep: int, + model: OrderedPermutationTransformer, + device, +): + """ + Compilation using the neural network approach. + """ + clifford_tableau = CliffordTableau(circuit.n_qubits) + clifford_tableau = tableau_from_circuit(clifford_tableau, circuit) + # best_permutation = predict_permutation_gumbel(model, clifford_tableau, device) + best_permutation = entropy_guided_search(model, clifford_tableau, device) + # best_permutation = ensemble_predict_permutation(model, clifford_tableau, device) + best_permutation = iter(best_permutation[0]) + + def pick_pivot_callback( + G, remaining: "CliffordTableau", remaining_rows: List[int], choice_fn=min + ): + row, col = next(best_permutation) + return row, col + + circ_out = synthesize_tableau_perm_row_col( + clifford_tableau, topology, pick_pivot_callback=pick_pivot_callback + ) + return {"n_rep": n_rep} | collect_circuit_data(circ_out) | {"method": "dummy-perm"} + +def rl_compilation(circuit: Circuit, topology: Topology, n_rep: int): + """Use the RL agent to compile the circuit.""" + tableau = tableau_from_circuit(CliffordTableau(circuit.n_qubits), circuit) + env = CliffordTableauEnv( + n_qubits=circuit.n_qubits, + nr_gates=0, + topology=topology, + cx_penalty=0.0, + h_penalty=0.0, + s_penalty=0.0, + final_reward=0.0 + ) + env.clifford_tableau_to_reduce = tableau.inverse() + env.final_circuit = Circuit(circuit.n_qubits) + env.final_cx = None + env.allowed_rows = list(range(circuit.n_qubits)) + env.allowed_cols = list(range(circuit.n_qubits)) + env.qubits_reduced = 0 + env.graph = env.topology.to_nx + env.adjacency_matrix = nx.adjacency_matrix(env.graph).toarray() + + def pick_pivot(G, remaining, rows, choice_fn=min): + obs = env._get_obs() + row, col = agent.act(obs, env.allowed_rows, env.allowed_cols, explore=False) + env.allowed_rows.remove(row) + env.allowed_cols.remove(col) + env.graph.remove_node(col) + return row, col + + circ_out = synthesize_tableau_perm_row_col(tableau, topology, pick_pivot_callback=pick_pivot) + return {"n_rep": n_rep, "method": "rl_model", **collect_circuit_data(circ_out)} + +def visualize_optimality_gaps(df): + """ + Plot the gap between each method and the optimum solution. + Args: + df (pd.DataFrame): DataFrame containing the results of the experiments. 
+ """ + plt.figure(figsize=(14, 8)) + sns.set_style("whitegrid") + + # Extract the methods we care about + methods = ["normal_heuristic", "dummy-perm", "combined_min", "optimum","rl_model"] + colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"] + labels = ["Standard Heuristic", "Neural Network", "Combined", "Optimum","RL Model"] + + # Plot main trends + for method, color, label in zip(methods, colors, labels): + method_df = df[df["method"] == method].sort_values("n_rep") + rolling_avg = method_df["cx"].rolling(window=50, min_periods=1).mean() + + # Plot rolling average + plt.plot(method_df["n_rep"], rolling_avg, color=color, linewidth=3, label=label) + + # Add title and labels + plt.title("CX Count Comparison with Optimum", fontsize=16) + plt.xlabel("Circuit Evaluation Index", fontsize=14) + plt.ylabel("Number of CX Gates", fontsize=14) + plt.legend(fontsize=12) + + # Add average gap statistics in text box + gaps_text = "Average Gap to Optimum:\n" + opt_mean = df[df["method"] == "optimum"]["cx"].mean() + + for method, label in zip(methods, labels): + if method != "optimum": + method_mean = df[df["method"] == method]["cx"].mean() + gap = method_mean - opt_mean + gap_percent = (gap / opt_mean) * 100 + gaps_text += f"{label}: +{gap:.2f} gates (+{gap_percent:.1f}%)\n" + + plt.figtext( + 0.02, 0.02, gaps_text, fontsize=12, bbox=dict(facecolor="white", alpha=0.9) + ) + + plt.tight_layout() + plt.savefig("optimality_gap_comparison.png", dpi=300) + plt.show() + def main(n_qubits: int = 4, nr_gates: int = 1000): """ @@ -97,26 +252,124 @@ def main(n_qubits: int = 4, nr_gates: int = 1000): :return: """ - df = pd.DataFrame(columns=["n_rep", "num_qubits", "method", "h", "s", "cx", "depth"]) + device = "cpu" + cnt_eval = 10 + + # # If want to train a new model, uncomment the following line. + # model = pretrain_a_model_from_file("nn/training_data_perm.pkl", None, 50, device) + + # # If want to use curriculum training, uncomment the following line. + # model = curriculum_train( + # "nn/training_data_perm.pkl", max_samples=None, epochs_per_stage=25 + # ) + + # # If want to use supervised learning fine-tuning, uncomment the following lines. + # sl_dataset = TableauPermutationDataset( + # "nn/training_data_perm_4_qubit.pkl", max_samples=320 + # ) + # sl_model = supervised_cx_fine_tune(model, sl_dataset, epochs=50, device="cpu") + + # If want to use pre-trained model, uncomment the following lines. + # The 4-qubit model `ordered_permutation_model.pth` is ready to use. 
+ checkpoint_perm = torch.load("models/ordered_permutation_model.pth", map_location=device) + #print(type(checkpoint_perm)) + #if isinstance(checkpoint_perm, dict): + # print(checkpoint_perm.keys()) + model = OrderedPermutationTransformer(n_qubits=n_qubits, dim=256, num_layers=12) + model.load_state_dict(checkpoint_perm) + model.eval() + + df = pd.DataFrame( + columns=["n_rep", "num_qubits", "method", "h", "s", "cx", "depth"] + ) topo = Topology.complete(n_qubits) - for i in range(20): - print(i) + confusion_matrix = pd.DataFrame() + + if nr_gates > 20: + print("Warning: nr_gates > 20, RL agent only trained up to 20 gates complexity.") + + for i in range(cnt_eval): + print(f"Iteration {i}") circuit = random_hscx_circuit(nr_qubits=n_qubits, nr_gates=nr_gates) + method_scores = {} + for method_fn in [ + our_compilation, + random_compilation, + # nn_compilation, + optimal_compilation, + rl_compilation, + dummy_perm_compilation + ]: + if method_fn == dummy_perm_compilation: + row = method_fn(circuit.copy(),topo,i,model,device) + else: + row = method_fn(circuit.copy(), topo, i) + df = pd.concat([df,pd.DataFrame([row])], ignore_index=True) + method_scores[row["method"]] = row["cx"] + print(f"{row['method']}: {row['cx']}", end=" | ") + print("\n") + + r1_score = method_scores["rl_model"] + optimum_score = method_scores["optimum"] + + if optimum_score not in confusion_matrix.index or r1_score not in confusion_matrix.columns: + confusion_matrix.loc[optimum_score, r1_score] = 0 + + confusion_matrix.loc[optimum_score, r1_score] += 1 + + # df_dictionary = pd.DataFrame( + # [dummy_perm_compilation(circuit.copy(), topo, i, sl_model)] + # ) # Replace the above line with this line if using SL fine-tuning - df_dictionary = pd.DataFrame([our_compilation(circuit.copy(), topo, i)]) - df = pd.concat([df, df_dictionary], ignore_index=True) - print("Min", df_dictionary["cx"]) - df_dictionary = pd.DataFrame([optimal_compilation(circuit.copy(), topo, i)]) - df = pd.concat([df, df_dictionary], ignore_index=True) - print("OPTIMUM", df_dictionary["cx"]) - df_dictionary = pd.DataFrame([random_compilation(circuit.copy(), topo, i)]) - df = pd.concat([df, df_dictionary], ignore_index=True) - print("Random", df_dictionary["cx"]) + # Convert the cx column to a numerical type + df["cx"] = pd.to_numeric(df["cx"]) + + # Create a combined method from existing results + combined_results = [] + for rep in df["n_rep"].unique(): + # Get results for this circuit + circuit_df = df[df["n_rep"] == rep] + # Get rows for both methods + heuristic_row = circuit_df[circuit_df["method"] == "normal_heuristic"].iloc[0] + dummy_row = circuit_df[circuit_df["method"] == "dummy-perm"].iloc[0] + # Choose the better one + if heuristic_row["cx"] <= dummy_row["cx"]: + best_row = heuristic_row.copy() + else: + best_row = dummy_row.copy() + # Update the method name + best_row["method"] = "combined_min" + # Add to results + combined_results.append(best_row) + # Add combined results to the DataFrame + combined_df = pd.DataFrame(combined_results) + df = pd.concat([df, combined_df], ignore_index=True) df.to_csv("test_clifford_synthesis.csv", index=False) + # Question: what should be the comparision metric? Mean, median, std, mse, etc.? + print("\nMean scores by method") print(df.groupby("method").mean()) + visualize_optimality_gaps(df) # Plot the results + + # Is the difference just luck? 
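+    # Editor's note (sketch): scipy.stats.ttest_ind assumes equal variances by default; a
+    # Welch-style test that drops that assumption can be requested with equal_var=False, e.g.
+    #     ttest_ind(df[df["method"] == "dummy-perm"]["cx"],
+    #               df[df["method"] == "random"]["cx"], equal_var=False)
+    # (the method labels follow those recorded earlier in this script).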
+    from scipy.stats import ttest_ind
+
+    # "dummy-perm" is the NN-based method recorded in the results above
+    nn_cx_values = df[df["method"] == "dummy-perm"]["cx"]
+    random_cx_values = df[df["method"] == "random"]["cx"]
+    t_stat, p_value = ttest_ind(nn_cx_values, random_cx_values)
+
+    print(f"T-test results: t-statistic = {t_stat}, p-value = {p_value}")
+    if p_value < 0.05:
+        print(
+            "The difference in cx values between dummy-perm and random is statistically significant (p < 0.05)."
+        )
+    else:
+        print(
+            "The difference in cx values between dummy-perm and random is not statistically significant (p >= 0.05)."
+        )
+
 
 if __name__ == "__main__":
     main()
diff --git a/src/nn_generate_data_main.py b/src/nn_generate_data_main.py
index 2888670..b5811c0 100644
--- a/src/nn_generate_data_main.py
+++ b/src/nn_generate_data_main.py
@@ -4,30 +4,57 @@
 from src.nn.preprocess_data import PreprocessingType
 
 
-def generate_data(n_qubits, n_gates, labels_as_described:bool, preprocessing_type:PreprocessingType):
-    X_train, Y_train = generate_dataset_ct(1, n_qubits, n_gates,labels_as_described=labels_as_described, preprocessing_type=preprocessing_type)
-    X_val, Y_val = generate_dataset_ct(1, n_qubits, n_gates,labels_as_described=labels_as_described, preprocessing_type=preprocessing_type)
-
-    torch.save((X_train, Y_train), f'train_data_{labels_as_described}_{preprocessing_type.value}.pt')
-    torch.save((X_val, Y_val), f'val_data_{labels_as_described}_{preprocessing_type.value}.pt')
+def generate_data(
+    n_qubits, n_gates, labels_as_described: bool, preprocessing_type: PreprocessingType
+):
+    X_train, Y_train = generate_dataset_ct(
+        1,
+        n_qubits,
+        n_gates,
+        labels_as_described=labels_as_described,
+        preprocessing_type=preprocessing_type,
+    )
+    X_val, Y_val = generate_dataset_ct(
+        1,
+        n_qubits,
+        n_gates,
+        labels_as_described=labels_as_described,
+        preprocessing_type=preprocessing_type,
+    )
+
+    torch.save(
+        (X_train, Y_train),
+        f"train_data_{labels_as_described}_{preprocessing_type.value}.pt",
+    )
+    torch.save(
+        (X_val, Y_val), f"val_data_{labels_as_described}_{preprocessing_type.value}.pt"
+    )
     print("Successfully generated example data!")
 
 
-#def main():
 #     X_train, Y_train = torch.load('train_data.pt')
 #     X_val, Y_val = torch.load('val_data.pt')
 #     print(X_train.shape, Y_train.shape)
 #     print(X_val.shape, Y_val.shape)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     labels_as_described = False
     preprocessing_type = PreprocessingType.ORIGINAL
     n_qubits = [4]
-    n_gates_per_circuit = [1000]
+    n_gates_per_circuit = [10]
     # The code is written so you can make multiple datasets at once
     # so you can define multiple n_qubits, and multiple n_gates
     # However, I believe they are all gathered in the same Tensor stack so they need to be the same number.
     # Feel free to fix this as need arises.
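+    # Editor's example (illustrative only): with the current signature, separate datasets for
+    # different qubit counts would be produced by separate calls, e.g.
+    #     generate_data([4], [10], labels_as_described, preprocessing_type)
+    #     generate_data([5], [10], labels_as_described, preprocessing_type)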
- generate_data(n_qubits, n_gates_per_circuit, labels_as_described, preprocessing_type) - #main() + generate_data( + n_qubits, n_gates_per_circuit, labels_as_described, preprocessing_type + ) + + labels_as_described = True + preprocessing_type = PreprocessingType.FROM_PROJECT_DESCRIPTION + generate_data( + n_qubits, n_gates_per_circuit, labels_as_described, preprocessing_type + ) + # main() diff --git a/src/rl/agent.py b/src/rl/agent.py index 45adb87..4433e9d 100644 --- a/src/rl/agent.py +++ b/src/rl/agent.py @@ -1,87 +1,225 @@ +from typing import Tuple import random from collections import deque -from typing import Tuple - import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F import torch.optim as optim - from src.rl.env import Array3D - -def _build_model(): - return nn.Sequential( - nn.Conv2d(3, 512, kernel_size=3), - nn.ReLU(), - nn.ConvTranspose2d(512, 1, kernel_size=3), - ) - +def init_weights(m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + +class ResidualBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + + def forward(self, x): + identity = x + out = self.relu(self.conv1(x)) + out = self.conv2(out) + return self.relu(out + identity) + +def _build_model(n_qubits=4): + class CNNFactorizedQNet(nn.Module): + def __init__(self, n_qubits): + super().__init__() + self.n_qubits = n_qubits + self.backbone = nn.Sequential( + nn.Conv2d(8, 128, kernel_size=3, padding=1), + nn.ReLU(), + ResidualBlock(128), + ResidualBlock(128), + ResidualBlock(128), + nn.AdaptiveAvgPool2d((n_qubits, n_qubits)), + nn.Flatten(), + nn.LayerNorm(128 * n_qubits * n_qubits), + nn.Linear(128 * n_qubits * n_qubits, 256), + nn.ReLU(), + nn.Dropout(p=0.2), + ) + self.control_head = nn.Linear(256, n_qubits) + self.target_head = nn.Linear(256, n_qubits) + self.apply(init_weights) + + def forward(self, x): + features = self.backbone(x) + control_q = self.control_head(features) + target_q = self.target_head(features) + q_matrix = torch.einsum("bi,bj->bij", control_q, target_q) + return q_matrix + + return CNNFactorizedQNet(n_qubits) class DQNAgent: - def __init__(self, n_qubits: int, gamma=0.99, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.997) -> None: - self.model = _build_model() + def __init__(self, n_qubits, config): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = _build_model(n_qubits).to(self.device) + self.target_model = _build_model(n_qubits).to(self.device) + self.target_model.load_state_dict(self.model.state_dict()) + self.target_model.eval() + self.memory = deque(maxlen=20000) - self.gamma = gamma - self.epsilon = epsilon - self.epsilon_min = epsilon_min - self.epsilon_decay = epsilon_decay - self.learning_rate = 1e-4 + self.archive = deque(maxlen=100) # Top-performing episodes + self.episode_buffer = [] + + self.gamma = config["gamma"] + self.epsilon = config.get("epsilon_start", 1.0) + self.epsilon_min = config["epsilon_min"] + self.epsilon_decay = config["epsilon_decay"] + self.learning_rate = config["learning_rate"] + self.gradient_clip_norm = config["gradient_clip_norm"] + self.reward_clip = config.get("reward_clip", 20.0) + self.final_reward = config["final_reward"] self.n_qubits = n_qubits + self.k_explore = config.get("top_k", 5) + self.optimizer = 
optim.Adam(self.model.parameters(), lr=self.learning_rate) + self.losses = [] + + def update_target_network(self): + self.target_model.load_state_dict(self.model.state_dict()) + + def remember(self, state, action, reward, next_state, done): + self.episode_buffer.append((state, action, reward, next_state, done)) + + def remember_episode(self, done: bool): + if done: + # Always save the current episode transitions to the full memory + self.memory.extend(self.episode_buffer) + + try: + # Try to get the final CX value from the last next_state + cx_val = self.episode_buffer[-1][3][0][-1, 0, 0] # (batch, channel, row, col) + + # Check if agent knows the true optimal CX for the environment + if hasattr(self, 'true_optimal_cx') and self.true_optimal_cx is not None: + # Supervised fine-tuning: archive only if CX is close to optimal + if cx_val <= 1 * self.true_optimal_cx: + self.archive.append(list(self.episode_buffer)) # Save a copy + else: + # Normal training (no true optimal known): archive best-performing episodes + if len(self.archive) < self.archive.maxlen or cx_val < max( + x[-1][3][0][-1, 0, 0] for x in self.archive + ): + self.archive.append(list(self.episode_buffer)) # Save a copy + except Exception: + # If anything fails (e.g., empty buffer), just skip archiving safely + pass + + # Clear current episode buffer for next episode + self.episode_buffer.clear() + + def act(self, state: Array3D, allowed_rows: list, allowed_cols: list, explore: bool = True) -> Tuple[int, int]: + if explore and np.random.rand() <= self.epsilon: + return random.choice(allowed_rows), random.choice(allowed_cols) + + with torch.no_grad(): + input_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device) + q_values = self.model(input_tensor)[0].cpu() + + mask = torch.full((self.n_qubits, self.n_qubits), float('-inf')) + for r in allowed_rows: + for c in allowed_cols: + mask[r, c] = q_values[r, c] + + if explore: + flat_values = mask.flatten() + topk = min(self.k_explore, len(allowed_rows) * len(allowed_cols)) + topk_indices = torch.topk(flat_values, topk).indices + topk_values = flat_values[topk_indices] + probs = torch.softmax(topk_values, dim=0).numpy() + chosen_idx = np.random.choice(topk_indices.numpy(), p=probs) + else: + chosen_idx = torch.argmax(mask).item() + + row_idx, col_idx = divmod(chosen_idx, self.n_qubits) + return row_idx, col_idx def replay(self, batch_size): - minibatch = random.sample(self.memory, batch_size) - all_current_q_values = [] - all_target_q_values = [] - for state, action, reward, next_state, done in minibatch: - target = reward - if not done: - with torch.no_grad(): - target = target + self.gamma * torch.max(self.model(torch.from_numpy(next_state[0]).float())) - - target_q_values = torch.zeros(size=(1, self.n_qubits, self.n_qubits)) - target_q_values[0][action] = target - - current_q_values = torch.from_numpy(state[0]).float() - - all_target_q_values.append(target_q_values) - all_current_q_values.append(current_q_values) - target_q_values = torch.stack([item for item in all_target_q_values]) - current_q_values = self.model(torch.stack([item for item in all_current_q_values])) - self.optimizer_step(current_q_values, target_q_values) - - if self.epsilon > self.epsilon_min: - self.epsilon *= self.epsilon_decay - - def remember(self, state: Tuple[Array3D, list, list], - action: Tuple[int, int], - reward: float, - next_state: Tuple[Array3D, list, list], done: bool): - - self.memory.append((state, action, reward, next_state, done)) - - def act(self, state: Array3D, allowed_rows: 
list, allowed_cols: list) -> Tuple[int, int]: - if np.random.rand() <= self.epsilon: - row = random.choice(allowed_rows) - col = random.choice(allowed_cols) - return row, col - q_values = self.model(torch.from_numpy(state).float()).cpu().detach()[0] - q_values = q_values[allowed_rows][:, allowed_cols] - row_idx, col_idx = divmod(torch.argmax(q_values).item(), q_values.size(1)) - - selected_row = allowed_rows[row_idx] - selected_col = allowed_cols[col_idx] - - return selected_row, selected_col - - def optimizer_step(self, state_action_values, expected_state_action_values): - criterion = nn.HuberLoss() - loss = criterion(expected_state_action_values, state_action_values) - # Optimize the model - self.optimizer.zero_grad() - loss.backward() - # In-place gradient clipping + if len(self.memory) < batch_size: + return + + """Base model replay function.""" + """ + primary_batch = random.sample(self.memory, int(batch_size * 0.8)) + archive_batch = [] + if self.archive: + for ep in random.sample(list(self.archive), min(len(self.archive), batch_size - len(primary_batch))): + archive_batch.append(random.choice(ep)) + batch = primary_batch + archive_batch + """ + """Sharp reward replay function for finetuning.""" + archive_ratio = 0.1 + archive_batch_size = int(batch_size * archive_ratio) + memory_batch_size = batch_size - archive_batch_size + + memory_batch = random.sample(self.memory, min(len(self.memory), memory_batch_size)) + archive_batch = [] + + if self.archive: + archive_batch = [ + random.choice(ep) for ep in random.sample(list(self.archive), min(len(self.archive), archive_batch_size)) + ] + batch = memory_batch + archive_batch + + states, actions, targets, rewards = [], [], [], [] + + for s, a, r, s_next, done in batch: + s_tensor = torch.from_numpy(s[0]).float().to(self.device) + s_next_tensor = torch.from_numpy(s_next[0]).float().unsqueeze(0).to(self.device) + rewards.append(r) + + with torch.no_grad(): + q_next = self.target_model(s_next_tensor) + best_action = torch.argmax(q_next.view(-1)).item() + max_q = q_next[0].view(-1)[best_action] + + y = r if done else r + self.gamma * max_q.item() + y = np.clip(y, -self.reward_clip, self.reward_clip) + + states.append(s_tensor) + actions.append(a) + targets.append(y) + + states = torch.stack(states) + actions = torch.tensor(actions).long().to(self.device) + targets = torch.tensor(targets).float().to(self.device) + rewards_batch = torch.tensor(rewards).float().to(self.device) + + q_pred = self.model(states) + q_vals = q_pred[torch.arange(len(states)), actions[:, 0], actions[:, 1]] + + # loss = F.mse_loss(q_vals, targets) # Loss function for base model training with lower reward scheme + + # Loss function for base model training with sharp rewards + losses = F.smooth_l1_loss(q_vals, targets, reduction='none') + + # whether you want to apply reward scaling + USE_REWARD_SCALING = True + + if USE_REWARD_SCALING: + rewards_batch = torch.tensor(rewards).float().to(self.device) + final_reward = self.final_reward + alpha = 0.5 + + reward_factor = (1 - alpha * (rewards_batch / final_reward)) + scaled_losses = losses * reward_factor + final_loss = scaled_losses.mean() + else: + final_loss = losses.mean() + self.optimizer.zero_grad() + final_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_norm) self.optimizer.step() - torch.nn.utils.clip_grad_value_(self.model.parameters(), 100) + + self.losses.append(final_loss.item()) + self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay) \ No newline at end 
of file diff --git a/src/rl/env.py b/src/rl/env.py index 318ec84..5179238 100644 --- a/src/rl/env.py +++ b/src/rl/env.py @@ -1,5 +1,4 @@ -from typing import Tuple, Optional, Union, List, Any, Dict - +from typing import Tuple, Optional, Union, List import gym import networkx as nx import numpy as np @@ -10,123 +9,145 @@ from pauliopt.gates import CX, H, S from pauliopt.topologies import Topology from pauliopt.utils import is_cutting - from src.utils import random_hscx_circuit, tableau_from_circuit +from src.nn.brute_force_data import get_best_cnots -Array3D = np.array - +Array3D = np.ndarray class CliffordTableauEnv(gym.Env[Tuple[int, int], np.ndarray]): - def __init__(self, n_qubits: int, nr_gates: int = 1000, topology: Topology = None): - """ - Defines a RL environment, that describes the synthesis of a clifford tableau. - - :param n_qubits: Nr of qubits of the tableau - :param nr_gates: Nr of gates of the tableau - :param topology: Which topology to synth the tableau with - """ - super(CliffordTableauEnv, self).__init__() + def __init__(self, n_qubits: int, nr_gates: int = 1000, topology: Topology = None, + cx_penalty: float = -0.5, h_penalty: float = -0.1, s_penalty: float = -0.1, + final_reward: float = 30.0, final_exp_decay: float = 0.3): + super().__init__() self.n_qubits = n_qubits self.nr_gates = nr_gates - self.clifford_tableau_to_reduce = None - self.final_circuit = None - self.qubits_reduced = 0 - if topology is None: - self.topology = Topology.complete(self.n_qubits) - else: - self.topology = topology - self.graph = self.topology.to_nx - self.adjacency_matrix = nx.adjacency_matrix(self.graph).toarray() - self.allowed_rows = list(range(self.n_qubits)) - self.allowed_cols = list(range(self.n_qubits)) + self.cx_penalty = cx_penalty + self.h_penalty = h_penalty + self.s_penalty = s_penalty + self.final_reward = final_reward * (1 + ((self.nr_gates - 5) // 5)) + self.final_exp_decay = final_exp_decay + self.topology = topology or Topology.complete(n_qubits) def reset(self, **kwargs): - """ - - :param kwargs: - :return: - """ - circuit = random_hscx_circuit(nr_qubits=self.n_qubits, nr_gates=self.nr_gates) - clifford_tableau = CliffordTableau(self.n_qubits) - clifford_tableau = tableau_from_circuit(clifford_tableau, circuit) - self.clifford_tableau_to_reduce = clifford_tableau.inverse() - self.final_circuit = Circuit(self.n_qubits) self.graph = self.topology.to_nx + self.adjacency_matrix = nx.adjacency_matrix(self.graph).toarray() self.allowed_rows = list(range(self.n_qubits)) self.allowed_cols = list(range(self.n_qubits)) - self.adjacency_matrix = nx.adjacency_matrix(self.graph).toarray() self.qubits_reduced = 0 + self.final_circuit = Circuit(self.n_qubits) + self.final_cx = None - return self._get_obs(), self.allowed_rows, self.allowed_cols - - def get_current_stats(self) -> float: - return self.final_circuit.to_qiskit().count_ops().get("cx", 0) + circuit = random_hscx_circuit(nr_qubits=self.n_qubits, nr_gates=self.nr_gates) + tableau = CliffordTableau(self.n_qubits) + self.initial_tableau = tableau_from_circuit(tableau, circuit) + self.clifford_tableau_to_reduce = self.initial_tableau.inverse() - def step(self, action: Tuple[int, int]) -> Tuple[Tuple[Array3D, list, list], float, bool, Dict[Any, Any]]: - current_circuit = Circuit(self.n_qubits) + self.true_optimal_cx = get_best_cnots(self.clifford_tableau_to_reduce.inverse(), self.topology)[0][1] + + return self._get_obs(), self.allowed_rows.copy(), self.allowed_cols.copy() - def apply(gate_name: str, gate_data: tuple) -> None: - if 
gate_name == "CNOT": - self.clifford_tableau_to_reduce.append_cnot(gate_data[0], gate_data[1]) - self.final_circuit.add_gate(CX(gate_data[0], gate_data[1])) - current_circuit.add_gate(CX(gate_data[0], gate_data[1])) - elif gate_name == "H": - self.clifford_tableau_to_reduce.append_h(gate_data[0]) - self.final_circuit.add_gate(H(gate_data[0])) - current_circuit.add_gate(H(gate_data[0])) - elif gate_name == "S": - self.clifford_tableau_to_reduce.append_s(gate_data[0]) - self.final_circuit.add_gate(S(gate_data[0])) - current_circuit.add_gate(S(gate_data[0])) - else: - raise Exception("Unknown Gate") + def get_current_stats(self) -> int: + if self.final_cx is not None: + return self.final_cx + return self.final_circuit.to_qiskit().count_ops().get("cx", 0) + + def _compute_supervised_reward(self, cx_found: int, cx_optimal: int) -> float: + if cx_found <= cx_optimal: + return self.final_reward + elif cx_found <= cx_optimal + 1: + return self.final_reward * 0.5 + elif cx_found <= cx_optimal + 2: + return self.final_reward * 0.25 + else: + overshoot = cx_found - cx_optimal + penalty = self.final_reward * np.exp(-0.2 * overshoot) + penalty = penalty - self.final_reward + return max(penalty, self.cx_penalty) + def step(self, action: Tuple[int, int]): pivot_row, pivot_col = action assert not is_cutting(pivot_col, self.graph) - self.allowed_rows.remove(pivot_row) - self.allowed_cols.remove(pivot_col) + reward = 0.0 + current_circuit = Circuit(self.n_qubits) + + def apply(gate_name: str, args: tuple): + gate_cls = {"CNOT": CX, "H": H, "S": S}[gate_name] + getattr(self.clifford_tableau_to_reduce, f"append_{gate_name.lower()}")(*args) + self.final_circuit.add_gate(gate_cls(*args)) + current_circuit.add_gate(gate_cls(*args)) + + if pivot_row in self.allowed_rows: + self.allowed_rows.remove(pivot_row) + if pivot_col in self.allowed_cols: + self.allowed_cols.remove(pivot_col) self.qubits_reduced += 1 + steiner_reduce_column(pivot_col, pivot_row, self.graph, self.clifford_tableau_to_reduce, apply) self.graph.remove_node(pivot_col) - done = False - if self.qubits_reduced >= self.n_qubits: - final_permutation = np.argmax(self.clifford_tableau_to_reduce.x_matrix, axis=1) - signs_copy_z = self.clifford_tableau_to_reduce.signs[ - self.clifford_tableau_to_reduce.n_qubits: 2 * self.clifford_tableau_to_reduce.n_qubits].copy() - for col in range(self.clifford_tableau_to_reduce.n_qubits): - if signs_copy_z[col] != 0: - apply("H", (final_permutation[col],)) - apply("S", (final_permutation[col],)) - apply("S", (final_permutation[col],)) - apply("H", (final_permutation[col],)) + # Reward based on gate cost + ops = current_circuit.to_qiskit().count_ops() + reward += self.cx_penalty * ops.get("cx", 0) + reward += self.h_penalty * ops.get("h", 0) + reward += self.s_penalty * ops.get("s", 0) + + done = self.qubits_reduced >= self.n_qubits + + """Basic reward structure (commented out): + if done: + self.final_cx = self.final_circuit.to_qiskit().count_ops().get("cx", 0) + bonus = self.final_reward * np.exp(-self.final_exp_decay * self.final_cx) + curriculum_level = max((self.nr_gates - 5) // 10, 0) + curriculum_bonus = curriculum_level * 6.0 + reward += bonus + curriculum_bonus + """ + + """ + Secondary reward structure (commented out): + if done: + self.final_cx = self.final_circuit.to_qiskit().count_ops().get("cx", 0) - for col in range(self.clifford_tableau_to_reduce.n_qubits): - if self.clifford_tableau_to_reduce.signs[col] != 0: - apply("S", (final_permutation[col],)) - apply("S", (final_permutation[col],)) + # Smoother 
final reward (larger values still get rewarded, just less) + decay = 0.25 # You can tune this down to 0.2 or up to 0.4 + smooth_bonus = self.final_reward / (1 + decay * self.final_cx) - done = True + # Curriculum bonus — give a small fixed increase every +2 gates + curriculum_level = max((self.nr_gates - 5) // 2, 0) + curriculum_bonus = curriculum_level * 7.5 # You can try 5.0 → 7.5 → 10.0 + """ + """Supervised finetuning reward structure with true optimal circuit (commented out):""" + if done: + self.final_cx = self.final_circuit.to_qiskit().count_ops().get("cx", 0) + reward = self._compute_supervised_reward(self.final_cx, self.true_optimal_cx) - reward = np.exp(-current_circuit.to_qiskit().count_ops().get("cx", 0)) - return (self._get_obs(), self.allowed_rows, self.allowed_cols), reward, done, {} - - def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: - print(self.clifford_tableau_to_reduce) - return None + return (self._get_obs(), self.allowed_rows.copy(), self.allowed_cols.copy()), reward, done, {} def _get_obs(self) -> Array3D: - disallowed_rows = list(set(list(range(self.n_qubits))) - set(self.allowed_rows)) - disallowed_columns = list(set(list(range(self.n_qubits))) - set(self.allowed_cols)) - bitmap = np.ones((self.n_qubits, self.n_qubits)) - bitmap[disallowed_rows, :] = 0.0 - bitmap[:, disallowed_columns] = 0.0 + n = self.n_qubits + + bitmap = np.ones((n, n), dtype=np.float32) + for i in range(n): + if i not in self.allowed_rows: + bitmap[i, :] = 0.0 + if i not in self.allowed_cols: + bitmap[:, i] = 0.0 + + x_block = self.clifford_tableau_to_reduce.x_matrix[:n, :n] + z_block = self.clifford_tableau_to_reduce.z_matrix[:n, :n] + adj_matrix = self.adjacency_matrix.astype(np.float32) + signs = self.clifford_tableau_to_reduce.signs[:n].astype(np.float32) + sign_map = np.tile(signs[:, np.newaxis], (1, n)) + + row_coords = np.tile(np.linspace(0, 1, n).reshape(n, 1), (1, n)).astype(np.float32) + col_coords = np.tile(np.linspace(0, 1, n).reshape(1, n), (n, 1)).astype(np.float32) + cx_channel = np.full((n, n), np.log1p(self.get_current_stats()), dtype=np.float32) + return np.stack([ - self.clifford_tableau_to_reduce.x_matrix, - self.clifford_tableau_to_reduce.z_matrix, - bitmap + x_block, z_block, bitmap, adj_matrix, + sign_map, row_coords, col_coords, cx_channel ], axis=0) - def _was_selected_previously(self, pivot_row: int, pivot_column: int) -> bool: - return pivot_row not in self.allowed_rows or pivot_column not in self.allowed_cols + def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: + print(self.clifford_tableau_to_reduce) + return None diff --git a/src/rl/finetune.py b/src/rl/finetune.py new file mode 100644 index 0000000..a8f1c5e --- /dev/null +++ b/src/rl/finetune.py @@ -0,0 +1,153 @@ +import torch +import matplotlib.pyplot as plt +from tqdm import trange +import numpy as np +from src.rl.agent import DQNAgent +from src.rl.env import CliffordTableauEnv + +# Finetuning-specific config +CONFIG = { + "learning_rate": 1e-5, + "batch_size": 32, + "epsilon_decay": 0.9999, + "epsilon_min": 0.001, + "gamma": 0.995, + "gradient_clip_norm": 10.0, + "step_penalty": -0.3, + "final_reward_max": 100.0, + "target_update_interval": 3, + "replay_start_size": 128, + "replay_every_n_steps": 1, + "aux_loss_weight": 0.4 +} + +class FinetunePlotter: + def __init__(self, total_episodes): + self.episodes = [] + self.cx = [] + self.opt_cx = [] + self.aux = [] + self.rewards = [] + self.losses = [] + + self.window = 50 + plt.ion() + self.fig, self.axs = plt.subplots(4, 1, 
figsize=(10, 10)) + self.fig.tight_layout(pad=2.0) + + self.lines = [] + for ax in self.axs: + line, = ax.plot([], []) + self.lines.append(line) + + self.axs[0].set_title("Final CX vs Optimal CX") + self.axs[0].legend(["Final CX", "Optimal CX"]) + self.axs[1].set_title("Total Reward") + self.axs[2].set_title("Auxiliary Label") + self.axs[3].set_title("Loss") + + self.ma_lines = [] + for ax in self.axs: + ma_line, = ax.plot([], [], linestyle='--', color='gray') + self.ma_lines.append(ma_line) + + def update(self, episode, cx, opt_cx, aux, reward, loss): + self.episodes.append(episode) + self.cx.append(cx) + self.opt_cx.append(opt_cx) + self.aux.append(aux) + self.rewards.append(reward) + self.losses.append(loss) + + self.lines[0].set_data(self.episodes, self.cx) + self.axs[0].plot(self.episodes, self.opt_cx, label="Optimal CX", color='orange') + + self.lines[1].set_data(self.episodes, self.rewards) + self.lines[2].set_data(self.episodes, self.aux) + self.lines[3].set_data(self.episodes, self.losses) + + # Moving averages + ma = lambda arr: np.convolve(arr, np.ones(self.window)/self.window, mode='valid') + if len(self.episodes) >= self.window: + ma_x = self.episodes[self.window-1:] + self.ma_lines[0].set_data(ma_x, ma(self.cx)) + self.ma_lines[1].set_data(ma_x, ma(self.rewards)) + self.ma_lines[2].set_data(ma_x, ma(self.aux)) + self.ma_lines[3].set_data(ma_x, ma(self.losses)) + + for ax in self.axs: + ax.relim() + ax.autoscale_view() + + plt.draw() + plt.pause(0.01) + + +def finetune(): + n_qubits = 4 + env = CliffordTableauEnv( + n_qubits=n_qubits, + nr_gates=100, + step_penalty=CONFIG["step_penalty"], + final_reward_max=CONFIG["final_reward_max"], + use_true_cx=True + ) + + agent = DQNAgent(n_qubits=n_qubits, config=CONFIG) + checkpoint = torch.load("models/best_model.pt") + agent.model.load_state_dict(checkpoint["model_state_dict"]) + agent.target_model.load_state_dict(checkpoint["target_model_state_dict"]) + agent.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + agent.epsilon = checkpoint["epsilon"] + agent.losses = checkpoint["losses"] + + print("Starting finetuning from baseline model...") + plotter = FinetunePlotter(total_episodes=1000) + + progress = trange(5000, desc="Finetuning", dynamic_ncols=True) + for episode in progress: + state = env.reset() + opt_cx = env.get_true_optimal_cx() + done = False + total_reward = 0.0 + + while not done: + action = agent.act(*state) + next_state, reward, done, _ = env.step(action) + final_cx = env.final_cx if done else None + aux_label = env.get_auxiliary_label() if done else 0.0 + agent.remember(state, action, reward, next_state, done, final_cx) + state = next_state + total_reward += reward + + if len(agent.memory) > CONFIG["replay_start_size"]: + agent.replay(CONFIG["batch_size"]) + + if episode % CONFIG["target_update_interval"] == 0: + agent.update_target_network() + + avg_loss = agent.losses[-1] if agent.losses else 0.0 + progress.write( + f"[Finetune] Episode {episode}: Final CX = {final_cx}, True Opt = {opt_cx}, " + f"Aux(true-opt) = {aux_label:.3f}, Reward = {total_reward:.2f}, Loss = {avg_loss:.4f}" + ) + + plotter.update(episode, cx=final_cx, opt_cx=opt_cx, aux=aux_label, reward=total_reward, loss=avg_loss) + + torch.save({ + "model_state_dict": agent.model.state_dict(), + "target_model_state_dict": agent.target_model.state_dict(), + "optimizer_state_dict": agent.optimizer.state_dict(), + "epsilon": agent.epsilon, + "losses": agent.losses, + "config": CONFIG + }, "models/finetuned_model.pt") + + print("Finetuning complete. 
diff --git a/src/rl_main.py b/src/rl_main.py
index 191f2fe..e64ffd5 100644
--- a/src/rl_main.py
+++ b/src/rl_main.py
@@ -1,48 +1,200 @@
+import os
 import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from tqdm import trange
 
-from src.nn.brute_force_data import get_best_cnots
 from src.rl.agent import DQNAgent
 from src.rl.env import CliffordTableauEnv
 
+CONFIG = {
+    "learning_rate": 0.00001,
+    "batch_size": 64,
+    "epsilon_start": 0,
+    "epsilon_min": 0.1,
+    "epsilon_decay": 0.99995,
+    "gamma": 0.95,
+    "gradient_clip_norm": 1.0,
 
-def main(n_qubits=4, n_gates=100, n_episodes=1000, batch_size=2000):
-    """Describes a DQN Algorithm to try to learn clifford tableau heuristics."""
-    env = CliffordTableauEnv(n_qubits, n_gates)
-    agent = DQNAgent(n_qubits)
+    # Reward structure
+    "final_reward": 50.0,
+    "reward_clip": 50.0,
+    "cx_penalty": -5.0,
+    "h_penalty": 0.0,
+    "s_penalty": 0.0,
 
-    scores_episode = []
+    # Target network & replay
+    "target_update_interval": 50,
+    "replay_every_n_steps": 4,
+
+    # Logging & training
+    "n_episodes": 200000,
+    "moving_avg_window": 50,
+
+    # Loss weight placeholders (currently unused, kept in case they are needed later)
+    "aux_loss_weight": 0.2,
+    "cx_loss_weight": 3.0,
+
+    # Exploration control
+    "top_k": 1,
+
+    # Curriculum learning
+    "use_curriculum": True,
+    "curriculum_start_gates": 10,
+    "curriculum_step": 5,
+    "curriculum_max_gates": 51000
+}
+
+def save_checkpoint(agent, episode, best_cx, path):
+    torch.save({
+        "model_state_dict": agent.model.state_dict(),
+        "target_model_state_dict": agent.target_model.state_dict(),
+        "optimizer_state_dict": agent.optimizer.state_dict(),
+        "epsilon": agent.epsilon,
+        "losses": agent.losses,
+        "episode": episode,
+        "config": CONFIG,
+        "best_cx": best_cx
+    }, path)
+    plt.savefig("models/best_model_plot.png")
+    with open("models/best_model_log.txt", "w") as f:
+        f.write(f"Best model at episode {episode} with avg CX count {best_cx:.4f}\n")
+
+def load_checkpoint(agent, path):
+    checkpoint = torch.load(path)
+    agent.model.load_state_dict(checkpoint["model_state_dict"])
+    agent.target_model.load_state_dict(checkpoint["target_model_state_dict"])
+    agent.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    agent.epsilon = CONFIG["epsilon_start"]
+    agent.losses = checkpoint["losses"]
+    return checkpoint["episode"], checkpoint.get("best_cx", float("inf"))
+
+def main():
+    resume_training = True
+    checkpoint_path = "models/checkpoint_ep_ft_5000.pt"
+    start_episode = 0
+
+    n_qubits = 4
+    current_gates = CONFIG["curriculum_start_gates"]
+
+    env = CliffordTableauEnv(
+        n_qubits=n_qubits,
+        nr_gates=current_gates,
+        cx_penalty=CONFIG["cx_penalty"],
+        h_penalty=CONFIG["h_penalty"],
+        s_penalty=CONFIG["s_penalty"],
+        final_reward=CONFIG["final_reward"]
+    )
+    agent = DQNAgent(n_qubits=n_qubits, config=CONFIG)
+
+    if resume_training and os.path.exists(checkpoint_path):
+        start_episode, best_cx = load_checkpoint(agent, checkpoint_path)
+        print(f"Resuming training from episode {start_episode} with best CX {best_cx:.4f}")
+    else:
+        best_cx = float("inf")
+
+    scores_episode, rewards_episode = [], []
+    moving_avg_scores, moving_avg_rewards = [], []
+
+    previous_gates = current_gates
+    curriculum_episode_threshold = 10000
+    next_curriculum_update = curriculum_episode_threshold
+
+    plt.ion()
+    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
+    ax1.set_title("#CX over Episodes")
+    ax1.set_xlabel("Episodes")
+    ax1.set_ylabel("#CX")
+    line_cx, = ax1.plot([], [], label="#CX")
+    ma_line_cx, = ax1.plot([], [], label="Moving Avg CX", color="orange")
+    ax1.legend()
+
+    ax2.set_title("Reward per Episode")
+    ax2.set_xlabel("Episodes")
+    ax2.set_ylabel("Reward")
+    line_r, = ax2.plot([], [], label="Reward")
+    ma_line_r, = ax2.plot([], [], label="Moving Avg Reward", color="green")
+    ax2.legend()
+
+    def update_plot():
+        x = range(len(scores_episode))
+        line_cx.set_xdata(x)
+        line_cx.set_ydata(scores_episode)
+        ma_line_cx.set_xdata(x)
+        ma_line_cx.set_ydata(moving_avg_scores)
+        line_r.set_xdata(x)
+        line_r.set_ydata(rewards_episode)
+        ma_line_r.set_xdata(x)
+        ma_line_r.set_ydata(moving_avg_rewards)
+        ax1.relim(); ax1.autoscale_view()
+        ax2.relim(); ax2.autoscale_view()
+        plt.draw(); plt.pause(0.01)
+
+    progress = trange(CONFIG["n_episodes"], desc="Training", dynamic_ncols=True)
+    for episode in progress:
+        if CONFIG["use_curriculum"] and episode >= next_curriculum_update:
+            new_gates = min(current_gates + CONFIG["curriculum_step"], CONFIG["curriculum_max_gates"])
+            if new_gates != current_gates:
+                current_gates = new_gates
+                agent.epsilon = CONFIG["epsilon_start"] * 0.5
+                # growth factor is currently 1, so the curriculum interval stays constant
+                curriculum_episode_threshold = int(curriculum_episode_threshold * 1)
+            next_curriculum_update = episode + curriculum_episode_threshold
+
+        env.nr_gates = current_gates
-    env.reset()
-    cnots, score = get_best_cnots(env.clifford_tableau_to_reduce)[0]
-    for episode in range(n_episodes):
-        print(f"Episode: {episode}")
         state = env.reset()
         done = False
-        while not done:
-            action = agent.act(*state)
+        total_reward = 0
+        step_count = 0
+        while not done:
+            action = agent.act(*state, explore=True)
             next_state, reward, done, _ = env.step(action)
-            if done:
-                break
-            agent.remember(state, action, reward, next_state, done)
             state = next_state
-            if len(agent.memory) > batch_size:
-                agent.replay(batch_size)
+            total_reward += reward
+            step_count += 1
+
+            if len(agent.memory) >= CONFIG["batch_size"] and step_count % CONFIG["replay_every_n_steps"] == 0:
+                agent.replay(CONFIG["batch_size"])
+
+        agent.remember_episode(done)
+
+        cx_count = env.get_current_stats()
+        scores_episode.append(cx_count)
+        rewards_episode.append(total_reward)
+
+        ma_cx = np.mean(scores_episode[-CONFIG["moving_avg_window"]:])
+        ma_reward = np.mean(rewards_episode[-CONFIG["moving_avg_window"]:])
+        moving_avg_scores.append(ma_cx)
+        moving_avg_rewards.append(ma_reward)
+
+        if episode % 5 == 0:
+            avg_loss = np.nanmean(agent.losses[-10:]) if agent.losses else float("nan")
+            progress.write(
+                f"Ep {episode} | Reward={ma_reward:.2f}, "
+                f"Eps={agent.epsilon:.3f}, "
+                f"Loss={avg_loss:.4f}, "
+                f"CX_MA={ma_cx:.2f}, CX={cx_count}, Gates={env.nr_gates}"
+            )
+
+        if episode % CONFIG["target_update_interval"] == 0:
+            agent.update_target_network()
 
-        if done:
-            scores_episode.append(env.get_current_stats())
-        else:
-            scores_episode.append(-1)
+        if episode % 5 == 0:
+            update_plot()
 
-    plt.plot(scores_episode, label="#CX over epochs")
-    plt.axhline(y=score, color='red', linestyle='--', label="Best possible #CX")
+        if ma_cx < best_cx and episode > 50:
+            best_cx = ma_cx
+            save_checkpoint(agent, episode, best_cx, "models/best_model.pt")
 
-    plt.xlabel("Epochs")
-    plt.ylabel("#CX")
-    plt.legend()
-    plt.savefig("./dqn_agent.png")
+        if episode % 1000 == 0:
+            save_checkpoint(agent, episode, ma_cx, f"models/checkpoint_ep_ft_{episode}.pt")
+
+    plt.savefig("dqn_agent_trends.png")
+    plt.ioff(); plt.show()
+    print(f"Training finished. Best CX achieved: {best_cx:.4f}")
 
-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    main()
\ No newline at end of file
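# Illustrative sketch: the training loop above logs a trailing mean over the
# last `moving_avg_window` episodes via np.mean, while FinetunePlotter smooths
# with np.convolve(..., mode='valid'). Once a full window of data exists the
# two computations agree, as this self-contained check shows (the random array
# is only a stand-in for the real per-episode scores).
import numpy as np

window = 50
scores = np.random.rand(200)

# per-episode trailing mean, as computed in main()
trailing = [np.mean(scores[max(0, i + 1 - window):i + 1]) for i in range(len(scores))]
# convolution form used by FinetunePlotter, defined only for full windows
smoothed = np.convolve(scores, np.ones(window) / window, mode="valid")

assert np.allclose(trailing[window - 1:], smoothed)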
diff --git a/src/test_clifford_synthesis.csv b/src/test_clifford_synthesis.csv
new file mode 100644
index 0000000..1ce5cf9
--- /dev/null
+++ b/src/test_clifford_synthesis.csv
@@ -0,0 +1,51 @@
+n_rep,num_qubits,method,h,s,cx,depth
+0,4,normal_heuristic,11,12,5,15
+0,4,optimum,13,17,5,19
+0,4,random,16,22,10,25
+0,4,dummy-perm,12,20,6,19
+1,4,normal_heuristic,11,20,10,21
+1,4,optimum,9,16,6,17
+1,4,random,13,19,12,24
+1,4,dummy-perm,11,20,8,16
+2,4,normal_heuristic,15,19,7,16
+2,4,optimum,15,21,6,18
+2,4,random,13,17,7,16
+2,4,dummy-perm,15,19,7,16
+3,4,normal_heuristic,10,16,6,15
+3,4,optimum,11,19,4,15
+3,4,random,10,16,11,18
+3,4,dummy-perm,12,18,6,17
+4,4,normal_heuristic,14,15,10,23
+4,4,optimum,9,20,6,20
+4,4,random,13,15,11,22
+4,4,dummy-perm,9,13,8,18
+5,4,normal_heuristic,9,13,6,15
+5,4,optimum,6,14,6,16
+5,4,random,12,22,10,25
+5,4,dummy-perm,10,14,9,20
+6,4,normal_heuristic,12,17,10,23
+6,4,optimum,14,27,7,24
+6,4,random,12,13,8,20
+6,4,dummy-perm,8,12,8,16
+7,4,normal_heuristic,13,16,10,23
+7,4,optimum,11,18,7,20
+7,4,random,11,19,10,24
+7,4,dummy-perm,7,14,7,19
+8,4,normal_heuristic,12,20,8,21
+8,4,optimum,9,13,6,17
+8,4,random,9,10,8,17
+8,4,dummy-perm,9,10,7,15
+9,4,normal_heuristic,11,16,9,16
+9,4,optimum,11,16,5,14
+9,4,random,12,23,10,24
+9,4,dummy-perm,9,19,8,15
+0,4,combined_min,11,12,5,15
+1,4,combined_min,11,20,8,16
+2,4,combined_min,15,19,7,16
+3,4,combined_min,10,16,6,15
+4,4,combined_min,9,13,8,18
+5,4,combined_min,9,13,6,15
+6,4,combined_min,8,12,8,16
+7,4,combined_min,7,14,7,19
+8,4,combined_min,9,10,7,15
+9,4,combined_min,9,19,8,15
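# Illustrative sketch: one way to summarise the benchmark table above
# (mean gate counts and depth per synthesis method). pandas is assumed to be
# available; the same call works for the root-level test_clifford_synthesis.csv
# added further down in this diff.
import pandas as pd

df = pd.read_csv("src/test_clifford_synthesis.csv")
summary = df.groupby("method")[["h", "s", "cx", "depth"]].mean().sort_values("cx")
print(summary)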
diff --git a/src/utils.py b/src/utils.py
index eab50f1..c6b66fe 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -22,6 +22,17 @@ def random_clifford_circuit(nr_gates=20, nr_qubits=4, gate_choice=None) -> Circu
     :param nr_gates:
     :param nr_qubits:
     :param gate_choice: Subset of ["CX", "H", "S", "V", "CY", "CZ", "Sdg", "Vdg", "X", "Y", "Z"]
+        CX: Controlled-NOT (CNOT) gate
+        H: Hadamard gate
+        S: Phase gate
+        V: V gate (also known as the √X gate, the square root of the X gate)
+        CY: Controlled-Y gate
+        CZ: Controlled-Z gate
+        Sdg: S-dagger gate (inverse of the S gate)
+        Vdg: V-dagger gate (inverse of the V gate)
+        X: Pauli-X gate (also known as the NOT gate)
+        Y: Pauli-Y gate
+        Z: Pauli-Z gate
     :return:
     """
     qc = Circuit(nr_qubits)
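# Illustrative sketch of the gate_choice parameter documented above. The import
# path is assumed from the repository layout, and the Circuit API beyond
# construction is not shown in this diff.
from src.utils import random_clifford_circuit

# 20-gate, 4-qubit random Clifford circuit drawn only from the CX/H/S subset
qc = random_clifford_circuit(nr_gates=20, nr_qubits=4, gate_choice=["CX", "H", "S"])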
diff --git a/test_clifford_synthesis.csv b/test_clifford_synthesis.csv
new file mode 100644
index 0000000..5428458
--- /dev/null
+++ b/test_clifford_synthesis.csv
@@ -0,0 +1,61 @@
+n_rep,num_qubits,method,h,s,cx,depth
+0,4,normal_heuristic,15,18,10,24
+0,4,random,14,17,7,16
+0,4,optimum,12,18,5,17
+0,4,rl_model,11,19,7,18
+0,4,dummy-perm,9,14,6,17
+1,4,normal_heuristic,13,17,6,20
+1,4,random,13,18,8,18
+1,4,optimum,12,14,5,20
+1,4,rl_model,11,18,11,26
+1,4,dummy-perm,10,15,8,16
+2,4,normal_heuristic,17,16,7,18
+2,4,random,17,16,9,23
+2,4,optimum,13,20,7,20
+2,4,rl_model,17,16,11,24
+2,4,dummy-perm,13,15,8,16
+3,4,normal_heuristic,8,12,6,13
+3,4,random,16,14,11,22
+3,4,optimum,10,12,6,14
+3,4,rl_model,14,24,12,26
+3,4,dummy-perm,10,12,6,14
+4,4,normal_heuristic,10,14,7,17
+4,4,random,9,12,9,21
+4,4,optimum,5,9,6,10
+4,4,rl_model,15,17,11,26
+4,4,dummy-perm,11,15,8,16
+5,4,normal_heuristic,10,17,10,17
+5,4,random,12,15,10,23
+5,4,optimum,10,17,6,23
+5,4,rl_model,14,15,9,19
+5,4,dummy-perm,14,21,8,24
+6,4,normal_heuristic,18,21,9,25
+6,4,random,14,12,12,19
+6,4,optimum,8,15,5,21
+6,4,rl_model,10,14,11,24
+6,4,dummy-perm,6,15,6,15
+7,4,normal_heuristic,12,18,7,18
+7,4,random,9,14,11,23
+7,4,optimum,10,13,6,17
+7,4,rl_model,14,19,13,29
+7,4,dummy-perm,10,11,8,18
+8,4,normal_heuristic,12,15,9,22
+8,4,random,4,10,12,20
+8,4,optimum,10,16,5,15
+8,4,rl_model,6,11,11,23
+8,4,dummy-perm,4,8,8,16
+9,4,normal_heuristic,16,19,10,25
+9,4,random,16,17,12,24
+9,4,optimum,14,19,7,16
+9,4,rl_model,9,17,11,20
+9,4,dummy-perm,17,22,9,21
+0,4,combined_min,9,14,6,17
+1,4,combined_min,13,17,6,20
+2,4,combined_min,17,16,7,18
+3,4,combined_min,8,12,6,13
+4,4,combined_min,10,14,7,17
+5,4,combined_min,14,21,8,24
+6,4,combined_min,6,15,6,15
+7,4,combined_min,12,18,7,18
+8,4,combined_min,4,8,8,16
+9,4,combined_min,17,22,9,21
diff --git a/train_data_False_original.pt b/train_data_False_original.pt
deleted file mode 100644
index 3dff485..0000000
Binary files a/train_data_False_original.pt and /dev/null differ
diff --git a/train_data_True_from_project_description.pt b/train_data_True_from_project_description.pt
deleted file mode 100644
index bd5f7ca..0000000
Binary files a/train_data_True_from_project_description.pt and /dev/null differ
diff --git a/val_data_False_original.pt b/val_data_False_original.pt
deleted file mode 100644
index c46bb5d..0000000
Binary files a/val_data_False_original.pt and /dev/null differ
diff --git a/val_data_True_from_project_description.pt b/val_data_True_from_project_description.pt
deleted file mode 100644
index 05dc639..0000000
Binary files a/val_data_True_from_project_description.pt and /dev/null differ
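# Illustrative sketch: in both CSVs the combined_min rows appear to pick, for
# every n_rep, the row with the lowest CX count among the non-optimum methods.
# A hedged reconstruction of that aggregation (pandas assumed, exact
# tie-breaking rule unknown):
import pandas as pd

df = pd.read_csv("test_clifford_synthesis.csv")
candidates = df[~df["method"].isin(["optimum", "combined_min"])]
combined = candidates.loc[candidates.groupby("n_rep")["cx"].idxmin()]
print(combined[["n_rep", "method", "h", "s", "cx", "depth"]].sort_values("n_rep"))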