From a5b576a3d97edb130b9ba496de449c49b9e0fec1 Mon Sep 17 00:00:00 2001
From: Kent Keirsey
Date: Fri, 15 Aug 2025 09:42:29 -0400
Subject: [PATCH 1/3] flux UI fixes

---
 .../ui/config_groups/flux_lora_config_group.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/src/invoke_training/ui/config_groups/flux_lora_config_group.py b/src/invoke_training/ui/config_groups/flux_lora_config_group.py
index 69b6c93f..dacad3ff 100644
--- a/src/invoke_training/ui/config_groups/flux_lora_config_group.py
+++ b/src/invoke_training/ui/config_groups/flux_lora_config_group.py
@@ -9,6 +9,7 @@
 )
 from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
 from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
+from invoke_training.ui.utils.prompts import convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts
 from invoke_training.ui.utils.utils import get_typing_literal_options
 
 
@@ -64,13 +65,20 @@ def __init__(self):
         with gr.Tab("Core"):
             with gr.Row():
                 self.train_transformer = gr.Checkbox(label="Train Transformer", interactive=True)
+                self.train_text_encoder = gr.Checkbox(label="Train Text Encoder", interactive=True)
             with gr.Row():
                 self.transformer_learning_rate = gr.Number(
                     label="Transformer Learning Rate",
                     info="The transformer learning rate. If None, then it is inherited from the base optimizer "
                     "learning rate.",
                     interactive=True,
                 )
+                self.text_encoder_learning_rate = gr.Number(
+                    label="Text Encoder Learning Rate",
+                    info="The text encoder learning rate. If None, then it is inherited from the base optimizer "
+                    "learning rate.",
+                    interactive=True,
+                )
             with gr.Row():
                 self.gradient_accumulation_steps = gr.Number(
                     label="Gradient Accumulation Steps",
@@ -187,7 +195,9 @@ def get_ui_output_components(self) -> list[gr.components.Component]:
         components = [
             self.model,
             self.train_transformer,
+            self.train_text_encoder,
             self.transformer_learning_rate,
+            self.text_encoder_learning_rate,
             self.gradient_accumulation_steps,
             self.gradient_checkpointing,
             self.lr_scheduler,
@@ -228,7 +238,9 @@ def update_ui_components_with_config_data(
         update_dict = {
             self.model: config.model,
             self.train_transformer: config.train_transformer,
+            self.train_text_encoder: config.train_text_encoder,
             self.transformer_learning_rate: config.transformer_learning_rate,
+            self.text_encoder_learning_rate: config.text_encoder_learning_rate,
             self.gradient_accumulation_steps: config.gradient_accumulation_steps,
             self.gradient_checkpointing: config.gradient_checkpointing,
             self.lr_scheduler: config.lr_scheduler,
@@ -247,7 +259,9 @@ def update_ui_components_with_config_data(
             self.guidance_scale: config.guidance_scale,
             self.use_masks: config.use_masks,
             self.prediction_type: config.prediction_type,
-            self.validation_prompts: config.validation_prompts,
+            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
+                config.validation_prompts, None
+            ),
             self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
             self.max_checkpoints: config.max_checkpoints,
         }
@@ -357,7 +371,7 @@ def safe_pop(component, default=None):
         # Handle validation prompts
         try:
            validation_prompts_text = safe_pop(self.validation_prompts, "")
-            positive_prompts = validation_prompts_text
+            positive_prompts, _ = convert_ui_prompts_to_pos_neg_prompts(validation_prompts_text)
             new_config.validation_prompts = positive_prompts
         except Exception as e:
             print(f"Error processing validation prompts: {e}")
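
Note: the prompt round-trip helpers imported in this patch
(convert_pos_neg_prompts_to_ui_prompts / convert_ui_prompts_to_pos_neg_prompts)
are not shown in the series. The sketch below illustrates the behavior the diff
appears to rely on: one prompt per textbox line, with a "[NEG]" separator for
negative prompts. The separator and signatures are assumptions for
illustration; the real implementations live in invoke_training/ui/utils/prompts.py.

    def convert_ui_prompts_to_pos_neg_prompts(ui_text: str) -> tuple[list[str], list[str] | None]:
        """Split newline-separated UI prompts into positive/negative lists."""
        positives: list[str] = []
        negatives: list[str] = []
        for line in ui_text.splitlines():
            line = line.strip()
            if not line:
                continue
            pos, sep, neg = line.partition("[NEG]")
            positives.append(pos.strip())
            negatives.append(neg.strip() if sep else "")
        return positives, (negatives if any(negatives) else None)

    def convert_pos_neg_prompts_to_ui_prompts(positives: list[str], negatives: list[str] | None) -> str:
        """Inverse mapping: render prompt pairs back into textbox form."""
        if negatives is None:
            return "\n".join(positives)
        return "\n".join(f"{p}[NEG]{n}" if n else p for p, n in zip(positives, negatives))

    # Round trip: "a photo of sks dog[NEG]blurry" -> (["a photo of sks dog"], ["blurry"])

Under these assumptions, passing None as the second argument in
update_ui_components_with_config_data simply renders the positive prompts one
per line, and the reverse conversion discards the negative half, which is why
the train-config handler unpacks and ignores the second return value.
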
From 9c40523dd10a25f6e1d2ac383dbe00da4e7ebc84 Mon Sep 17 00:00:00 2001
From: Kent Keirsey
Date: Sat, 16 Aug 2025 22:02:12 -0400
Subject: [PATCH 2/3] Add Stop Training Button

---
 .gitignore                                    |  1 +
 .../ui/gradio_blocks/pipeline_tab.py          | 54 +++++++++++++++++--
 src/invoke_training/ui/pages/training_page.py | 48 ++++++++++++++++-
 3 files changed, 99 insertions(+), 4 deletions(-)

diff --git a/.gitignore b/.gitignore
index 76b01973..252e16c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -170,3 +170,4 @@ cython_debug/
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 .aider*
+datasets/
\ No newline at end of file
diff --git a/src/invoke_training/ui/gradio_blocks/pipeline_tab.py b/src/invoke_training/ui/gradio_blocks/pipeline_tab.py
index 3c2f8611..b5e3e83a 100644
--- a/src/invoke_training/ui/gradio_blocks/pipeline_tab.py
+++ b/src/invoke_training/ui/gradio_blocks/pipeline_tab.py
@@ -16,6 +16,8 @@ def __init__(
         pipeline_config_cls: typing.Type[PipelineConfig],
         config_group_cls: typing.Type[UIConfigElement],
         run_training_cb: typing.Callable[[PipelineConfig], None],
+        stop_training_cb: typing.Callable[[], str],
+        is_training_active_cb: typing.Callable[[], bool],
         app: gr.Blocks,
     ):
         """A tab for a single training pipeline type.
@@ -27,6 +29,8 @@
         self._default_config_file_path = default_config_file_path
         self._pipeline_config_cls = pipeline_config_cls
         self._run_training_cb = run_training_cb
+        self._stop_training_cb = stop_training_cb
+        self._is_training_active_cb = is_training_active_cb

         # self._default_config is the config that was last loaded from the reference config file.
         self._default_config = None
@@ -52,7 +56,18 @@
                 **Warning: Click 'Generate Config' to capture all of the latest changes before starting training.**
                 """
            )
-            run_training_button = gr.Button(value="Start Training")
+
+            with gr.Row():
+                run_training_button = gr.Button(value="Start Training", variant="primary")
+                stop_training_button = gr.Button(value="Stop Training", variant="stop", interactive=False)
+
+            with gr.Row():
+                training_status = gr.Textbox(
+                    label="Training Status",
+                    value="No training active",
+                    interactive=False,
+                    lines=2
+                )

            gr.Markdown(
                 """# Visualize Results
@@ -80,7 +95,16 @@
             outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],
         )

-        run_training_button.click(self.on_run_training_button_click, inputs=[], outputs=[])
+        run_training_button.click(
+            self.on_run_training_button_click,
+            inputs=[],
+            outputs=[stop_training_button, training_status]
+        )
+        stop_training_button.click(
+            self.on_stop_training_button_click,
+            inputs=[],
+            outputs=[run_training_button, stop_training_button, training_status]
+        )

         # On app load, reset the configs based on the default reference config file.
         # We'll wrap this in a try-except block to handle any errors during loading
@@ -99,6 +123,22 @@ def safe_load_config(file_path):
             outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],
         )

+        # Add status update function for manual refresh
+        def update_training_status():
+            is_active = self._is_training_active_cb()
+            if is_active:
+                return gr.Button(interactive=False), gr.Button(interactive=True), "Training in progress..."
+            else:
+                return gr.Button(interactive=True), gr.Button(interactive=False), "No training active"
+
+        # Add a refresh button for status updates
+        refresh_status_button = gr.Button(value="🔄 Refresh Status", size="sm")
+        refresh_status_button.click(
+            update_training_status,
+            inputs=[],
+            outputs=[run_training_button, stop_training_button, training_status]
+        )
+
     def on_reset_config_button_click(self, file_path: str):
         try:
             print(f"Resetting UI configs for {self._name} to {file_path}.")
@@ -160,4 +200,12 @@ def on_generate_config_button_click(self, data: dict):
             return {self._config_yaml: f"Error generating config: {e}"}

     def on_run_training_button_click(self):
-        self._run_training_cb(self._current_config)
+        result = self._run_training_cb(self._current_config)
+        if result and "already in progress" in result:
+            return gr.Button(interactive=True), result
+        else:
+            return gr.Button(interactive=True), result or f"Training started for {self._name}..."
+
+    def on_stop_training_button_click(self):
+        result = self._stop_training_cb()
+        return gr.Button(interactive=True), gr.Button(interactive=False), result
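
Note: the wiring above relies on Gradio's positional mapping between a
handler's return values and its `outputs` list; returning a component instance
such as gr.Button(interactive=True) acts as an update to the corresponding
output component. In particular, on_run_training_button_click must return an
interactive Stop button on a successful start, since its first output is bound
to stop_training_button. A stripped-down, self-contained sketch of the same
enable/disable pattern (assumes Gradio 4.x semantics, as used in this patch):

    import gradio as gr

    with gr.Blocks() as demo:
        start = gr.Button("Start Training", variant="primary")
        stop = gr.Button("Stop Training", variant="stop", interactive=False)
        status = gr.Textbox(label="Training Status", value="No training active", interactive=False)

        def on_start():
            # Returned values map positionally onto outputs=[stop, status].
            return gr.Button(interactive=True), "Training in progress..."

        def on_stop():
            # Maps positionally onto outputs=[start, stop, status].
            return gr.Button(interactive=True), gr.Button(interactive=False), "Training stopped."

        start.click(on_start, inputs=[], outputs=[stop, status])
        stop.click(on_stop, inputs=[], outputs=[start, stop, status])

    demo.launch()
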
diff --git a/src/invoke_training/ui/pages/training_page.py b/src/invoke_training/ui/pages/training_page.py
index f397ffc8..184eca00 100644
--- a/src/invoke_training/ui/pages/training_page.py
+++ b/src/invoke_training/ui/pages/training_page.py
@@ -88,6 +88,8 @@ def __init__(self):
                     pipeline_config_cls=SdLoraConfig,
                     config_group_cls=SdLoraConfigGroup,
                     run_training_cb=self._run_training,
+                    stop_training_cb=self._stop_training,
+                    is_training_active_cb=self._is_training_active,
                     app=app,
                 )
             with gr.Tab(label="SDXL LoRA"):
@@ -97,6 +99,8 @@
                     pipeline_config_cls=SdxlLoraConfig,
                     config_group_cls=SdxlLoraConfigGroup,
                     run_training_cb=self._run_training,
+                    stop_training_cb=self._stop_training,
+                    is_training_active_cb=self._is_training_active,
                     app=app,
                 )
             with gr.Tab(label="SD Textual Inversion"):
@@ -106,6 +110,8 @@
                     pipeline_config_cls=SdTextualInversionConfig,
                     config_group_cls=SdTextualInversionConfigGroup,
                     run_training_cb=self._run_training,
+                    stop_training_cb=self._stop_training,
+                    is_training_active_cb=self._is_training_active,
                     app=app,
                 )
             with gr.Tab(label="SDXL Textual Inversion"):
@@ -115,6 +121,8 @@
                     pipeline_config_cls=SdxlTextualInversionConfig,
                     config_group_cls=SdxlTextualInversionConfigGroup,
                     run_training_cb=self._run_training,
+                    stop_training_cb=self._stop_training,
+                    is_training_active_cb=self._is_training_active,
                     app=app,
                 )
             with gr.Tab(label="SDXL LoRA and Textual Inversion"):
@@ -124,6 +132,8 @@
                     pipeline_config_cls=SdxlLoraAndTextualInversionConfig,
                     config_group_cls=SdxlLoraAndTextualInversionConfigGroup,
                     run_training_cb=self._run_training,
+                    stop_training_cb=self._stop_training,
+                    is_training_active_cb=self._is_training_active,
                     app=app,
                 )
             with gr.Tab(label="SDXL Finetune"):
@@ -133,6 +143,8 @@
                     pipeline_config_cls=SdxlFinetuneConfig,
                     config_group_cls=SdxlFinetuneConfigGroup,
                     run_training_cb=self._run_training,
+                    stop_training_cb=self._stop_training,
+                    is_training_active_cb=self._is_training_active,
                     app=app,
                 )
             with gr.Tab(label="Flux LoRA"):
@@ -144,6 +156,8 @@
                     pipeline_config_cls=FluxLoraConfig,
                     config_group_cls=FluxLoraConfigGroup,
                     run_training_cb=self._run_training,
+                    stop_training_cb=self._stop_training,
+                    is_training_active_cb=self._is_training_active,
                     app=app,
                 )

@@ -160,7 +174,7 @@ def _run_training(self, config: PipelineConfig):
                    "Tried to start a new training process, but another training process is already running. "
                    "Terminate the existing process first."
                )
-                return
+                return "Training already in progress. Stop the current training first."
            else:
                self._training_process = None

@@ -175,3 +189,35 @@
         self._training_process = subprocess.Popen(["invoke-train", "-c", str(config_path)])

         print(f"Started {config.type} training.")
+        return f"Started {config.type} training successfully."
+
+    def _stop_training(self):
+        """Gracefully terminate the active training process."""
+        if self._training_process is not None and self._training_process.poll() is None:
+            print("Terminating training process...")
+            try:
+                # Send SIGTERM for graceful termination
+                self._training_process.terminate()
+
+                # Wait for the process to terminate gracefully (up to 10 seconds)
+                try:
+                    self._training_process.wait(timeout=10)
+                    print("Training process terminated gracefully.")
+                except subprocess.TimeoutExpired:
+                    # If graceful termination fails, force kill
+                    print("Graceful termination failed, forcing kill...")
+                    self._training_process.kill()
+                    self._training_process.wait()
+                    print("Training process force killed.")
+
+                self._training_process = None
+                return "Training stopped successfully."
+            except Exception as e:
+                print(f"Error terminating training process: {e}")
+                return f"Error stopping training: {e}"
+        else:
+            return "No active training process to stop."
+
+    def _is_training_active(self):
+        """Check if there is an active training process."""
+        return self._training_process is not None and self._training_process.poll() is None
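
Note: _stop_training follows the standard two-stage shutdown: SIGTERM first so
the trainer can exit cleanly (flush logs, finish the current checkpoint write),
SIGKILL only if the process does not stop within the grace period. The same
escalation in isolation (the 10-second grace period mirrors the patch; tune it
to how long a checkpoint write can take):

    import subprocess

    def stop_process(proc: subprocess.Popen, grace_seconds: float = 10.0) -> str:
        """Terminate `proc`, escalating from SIGTERM to SIGKILL."""
        if proc.poll() is not None:
            return "Process already exited."
        proc.terminate()  # SIGTERM: request a clean shutdown
        try:
            proc.wait(timeout=grace_seconds)
            return "Terminated gracefully."
        except subprocess.TimeoutExpired:
            proc.kill()  # SIGKILL: cannot be caught or ignored
            proc.wait()  # reap the process to avoid a zombie
            return "Force killed."
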
From 08f29ca65c0c5e89ed5f9bd046a2f836880eac63 Mon Sep 17 00:00:00 2001
From: Kent Keirsey
Date: Mon, 18 Aug 2025 11:46:34 -0400
Subject: [PATCH 3/3] fix alpha scaling

---
 src/invoke_training/pipelines/flux/lora/train.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/invoke_training/pipelines/flux/lora/train.py b/src/invoke_training/pipelines/flux/lora/train.py
index 1ae48b71..0379b38a 100644
--- a/src/invoke_training/pipelines/flux/lora/train.py
+++ b/src/invoke_training/pipelines/flux/lora/train.py
@@ -424,8 +424,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None):
     if config.train_transformer:
         transformer_lora_config = peft.LoraConfig(
             r=config.lora_rank_dim,
-            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
-            lora_alpha=1.0,
+            lora_alpha=config.lora_rank_dim,
             target_modules=config.flux_lora_target_modules,
         )
         transformer = inject_lora_layers(transformer, transformer_lora_config, lr=config.transformer_learning_rate)
@@ -433,7 +432,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None):
     if config.train_text_encoder:
         text_encoder_lora_config = peft.LoraConfig(
             r=config.lora_rank_dim,
-            lora_alpha=1.0,
+            lora_alpha=config.lora_rank_dim,
             # init_lora_weights="gaussian",
             target_modules=config.text_encoder_lora_target_modules,
         )
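
Note: context for the alpha change. PEFT scales each LoRA delta by
lora_alpha / r, so the effective weight is W + (lora_alpha / r) * B @ A. With
the old lora_alpha=1.0, the adapter's contribution shrank as the rank grew
(scale 1/r); with lora_alpha=config.lora_rank_dim the scale is exactly 1.0 at
every rank, matching the Diffusers convention named in the removed TODO and
making learning rates and checkpoints comparable across ranks. A small worked
illustration of the two scaling regimes:

    # Effective LoRA scale in PEFT (non-rsLoRA): scaling = lora_alpha / r
    for rank in (4, 16, 64):
        old_scale = 1.0 / rank   # lora_alpha=1.0
        new_scale = rank / rank  # lora_alpha=rank
        print(f"rank={rank:2d}: alpha=1.0 -> scale={old_scale:.4f}, alpha=rank -> scale={new_scale:.1f}")
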