Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 68 additions & 36 deletions cli/azd/extensions/azure.ai.finetune/internal/cmd/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import (
"github.com/fatih/color"
"github.com/spf13/cobra"

FTYaml "azure.ai.finetune/internal/fine_tuning_yaml"
"azure.ai.finetune/internal/services"
JobWrapper "azure.ai.finetune/internal/tools"
Utils "azure.ai.finetune/internal/utils"
"azure.ai.finetune/pkg/models"
)

func newOperationCommand() *cobra.Command {
Expand Down Expand Up @@ -65,15 +66,18 @@ func formatFineTunedModel(model string) string {

func newOperationSubmitCommand() *cobra.Command {
var filename string
var model string
var trainingFile string
var validationFile string
var suffix string
var seed int64
cmd := &cobra.Command{
Use: "submit",
Short: "Submit fine tuning job",
Short: "submit fine tuning job",
RunE: func(cmd *cobra.Command, args []string) error {
ctx := azdext.WithAccessToken(cmd.Context())

// Validate filename is provided
if filename == "" {
return fmt.Errorf("config file is required, use -f or --file flag")
if filename == "" && (model == "" || trainingFile == "") {
return fmt.Errorf("either config file or model and training-file parameters are required")
}

azdClient, err := azdext.NewAzdClient()
Expand All @@ -82,60 +86,88 @@ func newOperationSubmitCommand() *cobra.Command {
}
defer azdClient.Close()

// Parse and validate the YAML configuration file
color.Green("Parsing configuration file...")
config, err := FTYaml.ParseFineTuningConfig(filename)
if err != nil {
return err
// Show spinner while creating job
spinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: "creating fine-tuning job...",
})
if err := spinner.Start(ctx); err != nil {
fmt.Printf("failed to start spinner: %v\n", err)
}

// Upload training file
// Parse and validate the YAML configuration file if provided
var config *models.CreateFineTuningRequest
if filename != "" {
color.Green("\nparsing configuration file...")
config, err = Utils.ParseCreateFineTuningRequestConfig(filename)
if err != nil {
_ = spinner.Stop(ctx)
fmt.Println()
return err
}
} else {
config = &models.CreateFineTuningRequest{}
}

trainingFileID, err := JobWrapper.UploadFileIfLocal(ctx, azdClient, config.TrainingFile)
if err != nil {
return fmt.Errorf("failed to upload training file: %w", err)
// Override config values with command-line parameters if provided
if model != "" {
config.BaseModel = model
}
if trainingFile != "" {

// Upload validation file if provided
var validationFileID string
if config.ValidationFile != "" {
validationFileID, err = JobWrapper.UploadFileIfLocal(ctx, azdClient, config.ValidationFile)
if err != nil {
return fmt.Errorf("failed to upload validation file: %w", err)
}
config.TrainingFile = trainingFile
}
if validationFile != "" {
config.ValidationFile = &validationFile
}
if suffix != "" {
config.Suffix = &suffix
}
if seed != 0 {
config.Seed = &seed
}

// Create fine-tuning job
// Convert YAML configuration to OpenAI job parameters
jobParams, err := ConvertYAMLToJobParams(config, trainingFileID, validationFileID)
fineTuneSvc, err := services.NewFineTuningService(ctx, azdClient, nil)
if err != nil {
return fmt.Errorf("failed to convert configuration to job parameters: %w", err)
_ = spinner.Stop(ctx)
fmt.Println()
return err
}

// Submit the fine-tuning job using CreateJob from JobWrapper
job, err := JobWrapper.CreateJob(ctx, azdClient, jobParams)
job, err := fineTuneSvc.CreateFineTuningJob(ctx, config)
_ = spinner.Stop(ctx)
fmt.Println()

if err != nil {
return err
}

// Print success message
fmt.Println(strings.Repeat("=", 120))
color.Green("\nSuccessfully submitted fine-tuning Job!\n")
fmt.Printf("Job ID: %s\n", job.Id)
fmt.Printf("Model: %s\n", job.Model)
fmt.Println("\n", strings.Repeat("=", 60))
color.Green("\nsuccessfully submitted fine-tuning Job!\n")
fmt.Printf("Job ID: %s\n", job.ID)
fmt.Printf("Model: %s\n", job.BaseModel)
fmt.Printf("Status: %s\n", job.Status)
fmt.Printf("Created: %s\n", job.CreatedAt)
if job.FineTunedModel != "" {
fmt.Printf("Fine-tuned: %s\n", job.FineTunedModel)
}
fmt.Println(strings.Repeat("=", 120))

fmt.Println(strings.Repeat("=", 60))
return nil
},
}

cmd.Flags().StringVarP(&filename, "file", "f", "", "Path to the config file")

cmd.Flags().StringVarP(&filename, "file", "f", "", "Path to the config file.")
cmd.Flags().StringVarP(&model, "model", "m", "", "Base model to fine-tune. Overrides config file. Required if --file is not provided")
cmd.Flags().StringVarP(&trainingFile, "training-file", "t", "", "Training file ID or local path. Use 'local:' prefix for local paths. Required if --file is not provided")
cmd.Flags().StringVarP(&validationFile, "validation-file", "v", "", "Validation file ID or local path. Use 'local:' prefix for local paths.")
cmd.Flags().StringVarP(&suffix, "suffix", "s", "", "An optional string of up to 64 characters that will be added to your fine-tuned model name. Overrides config file.")
cmd.Flags().Int64VarP(&seed, "seed", "r", 0, "Random seed for reproducibility of the job. If a seed is not specified, one will be generated for you. Overrides config file.")

//Either config file should be provided or at least `model` & `training-file` parameters
cmd.MarkFlagFilename("file", "yaml", "yml")
cmd.MarkFlagsOneRequired("file", "model")
cmd.MarkFlagsRequiredTogether("model", "training-file")
return cmd
}

Expand Down Expand Up @@ -289,7 +321,7 @@ func newOperationListCommand() *cobra.Command {
jobs, err := fineTuneSvc.ListFineTuningJobs(ctx, limit, after)
_ = spinner.Stop(ctx)
if err != nil {
fmt.Println()
fmt.Println()
return err
}

Expand Down
Loading