diff --git a/depth/infer_indoor_mps.sh b/depth/infer_indoor_mps.sh
new file mode 100644
index 0000000..1cf8de1
--- /dev/null
+++ b/depth/infer_indoor_mps.sh
@@ -0,0 +1,8 @@
+export mps=0
+
+PYTHONPATH="$(dirname $0)/..":"$(dirname $0)/../stable-diffusion":$PYTHONPATH \
+python infer_mps.py \
+    --video_path /%%%%.mp4 \
+    --max_depth 10.0 \
+    --min_depth 1e-3 \
+    --ckpt_dir /%%%%/EcoDepth/depth/checkpoints/nyu.ckpt
diff --git a/depth/infer_mps.py b/depth/infer_mps.py
new file mode 100644
index 0000000..7f55bca
--- /dev/null
+++ b/depth/infer_mps.py
@@ -0,0 +1,163 @@
+
+import cv2
+import numpy as np
+import torch
+from models.model import EcoDepth
+from configs.infer_options import InferOptions
+from utils import colorize_depth
+import math
+
+def predict(orig_img, model, device):
+    # requires a numpy image of shape (h,w,3) with pixel values 0~255, the model and device
+    # returns a numpy float array representing the depth map with shape (h, w)
+    # swap the colour channels (cv2.imread gives BGR) and normalise to [0, 1]
+    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_RGB2BGR).astype(np.float32)
+
+    orig_img = orig_img / 255.0
+    orig_h, orig_w, _ = orig_img.shape
+    max_area = 1000*720
+    area = orig_h*orig_w
+    # rescale so the image area matches max_area (upscales small inputs too)
+    ratio = math.sqrt(area/max_area)
+
+    new_h = int(orig_h/ratio)
+    new_w = int(orig_w/ratio)
+    new_img = cv2.resize(orig_img, (new_w, new_h))
+
+    # add padding to ensure img dimensions are multiples of 64
+    add_h = (64 - new_h % 64) % 64  # 0 when already a multiple of 64
+    add_w = (64 - new_w % 64) % 64
+
+    final_h = new_h+add_h
+    final_w = new_w+add_w
+
+    final_img = np.zeros((final_h, final_w, 3), dtype=np.float32)
+    final_img[:new_h, :new_w, :] = new_img
+
+    # convert to pytorch tensor, reshape to (1, 3, H, W) and send to device
+    final_img = torch.from_numpy(final_img)
+    final_img = final_img.permute(2,0,1)
+    final_img = final_img.unsqueeze(0)
+    final_img = final_img.to(torch.float32).to(device)
+
+    # batch the image with its horizontal flip for test-time augmentation
+    final_img_flipped = torch.flip(final_img, [3])
+    final_img_concat = torch.cat([final_img, final_img_flipped])
+
+    # ensure float32 before the forward pass (no-op: tensor is already float32)
+    final_img_concat = final_img_concat.to(torch.float32)
+
+    # run the depth model; 'pred_d' holds the predicted depth
+    with torch.no_grad():
+        final_depth_concat = model(final_img_concat)['pred_d']
+
+    final_depth = final_depth_concat[0]
+    final_depth_flipped = final_depth_concat[1]
+
+    # take an average of the two predicted images (un-flip the second first)
+    final_depth = (final_depth+torch.flip(final_depth_flipped, [2]))/2
+
+    # squeeze out extra batch and channel dimensions
+    final_depth = final_depth.squeeze()
+
+    # undo padding
+    final_depth = final_depth[:new_h, :new_w]
+
+    final_depth = final_depth.detach().cpu().numpy()
+
+    # resize to original shape
+    final_depth = cv2.resize(final_depth, (orig_w, orig_h))
+
+    return final_depth
+
+def visualize(img, depth):
+    # requires a numpy array of shape (h,w,3) with pixel values 0~255 representing the RGB image
+    # requires a numpy array of shape (h,w) of strictly positive depths (np.log below)
+    # returns a side-by-side visualization of the image and depth map
+
+    # obtain depth map using colorize_depth
+    # take log of depth to put greater focus on nearer objects
+
+    # crop 60px top / 20px bottom / 20px right to get a better visualization
+
+    img = img[60:-20, :-20]
+
+    depth_map = colorize_depth(np.log(depth))
+
+    depth_map = depth_map[60:-20, :-20]
+
+    # reverse the colour channel to get a better visual effect
+    depth_map = depth_map[:, :, ::-1]
+
+    # stack the img and depth horizontally with the img coming first
+    viz = np.hstack((img, depth_map))
+
+    return viz.astype(np.uint8)
+
+
+def main():
+    # set inference arguments, pick MPS when available, and load the checkpoint
+    opt = InferOptions()
+    args = opt.initialize().parse_args()
+    device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
+    model_weight = torch.load(args.ckpt_dir, map_location=device)['model']
+    model_weight = {k: v.to(torch.float32) for k, v in model_weight.items()}
+    model = EcoDepth(args=args)
+    model.load_state_dict(model_weight)
+    model = model.float()
+    model.to(device)
+    model.eval()
+
+    # model is ready for inference
+    if args.img_path is not None:
+        print(f"Converting {args.img_path} to a depth map")
+        img_name = args.img_path[:-4]
+        ext = args.img_path[-4:]
+        # allow support for png or jpg images only
+        if ext not in ('.png', '.jpg'): raise ValueError('only .png or .jpg images are supported')
+        depth_name = img_name+'_depth'
+        depth_path = depth_name+'.png'
+        # read img
+        img = cv2.imread(args.img_path)
+
+        # get depth
+        depth = predict(img, model, device)
+
+        # get visualization
+        viz = visualize(img, depth)
+
+        # write visualization to file
+        cv2.imwrite(depth_path, viz)
+
+    if args.video_path is not None:
+        print(f"Converting {args.video_path} to a depth video")
+        video_name = args.video_path[:-4]
+        ext = args.video_path[-4:]
+        # allow support for mp4 videos only
+        if ext != '.mp4': raise ValueError('only .mp4 videos are supported')
+        depth_name = video_name+'_depth'
+        depth_path = depth_name+'.avi'
+        # NOTE(review): assumes the first frame reads; img is None if the video cannot be opened
+
+        vidcap = cv2.VideoCapture(args.video_path)
+        # read a frame from the video
+        success, img = vidcap.read()
+        h, w, _ = img.shape
+        frame_rate = 30.0
+        # output frame is two (h-80, w-20) crops side by side -> (2*w-40, h-80)
+        video = cv2.VideoWriter(depth_path, cv2.VideoWriter_fourcc(*"MJPG"), frame_rate, (2*w-40, h-80))
+
+        while success:
+            # get depth
+            depth = predict(img, model, device)
+
+            # get visualization
+            viz = visualize(img, depth)
+            # write visualization to file
+            video.write(viz)
+
+            # read a frame from the video
+            success, img = vidcap.read()
+
+        video.release()
+
+if __name__ == '__main__':
+    main()
diff --git a/env_mps.yml b/env_mps.yml
new file mode 100644
index 0000000..d604347
--- /dev/null
+++ b/env_mps.yml
@@ -0,0 +1,125 @@
+name: ecodepth_mps
+channels:
+  - defaults
+dependencies:
+  - bzip2=1.0.8
+  - ca-certificates=2023.01.10
+  - libffi
+  - libuuid
+  - ncurses
+  - openssl
+  - pip=23.0.1
+  - python=3.11.3
+  - readline=8.2
+  - setuptools=66.0.0
+  - sqlite=3.41.2
+  - tk=8.6.12
+  - wheel=0.38.4
+  - xz=5.4.2
+  - zlib=1.2.13
+  - pip:
+    - addict==2.4.0
+    - aiohttp==3.8.4
+    - aiosignal==1.3.1
+    - albumentations==1.3.0
+    - antlr4-python3-runtime==4.9.3
+    - appdirs==1.4.4
+    - async-timeout==4.0.2
+    - attrs==23.1.0
+    - beautifulsoup4==4.12.2
+    - certifi==2023.5.7
+    - charset-normalizer==3.1.0
+    - click==8.1.3
+    - cmake==3.26.3
+    - colorama==0.4.6
+    - contourpy==1.0.7
+    - cycler==0.11.0
+    - docker-pycreds==0.4.0
+    - einops==0.6.1
+    - filelock==3.12.0
+    - fonttools==4.39.4
+    - frozenlist==1.3.3
+    - fsspec==2023.5.0
+    - ftfy==6.1.1
+    - gdown==4.7.1
+    - gitdb==4.0.10
+    - gitpython==3.1.31
+    - h5py==3.8.0
+    - huggingface-hub==0.14.1
+    - idna==3.4
+    - imageio==2.29.0
+    - jinja2==3.1.2
+    - joblib==1.2.0
+    - kiwisolver==1.4.4
+    - kornia==0.6.12
+    - lazy-loader==0.2
+    - lightning-utilities==0.8.0
+    - lit==16.0.5
+    - markdown==3.4.3
+    - markdown-it-py==2.2.0
+    - markupsafe==2.1.2
+    - matplotlib==3.7.1
+    - mdurl==0.1.2
+    - mmcv-full==1.7.1
+    - model-index==0.1.11
+    - mpmath==1.3.0
+    - multidict==6.0.4
+    - networkx==3.1
+    - numpy==1.24.3
+    - omegaconf==2.3.0
+    - opencv-python==4.7.0.72
+    - opencv-python-headless==4.7.0.72
+    - openmim==0.3.7
+    - ordered-set==4.1.0
+    - packaging==23.1
+    - pandas==2.0.1
+    - pathtools==0.1.2
+    - pillow==9.5.0
+    - protobuf==3.20.3
+    - psutil==5.9.5
+    - pygments==2.15.1
+    - pyparsing==3.0.9
+    - pysocks==1.7.1
+    - python-dateutil==2.8.2
+    - pytorch-lightning==2.0.2
+    - pytz==2023.3
+    - pywavelets==1.4.1
+    - pyyaml==6.0
+    - qudida==0.0.4
+    - regex==2023.5.5
+    - requests==2.31.0
+    - rich==13.3.5
+    - safetensors==0.3.1
+    - scikit-image==0.20.0
+    - scikit-learn==1.2.2
+    - scipy==1.10.1
+    - sentencepiece==0.1.99
+    - sentry-sdk==1.24.0
+    - setproctitle==1.3.2
+    - six==1.16.0
+    - smmap==5.0.0
+    - soupsieve==2.4.1
+    - sympy==1.12
+    - tabulate==0.9.0
+    - taming-transformers==0.0.1
+    - tensorboardx==2.6
+    - threadpoolctl==3.1.0
+    - tifffile==2023.4.12
+    - timm==0.9.2
+    - tokenizers==0.13.3
+    - tomli==2.0.1
+    - torch==2.0.1
+    - torchmetrics==0.11.4
+    - torchvision==0.15.2
+    - tqdm==4.65.0
+    - transformers==4.29.2
+    - typing-extensions==4.6.1
+    - tzdata==2023.3
+    - urllib3==1.26.16
+    - wandb==0.15.11
+    - wcwidth==0.2.6
+    - yapf==0.33.0
+    - yarl==1.9.2
+    - clip
+    - taming-transformers-rom1504