diff --git a/x509/ekcert.go b/x509/ekcert.go new file mode 100644 index 0000000..e811207 --- /dev/null +++ b/x509/ekcert.go @@ -0,0 +1,314 @@ +package x509ext + +import ( + "bytes" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" + "math/big" + + "github.com/google/go-attestation/oid" +) + +var ( + // The DER encoding of an empty SEQUENCE is 0x30 0x00. + emptyASN1Subject = []byte{0x30, 0} + oidAuthorityInfoAccess = []int{1, 3, 6, 1, 5, 5, 7, 1, 1} + oidBasicConstraints = []int{2, 5, 29, 19} + oidSubjectKeyIdentifier = []int{2, 5, 29, 14} + oidKeyUsage = []int{2, 5, 29, 15} + oidCRLDistributionPoints = []int{2, 5, 29, 31} + oidAuthorityKeyID = []int{2, 5, 29, 35} + oidExtendedKeyUsage = []int{2, 5, 29, 37} + mustHaveExtensions = []asn1.ObjectIdentifier{ + oid.SubjectAltName, + oidBasicConstraints, + oidKeyUsage, + oidAuthorityKeyID, + } + oidToExtNameMap = map[string]string{ + (asn1.ObjectIdentifier)(oid.SubjectAltName).String(): "SubjectAltName", + (asn1.ObjectIdentifier)(oidBasicConstraints).String(): "BasicConstraints", + (asn1.ObjectIdentifier)(oidKeyUsage).String(): "Key Usage", + (asn1.ObjectIdentifier)(oidAuthorityKeyID).String(): "Authority Key Identifier", + } +) + +type attribute struct { + Type asn1.ObjectIdentifier + Values []asn1.RawValue `asn1:"set"` +} + +// TpmSpecification represents the TPM specification of an EK certificate. +type TpmSpecification struct { + Family string + Level int + Revision int +} + +// EKCertificate extends x509.certificate with helper methods for working with +// TCG EK Certificates. +type EKCertificate struct { + *x509.Certificate + TpmManufacturer, TpmModel, TpmVersion string + TpmSpecification TpmSpecification + // If the certificate contains a tcg-kp-EKCertificate (2.23.133.8.1) in the + // Extended Key Usage, this will be true. + HasEkcExtendedKeyUsage bool +} + +// ParseEKCertificate parses a single certificate from the given ASN.1 DER data. +func ParseEKCertificate(der []byte) (*EKCertificate, error) { + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + return ToEKCertificate(cert) +} + +// ToEKCertificate converts a x509 certificate to an EKCertificate. It also +// validates the EK cert according to Section 3.2 of +// https://trustedcomputinggroup.org/wp-content/uploads/TCG-EK-Credential-Profile-for-TPM-Family-2.0-Level-0-Version-2.6_pub.pdf +// +// This function checks for the presence and criticality of required extensions +// and performs basic structural validation, but it does NOT validate the +// semantic values of every extension. Some additional validation may be necessary depending on the use case. +// +// Extensions handled by this function will be removed from `cert.UnhandledCriticalExtensions` in place if present. +// +// TODO: Handle TPM Security Assertions (Section 3.1.1 from the EK Credential profile spec) +func ToEKCertificate(cert *x509.Certificate) (*EKCertificate, error) { + // Some older EK certificates have RSA-OAEP public keys, which are not + // parsed by crypto/x509, resulting in PublicKey being nil. + if cert.PublicKey == nil { + return nil, errors.New("publicKey is nil") + } + + var spec TpmSpecification + var tpmManufacturer, tpmModel, tpmVersion string + var hasEKCExtendedKeyUsage bool + extPresent := make(map[string]bool) + + // Version must be 3. + if cert.Version != 3 { + return nil, fmt.Errorf("invalid version of EK certificate: %d", cert.Version) + } + + // SerialNumber must be a positive integer and not nil. + if err := validateSerialNumber(cert.SerialNumber); err != nil { + return nil, err + } + + // Issuer must be present. + if bytes.Equal(cert.RawIssuer, emptyASN1Subject) { + return nil, errors.New("issuer is empty") + } + + isSubjectEmpty := bytes.Equal(cert.RawSubject, emptyASN1Subject) + + // Basic Constraints must be valid and the certificate must not be a CA. + if !cert.BasicConstraintsValid || cert.IsCA { + return nil, errors.New("BasicConstraints are not valid or it is a CA certificate") + } + + for _, ext := range cert.Extensions { + switch { + case ext.Id.Equal(oid.SubjectAltName): + if isSubjectEmpty { + if !ext.Critical { + return nil, errors.New("SubjectAltName extension must be critical when Subject is not present") + } + } + san, err := ParseSubjectAltName(ext) + if err != nil { + return nil, err + } + if len(san.DirectoryNames) != 1 { + return nil, errors.New("only a single DirectoryName is supported") + } + tpmManufacturer, tpmModel, tpmVersion, err = parseName(san.DirectoryNames[0]) + if err != nil { + return nil, err + } + case ext.Id.Equal(oid.SubjectDirectoryAttributes): + subjectDirectoryAttributes, err := parseSubjectDirectoryAttributes(ext) + if err != nil { + return nil, err + } + if spec, err = parseTPMSpecification(subjectDirectoryAttributes); err != nil { + return nil, err + } + case ext.Id.Equal(oidBasicConstraints): + if !ext.Critical { + return nil, errors.New("extension \"Basic Constraints\" is not critical, supposed to be critical") + } + case ext.Id.Equal(oidKeyUsage): + if !ext.Critical { + return nil, errors.New("extension \"Key Usage\" is not critical, supposed to be critical") + } + case ext.Id.Equal(oidAuthorityKeyID): + if ext.Critical { + return nil, errors.New("extension \"Authority Key Identifier\" is critical, supposed to be non-critical") + } + case ext.Id.Equal(oid.CertificatePolicies): + if len(cert.PolicyIdentifiers) == 0 { + return nil, errors.New("extension \"Certificate Policies\" should contain at least 1 policy identifier if the extension is present") + } + case ext.Id.Equal(oidAuthorityInfoAccess): + if ext.Critical { + return nil, errors.New("extension \"Authority Info Access\" is critical, supposed to be non-critical") + } + case ext.Id.Equal(oidCRLDistributionPoints): + if ext.Critical { + return nil, errors.New("extension \"CRL Distribution Points\" is critical, supposed to be non-critical") + } + case ext.Id.Equal(oidExtendedKeyUsage): + if ext.Critical { + return nil, errors.New("extension \"Extended Key Usage\" is critical, supposed to be non-critical") + } + case ext.Id.Equal(oidSubjectKeyIdentifier): + if ext.Critical { + return nil, errors.New("extension \"Subject Key Identifier\" is critical, supposed to be non-critical") + } + } + + extPresent[ext.Id.String()] = true + } + + // Check that all must-have extensions are present. + for _, extOID := range mustHaveExtensions { + if !extPresent[extOID.String()] { + return nil, fmt.Errorf("extension %v is missing", oidToExtNameMap[extOID.String()]) + } + } + + // Authority Key ID must be present and non-empty. + if len(cert.AuthorityKeyId) == 0 { + return nil, errors.New("missing Authority Key ID") + } + + // KeyUsage must be set and correctly set for the public key type. + if err := validateKeyUsage(cert.PublicKeyAlgorithm, cert.KeyUsage); err != nil { + return nil, err + } + + // Iterate through unknown/custom ExtKeyUsage OIDs + for _, eku := range cert.UnknownExtKeyUsage { + if eku.Equal(oid.EKCertificate) { + hasEKCExtendedKeyUsage = true + } + } + + // Iterate through the unhandled critical extensions to remove the handled extensions from the list. + for i, ext := range cert.UnhandledCriticalExtensions { + if ext.Equal(oid.SubjectAltName) { + length := len(cert.UnhandledCriticalExtensions) + // Remove the extension from the list of unhandled critical extensions. + cert.UnhandledCriticalExtensions[i] = cert.UnhandledCriticalExtensions[length-1] + cert.UnhandledCriticalExtensions = cert.UnhandledCriticalExtensions[:length-1] + break + } + } + + return &EKCertificate{ + Certificate: cert, + TpmManufacturer: tpmManufacturer, + TpmModel: tpmModel, + TpmVersion: tpmVersion, + TpmSpecification: spec, + HasEkcExtendedKeyUsage: hasEKCExtendedKeyUsage, + }, nil +} + +func validateSerialNumber(serialNumber *big.Int) error { + if serialNumber == nil { + return errors.New("SerialNumber is nil, expected a positive integer") + } + if serialNumber.Cmp(big.NewInt(0)) <= 0 { + return errors.New("SerialNumber is not a positive integer") + } + return nil +} + +func parseSubjectDirectoryAttributes(ext pkix.Extension) ([]attribute, error) { + var attrs []attribute + rest, err := asn1.Unmarshal(ext.Value, &attrs) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("trailing data after X.509 extension") + } + return attrs, nil +} + +func parseTPMSpecification(subjectDirectoryAttributes []attribute) (TpmSpecification, error) { + for _, attr := range subjectDirectoryAttributes { + if attr.Type.Equal(oid.TPMSpecification) { + if len(attr.Values) != 1 { + return TpmSpecification{}, errors.New("expected SET size of 1") + } + value := attr.Values[0] + var spec TpmSpecification + rest, err := asn1.Unmarshal(value.FullBytes, &spec) + if err != nil { + return TpmSpecification{}, err + } + if len(rest) != 0 { + return TpmSpecification{}, errors.New("trailing data after TPMSpecification") + } + return spec, nil + } + } + return TpmSpecification{}, errors.New("TPMSpecification not present") +} + +func parseName(name pkix.Name) (string, string, string, error) { + var tpmManufacturer, tpmModel, tpmVersion string + for _, attr := range name.Names { + if attr.Type.Equal(oid.TPMManufacturer) { + tpmManufacturer = fmt.Sprintf("%v", attr.Value) + continue + } + if attr.Type.Equal(oid.TPMModel) { + tpmModel = fmt.Sprintf("%v", attr.Value) + continue + } + if attr.Type.Equal(oid.TPMVersion) { + tpmVersion = fmt.Sprintf("%v", attr.Value) + continue + } + return "", "", "", fmt.Errorf("unknown attribute type: %v", attr.Type) + } + if tpmManufacturer == "" { + return "", "", "", fmt.Errorf("TPM Manufacturer not present") + } + if tpmModel == "" { + return "", "", "", fmt.Errorf("TPM Model not present") + } + if tpmVersion == "" { + return "", "", "", fmt.Errorf("TPM Version not present") + } + return tpmManufacturer, tpmModel, tpmVersion, nil +} + +func validateKeyUsage(certType x509.PublicKeyAlgorithm, keyUsage x509.KeyUsage) error { + if keyUsage == 0 { + return fmt.Errorf("KeyUsage field is not set") + } + switch certType { + case x509.RSA: + if keyUsage&x509.KeyUsageKeyEncipherment == 0 { + return fmt.Errorf("KeyUsageKeyEncipherment is not set for RSA public key type") + } + case x509.ECDSA: + if keyUsage&x509.KeyUsageKeyAgreement == 0 { + return fmt.Errorf("KeyUsageKeyAgreement is not set for ECDSA public key type") + } + default: + return fmt.Errorf("unsupported public key type: %v", certType) + } + return nil +} diff --git a/x509/ekcert_test.go b/x509/ekcert_test.go new file mode 100644 index 0000000..72e5bd0 --- /dev/null +++ b/x509/ekcert_test.go @@ -0,0 +1,446 @@ +package x509ext + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "math/big" + "testing" + "time" + + "github.com/google/go-attestation/oid" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +var ( + SAN = &SubjectAltName{ + DirectoryNames: []pkix.Name{ + { + ExtraNames: []pkix.AttributeTypeAndValue{ + {Type: oid.TPMManufacturer, Value: "id:12345"}, + {Type: oid.TPMModel, Value: "vTPM"}, + {Type: oid.TPMVersion, Value: "id:246810"}, + }, + }, + }, + } + tpmSpec = TpmSpecification{ + Family: "2.0", + Level: 1, + Revision: 48, + } + + keySize = 2048 +) + +func generateRSAPrivateKey(t *testing.T) crypto.Signer { + privateKey, err := rsa.GenerateKey(rand.Reader, keySize) + if err != nil { + t.Fatalf("failed to generate RSA private key: %v", err) + } + + return privateKey +} + +// selfSignedCert creates a self-signed certificate. +func selfSignedCert(t *testing.T, signer crypto.Signer) *x509.Certificate { + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + IsCA: true, + NotBefore: time.Date(2024, 01, 01, 0, 0, 0, 0, time.UTC), + NotAfter: time.Date(2034, 01, 01, 0, 0, 0, 0, time.UTC), + Subject: pkix.Name{CommonName: "Test Signer"}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + certBytes, err := x509.CreateCertificate(nil, template, template, signer.Public(), signer) + if err != nil { + t.Fatalf("x509.CreateCertificate() failed: %v", err) + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + t.Fatalf("x509.ParseCertificate() failed: %v", err) + } + return cert +} + +// setupTestCert creates a new x509.Certificate for each test. +func setupTestCert(t *testing.T, modifyCertTemplate func(*testing.T, *x509.Certificate)) *x509.Certificate { + t.Helper() + privateKey := generateRSAPrivateKey(t) + parentCert := selfSignedCert(t, privateKey) + template := &x509.Certificate{ + SerialNumber: big.NewInt(1234), + Version: 3, + Issuer: pkix.Name{ + CommonName: "Test CA", + }, + Subject: pkix.Name{ + CommonName: "test-common-name", + }, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + BasicConstraintsValid: true, + IsCA: false, + ExtraExtensions: []pkix.Extension{ + createSubjectAltNameExtension(t, SAN, false), + createSubjectDirectoryAttributesExtension(t, tpmSpec, false), + }, + AuthorityKeyId: []byte("test-authority-key-id"), + UnknownExtKeyUsage: []asn1.ObjectIdentifier{oid.EKCertificate}, + } + + if modifyCertTemplate != nil { + modifyCertTemplate(t, template) + } + certBytes, err := x509.CreateCertificate(rand.Reader, template, parentCert, generateRSAPrivateKey(t).Public(), privateKey) + if err != nil { + t.Fatalf("Failed to create certificate: %v", err) + } + + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + t.Fatalf("Failed to parse created certificate: %v", err) + } + + return cert +} + +// createSubjectAltNameExtension creates a new pkix.Extension for the SubjectAltName extension. +func createSubjectAltNameExtension(t *testing.T, san *SubjectAltName, critical bool) pkix.Extension { + t.Helper() + marshalledSAN, err := MarshalSubjectAltName(san, critical) + if err != nil { + t.Fatalf("Failed to marshal SubjectAltName: %v", err) + } + return pkix.Extension{ + Id: oid.SubjectAltName, + Critical: critical, + Value: marshalledSAN.Value, + } +} + +// createSubjectDirectoryAttributesExtension creates a new pkix.Extension for the SubjectDirectoryAttributes extension. +func createSubjectDirectoryAttributesExtension(t *testing.T, tpmSpec TpmSpecification, critical bool) pkix.Extension { + t.Helper() + tpmSpecBytes, err := asn1.Marshal(tpmSpec) + if err != nil { + t.Fatalf("Failed to marshal TpmSpecification: %v", err) + } + subjectDirectoryAttributes := []attribute{ + { + Type: oid.TPMSpecification, + Values: []asn1.RawValue{ + {FullBytes: tpmSpecBytes}, + }, + }, + } + subjectDirectoryAttributesBytes, err := asn1.Marshal(subjectDirectoryAttributes) + if err != nil { + t.Fatalf("Failed to marshal subjectDirectoryAttributes: %v", err) + } + return pkix.Extension{ + Id: oid.SubjectDirectoryAttributes, + Critical: critical, + Value: subjectDirectoryAttributesBytes, + } +} + +// modifyExtension ensures that an extension with the given OID is present with the specified criticality. If not, it is added. +func modifyExtension(t *testing.T, cert *x509.Certificate, oid asn1.ObjectIdentifier, value []byte, critical bool) { + t.Helper() + for i, ext := range cert.Extensions { + if ext.Id.Equal(oid) { + cert.Extensions[i].Critical = critical + cert.Extensions[i].Value = value + return + } + } + + // Extension not found, so add it. + ext := pkix.Extension{ + Id: oid, + Critical: critical, + Value: value, + } + cert.Extensions = append(cert.Extensions, ext) +} + +func TestToEKCertificate_Success(t *testing.T) { + tests := []struct { + name string + modifyCertTemplate func(*testing.T, *x509.Certificate) + wantHasEkcExtendedKeyUsage bool + wantUnhandledCriticalExtensionsOIDs []asn1.ObjectIdentifier + }{ + { + name: "HasEkcExtendedKeyUsage is true", + modifyCertTemplate: func(t *testing.T, cert *x509.Certificate) { + cert.UnknownExtKeyUsage = []asn1.ObjectIdentifier{oid.EKCertificate} + }, + wantHasEkcExtendedKeyUsage: true, + }, + { + name: "HasEkcExtendedKeyUsage is false", + modifyCertTemplate: func(t *testing.T, cert *x509.Certificate) { + cert.UnknownExtKeyUsage = []asn1.ObjectIdentifier{} + }, + wantHasEkcExtendedKeyUsage: false, + }, + { + name: "Unhandled critical extension", + modifyCertTemplate: func(t *testing.T, cert *x509.Certificate) { + // Use empty subject, make the SubjectAltName extension critical, and + // add an unhandled critical extension. + cert.Subject = pkix.Name{} + for i, ext := range cert.ExtraExtensions { + if ext.Id.Equal(oid.SubjectAltName) { + cert.ExtraExtensions[i].Critical = true + } + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 2, 3, 4, 5}, + Critical: true, + Value: []byte("test"), + }) + }, + wantHasEkcExtendedKeyUsage: true, + wantUnhandledCriticalExtensionsOIDs: []asn1.ObjectIdentifier{{1, 2, 3, 4, 5}}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cert := setupTestCert(t, test.modifyCertTemplate) + ekCert, err := ToEKCertificate(cert) + if err != nil { + t.Fatalf("ToEKCertificate() returned an unexpected error: %v", err) + } + + if ekCert.Certificate != cert { + t.Errorf("ToEKCertificate() Certificate = %v, want: %v", ekCert.Certificate, cert) + } + + if ekCert.TpmManufacturer != SAN.DirectoryNames[0].ExtraNames[0].Value { + t.Errorf("ToEKCertificate() TPMManufacturer = %v, want: %v", ekCert.TpmManufacturer, SAN.DirectoryNames[0].ExtraNames[0].Value) + } + + if ekCert.TpmModel != SAN.DirectoryNames[0].ExtraNames[1].Value { + t.Errorf("ToEKCertificate() TPMModel = %v, want: %v", ekCert.TpmModel, SAN.DirectoryNames[0].ExtraNames[1].Value) + } + + if ekCert.TpmVersion != SAN.DirectoryNames[0].ExtraNames[2].Value { + t.Errorf("ToEKCertificate() TPMVersion = %v, want: %v", ekCert.TpmVersion, SAN.DirectoryNames[0].ExtraNames[2].Value) + } + + if ekCert.TpmSpecification != (tpmSpec) { + t.Errorf("ToEKCertificate() TpmSpecification is %v, want %v", ekCert.TpmSpecification, tpmSpec) + } + + if test.wantHasEkcExtendedKeyUsage != ekCert.HasEkcExtendedKeyUsage { + t.Errorf("ToEKCertificate() HasEkcExtendedKeyUsage = %v, want: %v", ekCert.HasEkcExtendedKeyUsage, test.wantHasEkcExtendedKeyUsage) + } + + sortOIDs := cmpopts.SortSlices(func(a, b asn1.ObjectIdentifier) bool { return a.String() < b.String() }) + if diff := cmp.Diff(test.wantUnhandledCriticalExtensionsOIDs, ekCert.UnhandledCriticalExtensions, cmpopts.EquateEmpty(), sortOIDs); diff != "" { + t.Errorf("ToEKCertificate() UnhandledCriticalExtensions differs (-want +got):\n%s", diff) + } + }) + } +} + +func TestToEKCertificate_Failures(t *testing.T) { + tests := []struct { + name string + modifyCert func(*testing.T, *x509.Certificate) + wantErr error + }{ + { + name: "Version is 2", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.Version = 2 + }, + wantErr: errors.New("invalid version of EK certificate: 2"), + }, + { + name: "SerialNumber is zero", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.SerialNumber = big.NewInt(0) + }, + wantErr: errors.New("SerialNumber is not a positive integer"), + }, + { + name: "SerialNumber negative", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.SerialNumber = big.NewInt(-1) + }, + wantErr: errors.New("SerialNumber is not a positive integer"), + }, + { + name: "SerialNumber nil", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.SerialNumber = nil + }, + wantErr: errors.New("SerialNumber is nil, expected a positive integer"), + }, + { + name: "Issuer is empty", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.RawIssuer = []byte{0x30, 0x00} + }, + wantErr: errors.New("issuer is empty"), + }, + { + name: "BasicConstraints not set", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.BasicConstraintsValid = false + cert.IsCA = false + }, + wantErr: errors.New("BasicConstraints are not valid or it is a CA certificate"), + }, + { + name: "BasicConstraints is set but cert is a CA", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.BasicConstraintsValid = true + cert.IsCA = true + }, + wantErr: errors.New("BasicConstraints are not valid or it is a CA certificate"), + }, + { + name: "BasicConstraints is not critical", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + for i, ext := range cert.Extensions { + if ext.Id.Equal(oidBasicConstraints) { + cert.Extensions[i].Critical = false + } + } + }, + wantErr: errors.New("extension \"Basic Constraints\" is not critical, supposed to be critical"), + }, + { + name: "SAN is not critical when Subject is empty", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.RawSubject = []byte{0x30, 0x00} + for i, ext := range cert.Extensions { + if ext.Id.Equal(oid.SubjectAltName) { + cert.Extensions[i].Critical = false + } + } + }, + wantErr: errors.New("SubjectAltName extension must be critical when Subject is not present"), + }, + { + name: "MUST extension is missing", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.Extensions = []pkix.Extension{} + }, + wantErr: errors.New("extension SubjectAltName is missing"), + }, + { + name: "AuthorityKeyId empty", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.AuthorityKeyId = []byte{} + }, + wantErr: errors.New("missing Authority Key ID"), + }, + { + name: "Authority Key Identifier is critical", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + modifyExtension(t, cert, oidAuthorityKeyID, []byte{0x30, 0x00}, true) + }, + wantErr: errors.New("extension \"Authority Key Identifier\" is critical, supposed to be non-critical"), + }, + { + name: "KeyUsage not set", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.PublicKeyAlgorithm = x509.RSA + cert.KeyUsage = 0 + }, + wantErr: errors.New("KeyUsage field is not set"), + }, + { + name: "KeyUsageKeyEncipherment not set for RSA", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.PublicKeyAlgorithm = x509.RSA + cert.KeyUsage = x509.KeyUsageKeyAgreement + }, + wantErr: errors.New("KeyUsageKeyEncipherment is not set for RSA public key type"), + }, + { + name: "KeyUsageKeyAgreement not set for ECDSA", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + cert.PublicKeyAlgorithm = x509.ECDSA + cert.KeyUsage = x509.KeyUsageKeyEncipherment + }, + wantErr: errors.New("KeyUsageKeyAgreement is not set for ECDSA public key type"), + }, + { + name: "Key Usage is not critical", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + for i, ext := range cert.Extensions { + if ext.Id.Equal(oidKeyUsage) { + cert.Extensions[i].Critical = false + } + } + }, + wantErr: errors.New("extension \"Key Usage\" is not critical, supposed to be critical"), + }, + { + name: "CertificatePolicies present but empty PolicyIdentifiers", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + modifyExtension(t, cert, oid.CertificatePolicies, []byte{0x30, 0x00}, true) + cert.PolicyIdentifiers = []asn1.ObjectIdentifier{} + }, + wantErr: errors.New("extension \"Certificate Policies\" should contain at least 1 policy identifier if the extension is present"), + }, + { + name: "AuthorityInfoAccess is critical", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + modifyExtension(t, cert, oidAuthorityInfoAccess, []byte{0x30, 0x00}, true) + }, + wantErr: errors.New("extension \"Authority Info Access\" is critical, supposed to be non-critical"), + }, + { + name: "CRLDistributionPoints is critical", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + modifyExtension(t, cert, oidCRLDistributionPoints, []byte{0x30, 0x00}, true) + }, + wantErr: errors.New("extension \"CRL Distribution Points\" is critical, supposed to be non-critical"), + }, + { + name: "ExtendedKeyUsage is critical", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + modifyExtension(t, cert, oidExtendedKeyUsage, []byte{0x30, 0x00}, true) + }, + wantErr: errors.New("extension \"Extended Key Usage\" is critical, supposed to be non-critical"), + }, + { + name: "SubjectKeyIdentifier is critical", + modifyCert: func(t *testing.T, cert *x509.Certificate) { + modifyExtension(t, cert, oidSubjectKeyIdentifier, []byte{0x30, 0x00}, true) + }, + wantErr: errors.New("extension \"Subject Key Identifier\" is critical, supposed to be non-critical"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cert := setupTestCert(t, nil) + test.modifyCert(t, cert) + _, err := ToEKCertificate(cert) + + if err == nil { + t.Fatalf("ToEKCertificate() succeeded unexpectedly, want error: %v", test.wantErr) + } + + if err.Error() != test.wantErr.Error() { + t.Fatalf("ToEKCertificate() error = %v, want error: %v", err, test.wantErr) + } + }) + } +}