-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsplit_bands.py
More file actions
66 lines (51 loc) · 1.97 KB
/
split_bands.py
File metadata and controls
66 lines (51 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Split multiband raster into individual raster bands."""
import argparse
import os
import logging
import multiprocessing
from ecoshard import geoprocessing
from ecoshard import taskgraph
from osgeo import gdal
# Give GDAL a 1024**3-byte (1 GiB) raster block cache.
gdal.SetCacheMax(1024 ** 3)

# Verbose root logging that includes timing and call-site information.
logging.basicConfig(
    format=(
        '%(asctime)s (%(relativeCreated)d) %(levelname)s %(name)s'
        ' [%(funcName)s:%(lineno)d] %(message)s'),
    level=logging.DEBUG)

# Only let warnings and above through from the 'taskgraph' logger.
# NOTE(review): ecoshard's taskgraph may log under an 'ecoshard.taskgraph...'
# logger name, in which case this call has no effect -- confirm the name.
logging.getLogger('taskgraph').setLevel(logging.WARNING)

LOGGER = logging.getLogger(__name__)
def main():
    """Split a multiband raster into one single-band raster per band.

    Parses the command line and writes one raster per band of ``base_path``
    into the current working directory. Each output is named
    ``<stem><offset_count + band_index><ext>``, where ``band_index`` is the
    0-based band index and ``offset_count`` defaults to 0. For example, the
    first band of ``dem.tif`` becomes ``dem0.tif``.

    Raises:
        ValueError: if any target raster already exists, so existing outputs
            are never overwritten.
    """
    parser = argparse.ArgumentParser(
        description='Split multiband raster into individual raster')
    parser.add_argument('base_path', type=str, help='path to multiband raster')
    # Without a default, omitting the flag left offset_count as None and the
    # suffix arithmetic below raised TypeError.
    parser.add_argument(
        '--offset_count', type=int, default=0,
        help='number to add to target raster band id suffix')
    args = parser.parse_args()

    raster_info = geoprocessing.get_raster_info(args.base_path)
    n_bands = raster_info['n_bands']
    target_path_list = _build_target_path_list(
        args.base_path, n_bands, args.offset_count)

    # Report only the paths that actually conflict.
    existing_path_list = [
        path for path in target_path_list if os.path.exists(path)]
    if existing_path_list:
        raise ValueError(
            "expected paths already exist, don't want to overwrite: "
            f"{existing_path_list}")

    # '.' is presumably taskgraph's working/database directory -- confirm.
    # Never start more workers than there are bands to extract.
    task_graph = taskgraph.TaskGraph(
        '.', min(multiprocessing.cpu_count(), n_bands))
    for band_index, target_path in enumerate(target_path_list):
        task_graph.add_task(
            func=geoprocessing.raster_calculator,
            args=[
                # Raster band ids passed here are 1-based; band_index is
                # 0-based.
                [(args.base_path, band_index + 1)], passthrough_op,
                target_path, raster_info['datatype'],
                raster_info['nodata'][band_index]],
            target_path_list=[target_path],
            task_name=f'extract band {band_index}')
    task_graph.close()
    task_graph.join()


def _build_target_path_list(base_path, n_bands, offset_count):
    """Return the single-band output path for each band of ``base_path``.

    Args:
        base_path: path to the multiband raster; only its basename is used,
            so outputs land in the current working directory.
        n_bands: number of bands in the raster.
        offset_count: integer added to each 0-based band index to form the
            numeric suffix inserted before the file extension.

    Returns:
        A list of ``n_bands`` paths such as ``['dem0.tif', 'dem1.tif']``.
    """
    stem, ext = os.path.splitext(os.path.basename(base_path))
    return [
        f'{stem}{offset_count + band_index}{ext}'
        for band_index in range(n_bands)]
def passthrough_op(x):
    """Identity operation: hand back the input block untouched.

    ``main`` passes this as the ``raster_calculator`` operation, so the
    selected band is written to its target raster as-is.
    """
    return x
if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()