Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions depth/infer_indoor_mps.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Launch EcoDepth depth inference (infer_mps.py) on an indoor video with the
# NYU checkpoint. The /%%%%/ placeholders must be replaced with real paths.
# NOTE(review): `mps` is exported but not obviously read by infer_mps.py —
# confirm whether it is still needed.
export mps=0

# Prepend the repo root and the bundled stable-diffusion checkout to
# PYTHONPATH so infer_mps.py can import the project's modules.
# NOTE(review): the final line ends with a trailing backslash, as if another
# argument line were expected — confirm nothing was dropped.
PYTHONPATH="$(dirname $0)/..":"$(dirname $0)/../stable-diffusion":$PYTHONPATH \
python infer_mps.py \
--video_path /%%%%.mp4 \
--max_depth 10.0 \
--min_depth 1e-3 \
--ckpt_dir /%%%%/EcoDepth/depth/checkpoints/nyu.ckpt \
163 changes: 163 additions & 0 deletions depth/infer_mps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@

import cv2
import numpy as np
import torch
from models.model import EcoDepth
from configs.infer_options import InferOptions
from utils import colorize_depth
import math

def predict(orig_img, model, device):
    """Predict a depth map for a single image.

    Args:
        orig_img: numpy array of shape (h, w, 3) with pixel values 0~255
            (BGR channel order, as produced by cv2.imread; converted below).
        model: depth model; called as ``model(batch)`` and expected to
            return a dict whose ``'pred_d'`` entry holds predicted depth.
        device: torch.device on which to run inference.

    Returns:
        numpy array of shape (h, w) with the predicted depth, resized back
        to the original image resolution.
    """
    # Swap channel order for the model and normalize pixels to [0, 1].
    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_RGB2BGR).astype(np.float32)
    orig_img = orig_img / 255.0
    orig_h, orig_w, _ = orig_img.shape

    # Rescale so the image area equals max_area while keeping aspect ratio.
    # NOTE(review): when the input is smaller than max_area this UPSCALES it
    # — presumably to give the model a consistent working resolution.
    max_area = 1000 * 720
    area = orig_h * orig_w
    ratio = math.sqrt(area / max_area)
    new_h = int(orig_h / ratio)
    new_w = int(orig_w / ratio)
    new_img = cv2.resize(orig_img, (new_w, new_h))

    # Zero-pad bottom/right so both dimensions are multiples of 64.
    # Bugfix: the original `64 - dim % 64` added a needless 64-pixel border
    # when a dimension was already a multiple of 64; `(-dim) % 64` gives 0.
    add_h = (-new_h) % 64
    add_w = (-new_w) % 64
    final_h = new_h + add_h
    final_w = new_w + add_w
    # Allocate float32 directly (the original built float64 and converted).
    final_img = np.zeros((final_h, final_w, 3), dtype=np.float32)
    final_img[:new_h, :new_w, :] = new_img

    # To torch: shape (1, 3, H, W), float32, on the target device.
    final_img = torch.from_numpy(final_img)
    final_img = final_img.permute(2, 0, 1).unsqueeze(0)
    final_img = final_img.to(torch.float32).to(device)

    # Test-time augmentation: batch the image with its horizontal flip.
    final_img_flipped = torch.flip(final_img, [3])
    final_img_concat = torch.cat([final_img, final_img_flipped])

    with torch.no_grad():
        final_depth_concat = model(final_img_concat)['pred_d']

    final_depth = final_depth_concat[0]
    final_depth_flipped = final_depth_concat[1]

    # Un-flip the flipped prediction and average the two estimates.
    final_depth = (final_depth + torch.flip(final_depth_flipped, [2])) / 2

    # Drop extra batch/channel dimensions and remove the padding.
    final_depth = final_depth.squeeze()
    final_depth = final_depth[:new_h, :new_w]

    final_depth = final_depth.detach().cpu().numpy()

    # Resize the depth map back to the original image resolution.
    final_depth = cv2.resize(final_depth, (orig_w, orig_h))

    return final_depth

def visualize(img, depth):
    """Build a side-by-side visualization of an image and its depth map.

    Args:
        img: numpy array of shape (h, w, 3) with pixel values 0~255.
        depth: numpy array of shape (h, w) with the predicted depth.

    Returns:
        uint8 numpy array: the image and the colorized depth map stacked
        horizontally (image on the left), each cropped by 60 px at the top
        and 20 px at the bottom and right for a cleaner look.
    """
    # Log-scale the depth before colorizing so nearer objects get a larger
    # share of the colour range.
    colored = colorize_depth(np.log(depth))

    # Apply the same crop to both panels: drop the top strip and a little
    # of the bottom and right edges.
    left_panel = img[60:-20, :-20]
    cropped_depth = colored[60:-20, :-20]

    # Reverse the channel order of the depth panel for a better palette.
    right_panel = cropped_depth[:, :, ::-1]

    side_by_side = np.hstack((left_panel, right_panel))
    return side_by_side.astype(np.uint8)


def main():
    """Entry point: load the EcoDepth model on MPS (or CPU) and convert the
    image and/or video given on the command line into depth visualizations."""
    # Parse inference arguments and pick the best available device.
    opt = InferOptions()
    args = opt.initialize().parse_args()
    device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
    # Load checkpoint weights, forcing float32 tensors for MPS compatibility.
    model_weight = torch.load(args.ckpt_dir, map_location=device)['model']
    model_weight = {k: v.to(torch.float32) for k, v in model_weight.items()}
    model = EcoDepth(args=args)
    model.load_state_dict(model_weight)
    model = model.float()
    model.to(device)
    model.eval()

    # Model is ready for inference.
    if args.img_path is not None:
        print("Converting {} to a depth map".format(args.img_path))
        img_name = args.img_path[:-4]
        ext = args.img_path[-4:]
        # Only png/jpg images are supported. Bugfix: raise instead of
        # `assert`, which is stripped when Python runs with -O.
        if ext not in ('.png', '.jpg'):
            raise ValueError("unsupported image extension: " + ext)
        depth_name = img_name + '_depth'
        depth_path = depth_name + '.png'
        # Read the image; cv2.imread returns None on failure rather than
        # raising, so check explicitly (the original crashed later).
        img = cv2.imread(args.img_path)
        if img is None:
            raise IOError("could not read image: " + args.img_path)

        # Get depth.
        depth = predict(img, model, device)

        # Get visualization.
        viz = visualize(img, depth)

        # Write visualization to file.
        cv2.imwrite(depth_path, viz)

    if args.video_path is not None:
        print("Converting {} to a depth video".format(args.video_path))
        video_name = args.video_path[:-4]
        ext = args.video_path[-4:]
        # Only mp4 videos are supported (raise instead of assert, as above).
        if ext != '.mp4':
            raise ValueError("unsupported video extension: " + ext)
        depth_name = video_name + '_depth'
        depth_path = depth_name + '.avi'

        vidcap = cv2.VideoCapture(args.video_path)
        # Read the first frame from the video.
        success, img = vidcap.read()
        # Bugfix: the original dereferenced img.shape unconditionally and
        # crashed with AttributeError when the video could not be opened.
        if not success:
            vidcap.release()
            raise IOError("could not read a frame from: " + args.video_path)
        h, w, _ = img.shape
        # NOTE(review): output fps is hard-coded; the source fps is not
        # queried from the capture — confirm 30 fps is intended.
        frame_rate = 30.0
        # Frame size mirrors the crops applied inside visualize():
        # 60+20 px off the height, 20 px off each panel's width.
        video = cv2.VideoWriter(depth_path, cv2.VideoWriter_fourcc(*"MJPG"), frame_rate, (2*w-40, h-80))

        while success:
            # Get depth.
            depth = predict(img, model, device)

            # Get visualization and write it to the output video.
            viz = visualize(img, depth)
            video.write(viz)

            # Read the next frame from the video.
            success, img = vidcap.read()

        # Bugfix: release the capture as well as the writer.
        vidcap.release()
        video.release()

# Run inference only when executed as a script, not on import.
if __name__ == '__main__':
    main()
125 changes: 125 additions & 0 deletions env_mps.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
name: ecodepth_mps
channels:
- defaults
dependencies:
- bzip2=1.0.8
- ca-certificates=2023.01.10
- libffi
- libuuid
- ncurses
- openssl
- pip=23.0.1
- python=3.11.3
- readline=8.2
- setuptools=66.0.0
- sqlite=3.41.2
- tk=8.6.12
- wheel=0.38.4
- xz=5.4.2
- zlib=1.2.13
- pip:
- addict==2.4.0
- aiohttp==3.8.4
- aiosignal==1.3.1
- albumentations==1.3.0
- antlr4-python3-runtime==4.9.3
- appdirs==1.4.4
- async-timeout==4.0.2
- attrs==23.1.0
- beautifulsoup4==4.12.2
- certifi==2023.5.7
- charset-normalizer==3.1.0
- click==8.1.3
- cmake==3.26.3
- colorama==0.4.6
- contourpy==1.0.7
- cycler==0.11.0
- docker-pycreds==0.4.0
- einops==0.6.1
- filelock==3.12.0
- fonttools==4.39.4
- frozenlist==1.3.3
- fsspec==2023.5.0
- ftfy==6.1.1
- gdown==4.7.1
- gitdb==4.0.10
- gitpython==3.1.31
- h5py==3.8.0
- huggingface-hub==0.14.1
- idna==3.4
- imageio==2.29.0
- jinja2==3.1.2
- joblib==1.2.0
- kiwisolver==1.4.4
- kornia==0.6.12
- lazy-loader==0.2
- lightning-utilities==0.8.0
- lit==16.0.5
- markdown==3.4.3
- markdown-it-py==2.2.0
- markupsafe==2.1.2
- matplotlib==3.7.1
- mdurl==0.1.2
- mmcv-full==1.7.1
- model-index==0.1.11
- mpmath==1.3.0
- multidict==6.0.4
- networkx==3.1
- numpy==1.24.3
- omegaconf==2.3.0
- opencv-python==4.7.0.72
- opencv-python-headless==4.7.0.72
- openmim==0.3.7
- ordered-set==4.1.0
- packaging==23.1
- pandas==2.0.1
- pathtools==0.1.2
- pillow==9.5.0
- protobuf==3.20.3
- psutil==5.9.5
- pygments==2.15.1
- pyparsing==3.0.9
- pysocks==1.7.1
- python-dateutil==2.8.2
- pytorch-lightning==2.0.2
- pytz==2023.3
- pywavelets==1.4.1
- pyyaml==6.0
- qudida==0.0.4
- regex==2023.5.5
- requests==2.31.0
- rich==13.3.5
- safetensors==0.3.1
- scikit-image==0.20.0
- scikit-learn==1.2.2
- scipy==1.10.1
- sentencepiece==0.1.99
- sentry-sdk==1.24.0
- setproctitle==1.3.2
- six==1.16.0
- smmap==5.0.0
- soupsieve==2.4.1
- sympy==1.12
- tabulate==0.9.0
- taming-transformers==0.0.1
- tensorboardx==2.6
- threadpoolctl==3.1.0
- tifffile==2023.4.12
- timm==0.9.2
- tokenizers==0.13.3
- tomli==2.0.1
- torch==2.0.1
- torchmetrics==0.11.4
- torchvision==0.15.2
- tqdm==4.65.0
- transformers==4.29.2
- typing-extensions==4.6.1
- tzdata==2023.3
- urllib3==1.26.16
- wandb==0.15.11
- wcwidth==0.2.6
- yapf==0.33.0
- yarl==1.9.2
- clip
- taming-transformers-rom1504