diff --git a/cli/azd/extensions/azure.ai.finetune/internal/cmd/operations.go b/cli/azd/extensions/azure.ai.finetune/internal/cmd/operations.go
index bfb574eb7db..e90329f35a8 100644
--- a/cli/azd/extensions/azure.ai.finetune/internal/cmd/operations.go
+++ b/cli/azd/extensions/azure.ai.finetune/internal/cmd/operations.go
@@ -13,7 +13,6 @@ import (
 	"github.com/azure/azure-dev/cli/azd/pkg/azdext"
 	"github.com/azure/azure-dev/cli/azd/pkg/ux"

-	FTYaml "azure.ai.finetune/internal/fine_tuning_yaml"
 	"azure.ai.finetune/internal/services"
 	JobWrapper "azure.ai.finetune/internal/tools"
 	"azure.ai.finetune/internal/utils"
@@ -68,15 +67,18 @@ func formatFineTunedModel(model string) string {
 func newOperationSubmitCommand() *cobra.Command {
 	var filename string
+	var model string
+	var trainingFile string
+	var validationFile string
+	var suffix string
+	var seed int64
 	cmd := &cobra.Command{
 		Use:   "submit",
-		Short: "Submit fine tuning job",
+		Short: "Submit a fine-tuning job",
 		RunE: func(cmd *cobra.Command, args []string) error {
 			ctx := azdext.WithAccessToken(cmd.Context())
-
-			// Validate filename is provided
-			if filename == "" {
-				return fmt.Errorf("config file is required, use -f or --file flag")
+			if filename == "" && (model == "" || trainingFile == "") {
+				return fmt.Errorf("either a config file (--file) or both --model and --training-file are required")
 			}

 			azdClient, err := azdext.NewAzdClient()
@@ -85,60 +87,88 @@ func newOperationSubmitCommand() *cobra.Command {
 			}
 			defer azdClient.Close()

-			// Parse and validate the YAML configuration file
-			color.Green("Parsing configuration file...")
-			config, err := FTYaml.ParseFineTuningConfig(filename)
-			if err != nil {
-				return err
+			// Show spinner while creating the job
+			spinner := ux.NewSpinner(&ux.SpinnerOptions{
+				Text: "Creating fine-tuning job...",
+			})
+			if err := spinner.Start(ctx); err != nil {
+				fmt.Printf("failed to start spinner: %v\n", err)
 			}

-			// Upload training file
+			// Parse and validate the YAML configuration file if provided
+			var config *models.CreateFineTuningRequest
+			if filename != "" {
+				color.Green("\nParsing configuration file...")
+				config, err = utils.ParseCreateFineTuningRequestConfig(filename)
+				if err != nil {
+					_ = spinner.Stop(ctx)
+					fmt.Println()
+					return err
+				}
+			} else {
+				config = &models.CreateFineTuningRequest{}
+			}

-			trainingFileID, err := JobWrapper.UploadFileIfLocal(ctx, azdClient, config.TrainingFile)
-			if err != nil {
-				return fmt.Errorf("failed to upload training file: %w", err)
-			}
+			// Override config values with command-line flags if provided
+			if model != "" {
+				config.BaseModel = model
+			}
+			if trainingFile != "" {
+				config.TrainingFile = trainingFile
+			}

-			// Upload validation file if provided
-			var validationFileID string
-			if config.ValidationFile != "" {
-				validationFileID, err = JobWrapper.UploadFileIfLocal(ctx, azdClient, config.ValidationFile)
-				if err != nil {
-					return fmt.Errorf("failed to upload validation file: %w", err)
-				}
-			}
+			if validationFile != "" {
+				config.ValidationFile = &validationFile
+			}
+			if suffix != "" {
+				config.Suffix = &suffix
+			}
+			if seed != 0 {
+				config.Seed = &seed
+			}

-			// Create fine-tuning job
-			// Convert YAML configuration to OpenAI job parameters
-			jobParams, err := ConvertYAMLToJobParams(config, trainingFileID, validationFileID)
+			fineTuneSvc, err := services.NewFineTuningService(ctx, azdClient, nil)
 			if err != nil {
-				return fmt.Errorf("failed to convert configuration to job parameters: %w", err)
+				_ = spinner.Stop(ctx)
+				fmt.Println()
+				return err
 			}

-			// Submit the fine-tuning job using CreateJob from JobWrapper
-			job, err := JobWrapper.CreateJob(ctx, azdClient, jobParams)
+			// Submit the fine-tuning job via the fine-tuning service
+			job, err := fineTuneSvc.CreateFineTuningJob(ctx, config)
+			_ = spinner.Stop(ctx)
+			fmt.Println()
+
 			if err != nil {
 				return err
 			}

 			// Print success message
-			fmt.Println(strings.Repeat("=", 120))
-			color.Green("\nSuccessfully submitted fine-tuning Job!\n")
-			fmt.Printf("Job ID: %s\n", job.Id)
-			fmt.Printf("Model: %s\n", job.Model)
+			fmt.Println("\n" + strings.Repeat("=", 60))
+			color.Green("\nSuccessfully submitted fine-tuning job!\n")
+			fmt.Printf("Job ID: %s\n", job.ID)
+			fmt.Printf("Model: %s\n", job.BaseModel)
 			fmt.Printf("Status: %s\n", job.Status)
 			fmt.Printf("Created: %s\n", job.CreatedAt)
 			if job.FineTunedModel != "" {
 				fmt.Printf("Fine-tuned: %s\n", job.FineTunedModel)
 			}
-			fmt.Println(strings.Repeat("=", 120))
-
+			fmt.Println(strings.Repeat("=", 60))
 			return nil
 		},
 	}

-	cmd.Flags().StringVarP(&filename, "file", "f", "", "Path to the config file")
-
+	cmd.Flags().StringVarP(&filename, "file", "f", "", "Path to the config file.")
+	cmd.Flags().StringVarP(&model, "model", "m", "", "Base model to fine-tune. Overrides config file. Required if --file is not provided.")
+	cmd.Flags().StringVarP(&trainingFile, "training-file", "t", "", "Training file ID or local path. Use the 'local:' prefix for local paths. Required if --file is not provided.")
+	cmd.Flags().StringVarP(&validationFile, "validation-file", "v", "", "Validation file ID or local path. Use the 'local:' prefix for local paths.")
+	cmd.Flags().StringVarP(&suffix, "suffix", "s", "", "An optional string of up to 64 characters that will be added to your fine-tuned model name. Overrides config file.")
+	cmd.Flags().Int64VarP(&seed, "seed", "r", 0, "Random seed for reproducibility of the job. If a seed is not specified, one will be generated for you. Overrides config file.")
+
+	// Either a config file or both --model and --training-file must be provided.
+	_ = cmd.MarkFlagFilename("file", "yaml", "yml")
+	cmd.MarkFlagsOneRequired("file", "model")
+	cmd.MarkFlagsRequiredTogether("model", "training-file")
 	return cmd
 }
diff --git a/cli/azd/extensions/azure.ai.finetune/internal/providers/openai/conversions.go b/cli/azd/extensions/azure.ai.finetune/internal/providers/openai/conversions.go
index 914524f0069..f27372483f1 100644
--- a/cli/azd/extensions/azure.ai.finetune/internal/providers/openai/conversions.go
+++ b/cli/azd/extensions/azure.ai.finetune/internal/providers/openai/conversions.go
@@ -4,6 +4,9 @@ package openai

 import (
+	"encoding/json"
+	"strings"
+
 	"github.com/openai/openai-go/v3"
 	"github.com/openai/openai-go/v3/packages/pagination"

@@ -121,3 +124,280 @@ func convertOpenAIJobCheckpointsToModel(checkpointsPage *pagination.CursorPage[o
 		HasMore: checkpointsPage.HasMore,
 	}
 }
+
+// convertInternalJobParamToOpenAiJobParams converts the internal create fine-tuning
+// request model to OpenAI job parameters.
+func convertInternalJobParamToOpenAiJobParams(config *models.CreateFineTuningRequest) (*openai.FineTuningJobNewParams, error) {
+	jobParams := openai.FineTuningJobNewParams{
+		Model:        openai.FineTuningJobNewParamsModel(config.BaseModel),
+		TrainingFile: config.TrainingFile,
+	}
+
+	if config.ValidationFile != nil && *config.ValidationFile != "" {
+		jobParams.ValidationFile = openai.String(*config.ValidationFile)
+	}
+
+	// Set optional fields
+	if config.Suffix != nil && *config.Suffix != "" {
+		jobParams.Suffix = openai.String(*config.Suffix)
+	}
+
+	if config.Seed != nil {
+		jobParams.Seed = openai.Int(*config.Seed)
+	}
+
+	// Set metadata if provided
+	if len(config.Metadata) > 0 {
+		jobParams.Metadata = make(map[string]string)
+		for k, v := range config.Metadata {
+			jobParams.Metadata[k] = v
+		}
+	}
+
+	// Set hyperparameters if provided
+	if config.Method.Type == "supervised" && config.Method.Supervised != nil {
+		hp := config.Method.Supervised.Hyperparameters
+		supervisedMethod := openai.SupervisedMethodParam{
+			Hyperparameters: openai.SupervisedHyperparameters{},
+		}
+
+		if hp.BatchSize != nil {
+			if batchSize := convertHyperparameterToInt(hp.BatchSize); batchSize != nil {
+				supervisedMethod.Hyperparameters.BatchSize = openai.SupervisedHyperparametersBatchSizeUnion{
+					OfInt: openai.Int(*batchSize),
+				}
+			}
+		}
+
+		if hp.LearningRateMultiplier != nil {
+			if lr := convertHyperparameterToFloat(hp.LearningRateMultiplier); lr != nil {
+				supervisedMethod.Hyperparameters.LearningRateMultiplier = openai.SupervisedHyperparametersLearningRateMultiplierUnion{
+					OfFloat: openai.Float(*lr),
+				}
+			}
+		}
+
+		if hp.Epochs != nil {
+			if epochs := convertHyperparameterToInt(hp.Epochs); epochs != nil {
+				supervisedMethod.Hyperparameters.NEpochs = openai.SupervisedHyperparametersNEpochsUnion{
+					OfInt: openai.Int(*epochs),
+				}
+			}
+		}
+
+		jobParams.Method = openai.FineTuningJobNewParamsMethod{
+			Type:       "supervised",
+			Supervised: supervisedMethod,
+		}
+
+	} else if config.Method.Type == "dpo" && config.Method.DPO != nil {
+		hp := config.Method.DPO.Hyperparameters
+		dpoMethod := openai.DpoMethodParam{
+			Hyperparameters: openai.DpoHyperparameters{},
+		}
+
+		if hp.BatchSize != nil {
+			if batchSize := convertHyperparameterToInt(hp.BatchSize); batchSize != nil {
+				dpoMethod.Hyperparameters.BatchSize = openai.DpoHyperparametersBatchSizeUnion{
+					OfInt: openai.Int(*batchSize),
+				}
+			}
+		}
+
+		if hp.LearningRateMultiplier != nil {
+			if lr := 
convertHyperparameterToFloat(hp.LearningRateMultiplier); lr != nil { + dpoMethod.Hyperparameters.LearningRateMultiplier = openai.DpoHyperparametersLearningRateMultiplierUnion{ + OfFloat: openai.Float(*lr), + } + } + } + + if hp.Epochs != nil { + if epochs := convertHyperparameterToInt(hp.Epochs); epochs != nil { + dpoMethod.Hyperparameters.NEpochs = openai.DpoHyperparametersNEpochsUnion{ + OfInt: openai.Int(*epochs), + } + } + } + + if hp.Beta != nil { + if beta := convertHyperparameterToFloat(hp.Beta); beta != nil { + dpoMethod.Hyperparameters.Beta = openai.DpoHyperparametersBetaUnion{ + OfFloat: openai.Float(*beta), + } + } + } + + jobParams.Method = openai.FineTuningJobNewParamsMethod{ + Type: "dpo", + Dpo: dpoMethod, + } + + } else if config.Method.Type == "reinforcement" && config.Method.Reinforcement != nil { + hp := config.Method.Reinforcement.Hyperparameters + reinforcementMethod := openai.ReinforcementMethodParam{ + Hyperparameters: openai.ReinforcementHyperparameters{}, + } + + if hp.BatchSize != nil { + if batchSize := convertHyperparameterToInt(hp.BatchSize); batchSize != nil { + reinforcementMethod.Hyperparameters.BatchSize = openai.ReinforcementHyperparametersBatchSizeUnion{ + OfInt: openai.Int(*batchSize), + } + } + } + + if hp.LearningRateMultiplier != nil { + if lr := convertHyperparameterToFloat(hp.LearningRateMultiplier); lr != nil { + reinforcementMethod.Hyperparameters.LearningRateMultiplier = openai.ReinforcementHyperparametersLearningRateMultiplierUnion{ + OfFloat: openai.Float(*lr), + } + } + } + + if hp.Epochs != nil { + if epochs := convertHyperparameterToInt(hp.Epochs); epochs != nil { + reinforcementMethod.Hyperparameters.NEpochs = openai.ReinforcementHyperparametersNEpochsUnion{ + OfInt: openai.Int(*epochs), + } + } + } + + if hp.ComputeMultiplier != nil { + if compute := convertHyperparameterToFloat(hp.ComputeMultiplier); compute != nil { + reinforcementMethod.Hyperparameters.ComputeMultiplier = openai.ReinforcementHyperparametersComputeMultiplierUnion{ + OfFloat: openai.Float(*compute), + } + } + } + + if hp.EvalInterval != nil { + if evalSteps := convertHyperparameterToInt(hp.EvalInterval); evalSteps != nil { + reinforcementMethod.Hyperparameters.EvalInterval = openai.ReinforcementHyperparametersEvalIntervalUnion{ + OfInt: openai.Int(*evalSteps), + } + } + } + + if hp.EvalSamples != nil { + if evalSamples := convertHyperparameterToInt(hp.EvalSamples); evalSamples != nil { + reinforcementMethod.Hyperparameters.EvalSamples = openai.ReinforcementHyperparametersEvalSamplesUnion{ + OfInt: openai.Int(*evalSamples), + } + } + } + + if hp.ReasoningEffort != "" { + reinforcementMethod.Hyperparameters.ReasoningEffort = getReasoningEffortValue(hp.ReasoningEffort) + } + + grader := config.Method.Reinforcement.Grader + if grader != nil { + // Convert grader to JSON and unmarshal to ReinforcementMethodGraderUnionParam + graderJSON, err := json.Marshal(grader) + if err != nil { + return nil, err + } + + var graderUnion openai.ReinforcementMethodGraderUnionParam + err = json.Unmarshal(graderJSON, &graderUnion) + if err != nil { + return nil, err + } + reinforcementMethod.Grader = graderUnion + } + + jobParams.Method = openai.FineTuningJobNewParamsMethod{ + Type: "reinforcement", + Reinforcement: reinforcementMethod, + } + } + + // Set integrations if provided + if len(config.Integrations) > 0 { + var integrations []openai.FineTuningJobNewParamsIntegration + + for _, integration := range config.Integrations { + if integration.Type == "" || integration.Type == "wandb" { + + 
wandbConfigJSON, err := json.Marshal(integration.Config) + if err != nil { + return nil, err + } + + var wandbConfig openai.FineTuningJobNewParamsIntegrationWandb + err = json.Unmarshal(wandbConfigJSON, &wandbConfig) + if err != nil { + return nil, err + } + integrations = append(integrations, openai.FineTuningJobNewParamsIntegration{ + Type: "wandb", + Wandb: wandbConfig, + }) + } + } + + if len(integrations) > 0 { + jobParams.Integrations = integrations + } + } + + return &jobParams, nil +} + +// convertHyperparameterToInt converts interface{} hyperparameter to *int64 +func convertHyperparameterToInt(value interface{}) *int64 { + if value == nil { + return nil + } + switch v := value.(type) { + case int: + val := int64(v) + return &val + case int64: + return &v + case float64: + val := int64(v) + return &val + case string: + // "auto" string handled separately + return nil + default: + return nil + } +} + +// convertHyperparameterToFloat converts interface{} hyperparameter to *float64 +func convertHyperparameterToFloat(value interface{}) *float64 { + if value == nil { + return nil + } + switch v := value.(type) { + case int: + val := float64(v) + return &val + case int64: + val := float64(v) + return &val + case float64: + return &v + case string: + // "auto" string handled separately + return nil + default: + return nil + } +} + +func getReasoningEffortValue(effort string) openai.ReinforcementHyperparametersReasoningEffort { + + switch strings.ToLower(effort) { + case "low": + return openai.ReinforcementHyperparametersReasoningEffortLow + case "medium": + return openai.ReinforcementHyperparametersReasoningEffortMedium + case "high": + return openai.ReinforcementHyperparametersReasoningEffortHigh + default: + return openai.ReinforcementHyperparametersReasoningEffortDefault + } +} diff --git a/cli/azd/extensions/azure.ai.finetune/internal/providers/openai/provider.go b/cli/azd/extensions/azure.ai.finetune/internal/providers/openai/provider.go index f0e99e0a19a..5ed9c1404a9 100644 --- a/cli/azd/extensions/azure.ai.finetune/internal/providers/openai/provider.go +++ b/cli/azd/extensions/azure.ai.finetune/internal/providers/openai/provider.go @@ -5,10 +5,14 @@ package openai import ( "context" - - "github.com/openai/openai-go/v3" + "fmt" + "os" + "time" "azure.ai.finetune/pkg/models" + "github.com/azure/azure-dev/cli/azd/pkg/ux" + "github.com/fatih/color" + "github.com/openai/openai-go/v3" ) // OpenAIProvider implements the provider interface for OpenAI APIs @@ -25,11 +29,18 @@ func NewOpenAIProvider(client *openai.Client) *OpenAIProvider { // CreateFineTuningJob creates a new fine-tuning job via OpenAI API func (p *OpenAIProvider) CreateFineTuningJob(ctx context.Context, req *models.CreateFineTuningRequest) (*models.FineTuningJob, error) { - // TODO: Implement - // 1. Convert domain model to OpenAI SDK format - // 2. Call OpenAI SDK CreateFineTuningJob - // 3. 
Convert OpenAI response to domain model
-	return nil, nil
+
+	params, err := convertInternalJobParamToOpenAiJobParams(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to convert request to OpenAI job parameters: %w", err)
+	}
+
+	job, err := p.client.FineTuning.Jobs.New(ctx, *params)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create fine-tuning job: %w", err)
+	}
+
+	return convertOpenAIJobToModel(*job), nil
 }

 // GetFineTuningStatus retrieves the status of a fine-tuning job
@@ -121,8 +132,60 @@ func (p *OpenAIProvider) CancelJob(ctx context.Context, jobID string) (*models.F

 // UploadFile uploads a file for fine-tuning
 func (p *OpenAIProvider) UploadFile(ctx context.Context, filePath string) (string, error) {
-	// TODO: Implement
-	return "", nil
+	if filePath == "" {
+		return "", fmt.Errorf("file path cannot be empty")
+	}
+
+	// Show spinner while uploading the file
+	spinner := ux.NewSpinner(&ux.SpinnerOptions{
+		Text: "Uploading file for fine-tuning...",
+	})
+	if err := spinner.Start(ctx); err != nil {
+		fmt.Printf("failed to start spinner: %v\n", err)
+	}
+
+	file, err := os.Open(filePath)
+	if err != nil {
+		_ = spinner.Stop(ctx)
+		return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
+	}
+	defer file.Close()
+
+	uploadedFile, err := p.client.Files.New(ctx, openai.FileNewParams{
+		File:    file,
+		Purpose: openai.FilePurposeFineTune,
+	})
+	if err != nil {
+		_ = spinner.Stop(ctx)
+		return "", fmt.Errorf("failed to upload file: %w", err)
+	}
+
+	if uploadedFile == nil || uploadedFile.ID == "" {
+		_ = spinner.Stop(ctx)
+		return "", fmt.Errorf("upload returned an empty file ID")
+	}
+
+	// Poll for file processing status until processed, failed, or the context is cancelled
+	for {
+		f, err := p.client.Files.Get(ctx, uploadedFile.ID)
+		if err != nil {
+			_ = spinner.Stop(ctx)
+			return "", fmt.Errorf("failed to check file status: %w", err)
+		}
+		if f.Status == openai.FileObjectStatusProcessed {
+			_ = spinner.Stop(ctx)
+			break
+		}
+		if f.Status == openai.FileObjectStatusError {
+			_ = spinner.Stop(ctx)
+			return "", fmt.Errorf("file processing failed with status: %s", f.Status)
+		}
+		color.Yellow(".")
+		select {
+		case <-ctx.Done():
+			_ = spinner.Stop(ctx)
+			return "", ctx.Err()
+		case <-time.After(2 * time.Second):
+		}
+	}
+
+	return uploadedFile.ID, nil
 }

 // GetUploadedFile retrieves information about an uploaded file
diff --git a/cli/azd/extensions/azure.ai.finetune/internal/services/finetune_service.go b/cli/azd/extensions/azure.ai.finetune/internal/services/finetune_service.go
index 054583eee5c..b257e9f9ad8 100644
--- a/cli/azd/extensions/azure.ai.finetune/internal/services/finetune_service.go
+++ b/cli/azd/extensions/azure.ai.finetune/internal/services/finetune_service.go
@@ -6,13 +6,15 @@ package services

 import (
 	"context"
 	"fmt"
-
-	"github.com/azure/azure-dev/cli/azd/pkg/azdext"
+	"os"

 	"azure.ai.finetune/internal/providers"
 	"azure.ai.finetune/internal/providers/factory"
 	"azure.ai.finetune/internal/utils"
 	"azure.ai.finetune/pkg/models"
+	"github.com/azure/azure-dev/cli/azd/pkg/azdext"
+	"github.com/fatih/color"
 )

 // Ensure fineTuningServiceImpl implements FineTuningService interface
@@ -41,13 +43,62 @@ func NewFineTuningService(ctx context.Context, azdClient *azdext.AzdClient, stat

 // CreateFineTuningJob creates a new fine-tuning job with business validation
 func (s *fineTuningServiceImpl) CreateFineTuningJob(ctx context.Context, req *models.CreateFineTuningRequest) (*models.FineTuningJob, error) {
-	// TODO: Implement
-	// 1. Validate request (model exists, data size valid, etc.)
-	// 2. Call provider.CreateFineTuningJob()
-	// 3. Transform any errors to standardized ErrorDetail
-	// 4. Persist job to state store
-	// 5. Return job
-	return nil, nil
+	// Validate request
+	if req == nil {
+		return nil, fmt.Errorf("request cannot be nil")
+	}
+	if req.BaseModel == "" {
+		return nil, fmt.Errorf("base model is required")
+	}
+	if req.TrainingFile == "" {
+		return nil, fmt.Errorf("training file is required")
+	}
+
+	if utils.IsLocalFilePath(req.TrainingFile) {
+		color.Green("\nUploading training file...")
+
+		trainingDataID, err := s.UploadFile(ctx, utils.GetLocalFilePath(req.TrainingFile))
+		if err != nil {
+			return nil, fmt.Errorf("failed to upload training file: %w", err)
+		}
+		req.TrainingFile = trainingDataID
+	} else {
+		color.Yellow("\nProvided training file is non-local, skipping upload...")
+	}
+
+	// Upload validation file if provided
+	if req.ValidationFile != nil && *req.ValidationFile != "" {
+		if utils.IsLocalFilePath(*req.ValidationFile) {
+			color.Green("\nUploading validation file...")
+			validationDataID, err := s.UploadFile(ctx, utils.GetLocalFilePath(*req.ValidationFile))
+			if err != nil {
+				return nil, fmt.Errorf("failed to upload validation file: %w", err)
+			}
+			req.ValidationFile = &validationDataID
+		} else {
+			color.Yellow("\nProvided validation file is non-local, skipping upload...")
+		}
+	}
+
+	// Call provider with retry logic
+	var job *models.FineTuningJob
+	err := utils.RetryOperation(ctx, utils.DefaultRetryConfig(), func() error {
+		var err error
+		job, err = s.provider.CreateFineTuningJob(ctx, req)
+		return err
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create fine-tuning job: %w", err)
+	}
+
+	// Persist job to state store if available
+	if s.stateStore != nil {
+		if err := s.stateStore.SaveJob(ctx, job); err != nil {
+			return nil, fmt.Errorf("failed to persist job: %w", err)
+		}
+	}
+
+	return job, nil
 }

 // GetFineTuningStatus retrieves the current status of a job
@@ -146,16 +197,39 @@ func (s *fineTuningServiceImpl) CancelJob(ctx context.Context, jobID string) (*m
 	return nil, nil
 }

-// UploadTrainingFile uploads and validates a training file
-func (s *fineTuningServiceImpl) UploadTrainingFile(ctx context.Context, filePath string) (string, error) {
-	// TODO: Implement
-	return "", nil
+// UploadFile uploads and validates a file
+func (s *fineTuningServiceImpl) UploadFile(ctx context.Context, filePath string) (string, error) {
+	if filePath == "" {
+		return "", fmt.Errorf("file path cannot be empty")
+	}
+	uploadedFileId, err := s.uploadFile(ctx, filePath)
+	if err != nil {
+		return "", fmt.Errorf("failed to upload file: %w", err)
+	}
+	if uploadedFileId == "" {
+		return "", fmt.Errorf("upload succeeded but returned an empty file ID")
+	}
+	return uploadedFileId, nil
 }

-// UploadValidationFile uploads and validates a validation file
-func (s *fineTuningServiceImpl) UploadValidationFile(ctx context.Context, filePath string) (string, error) {
-	// TODO: Implement
-	return "", nil
+// uploadFile validates the path on disk and uploads it via the provider with retry.
+func (s *fineTuningServiceImpl) uploadFile(ctx context.Context, filePath string) (string, error) {
+	// Validate that the path exists and is not a directory
+	fileInfo, err := os.Stat(filePath)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return "", fmt.Errorf("file does not exist: %s", filePath)
+		}
+		return "", fmt.Errorf("failed to stat file %s: %w", filePath, err)
+	}
+	if fileInfo.IsDir() {
+		return "", fmt.Errorf("path is a directory, not a file: %s", filePath)
+	}
+
+	// Upload the file with retry
+	uploadedFileId := ""
+	err = utils.RetryOperation(ctx, utils.DefaultRetryConfig(), func() error {
+		var err error
+		uploadedFileId, err = s.provider.UploadFile(ctx, filePath)
+		return err
+	})
+	return 
uploadedFileId, err } // PollJobUntilCompletion polls a job until it completes or fails diff --git a/cli/azd/extensions/azure.ai.finetune/internal/services/interface.go b/cli/azd/extensions/azure.ai.finetune/internal/services/interface.go index c4d20d13c9d..4bceba7daa2 100644 --- a/cli/azd/extensions/azure.ai.finetune/internal/services/interface.go +++ b/cli/azd/extensions/azure.ai.finetune/internal/services/interface.go @@ -38,11 +38,8 @@ type FineTuningService interface { // CancelJob cancels a job with proper state validation CancelJob(ctx context.Context, jobID string) (*models.FineTuningJob, error) - // UploadTrainingFile uploads and validates a training file - UploadTrainingFile(ctx context.Context, filePath string) (string, error) - - // UploadValidationFile uploads and validates a validation file - UploadValidationFile(ctx context.Context, filePath string) (string, error) + // UploadFile uploads and validates a file + UploadFile(ctx context.Context, filePath string) (string, error) // PollJobUntilCompletion polls a job until it completes or fails PollJobUntilCompletion(ctx context.Context, jobID string, intervalSeconds int) (*models.FineTuningJob, error) diff --git a/cli/azd/extensions/azure.ai.finetune/internal/utils/common.go b/cli/azd/extensions/azure.ai.finetune/internal/utils/common.go new file mode 100644 index 00000000000..491e04a4a22 --- /dev/null +++ b/cli/azd/extensions/azure.ai.finetune/internal/utils/common.go @@ -0,0 +1,18 @@ +package utils + +func IsLocalFilePath(fileID string) bool { + if fileID == "" { + return false + } + if len(fileID) > 6 && fileID[:6] == "local:" { + return true + } + return false +} + +func GetLocalFilePath(fileID string) string { + if IsLocalFilePath(fileID) { + return fileID[6:] + } + return fileID +} diff --git a/cli/azd/extensions/azure.ai.finetune/internal/utils/parser.go b/cli/azd/extensions/azure.ai.finetune/internal/utils/parser.go new file mode 100644 index 00000000000..8c487a1b7a2 --- /dev/null +++ b/cli/azd/extensions/azure.ai.finetune/internal/utils/parser.go @@ -0,0 +1,30 @@ +package utils + +import ( + "fmt" + "os" + + "azure.ai.finetune/pkg/models" + "github.com/braydonk/yaml" +) + +func ParseCreateFineTuningRequestConfig(filePath string) (*models.CreateFineTuningRequest, error) { + // Read the YAML file + yamlFile, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read config file %s: %w", filePath, err) + } + + // Parse YAML into config struct + var config models.CreateFineTuningRequest + if err := yaml.Unmarshal(yamlFile, &config); err != nil { + return nil, fmt.Errorf("failed to parse YAML config: %w", err) + } + + // Validate the configuration + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) + } + + return &config, nil +} diff --git a/cli/azd/extensions/azure.ai.finetune/pkg/models/finetune.go b/cli/azd/extensions/azure.ai.finetune/pkg/models/finetune.go index 9a943580c52..35f511667fd 100644 --- a/cli/azd/extensions/azure.ai.finetune/pkg/models/finetune.go +++ b/cli/azd/extensions/azure.ai.finetune/pkg/models/finetune.go @@ -3,7 +3,10 @@ package models -import "time" +import ( + "fmt" + "time" +) // JobStatus represents the status of a fine-tuning job type JobStatus string @@ -18,6 +21,15 @@ const ( StatusPaused JobStatus = "paused" ) +// Represents the type of method used for fine-tuning +type MethodType string + +const ( + Supervised MethodType = "supervised" + DPO MethodType = "dpo" + Reinforcement MethodType = "reinforcement" +) + // 
FineTuningJob represents a vendor-agnostic fine-tuning job type FineTuningJob struct { // Core identification @@ -42,14 +54,6 @@ type FineTuningJob struct { ErrorDetails *ErrorDetail } -// CreateFineTuningRequest represents a request to create a fine-tuning job -type CreateFineTuningRequest struct { - BaseModel string - TrainingDataID string - ValidationDataID string - Hyperparameters *Hyperparameters -} - // Hyperparameters represents fine-tuning hyperparameters type Hyperparameters struct { BatchSize int64 @@ -115,3 +119,193 @@ type CheckpointMetrics struct { FullValidLoss float64 FullValidMeanTokenAccuracy float64 } + +// CreateFineTuningRequest represents a request to create a fine-tuning job +type CreateFineTuningRequest struct { + // Required: The name of the model to fine-tune + BaseModel string `yaml:"model"` + + // Required: Path to training file + // Format: "file-id" or "local:/path/to/file.jsonl" + TrainingFile string `yaml:"training_file"` + + // Optional: Path to validation file + ValidationFile *string `yaml:"validation_file,omitempty"` + + // Optional: Suffix for the fine-tuned model name (up to 64 characters) + // Example: "custom-model-name" produces "ft:gpt-4o-mini:openai:custom-model-name:7p4lURel" + Suffix *string `yaml:"suffix,omitempty"` + + // Optional: Random seed for reproducibility + Seed *int64 `yaml:"seed,omitempty"` + + // Optional: Custom metadata for the fine-tuning job + // Max 16 key-value pairs, keys max 64 chars, values max 512 chars + Metadata map[string]string `yaml:"metadata,omitempty"` + + // Optional: Fine-tuning method configuration (supervised, dpo, or reinforcement) + Method MethodConfig `yaml:"method,omitempty"` + + // Optional: Integrations to enable (e.g., wandb for Weights & Biases) + Integrations []Integration `yaml:"integrations,omitempty"` + + // Optional: Additional request body fields not covered by standard config + ExtraBody map[string]interface{} `yaml:"extra_body,omitempty"` +} + +// MethodConfig represents fine-tuning method configuration +type MethodConfig struct { + // Type of fine-tuning method: "supervised", "dpo", or "reinforcement" + Type string `yaml:"type"` + + // Supervised fine-tuning configuration + Supervised *SupervisedConfig `yaml:"supervised,omitempty"` + + // Direct Preference Optimization (DPO) configuration + DPO *DPOConfig `yaml:"dpo,omitempty"` + + // Reinforcement learning fine-tuning configuration + Reinforcement *ReinforcementConfig `yaml:"reinforcement,omitempty"` +} + +// SupervisedConfig represents supervised fine-tuning method configuration +// Suitable for standard supervised learning tasks +type SupervisedConfig struct { + Hyperparameters HyperparametersConfig `yaml:"hyperparameters,omitempty"` +} + +// DPOConfig represents Direct Preference Optimization (DPO) configuration +// DPO is used for preference-based fine-tuning +type DPOConfig struct { + Hyperparameters HyperparametersConfig `yaml:"hyperparameters,omitempty"` +} + +// ReinforcementConfig represents reinforcement learning fine-tuning configuration +// Suitable for reasoning models that benefit from reinforcement learning +type ReinforcementConfig struct { + // Grader configuration for reinforcement learning (evaluates model outputs) + Grader map[string]interface{} `yaml:"grader,omitempty"` + + // Hyperparameters specific to reinforcement learning + Hyperparameters HyperparametersConfig `yaml:"hyperparameters,omitempty"` +} + +// HyperparametersConfig represents hyperparameter configuration +// Values can be integers, floats, or "auto" for 
automatic configuration +type HyperparametersConfig struct { + // Number of training epochs + // Can be: integer (1-10), "auto" + Epochs interface{} `yaml:"epochs,omitempty"` + + // Batch size for training + // Can be: integer (1, 8, 16, 32, 64, 128), "auto" + BatchSize interface{} `yaml:"batch_size,omitempty"` + + // Learning rate multiplier + // Can be: float (0.1-2.0), "auto" + LearningRateMultiplier interface{} `yaml:"learning_rate_multiplier,omitempty"` + + // Weight for prompt loss in supervised learning (0.0-1.0) + PromptLossWeight *float64 `yaml:"prompt_loss_weight,omitempty"` + + // Beta parameter for DPO (temperature-like parameter) + // Can be: float, "auto" + Beta interface{} `yaml:"beta,omitempty"` + + // Compute multiplier for reinforcement learning + // Multiplier on amount of compute used for exploring search space during training + // Can be: float, "auto" + ComputeMultiplier interface{} `yaml:"compute_multiplier,omitempty"` + + // Reasoning effort level for reinforcement learning with reasoning models + // Options: "low", "medium", "high" + ReasoningEffort string `yaml:"reasoning_effort,omitempty"` + + // Evaluation interval for reinforcement learning + // Number of training steps between evaluation runs + // Can be: integer, "auto" + EvalInterval interface{} `yaml:"eval_interval,omitempty"` + + // Evaluation samples for reinforcement learning + // Number of evaluation samples to generate per training step + // Can be: integer, "auto" + EvalSamples interface{} `yaml:"eval_samples,omitempty"` +} + +// Integration represents integration configuration (e.g., Weights & Biases) +type Integration struct { + // Type of integration: "wandb" (Weights & Biases), etc. + Type string `yaml:"type"` + + // Integration-specific configuration (API keys, project names, etc.) 
+ Config map[string]interface{} `yaml:"config,omitempty"` +} + +// Validate checks if the configuration is valid +func (c CreateFineTuningRequest) Validate() error { + // Validate required fields + if c.BaseModel == "" { + return fmt.Errorf("model is required") + } + + if c.TrainingFile == "" { + return fmt.Errorf("training_file is required") + } + + // Validate method if provided + if c.Method.Type != "" { + if c.Method.Type != string(Supervised) && c.Method.Type != string(DPO) && c.Method.Type != string(Reinforcement) { + return fmt.Errorf("invalid method type: %s (must be 'supervised', 'dpo', or 'reinforcement')", c.Method.Type) + } + + // Validate method-specific configuration + switch c.Method.Type { + case string(Supervised): + if c.Method.Supervised == nil { + return fmt.Errorf("supervised method requires 'supervised' configuration block") + } + case string(DPO): + if c.Method.DPO == nil { + return fmt.Errorf("dpo method requires 'dpo' configuration block") + } + case string(Reinforcement): + if c.Method.Reinforcement == nil { + return fmt.Errorf("reinforcement method requires 'reinforcement' configuration block") + } + } + } + + // Validate integrations if provided + if len(c.Integrations) > 0 { + for _, integration := range c.Integrations { + if integration.Type == "" { + return fmt.Errorf("integration type is required if integrations are specified") + } + if integration.Config == nil { + return fmt.Errorf("integration of type '%s' requires 'config' block", integration.Type) + } + } + } + + // Validate suffix length if provided + if c.Suffix != nil && len(*c.Suffix) > 64 { + return fmt.Errorf("suffix exceeds maximum length of 64 characters: %d", len(*c.Suffix)) + } + + // Validate metadata constraints + if c.Metadata != nil { + if len(c.Metadata) > 16 { + return fmt.Errorf("metadata exceeds maximum of 16 key-value pairs: %d", len(c.Metadata)) + } + for k, v := range c.Metadata { + if len(k) > 64 { + return fmt.Errorf("metadata key exceeds maximum length of 64 characters: %s", k) + } + if len(v) > 512 { + return fmt.Errorf("metadata value exceeds maximum length of 512 characters for key: %s", k) + } + } + } + + return nil +}
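
---
Illustrative examples (not part of the patch).

A minimal config file accepted by ParseCreateFineTuningRequestConfig, using the yaml tags defined on CreateFineTuningRequest; the model name, paths, and values below are hypothetical:

    # finetune.yaml
    model: gpt-4o-mini
    training_file: "local:./data/train.jsonl"    # "local:" marks a path to upload before submission
    validation_file: "local:./data/valid.jsonl"
    suffix: my-experiment
    seed: 42
    metadata:
      team: ai-platform
    method:
      type: supervised
      supervised:
        hyperparameters:
          epochs: 3
          batch_size: auto          # "auto" leaves the field unset so the service default applies
          learning_rate_multiplier: 0.5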
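The two submission paths the new flags enable. The "azd finetune" prefix is illustrative only (the parent command path is not shown in this diff); the submit flags come from operations.go above:

    # Submit from a config file
    azd finetune submit -f ./finetune.yaml

    # Submit from flags only; --model and --training-file are required together
    azd finetune submit -m gpt-4o-mini -t "local:./data/train.jsonl" -s my-experiment --seed 42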
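A small, self-contained sketch of the "local:" convention implemented in internal/utils/common.go; the printed values follow directly from IsLocalFilePath/GetLocalFilePath as defined above:

    package main

    import (
    	"fmt"

    	"azure.ai.finetune/internal/utils"
    )

    func main() {
    	// A "local:"-prefixed value is a filesystem path that must be uploaded first.
    	fmt.Println(utils.IsLocalFilePath("local:./train.jsonl"))  // true
    	fmt.Println(utils.GetLocalFilePath("local:./train.jsonl")) // "./train.jsonl"

    	// Anything else is treated as an already-uploaded file ID and passed through.
    	fmt.Println(utils.IsLocalFilePath("file-abc123"))  // false
    	fmt.Println(utils.GetLocalFilePath("file-abc123")) // "file-abc123"
    }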