Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.aider*
datasets/
5 changes: 2 additions & 3 deletions src/invoke_training/pipelines/flux/lora/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,16 +424,15 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
if config.train_transformer:
transformer_lora_config = peft.LoraConfig(
r=config.lora_rank_dim,
# TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
lora_alpha=1.0,
lora_alpha=config.lora_rank_dim,
target_modules=config.flux_lora_target_modules,
)
transformer = inject_lora_layers(transformer, transformer_lora_config, lr=config.transformer_learning_rate)

if config.train_text_encoder:
text_encoder_lora_config = peft.LoraConfig(
r=config.lora_rank_dim,
lora_alpha=1.0,
lora_alpha=config.lora_rank_dim,
# init_lora_weights="gaussian",
target_modules=config.text_encoder_lora_target_modules,
)
Expand Down
18 changes: 16 additions & 2 deletions src/invoke_training/ui/config_groups/flux_lora_config_group.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import typing

import gradio as gr

from invoke_training.pipelines.flux.lora.config import FluxLoraConfig
from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (
ImageCaptionSDDataLoaderConfigGroup,
)
from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.prompts import convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts

Check failure on line 12 in src/invoke_training/ui/config_groups/flux_lora_config_group.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (E501)

src/invoke_training/ui/config_groups/flux_lora_config_group.py:12:121: E501 Line too long (121 > 120)
from invoke_training.ui.utils.utils import get_typing_literal_options

Check failure on line 13 in src/invoke_training/ui/config_groups/flux_lora_config_group.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (I001)

src/invoke_training/ui/config_groups/flux_lora_config_group.py:1:1: I001 Import block is un-sorted or un-formatted


Expand Down Expand Up @@ -64,13 +65,20 @@
with gr.Tab("Core"):
with gr.Row():
self.train_transformer = gr.Checkbox(label="Train Transformer", interactive=True)
self.train_text_encoder = gr.Checkbox(label="Train Text Encoder", interactive=True)
with gr.Row():
self.transformer_learning_rate = gr.Number(
label="Transformer Learning Rate",
info="The transformer learning rate. If None, then it is inherited from the base optimizer "
"learning rate.",
interactive=True,
)
self.text_encoder_learning_rate = gr.Number(
label="Text Encoder Learning Rate",
info="The text encoder learning rate. If None, then it is inherited from the base optimizer "
"learning rate.",
interactive=True,
)
with gr.Row():
self.gradient_accumulation_steps = gr.Number(
label="Gradient Accumulation Steps",
Expand Down Expand Up @@ -187,7 +195,9 @@
components = [
self.model,
self.train_transformer,
self.train_text_encoder,
self.transformer_learning_rate,
self.text_encoder_learning_rate,
self.gradient_accumulation_steps,
self.gradient_checkpointing,
self.lr_scheduler,
Expand Down Expand Up @@ -228,7 +238,9 @@
update_dict = {
self.model: config.model,
self.train_transformer: config.train_transformer,
self.train_text_encoder: config.train_text_encoder,
self.transformer_learning_rate: config.transformer_learning_rate,
self.text_encoder_learning_rate: config.text_encoder_learning_rate,
self.gradient_accumulation_steps: config.gradient_accumulation_steps,
self.gradient_checkpointing: config.gradient_checkpointing,
self.lr_scheduler: config.lr_scheduler,
Expand All @@ -247,7 +259,9 @@
self.guidance_scale: config.guidance_scale,
self.use_masks: config.use_masks,
self.prediction_type: config.prediction_type,
self.validation_prompts: config.validation_prompts,
self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
config.validation_prompts, None
),
self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
self.max_checkpoints: config.max_checkpoints,
}
Expand Down Expand Up @@ -357,7 +371,7 @@
# Handle validation prompts
try:
validation_prompts_text = safe_pop(self.validation_prompts, "")
positive_prompts = validation_prompts_text
positive_prompts, _ = convert_ui_prompts_to_pos_neg_prompts(validation_prompts_text)
new_config.validation_prompts = positive_prompts
except Exception as e:
print(f"Error processing validation prompts: {e}")
Expand Down
54 changes: 51 additions & 3 deletions src/invoke_training/ui/gradio_blocks/pipeline_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
pipeline_config_cls: typing.Type[PipelineConfig],
config_group_cls: typing.Type[UIConfigElement],
run_training_cb: typing.Callable[[PipelineConfig], None],
stop_training_cb: typing.Callable[[], str],
is_training_active_cb: typing.Callable[[], bool],
app: gr.Blocks,
):
"""A tab for a single training pipeline type.
Expand All @@ -27,6 +29,8 @@
self._default_config_file_path = default_config_file_path
self._pipeline_config_cls = pipeline_config_cls
self._run_training_cb = run_training_cb
self._stop_training_cb = stop_training_cb
self._is_training_active_cb = is_training_active_cb

# self._default_config is the config that was last loaded from the reference config file.
self._default_config = None
Expand All @@ -52,7 +56,18 @@
**Warning: Click 'Generate Config' to capture all of the latest changes before starting training.**
"""
)
run_training_button = gr.Button(value="Start Training")

Check failure on line 59 in src/invoke_training/ui/gradio_blocks/pipeline_tab.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (W293)

src/invoke_training/ui/gradio_blocks/pipeline_tab.py:59:1: W293 Blank line contains whitespace
with gr.Row():
run_training_button = gr.Button(value="Start Training", variant="primary")
stop_training_button = gr.Button(value="Stop Training", variant="stop", interactive=False)

Check failure on line 63 in src/invoke_training/ui/gradio_blocks/pipeline_tab.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (W293)

src/invoke_training/ui/gradio_blocks/pipeline_tab.py:63:1: W293 Blank line contains whitespace
with gr.Row():
training_status = gr.Textbox(
label="Training Status",

Check failure on line 66 in src/invoke_training/ui/gradio_blocks/pipeline_tab.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (W291)

src/invoke_training/ui/gradio_blocks/pipeline_tab.py:66:41: W291 Trailing whitespace
value="No training active",

Check failure on line 67 in src/invoke_training/ui/gradio_blocks/pipeline_tab.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (W291)

src/invoke_training/ui/gradio_blocks/pipeline_tab.py:67:44: W291 Trailing whitespace
interactive=False,
lines=2
)

gr.Markdown(
"""# Visualize Results
Expand Down Expand Up @@ -80,7 +95,16 @@
outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],
)

run_training_button.click(self.on_run_training_button_click, inputs=[], outputs=[])
run_training_button.click(
self.on_run_training_button_click,

Check failure on line 99 in src/invoke_training/ui/gradio_blocks/pipeline_tab.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (W291)

src/invoke_training/ui/gradio_blocks/pipeline_tab.py:99:47: W291 Trailing whitespace
inputs=[],

Check failure on line 100 in src/invoke_training/ui/gradio_blocks/pipeline_tab.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (W291)

src/invoke_training/ui/gradio_blocks/pipeline_tab.py:100:23: W291 Trailing whitespace
outputs=[stop_training_button, training_status]
)
stop_training_button.click(
self.on_stop_training_button_click,
inputs=[],
outputs=[run_training_button, stop_training_button, training_status]
)

# On app load, reset the configs based on the default reference config file.
# We'll wrap this in a try-except block to handle any errors during loading
Expand All @@ -99,6 +123,22 @@
outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],
)

# Add status update function for manual refresh
def update_training_status():
is_active = self._is_training_active_cb()
if is_active:
return gr.Button(interactive=False), gr.Button(interactive=True), "Training in progress..."
else:
return gr.Button(interactive=True), gr.Button(interactive=False), "No training active"

# Add a refresh button for status updates
refresh_status_button = gr.Button(value="🔄 Refresh Status", size="sm")
refresh_status_button.click(
update_training_status,
inputs=[],
outputs=[run_training_button, stop_training_button, training_status]
)

def on_reset_config_button_click(self, file_path: str):
try:
print(f"Resetting UI configs for {self._name} to {file_path}.")
Expand Down Expand Up @@ -160,4 +200,12 @@
return {self._config_yaml: f"Error generating config: {e}"}

def on_run_training_button_click(self):
self._run_training_cb(self._current_config)
result = self._run_training_cb(self._current_config)
if result and "already in progress" in result:
return gr.Button(interactive=True), result
else:
return gr.Button(interactive=False), result or f"Training started for {self._name}..."

def on_stop_training_button_click(self):
result = self._stop_training_cb()
return gr.Button(interactive=True), gr.Button(interactive=False), result
48 changes: 47 additions & 1 deletion src/invoke_training/ui/pages/training_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@
pipeline_config_cls=SdLoraConfig,
config_group_cls=SdLoraConfigGroup,
run_training_cb=self._run_training,
stop_training_cb=self._stop_training,
is_training_active_cb=self._is_training_active,
app=app,
)
with gr.Tab(label="SDXL LoRA"):
Expand All @@ -97,6 +99,8 @@
pipeline_config_cls=SdxlLoraConfig,
config_group_cls=SdxlLoraConfigGroup,
run_training_cb=self._run_training,
stop_training_cb=self._stop_training,
is_training_active_cb=self._is_training_active,
app=app,
)
with gr.Tab(label="SD Textual Inversion"):
Expand All @@ -106,6 +110,8 @@
pipeline_config_cls=SdTextualInversionConfig,
config_group_cls=SdTextualInversionConfigGroup,
run_training_cb=self._run_training,
stop_training_cb=self._stop_training,
is_training_active_cb=self._is_training_active,
app=app,
)
with gr.Tab(label="SDXL Textual Inversion"):
Expand All @@ -115,6 +121,8 @@
pipeline_config_cls=SdxlTextualInversionConfig,
config_group_cls=SdxlTextualInversionConfigGroup,
run_training_cb=self._run_training,
stop_training_cb=self._stop_training,
is_training_active_cb=self._is_training_active,
app=app,
)
with gr.Tab(label="SDXL LoRA and Textual Inversion"):
Expand All @@ -124,6 +132,8 @@
pipeline_config_cls=SdxlLoraAndTextualInversionConfig,
config_group_cls=SdxlLoraAndTextualInversionConfigGroup,
run_training_cb=self._run_training,
stop_training_cb=self._stop_training,
is_training_active_cb=self._is_training_active,
app=app,
)
with gr.Tab(label="SDXL Finetune"):
Expand All @@ -133,6 +143,8 @@
pipeline_config_cls=SdxlFinetuneConfig,
config_group_cls=SdxlFinetuneConfigGroup,
run_training_cb=self._run_training,
stop_training_cb=self._stop_training,
is_training_active_cb=self._is_training_active,
app=app,
)
with gr.Tab(label="Flux LoRA"):
Expand All @@ -144,6 +156,8 @@
pipeline_config_cls=FluxLoraConfig,
config_group_cls=FluxLoraConfigGroup,
run_training_cb=self._run_training,
stop_training_cb=self._stop_training,
is_training_active_cb=self._is_training_active,
app=app,
)

Expand All @@ -160,7 +174,7 @@
"Tried to start a new training process, but another training process is already running. "
"Terminate the existing process first."
)
return
return "Training already in progress. Stop the current training first."
else:
self._training_process = None

Expand All @@ -175,3 +189,35 @@
self._training_process = subprocess.Popen(["invoke-train", "-c", str(config_path)])

print(f"Started {config.type} training.")
return f"Started {config.type} training successfully."

def _stop_training(self):
"""Gracefully terminate the active training process."""
if self._training_process is not None and self._training_process.poll() is None:
print("Terminating training process...")
try:
# Send SIGTERM for graceful termination
self._training_process.terminate()

Check failure on line 201 in src/invoke_training/ui/pages/training_page.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (W293)

src/invoke_training/ui/pages/training_page.py:201:1: W293 Blank line contains whitespace
# Wait for the process to terminate gracefully (up to 10 seconds)
try:
self._training_process.wait(timeout=10)
print("Training process terminated gracefully.")
except subprocess.TimeoutExpired:
# If graceful termination fails, force kill
print("Graceful termination failed, forcing kill...")
self._training_process.kill()
self._training_process.wait()
print("Training process force killed.")

Check failure on line 212 in src/invoke_training/ui/pages/training_page.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (W293)

src/invoke_training/ui/pages/training_page.py:212:1: W293 Blank line contains whitespace
self._training_process = None
return "Training stopped successfully."
except Exception as e:
print(f"Error terminating training process: {e}")
return f"Error stopping training: {e}"
else:
return "No active training process to stop."

def _is_training_active(self):
"""Check if there is an active training process."""
return self._training_process is not None and self._training_process.poll() is None
Loading