diff --git a/scripts/replay_implementation.py b/scripts/replay_implementation.py index c549156..8206a2b 100644 --- a/scripts/replay_implementation.py +++ b/scripts/replay_implementation.py @@ -40,16 +40,28 @@ from omni.ext.mobility_gen.build import load_scenario +def str2bool(v): + """Convert string to boolean for argparse.""" + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--input_path", type=str) parser.add_argument("--output_path", type=str) - parser.add_argument("--rgb_enabled", type=bool, default=True) - parser.add_argument("--segmentation_enabled", type=bool, default=True) - parser.add_argument("--depth_enabled", type=bool, default=True) - parser.add_argument("--instance_id_segmentation_enabled", type=bool, default=True) - parser.add_argument("--normals_enabled", type=bool, default=False) + parser.add_argument("--rgb_enabled", type=str2bool, default=True) + parser.add_argument("--segmentation_enabled", type=str2bool, default=True) + parser.add_argument("--depth_enabled", type=str2bool, default=True) + parser.add_argument("--instance_id_segmentation_enabled", type=str2bool, default=True) + parser.add_argument("--normals_enabled", type=str2bool, default=False) parser.add_argument("--render_rt_subframes", type=int, default=1) parser.add_argument("--render_interval", type=int, default=1) @@ -96,6 +108,9 @@ print(f"\tOutput path: {args.output_path}") print(f"\tRgb enabled: {args.rgb_enabled}") print(f"\tSegmentation enabled: {args.segmentation_enabled}") + print(f"\tDepth enabled: {args.depth_enabled}") + print(f"\tInstance id segmentation enabled: {args.instance_id_segmentation_enabled}") + print(f"\tNormals enabled: {args.normals_enabled}") print(f"\tRendering RT subframes: {args.render_rt_subframes}") print(f"\tRender interval: {args.render_interval}")