Patch summary: replace the hard-coded CUDA_VISIBLE_DEVICES = "2,3" assignment
at the top of webui.py with a --cuda-visible-devices command-line option
(default "2,3"). An empty value skips the assignment, leaving whatever
CUDA_VISIBLE_DEVICES the process inherited untouched.

Review notes:
  * When the option is not passed, the default "2,3" is still written to
    os.environ, so it overrides a CUDA_VISIBLE_DEVICES value exported by the
    caller (same effect as the old hard-coded line).
  * The variable is now set after argument parsing instead of on line 2.
    NOTE(review): confirm that nothing imported before this point initializes
    CUDA, otherwise the setting may be applied too late.
  * Indentation of the hunk bodies was reconstructed assuming 4-space indents;
    verify against the target file before applying.

diff --git a/eagle/application/webui.py b/eagle/application/webui.py
index a1b96dff..01ac7791 100644
--- a/eagle/application/webui.py
+++ b/eagle/application/webui.py
@@ -1,5 +1,4 @@
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
 import time
 
 import gradio as gr
@@ -269,8 +268,18 @@ def clear(history,session_state):
     default=512,
     help="The maximum number of new generated tokens.",
 )
+parser.add_argument(
+    "--cuda-visible-devices",
+    type=str,
+    default="2,3",
+    help="Comma-separated list of GPU IDs to use (e.g., '0,1' or '2,3'). Default: '2,3'. Set to empty string to use all available GPUs.",
+)
 args = parser.parse_args()
 
+# Set CUDA visible devices if specified (empty string means use all GPUs)
+if args.cuda_visible_devices:
+    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
+
 model = EaModel.from_pretrained(
     base_model_path=args.base_model_path,
     ea_model_path=args.ea_model_path,