-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcompose.py
More file actions
95 lines (79 loc) · 3.49 KB
/
compose.py
File metadata and controls
95 lines (79 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import asyncio
import os
import aiohttp
import aiofiles
BACKEND_V1_API_URL = "https://public-api.beatoven.ai/api/v1"
BACKEND_API_HEADER_KEY = os.getenv("BEATOVEN_API_KEY", "")
async def compose_track(request_data):
async with aiohttp.ClientSession() as session:
try:
async with session.post(
f"{BACKEND_V1_API_URL}/tracks/compose",
json=request_data,
headers={"Authorization": f"Bearer {BACKEND_API_HEADER_KEY}"},
) as response:
if response.status == 200:
data = await response.json()
return data
except aiohttp.ClientConnectionError:
raise Exception({"error": "Could not connect to beatoven.ai"})
except Exception as e:
raise Exception({"error": "Failed to make a request to beatoven.ai"}) from e
finally:
if not data.get("task_id"):
raise Exception(data)
async def get_track_status(task_id):
async with aiohttp.ClientSession() as session:
try:
async with session.get(
f"{BACKEND_V1_API_URL}/tasks/{task_id}",
headers={"Authorization": f"Bearer {BACKEND_API_HEADER_KEY}"},
) as response:
if response.status == 200:
data = await response.json()
return data
else:
raise Exception({"error": "Composition failed"})
except aiohttp.ClientConnectionError as e:
raise Exception({"error": "Could not connect"}) from e
except Exception as e:
raise Exception({"error": "Failed to make a request"}) from e
async def handle_track_file(track_path: str, track_url: str):
async with aiohttp.ClientSession() as session:
try:
async with session.get(track_url) as response:
if response.status == 200:
async with aiofiles.open(track_path, "wb+") as f:
await f.write(await response.read())
return {}
except aiohttp.ClientConnectionError as e:
raise Exception({"error": "Could not download file"}) from e
except Exception as e:
raise Exception(
{"error": "Failed to make a request to get track file"}
) from e
async def watch_task_status(task_id, interval=10):
while True:
track_status = await get_track_status(task_id)
if "error" in track_status:
raise Exception(track_status)
print(f"Task status: {track_status}")
if track_status.get("status") == "composing":
await asyncio.sleep(interval)
elif track_status.get("status") == "failed":
raise Exception({"error": "task failed"})
else:
return track_status
async def create_and_compose(text_prompt: str):
track_meta = {"prompt": {"text": text_prompt}, "format": "wav"}
compose_res = await compose_track(track_meta)
task_id = compose_res["task_id"]
print(f"Started composition task with ID: {task_id}")
generation_meta = await watch_task_status(task_id)
print(generation_meta)
track_url = generation_meta["meta"]["track_url"]
print("Downloading track file")
await handle_track_file(os.path.join(os.getcwd(), "composed_track.mp3"), track_url)
print("Composed! you can find your track as `composed_track.mp3`")
if __name__ == "__main__":
asyncio.run(create_and_compose("30 seconds soft piano melody"))