diff --git a/CATNet_dataset_train/stage2/infer.py b/CATNet_dataset_train/stage2/infer.py new file mode 100644 index 0000000..52cbd49 --- /dev/null +++ b/CATNet_dataset_train/stage2/infer.py @@ -0,0 +1,108 @@ +import cv2 +import numpy as np +import os +import torchvision +from torchvision.datasets import ImageFolder +import argparse +import torch +import torch.nn as nn +from model import MyModel +from PIL import Image + +class MPC(nn.Module): + def __init__(self, weight_path, device): + super(MPC, self).__init__() + self.cur_net = MyModel().to(device) + self.load(self.cur_net, weight_path) + + def process(self, Ii): + with torch.no_grad(): + Fo = self.cur_net(Ii) + return torch.sigmoid(Fo) + + def load(self, model, path): + weights = torch.load(path) + model_state_dict = model.state_dict() + + loaded_layers = [] + missing_layers = [] + mismatched_shapes = [] + + # 遍历加载的权重字典 + for name, param in weights.items(): + if name in model_state_dict: + if param.shape == model_state_dict[name].shape: + model_state_dict[name].copy_(param) # 更新模型的权重 + loaded_layers.append(name) + else: + mismatched_shapes.append(name) + else: + # 如果模型中没有该层,记录缺失的层 + missing_layers.append(name) + + # 打印加载成功的层 + if loaded_layers: + print(f"Successfully loaded the following layers: {', '.join(loaded_layers)}") + + # 打印形状不匹配的层 + if mismatched_shapes: + print(f"The following layers have mismatched shapes: {', '.join(mismatched_shapes)}") + + # 打印缺失的层 + if missing_layers: + print(f"The following layers are missing in the model: {', '.join(missing_layers)}") + + # 如果都加载成功,打印成功信息 + if not mismatched_shapes and not missing_layers: + print("All layers have been successfully loaded!") + +def convert(x): + x = x * 255. + return x.permute(0, 2, 3, 1).cpu().detach().numpy() + +def thresholding(x, thres=0.5): + x[x <= int(thres * 255)] = 0 + x[x > int(thres * 255)] = 255 + return x + +def args_parser(): + parser = argparse.ArgumentParser(description="MPC Model Inference") + parser.add_argument('--weight-path', type=str, default="./weights/MPC_CATNet_stage2_weights.pth", help='Path to the model weights') + parser.add_argument('--device', type=str, default='cuda:0', help='Device to run the model on (e.g., cuda:0 or cpu)') + parser.add_argument('--infer-folder', type=str, default='./data', help='Folder containing images for inference') + parser.add_argument('--output-folder', type=str, default='./output/', help='Folder to save inference results') + return parser.parse_args() + +def infer(args): + + device = torch.device(args.device) + mpc_model = MPC(weight_path=args.weight_path, device=device) + + image_transform = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + # torchvision.transforms.Resize((512, 512)), + ]) + + with torch.no_grad(): + for img in os.listdir(args.infer_folder): + img_path = os.path.join(args.infer_folder, img) + print(f"Processing image: {img_path}") + if not os.path.isfile(img_path): + continue + images = Image.open(img_path).convert('RGB') + images = image_transform(images).unsqueeze(0).to(device) + images = images.to(device) + Mo = mpc_model.process(images) + Mo = convert(Mo) + print(f"Processed image shape: {Mo.shape}") + output_path = os.path.join(args.output_folder, img) + if not os.path.exists(args.output_folder): + os.makedirs(args.output_folder) + cv2.imwrite(output_path, thresholding(Mo[0])) + +if __name__ == '__main__': + args = args_parser() + os.makedirs(args.output_folder, exist_ok=True) + infer(args) +Doc print("Inference completed and results saved to:", args.output_folder) + \ No newline at end of file diff --git a/README.md b/README.md index a72adac..15c946b 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,20 @@ An official implementation code for paper "[Exploring Multi-View Pixel Contrast - torchvision 0.8.2+cu110 - python 3.8 +### Docker + +Pull the PyTorch CUDA image +```bash +docker pull pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel +``` + +Run the container interactively, install dependencies, and mount your project directory: +```bash +docker run --gpus all --rm -v ~/projects/:/workspace -it pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel bash +apt update && apt install libgl1-mesa-glx libglib2.0-0 -y +pip install albumentations==1.3.1 fvcore==0.1.5.post20221221 numpy==1.23.0 opencv-python==4.8.1.78 opencv-python-headless==4.9.0.80 einops timm +``` + ## Usage Generate the file list: @@ -41,7 +55,14 @@ For example to test: download [MPC_CATNet_stage2_weights.pth](https://www.123684 cd CATNet_dataset_train/stage2 python test.py ``` -If you want to test MPC of trained with CASIAv2 dataset, please download the weight file from [MPC_CASIAv2_stage2_weights.pth](https://www.123684.com/s/2pf9-ylCHv) +If you want to test MPC of trained with CASIAv2 dataset, please download the weight file from [MPC_CASIAv2_stage2_weights.pth](https://www.123684.com/s/2pf9-ylCHv) or [MPC_CASIAv2_stage2_weights.pth](https://drive.google.com/file/d/1vXDvrTVizINKcR_MgdA_uuTWcB3I4HdI/view?usp=sharing). + +To run inference on a folder of images, use: +```bash +cd CATNet_dataset_train/stage2 +python infer.py +``` + ## Citation If you use this code for your research, please cite our paper ```