-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdeployment.py
More file actions
107 lines (82 loc) · 3.19 KB
/
deployment.py
File metadata and controls
107 lines (82 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
from PIL import Image
import torch
import os
import argparse
import yaml
import utils as util
from train import inference
import tempfile
from huggingface_hub import hf_hub_download
# Fetch the pretrained DehazeDet checkpoint from the Hugging Face Hub at
# import time. hf_hub_download caches the file locally and returns the
# local filesystem path, so repeated app reruns do not re-download.
# NOTE(review): runs on every cold start and requires network access the
# first time — confirm this is acceptable for the deployment target.
MODEL_PATH = hf_hub_download(
    repo_id="ntmy777/dehazedet",
    filename="best.pt"
)
# Load arguments and parameters
def load_args_params():
    """Build the argument namespace and load hyper-parameters for inference.

    Parses an empty argv (``parse_args([])``) so Streamlit's own command-line
    arguments are ignored and only the defaults below apply. Environment
    variables ``LOCAL_RANK``/``WORLD_SIZE`` override the distributed settings.

    Returns:
        tuple: ``(args, params)`` — ``args`` is an ``argparse.Namespace`` of
        defaults, ``params`` is the dict loaded from ``args.yaml``.

    Raises:
        FileNotFoundError: if ``args.yaml`` is not present in the working dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-size', default=640, type=int)
    parser.add_argument('--batch-size', default=16, type=int)
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--model_path', default=MODEL_PATH, type=str)
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--inference', action='store_true')
    parser.add_argument('--inference_test', action='store_true')
    parser.add_argument('--data', default='rtts', type=str)
    # FIX: these loss weights are fractional (0.1 / 0.9). argparse applies
    # `type` to command-line strings, so `type=int` would make int("0.1")
    # raise ValueError for any user-supplied value — they must be float.
    parser.add_argument('--detection_weight', default=0.1, type=float)
    parser.add_argument('--dehazing_weight', default=0.9, type=float)
    parser.add_argument('--resume', action='store_true')
    args = parser.parse_args([])

    # Distributed settings come from the launcher's environment, if any.
    args.local_rank = int(os.getenv('LOCAL_RANK', 0))
    args.world_size = int(os.getenv('WORLD_SIZE', 1))
    if args.world_size > 1:
        torch.cuda.set_device(device=args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
        os.makedirs('weights', exist_ok=True)

    util.setup_seed()
    util.setup_multi_processes()

    # os.path.join with a single component was a no-op — open the path directly.
    with open('args.yaml', errors='ignore') as f:
        params = yaml.safe_load(f)
    return args, params
# --- Streamlit UI ----------------------------------------------------------
st.set_page_config(page_title="DehazeDET Inference", layout="centered")
st.title("DehazeDet - Object Detection with Image Restoration")
st.markdown(
    """
Select an input image to perform joint dehazing and object detection.
You may upload your own image or use the provided sample.
"""
)
st.divider()

# Two input routes: a user upload, or the bundled sample image.
upload = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
sample_clicked = st.button("Use Provided Sample Image")

input_image, input_path = None, None
if upload is not None:
    input_image = Image.open(upload).convert("RGB")
    # inference() expects a filesystem path, so persist the upload to a
    # temporary file (delete=False keeps it alive past the `with` block).
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        input_image.save(tmp.name)
        input_path = tmp.name
elif sample_clicked:
    input_path = "./sample-input.jpg"
    input_image = Image.open(input_path).convert("RGB")

# Only proceed once one of the two routes produced an image.
if input_image is not None:
    st.subheader("Input Image")
    st.image(input_image, use_container_width=True)

    args, params = load_args_params()

    st.divider()
    with st.spinner("Running inference..."):
        result_img = inference(
            args.model_path,
            input_path,
            args,
            params,
            device="cpu"
        )

    st.subheader("Inference Result")
    st.image(result_img, use_container_width=True)