run_api.py
"""
Quick start script for SoulX-Podcast API
使用示例:
python run_api.py
python run_api.py --port 8080
python run_api.py --model pretrained_models/SoulX-Podcast-1.7B-dialect
"""
import argparse
import os
import signal
import sys


def main():
    parser = argparse.ArgumentParser(description="Start the SoulX-Podcast API server")
    parser.add_argument(
        "--model",
        type=str,
        default="pretrained_models/SoulX-Podcast-1.7B",
        help="Model path (default: pretrained_models/SoulX-Podcast-1.7B)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="API port (default: 8000)"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="API host address (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--engine",
        type=str,
        choices=["hf", "vllm"],
        default="hf",
        help="LLM engine (default: hf)"
    )
    parser.add_argument(
        "--fp16-flow",
        action="store_true",
        help="Run the flow model in FP16 (faster, slightly lower quality)"
    )
    parser.add_argument(
        "--max-tasks",
        type=int,
        default=2,
        help="Maximum number of concurrent tasks (default: 2)"
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        help="Enable hot reload (development mode)"
    )
    args = parser.parse_args()

    # Set environment variables for the API process
    os.environ["MODEL_PATH"] = args.model
    os.environ["API_HOST"] = args.host
    os.environ["API_PORT"] = str(args.port)
    os.environ["LLM_ENGINE"] = args.engine
    os.environ["FP16_FLOW"] = "true" if args.fp16_flow else "false"
    os.environ["MAX_CONCURRENT_TASKS"] = str(args.max_tasks)
    os.environ["API_RELOAD"] = "true" if args.reload else "false"

    # Check that the model path exists
    if not os.path.exists(args.model):
        print(f"Error: model path does not exist: {args.model}")
        print("\nPlease download the model first:")
        print(f"huggingface-cli download --resume-download Soul-AILab/SoulX-Podcast-1.7B --local-dir {args.model}")
        sys.exit(1)

    # Print startup information
    print("=" * 60)
    print("Starting SoulX-Podcast API server...")
    print("=" * 60)
    print(f"Model path: {args.model}")
    print(f"Server address: http://{args.host}:{args.port}")
    print(f"API docs: http://localhost:{args.port}/docs")
    print(f"LLM engine: {args.engine}")
    print(f"FP16 flow: {'yes' if args.fp16_flow else 'no'}")
    print(f"Max concurrent tasks: {args.max_tasks}")
    print("=" * 60)
    print("\nLoading model, please wait...\n")
    print("Tip: press Ctrl+C to stop the server (press twice to force quit if it is slow to respond)\n")

    # Install signal handlers to support fast shutdown
    shutdown_count = 0

    def signal_handler(signum, frame):
        nonlocal shutdown_count
        shutdown_count += 1
        if shutdown_count == 1:
            print("\n\nShutting down gracefully... (press Ctrl+C again to force quit)")
        else:
            print("\n\nForcing exit!")
            # Free GPU memory before the hard exit
            try:
                import torch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                pass
            os._exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    if hasattr(signal, 'SIGTERM'):
        signal.signal(signal.SIGTERM, signal_handler)

    # Start the API server
    import uvicorn
    uvicorn.run(
        "api.main:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        log_level="info"
    )

if __name__ == "__main__":
    main()
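
# Quick check (assumption: api.main:app is a FastAPI app, as implied by the
# /docs URL printed above). Once the server reports it is running, the
# interactive docs and the OpenAPI schema should be reachable, for example:
#
#   curl http://localhost:8000/openapi.json
#   # or open http://localhost:8000/docs in a browser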