-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_content.py
More file actions
99 lines (86 loc) · 4.19 KB
/
generate_content.py
File metadata and controls
99 lines (86 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import sys
import json
import time
import asyncio
import zhipuai
import os
import streamlit as st
# Load API Key from streamlit secrets
zhipuai.api_key = st.secrets["ZHIPUAI_API_KEY"]
async def async_generate(topic, module):
try:
print(f"[DEBUG] 开始生成模块内容 - 主题: {topic}, 模块: {module}", flush=True)
print(f"[DEBUG] 发送API请求...", flush=True)
response = zhipuai.model_api.invoke(
model="chatglm_turbo",
prompt=[{
"role": "user",
"content": f"""你是一个帮助解释学习主题模块的助手。请详细解释"{topic}"这个主题中的"{module}"模块,如果需要的话可以包含示例。请用markdown格式回复。"""
}],
temperature=0.9,
top_p=0.7
)
print(f"[DEBUG] 收到API响应: {response.get('code')}", flush=True)
if response.get("code") == 200:
content = response["data"]["choices"][0]["content"]
print(f"[DEBUG] 内容长度: {len(content)} 字符", flush=True)
print(f"* '{module}' 内容生成完成", flush=True)
return content
else:
error_msg = response.get('msg', '未知错误')
print(f"! 生成'{module}'内容时发生错误: {error_msg}", flush=True)
return f"生成{module}的内容时发生错误: {error_msg}"
except Exception as e:
print(f"! 生成'{module}'内容时发生异常: {str(e)}", flush=True)
return f"生成{module}的内容时发生异常: {str(e)}"
async def generate_concurrently(topic, sections):
print(f"[DEBUG] 接收到的sections: {sections}", flush=True)
print(f"开始为 '{topic}' 生成 {len(sections)} 个模块的内容...", flush=True)
tasks = []
for module in sections:
# 确保module是一个干净的字符串
module = module.strip()
print(f"[DEBUG] 处理模块: '{module}'", flush=True)
if module: # 只处理非空模块
tasks.append(async_generate(topic, module))
else:
print(f"[DEBUG] 跳过空模块", flush=True)
print(f"[DEBUG] 创建的任务数量: {len(tasks)}", flush=True)
results = await asyncio.gather(*tasks)
print(f"[DEBUG] 获得结果数量: {len(results)}", flush=True)
return results
def save_to_json(topic, sections, contents):
print(f"[DEBUG] 准备保存内容 - sections数量: {len(sections)}, contents数量: {len(contents)}", flush=True)
# 清理sections,确保它们是干净的字符串
cleaned_sections = [section.strip() for section in sections if section.strip()]
print(f"[DEBUG] 清理后的sections: {cleaned_sections}", flush=True)
content_data = {}
for section, content in zip(cleaned_sections, contents):
print(f"[DEBUG] 保存模块 '{section}' 的内容", flush=True)
content_data[section] = content
print(f"[DEBUG] 最终数据包含 {len(content_data)} 个模块", flush=True)
print(f"保存内容到 {topic}_content.json", flush=True)
try:
with open(f"{topic}_content.json", "w", encoding='utf-8') as f:
json.dump(content_data, f, indent=4, ensure_ascii=False)
print("* 保存完成", flush=True)
except Exception as e:
print(f"! 保存文件时发生错误: {str(e)}", flush=True)
def async_main(topic, sections):
print(f"[DEBUG] 主程序开始 - 主题: {topic}", flush=True)
print(f"[DEBUG] 收到的命令行参数: {sections}", flush=True)
s = time.perf_counter()
contents = asyncio.run(generate_concurrently(topic, sections))
elapsed = time.perf_counter() - s
print(f"总用时: {elapsed:.2f} 秒", flush=True)
save_to_json(topic, sections, contents)
if __name__ == "__main__":
if sys.stdout.encoding != 'utf-8':
import codecs
sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict')
print(f"[DEBUG] 程序启动 - 参数数量: {len(sys.argv)}", flush=True)
print(f"[DEBUG] 完整参数列表: {sys.argv}", flush=True)
# 确保命令行参数是UTF-8编码
user_topic = sys.argv[1]
sections = [arg.strip() for arg in sys.argv[2:] if arg.strip()]
async_main(user_topic, sections)