diff --git a/cmd/base64.go b/cmd/base64.go index a15b0c6..5452763 100644 --- a/cmd/base64.go +++ b/cmd/base64.go @@ -33,7 +33,6 @@ var base64Cmd = &cobra.Command{ if err != nil { log.Fatal(err) } - base64String, err := base64.Process(inputData, formatType) if err != nil { log.Fatal(err) @@ -49,7 +48,6 @@ var base64Cmd = &cobra.Command{ // init - cobra init function func init() { rootCmd.AddCommand(base64Cmd) - requiredFlags := map[string]bool{ "in": true, } @@ -58,4 +56,5 @@ func init() { base64Cmd.PersistentFlags().String(base64.FormatFlagName, base64.TextFormat, base64.FormatFlagDescription) base64Cmd.PersistentFlags().String(base64.OutputFlagName, "", base64.OutputFlagDescription) common.SetCustomHelpTemplate(base64Cmd, requiredFlags) + common.SetCustomErrorTemplate(base64Cmd) } diff --git a/cmd/base64Tgz.go b/cmd/base64Tgz.go index ef8293c..d529efd 100644 --- a/cmd/base64Tgz.go +++ b/cmd/base64Tgz.go @@ -59,4 +59,5 @@ func init() { base64TgzCmd.PersistentFlags().String(base64Tgz.CertFlagName, "", base64Tgz.CertPathDescription) base64TgzCmd.PersistentFlags().String(base64Tgz.OutputFlagName, "", base64Tgz.OutputPathDescription) common.SetCustomHelpTemplate(base64TgzCmd, requiredFlags) + common.SetCustomErrorTemplate(base64TgzCmd) } diff --git a/cmd/decryptAttestation.go b/cmd/decryptAttestation.go index 901e957..2b70503 100644 --- a/cmd/decryptAttestation.go +++ b/cmd/decryptAttestation.go @@ -54,8 +54,9 @@ func init() { "in": true, "priv": true, } - decryptAttestationCmd.PersistentFlags().String(decryptAttestation.InputFlagName, decryptAttestation.DecryptAttestFileInDefaultPath, decryptAttestation.DecryptAttestFileInDescription) + decryptAttestationCmd.PersistentFlags().String(decryptAttestation.InputFlagName, "", decryptAttestation.DecryptAttestFileInDescription) decryptAttestationCmd.PersistentFlags().String(decryptAttestation.PrivateKeyFlagName, "", decryptAttestation.PrivateKeyFlagDescription) decryptAttestationCmd.PersistentFlags().String(decryptAttestation.OutputFlagName, "", decryptAttestation.DecryptAttestFlagDescription) common.SetCustomHelpTemplate(decryptAttestationCmd, requiredFlags) + common.SetCustomErrorTemplate(decryptAttestationCmd) } diff --git a/cmd/downloadCertificates.go b/cmd/downloadCertificates.go index e0f4d36..4716e29 100644 --- a/cmd/downloadCertificates.go +++ b/cmd/downloadCertificates.go @@ -31,7 +31,7 @@ var ( Short: downloadCertificate.ParameterShortDescription, Long: downloadCertificate.ParameterLongDescription, Run: func(cmd *cobra.Command, args []string) { - formatType, certificatePath, err := downloadCertificate.ValidateInput(cmd) + formatType, certificatePath, err := downloadCertificate.ValidateInput(cmd, versions) if err != nil { log.Fatal(err) } @@ -62,4 +62,5 @@ func init() { downloadCertificatesCmd.PersistentFlags().String(downloadCertificate.FormatFlag, downloadCertificate.JsonFormat, downloadCertificate.DataFormatFlag) downloadCertificatesCmd.PersistentFlags().String(downloadCertificate.OutputFlagName, "", downloadCertificate.OutputPathDescription) common.SetCustomHelpTemplate(downloadCertificatesCmd, requiredFlags) + common.SetCustomErrorTemplate(downloadCertificatesCmd) } diff --git a/cmd/encrypt.go b/cmd/encrypt.go index cf6b0c4..ef3bce7 100644 --- a/cmd/encrypt.go +++ b/cmd/encrypt.go @@ -78,4 +78,5 @@ func init() { encryptCmd.PersistentFlags().String(encrypt.CsrFlag, "", encrypt.CsrFlagDescription) encryptCmd.PersistentFlags().Int(encrypt.ExpiryDaysFlag, 0, encrypt.ExpiryDaysFlagDescription) common.SetCustomHelpTemplate(encryptCmd, requiredFlags) + common.SetCustomErrorTemplate(encryptCmd) } diff --git a/cmd/encryptString.go b/cmd/encryptString.go index eb8d063..3ebe9a8 100644 --- a/cmd/encryptString.go +++ b/cmd/encryptString.go @@ -59,4 +59,5 @@ func init() { encryptStringCmd.PersistentFlags().String(encryptString.CertFlagName, "", encryptString.CertFlagDescription) encryptStringCmd.PersistentFlags().String(encryptString.OutputFlagName, "", encryptString.OutputFlagDescription) common.SetCustomHelpTemplate(encryptStringCmd, requiredFlags) + common.SetCustomErrorTemplate(encryptStringCmd) } diff --git a/cmd/getCertificate.go b/cmd/getCertificate.go index 26500f5..b6499d0 100644 --- a/cmd/getCertificate.go +++ b/cmd/getCertificate.go @@ -58,4 +58,5 @@ func init() { getCertificateCmd.PersistentFlags().String(getCertificate.VersionFlagName, "", getCertificate.VersionFlagDescription) getCertificateCmd.PersistentFlags().String(getCertificate.OutputFlagName, "", getCertificate.FileOutFlagDescription) common.SetCustomHelpTemplate(getCertificateCmd, requiredFlags) + common.SetCustomErrorTemplate(getCertificateCmd) } diff --git a/cmd/image.go b/cmd/image.go index 917d44a..ed1674a 100644 --- a/cmd/image.go +++ b/cmd/image.go @@ -69,4 +69,5 @@ func init() { imageCmd.PersistentFlags().String(image.FormatFlag, image.JsonFormat, image.DataFormatFlagDescription) imageCmd.PersistentFlags().String(image.OutputFlagName, "", image.OutputFlagDescription) common.SetCustomHelpTemplate(imageCmd, requiredFlags) + common.SetCustomErrorTemplate(imageCmd) } diff --git a/cmd/validateContract.go b/cmd/validateContract.go index 3c3170d..789096d 100644 --- a/cmd/validateContract.go +++ b/cmd/validateContract.go @@ -63,4 +63,5 @@ func init() { validateContractCmd.PersistentFlags().String(validateContract.InputFlagName, "", validateContract.InputFlagDescription) validateContractCmd.PersistentFlags().String(validateContract.OsVersionFlagName, "", validateContract.OsVersionFlagDescription) common.SetCustomHelpTemplate(validateContractCmd, requiredFlags) + common.SetCustomErrorTemplate(validateContractCmd) } diff --git a/cmd/validateEncryptionCertificate.go b/cmd/validateEncryptionCertificate.go index 3407f47..fc7a7e2 100644 --- a/cmd/validateEncryptionCertificate.go +++ b/cmd/validateEncryptionCertificate.go @@ -57,4 +57,5 @@ func init() { } validateEncryptionCertificateCmd.PersistentFlags().String(validateEncryptionCertificate.InputFlagName, "", validateEncryptionCertificate.CertVersionFlagDescription) common.SetCustomHelpTemplate(validateEncryptionCertificateCmd, requiredFlags) + common.SetCustomErrorTemplate(validateEncryptionCertificateCmd) } diff --git a/cmd/validateNetwork.go b/cmd/validateNetwork.go index 60fc93c..01e452d 100644 --- a/cmd/validateNetwork.go +++ b/cmd/validateNetwork.go @@ -62,4 +62,5 @@ func init() { } validateNetworkConfigCmd.PersistentFlags().String(validateNetwork.InputFlagName, "", validateNetwork.InputFlagDescription) common.SetCustomHelpTemplate(validateNetworkConfigCmd, requiredFlags) + common.SetCustomErrorTemplate(validateNetworkConfigCmd) } diff --git a/common/common.go b/common/common.go index 17666cc..bdb677c 100644 --- a/common/common.go +++ b/common/common.go @@ -196,3 +196,23 @@ func SetCustomHelpTemplate(cmd *cobra.Command, requiredFlags map[string]bool) { printFlags(false) }) } + +// SetCustomErrorTemplate function customizes Cobra's default flag error handling to +// print a spaced error message followed by command usage. +func SetCustomErrorTemplate(cmd *cobra.Command) { + cmd.SetFlagErrorFunc(func(cmd *cobra.Command, err error) error { + SetMandatoryFlagError(cmd, err) + return nil + }) + + cmd.SilenceUsage = true +} + +// SetMandatoryFlagError function print error if required flag is missing +func SetMandatoryFlagError(cmd *cobra.Command, err error) { + out := cmd.ErrOrStderr() + fmt.Fprintln(out, err) + fmt.Fprintln(out) + cmd.Usage() + os.Exit(1) +} diff --git a/go.sum b/go.sum index 35f0a3c..cabc783 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,6 @@ github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lpr github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/ibm-hyper-protect/contract-go/v2 v2.5.0 h1:4A/sm/dHSNUci37mImOCMWRWAEIoiuu/PUslYAr6MNQ= -github.com/ibm-hyper-protect/contract-go/v2 v2.5.0/go.mod h1:i0Bb/Ko6N2RlCQiCsOqQw6bElpfLipYjXDYi3HZGejs= github.com/ibm-hyper-protect/contract-go/v2 v2.5.1 h1:rClqsX+fYcko0ZxsTMM723o4cx7v+UwMqlbIGvGmXwo= github.com/ibm-hyper-protect/contract-go/v2 v2.5.1/go.mod h1:i0Bb/Ko6N2RlCQiCsOqQw6bElpfLipYjXDYi3HZGejs= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/lib/base64/base64.go b/lib/base64/base64.go index 3bca056..2058ddf 100644 --- a/lib/base64/base64.go +++ b/lib/base64/base64.go @@ -54,6 +54,11 @@ func ValidateInput(cmd *cobra.Command) (string, string, string, error) { return "", "", "", err } + if inputData == "" { + err := fmt.Errorf("Error: required flag '--in' is missing") + common.SetMandatoryFlagError(cmd, err) + } + formatType, err := cmd.Flags().GetString(FormatFlagName) if err != nil { return "", "", "", err @@ -72,10 +77,6 @@ func Process(inputData, formatType string) (string, error) { var base64String string var err error - if inputData == "" { - return "", fmt.Errorf(inputMissingMessageBase64) - } - if formatType == TextFormat { base64String, _, _, err = contract.HpcrText(inputData) if err != nil { diff --git a/lib/base64Tgz/base64Tgz.go b/lib/base64Tgz/base64Tgz.go index 4f32c06..c1965fb 100644 --- a/lib/base64Tgz/base64Tgz.go +++ b/lib/base64Tgz/base64Tgz.go @@ -53,6 +53,11 @@ func ValidateInput(cmd *cobra.Command) (string, string, string, string, string, return "", "", "", "", "", err } + if inputData == "" { + err := fmt.Errorf("Error: required flag '--in' is missing") + common.SetMandatoryFlagError(cmd, err) + } + outputFormat, err := cmd.Flags().GetString(OutputFormatFlag) if err != nil { return "", "", "", "", "", err diff --git a/lib/decryptAttestation/decryptAttestation.go b/lib/decryptAttestation/decryptAttestation.go index 3068bc9..ee9157c 100644 --- a/lib/decryptAttestation/decryptAttestation.go +++ b/lib/decryptAttestation/decryptAttestation.go @@ -18,6 +18,7 @@ package decryptAttestation import ( "fmt" "log" + "strings" "github.com/ibm-hyper-protect/contract-cli/common" "github.com/ibm-hyper-protect/contract-go/v2/attestation" @@ -32,7 +33,6 @@ const ( Attestation records are typically found at /var/hyperprotect/se-checksums.txt.enc and contain cryptographic hashes for verifying workload integrity.` - DecryptAttestFileInDefaultPath = "build/se-checksums.txt.enc" DecryptAttestFileInDescription = "Path to encrypted attestation file (se-checksums.txt.enc)" DecryptAttestFlagDescription = "Path to save decrypted attestation records" successMessageDecryptAttestation = "Successfully decrypted attestation records" @@ -59,13 +59,40 @@ func ValidateInput(cmd *cobra.Command) (string, string, string, error) { return "", "", "", err } + requiredFlags := map[string]string{ + "--in": encAttestPath, + "--priv": privateKeyPath, + } + + var missing []string + for flag, val := range requiredFlags { + if val == "" { + missing = append(missing, flag) + } + } + + if len(missing) > 0 { + if len(missing) == 1 { + err := fmt.Errorf("Error: required flag %s is missing.", + strings.Join(missing, ", ")) + common.SetMandatoryFlagError(cmd, err) + } else { + err := fmt.Errorf("Error: required flag %s are missing.", + strings.Join(missing, ", ")) + common.SetMandatoryFlagError(cmd, err) + } + } + return encAttestPath, privateKeyPath, decryptedAttestPath, nil } // DecryptAttestationRecords - function to decrypt attestation records func DecryptAttestationRecords(encryptedAttestationRecordsPath, privateKeyPath string) (string, error) { - if !common.CheckFileFolderExists(encryptedAttestationRecordsPath) || !common.CheckFileFolderExists(privateKeyPath) { - log.Fatal("The path to encrypted attestation records file or private key doesn't exists") + if !common.CheckFileFolderExists(encryptedAttestationRecordsPath) { + log.Fatal("The path to encrypted attestation records file doesn't exists") + } + if !common.CheckFileFolderExists(privateKeyPath) { + log.Fatal("The path to private key doesn't exists") } encryptedChecksum, err := common.ReadDataFromFile(encryptedAttestationRecordsPath) diff --git a/lib/downloadCertificate/downloadCertificate.go b/lib/downloadCertificate/downloadCertificate.go index 21ed99b..7dcd021 100644 --- a/lib/downloadCertificate/downloadCertificate.go +++ b/lib/downloadCertificate/downloadCertificate.go @@ -40,7 +40,7 @@ for contract encryption and workload deployment.` ) // ValidateInput - function to validate download-certificate inputs -func ValidateInput(cmd *cobra.Command) (string, string, error) { +func ValidateInput(cmd *cobra.Command, versions []string) (string, string, error) { formatType, err := cmd.Flags().GetString(FormatFlag) if err != nil { return "", "", err @@ -50,6 +50,10 @@ func ValidateInput(cmd *cobra.Command) (string, string, error) { if err != nil { return "", "", err } + if len(versions) == 0 { + err := fmt.Errorf("Error: required flag '--version' is missing") + common.SetMandatoryFlagError(cmd, err) + } return formatType, certificatePath, nil } diff --git a/lib/encrypt/encrypt.go b/lib/encrypt/encrypt.go index bc58033..7ef6055 100644 --- a/lib/encrypt/encrypt.go +++ b/lib/encrypt/encrypt.go @@ -63,6 +63,11 @@ func ValidateInput(cmd *cobra.Command) (string, string, string, string, string, return "", "", "", "", "", err } + if inputData == "" { + err := fmt.Errorf("Error: required flag '--in' is missing") + common.SetMandatoryFlagError(cmd, err) + } + osVersion, err := cmd.Flags().GetString(OsVersionFlagName) if err != nil { return "", "", "", "", "", err diff --git a/lib/encryptString/encryptString.go b/lib/encryptString/encryptString.go index 89c244f..8d5fdce 100644 --- a/lib/encryptString/encryptString.go +++ b/lib/encryptString/encryptString.go @@ -51,6 +51,10 @@ func ValidateInput(cmd *cobra.Command) (string, string, string, string, string, if err != nil { return "", "", "", "", "", err } + if inputData == "" { + err := fmt.Errorf("Error: required flag '--in' is missing") + common.SetMandatoryFlagError(cmd, err) + } inputFormat, err := cmd.Flags().GetString(FormatFlag) if err != nil { diff --git a/lib/getCertificate/getCertificate.go b/lib/getCertificate/getCertificate.go index 9c6da2c..7794062 100644 --- a/lib/getCertificate/getCertificate.go +++ b/lib/getCertificate/getCertificate.go @@ -17,6 +17,7 @@ package getCertificate import ( "fmt" + "strings" "github.com/ibm-hyper-protect/contract-cli/common" "github.com/ibm-hyper-protect/contract-go/v2/certificate" @@ -55,6 +56,30 @@ func ValidateInput(cmd *cobra.Command) (string, string, string, error) { return "", "", "", err } + requiredFlags := map[string]string{ + "--in": encryptionCertsPath, + "--version": version, + } + + var missing []string + for flag, val := range requiredFlags { + if val == "" { + missing = append(missing, flag) + } + } + + if len(missing) > 0 { + if len(missing) == 1 { + err := fmt.Errorf("Error: required flag %s is missing.", + strings.Join(missing, ", ")) + common.SetMandatoryFlagError(cmd, err) + } else { + err := fmt.Errorf("Error: required flag %s are missing.", + strings.Join(missing, ", ")) + common.SetMandatoryFlagError(cmd, err) + } + } + return encryptionCertsPath, version, encryptionCertificatePath, nil } diff --git a/lib/image/image.go b/lib/image/image.go index 9f16e6d..6a5949b 100644 --- a/lib/image/image.go +++ b/lib/image/image.go @@ -60,6 +60,11 @@ func ValidateInput(cmd *cobra.Command) (string, string, string, string, error) { if err != nil { return "", "", "", "", err } + if imageListJsonPath == "" { + err := fmt.Errorf("Error: required flag '--in' is missing") + common.SetMandatoryFlagError(cmd, err) + } + versionName, err := cmd.Flags().GetString(VersionFlagName) if err != nil { return "", "", "", "", err diff --git a/lib/validateContract/validateContract.go b/lib/validateContract/validateContract.go index c827c57..69eeab3 100644 --- a/lib/validateContract/validateContract.go +++ b/lib/validateContract/validateContract.go @@ -16,6 +16,9 @@ package validateContract import ( + "fmt" + + "github.com/ibm-hyper-protect/contract-cli/common" "github.com/spf13/cobra" ) @@ -38,6 +41,10 @@ func ValidateInput(cmd *cobra.Command) (string, string, error) { if err != nil { return "", "", err } + if contract == "" { + err := fmt.Errorf("Error: required flag '--in' is missing") + common.SetMandatoryFlagError(cmd, err) + } version, err := cmd.Flags().GetString(OsVersionFlagName) if err != nil { diff --git a/lib/validateEncryptionCertificate/validateEncryptionCertificate.go b/lib/validateEncryptionCertificate/validateEncryptionCertificate.go index 1c53411..9dc4b10 100644 --- a/lib/validateEncryptionCertificate/validateEncryptionCertificate.go +++ b/lib/validateEncryptionCertificate/validateEncryptionCertificate.go @@ -40,6 +40,10 @@ func ValidateInput(cmd *cobra.Command) (string, error) { if err != nil { return "", err } + if encryptionCertsPath == "" { + err := fmt.Errorf("Error: required flag '--in' is missing") + common.SetMandatoryFlagError(cmd, err) + } return encryptionCertsPath, nil } diff --git a/lib/validateNetwork/validateNetwork.go b/lib/validateNetwork/validateNetwork.go index a9f3f03..6765567 100644 --- a/lib/validateNetwork/validateNetwork.go +++ b/lib/validateNetwork/validateNetwork.go @@ -16,6 +16,9 @@ package validateNetwork import ( + "fmt" + + "github.com/ibm-hyper-protect/contract-cli/common" "github.com/spf13/cobra" ) @@ -36,6 +39,10 @@ func ValidateInput(cmd *cobra.Command) (string, error) { if err != nil { return "", err } + if networkConfig == "" { + err := fmt.Errorf("Error: required flag '--in' is missing") + common.SetMandatoryFlagError(cmd, err) + } return networkConfig, nil }