diff --git a/Ollama.md b/Ollama.md
new file mode 100644
index 00000000..846bcbde
--- /dev/null
+++ b/Ollama.md
@@ -0,0 +1,25 @@
+# Using Ollama for Lyric Generation
+
+As an alternative to running GGUF-format LLM models directly with `llama-cpp`
+via the Python wrapper `llama-cpp-python`, the `llama-cpp` runtime hosted
+by [Ollama](https://github.com/ollama) can be used for lyric generation.
+
+The advantages are:
+- use of Ollama LLM models
+- use of several GPUs out of the box
+
+## Prerequisites
+
+An Ollama instance must be up and running with the intended Ollama LLM model.
+
+Install the Python wrapper for the Ollama REST API:
+```
+pip install ollama
+```
+
+## Usage
+
+To use Ollama, add the following command-line arguments:
+- `--ollama`
+- `--model_path <ollama model name>`, e.g. `--model_path gemma3:12b-it-q4_K_M`
+
\ No newline at end of file
diff --git a/radio_gradio.py b/radio_gradio.py
index 0634d2ef..24d3e798 100644
--- a/radio_gradio.py
+++ b/radio_gradio.py
@@ -238,17 +238,19 @@ def generate_identity(cls, genre: str, theme: str):
         return cls(name, slogan)
 
 class AIRadioStation:
-    def __init__(self, ace_step_pipeline: ACEStepPipeline, model_path: str = "gemma-3-12b-it-abliterated.q4_k_m"):
+    def __init__(self, ace_step_pipeline: ACEStepPipeline, model_path: str = "gemma-3-12b-it-abliterated.q4_k_m", ollama: bool = False):
         """
         Initialize the AI Radio Station with continuous generation.
 
         Args:
             ace_step_pipeline: Initialized ACEStepPipeline for music generation
             model_path: Path to LLM model for lyric generation
+            ollama: Whether to use Ollama for lyric generation (default: False)
         """
         self._pipeline = ace_step_pipeline  # Store the original pipeline reference
         self.random_mode = False
         self.llm_model_path = model_path
+        self.ollama = ollama
         self.llm = None
         self._first_play = True
         self.pipeline_args = {
@@ -286,26 +288,30 @@ def load_llm(self):
         self.unload_llm()
         gc.collect()
         if self.llm is None:
-            print("Loading LLM model...")
-            try:
-                from llama_cpp import Llama
-                self.llm = Llama(
-                    model_path=self.llm_model_path,
-                    n_ctx=2048,
-                    n_threads=4,
-                    n_gpu_layers=-1,
-                    seed = -1 # random seed for random lyrics
+            if self.ollama:
+                self.llm = True
+            else:
+                print("Loading LLM model...")
+                try:
+                    from llama_cpp import Llama
+                    self.llm = Llama(
+                        model_path=self.llm_model_path,
+                        n_ctx=2048,
+                        n_threads=4,
+                        n_gpu_layers=-1,
+                        seed = -1 # random seed for random lyrics

-                )
-            except ImportError:
-                print("Warning: llama-cpp-python not installed, using simple lyric generation")
-                self.llm = None
+                    )
+                except ImportError:
+                    print("Warning: llama-cpp-python not installed, using simple lyric generation")
+                    self.llm = None

     def unload_llm(self):
         """Unload the LLM model from memory"""
         if self.llm is not None:
-            print("Unloading LLM model...")
-            del self.llm
+            if not self.ollama:
+                print("Unloading LLM model...")
+                del self.llm
             self.llm = None
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -742,18 +748,31 @@ def generate_lyrics_and_prompt(self, genre: str, theme: str, language: str = "En
             if self.llm:  # Check if load was successful
                 print(f"Using LLM for lyric generation (attempt {retry_count + 1}/{max_retries + 1})...")
-                output = self.llm(
-                    prompt,
-                    max_tokens=700,
-                    temperature=0.7,
-                    top_p=0.9,
-                    repeat_penalty=1.1,
-                    stop=["[End]", "\n\n\n"],
-                    echo=False,
-                    seed=-1
-                )
-
-                lyrics = output["choices"][0]["text"].strip()
+                if self.ollama:
+                    import ollama
+                    output = ollama.chat(
+                        model=self.llm_model_path,
+                        messages=[
+                            {
+                                'role': 'user',
+                                'content': prompt,
+                            },
+                        ]
+                    )
+                    lyrics = output.message.content.strip()
+                else:
+                    output = self.llm(
+                        prompt,
+                        max_tokens=700,
+                        temperature=0.7,
+                        top_p=0.9,
+                        repeat_penalty=1.1,
+                        stop=["[End]", "\n\n\n"],
+                        echo=False,
+                        seed=-1
+                    )
+                    lyrics = output["choices"][0]["text"].strip()
+                print(f"Generated lyrics:\n{lyrics}")


                 # Validate lyrics quality
@@ -1239,7 +1258,9 @@ def update_theme_suggestions(genre):
     ) as demo:
         gr.Markdown("# 🎵 AI Radio Station")
         gr.Markdown("Continuous AI-powered music generation using ACE")
-
+        if radio.ollama:
+            gr.Markdown("#### Ollama Model for Lyric Generation: `" + radio.llm_model_path + "`")
+

         # Add a timer component for automatic updates
         timer = gr.Timer(0.5, active=True)
@@ -1262,12 +1283,19 @@ def update_theme_suggestions(genre):
                 buffer_size = gr.Slider(1, 10, value=1, step=1, label="Buffer Size (songs)")
                 random_mode = gr.Checkbox(label="Continuous Random Mode (after the first song)", value=True)
                 random_languages = gr.Checkbox(label="Randomize Languages (after the first song)", value=False)
-                model_path_input = gr.File(
-                    label="GGUF Model File",
-                    file_types=[".gguf"],
-                    value="gemma-3-12b-it-abliterated.q4_k_m.gguf"
-                )
-
+                if radio.ollama:
+                    model_path_input = gr.File(
+                        label="GGUF Model File (Ollama model used instead)",
+                        file_types=[".gguf"],
+                        value="gemma-3-12b-it-abliterated.q4_k_m.gguf"
+                    )
+                else:
+                    model_path_input = gr.File(
+                        label="GGUF Model File",
+                        file_types=[".gguf"],
+                        value="gemma-3-12b-it-abliterated.q4_k_m.gguf"
+                    )
+
             with gr.Tab("Advanced Settings"):
                 language_input = gr.Dropdown(
                     choices=list(SUPPORTED_LANGUAGES.keys()),
@@ -1413,6 +1441,8 @@ def main():
                         help="Use bfloat16 precision")
     parser.add_argument("--torch_compile", default=False,
                         help="Enable torch compilation for faster inference")
+    parser.add_argument("--ollama", action="store_true",
+                        help="Enable Ollama for lyric generation")
     args = parser.parse_args()

     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
@@ -1430,7 +1460,8 @@ def main():
     print("Initializing AI Radio Station...")
     radio = AIRadioStation(
         ace_step_pipeline=pipeline,
-        model_path=args.model_path
+        model_path=args.model_path,
+        ollama=args.ollama
     )

     # Create and launch interface
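For reference, a minimal sketch (not part of the diff) of the `ollama.chat` call pattern the new code path relies on; it assumes the `ollama` package is installed, an Ollama server is running locally, and the example model tag below has already been pulled:

```
import ollama

# Example model tag; substitute whatever you pass via --model_path.
model = "gemma3:12b-it-q4_K_M"

response = ollama.chat(
    model=model,
    messages=[{"role": "user", "content": "Write two lines of song lyrics about rain."}],
)

# Recent versions of the ollama package return a response object with attribute
# access (matching `output.message.content` in the diff); older versions return
# a dict, hence the fallback.
try:
    print(response.message.content.strip())
except AttributeError:
    print(response["message"]["content"].strip())
```

If this snippet works, the radio can be started as described in `Ollama.md`, e.g. `python radio_gradio.py --ollama --model_path gemma3:12b-it-q4_K_M`.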