-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathinference.py
More file actions
164 lines (129 loc) · 5.2 KB
/
inference.py
File metadata and controls
164 lines (129 loc) · 5.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Inference script for AMF / NMF classification on stacked histopathology crop‑outs.
* Model : EfficientNet‑V2‑M
* Weights: loaded from /opt/app/model/*.pth (or the path given by the MODEL_PATH env var); download efficientnetv2_m_fold3_best.pth from https://midog2025.deepmicroscopy.org/files/efficientnetv2_m_fold3_best.pth and place it in the model/ folder
* Input : /input/images/stacked-histopathology-roi-cropouts/*.tif(f) (multi‑page TIFF generated by Pillow)
* Output : /output/multiple-mitotic-figure-classification.json
Each TIFF page is scored independently. The JSON contains one object per
page with the probability of atypical mitosis, regardless of the winning
class.
"""
from __future__ import annotations
from pathlib import Path
import json
from glob import glob
import numpy as np
from PIL import Image, ImageSequence
import torch
import torch.nn as nn
from torchvision import transforms
import os
from typing import Any, Dict, List
import logging as log
import timm
# Paths & constants
INPUT_PATH = Path("/input")    # Grand-Challenge-style input root
OUTPUT_PATH = Path("/output")  # predictions are written here

# Directory holding the checkpoint(s) baked into the container image.
MODEL_WEIGHTS = Path(
    "/opt/app/model"
)

# Resolve which checkpoint to load:
#   1. honour an explicit MODEL_PATH env var when it points at an existing file;
#   2. otherwise fall back to the first (alphabetically) *.pth in MODEL_WEIGHTS.
_env_model = os.getenv("MODEL_PATH")
if _env_model and Path(_env_model).is_file():
    MODEL_PATH = Path(_env_model)
else:
    # pick first .pth file in model dir
    candidates = sorted(MODEL_WEIGHTS.glob("*.pth"))
    if not candidates:
        raise FileNotFoundError(
            f"No .pth checkpoint found in {MODEL_WEIGHTS}. "
            f"You can also set MODEL_PATH env var.")
    MODEL_PATH = candidates[0]
    if _env_model:
        # The env var was set but did not point at a real file — warn loudly.
        log.warning("MODEL_PATH '%s' not found. Using '%s' instead.", _env_model, MODEL_PATH)
    else:
        log.info("Using checkpoint %s", MODEL_PATH)

CLASS_NAMES = {0: "atypical", 1: "normal"}  # logit index -> human-readable label
THRESHOLD = 0.5  # probability threshold for classifying *normal*

# Model definition
# NOTE(review): `timm` is already imported unconditionally at the top of the
# file, so a missing package raises ImportError there first; the guarded
# re-import with a friendlier message that used to live here was unreachable
# and has been removed.
MODEL_NAME = "efficientnetv2_m"
class BinaryEfficientNetV2M(nn.Module):
    """EfficientNet-V2-M backbone with a single-logit binary head.

    Mirrors the training-time architecture exactly: the timm model is
    created with its default classifier (num_classes=1000, as during
    training) and that classifier is then replaced in-place by a
    one-unit linear layer emitting a single logit.
    """

    def __init__(self, pretrained: bool = False) -> None:
        super().__init__()
        # Build the backbone with the same call signature used at training
        # time, then swap the head for a single-logit output.
        backbone = timm.create_model(MODEL_NAME, pretrained=pretrained)
        in_features = backbone.classifier.in_features
        backbone.classifier = nn.Linear(in_features, 1)
        self.net = backbone

    def forward(self, x):
        """Return the raw (pre-sigmoid) logit batch for input images *x*."""
        return self.net(x)
# Transforms (must match validation)
# Preprocessing applied to every TIFF page before inference: resize to the
# network's 224x224 input, convert to a [0, 1] float tensor, and normalize
# with ImageNet statistics (same pipeline as validation during training).
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet mean/std — must match the normalization used at training.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Helper functions
def get_interface_key() -> tuple[str, ...]:
    """Return the sorted tuple of interface slugs declared in /input/inputs.json."""
    inputs_file = INPUT_PATH / "inputs.json"
    with open(inputs_file, "r") as fh:
        socket_values = json.load(fh)
    slugs = [sv["interface"]["slug"] for sv in socket_values]
    return tuple(sorted(slugs))
def load_image_stack(dir_path: Path) -> List[np.ndarray]:
    """Load every page of the first multi-page TIFF found in *dir_path*.

    Parameters
    ----------
    dir_path : Path
        Directory expected to contain at least one ``*.tif`` / ``*.tiff`` file.

    Returns
    -------
    List[np.ndarray]
        One RGB (H, W, 3) array per TIFF page.

    Raises
    ------
    FileNotFoundError
        If no TIFF file is present in *dir_path*.
    """
    # Sort so that the chosen file is deterministic when several TIFFs exist:
    # glob() returns matches in arbitrary, platform-dependent order.
    tiffs = sorted(glob(str(dir_path / "*.tif")) + glob(str(dir_path / "*.tiff")))
    if not tiffs:
        raise FileNotFoundError(f"No TIFF files found in {dir_path}")
    # Only the first TIFF is read; each page becomes one RGB numpy array.
    with Image.open(tiffs[0]) as tif:
        return [np.array(p.convert("RGB")) for p in ImageSequence.Iterator(tif)]
def write_json(path: Path, obj):
    """Serialize *obj* as pretty-printed JSON at *path*, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(obj, indent=4)
    with open(path, "w") as fh:
        fh.write(payload)
def _show_torch_cuda_info():
print("=+=" * 10)
avail = torch.cuda.is_available()
print("Torch CUDA available:", avail)
if avail:
print(" num devices:", torch.cuda.device_count())
cur = torch.cuda.current_device()
print(" current device:", cur)
print(" properties :", torch.cuda.get_device_properties(cur))
print("=+=" * 10)
# Inference handler
def interf0_handler() -> int:
    """Score every page of the stacked-cropout TIFF and write the output JSON.

    Loads the checkpoint at MODEL_PATH into a BinaryEfficientNetV2M, scores
    each TIFF page independently, and writes one record per page to
    /output/multiple-mitotic-figure-classification.json.

    Returns:
        int: 0 on success (used as the process exit code).
    """
    # Load TIFF stack
    slices = load_image_stack(INPUT_PATH / "images/stacked-histopathology-roi-cropouts")
    # Prepare model
    _show_torch_cuda_info()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BinaryEfficientNetV2M(pretrained=False).to(device)
    # strict=True: fail loudly if the checkpoint does not match the architecture.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device), strict=True)
    model.eval()
    # Run inference
    preds: List[Dict[str, Any]] = []
    with torch.no_grad():
        for sl in slices:
            img = Image.fromarray(sl)
            # Add the batch dimension expected by the model (1, C, H, W).
            x = val_transform(img).unsqueeze(0).to(device)
            logit = model(x)
            p_normal = torch.sigmoid(logit).item()  # P(normal)
            p_atyp = 1.0 - p_normal  # P(atypical)
            # The winning class uses THRESHOLD on P(normal), but "confidence"
            # always reports P(atypical) regardless of the winning class
            # (per the module docstring / expected output format).
            cls_idx = 1 if p_normal >= THRESHOLD else 0
            preds.append({"class": CLASS_NAMES[cls_idx], "confidence": round(p_atyp, 4)})
    # Persist
    write_json(OUTPUT_PATH / "multiple-mitotic-figure-classification.json", preds)
    return 0
# Entrypoint
def run() -> int:
    """Dispatch to the handler matching the detected input interface key."""
    dispatch = {("stacked-histopathology-roi-cropouts",): interf0_handler}
    key = get_interface_key()
    handler = dispatch.get(key)
    if handler is None:
        raise KeyError(f"Unsupported interface combination: {key}")
    return handler()
if __name__ == "__main__":
raise SystemExit(run())