From f3f0ccc3309312e0bebcc1bdfb992efffffdaf51 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 08:02:07 +0000 Subject: [PATCH 1/5] Refactor Provider interface and implement Phase 1 changes - Update Provider interface to be zone-centric (ListZones, GetZone, ListRecords, etc.) - Update Record struct with new fields (Priority, Weight, Port, Target, Metadata, Raw) - Add conformance test harness in pkg/dns/provider/conformance - Stub RESTProvider and NamecheapProvider to match new interface - Update Service layer to resolve Zone IDs and use new Provider methods - Fix tests in pkg/dns and pkg/dns/provider/rest Co-authored-by: SamyRai <919510+SamyRai@users.noreply.github.com> --- .../provider/conformance/conformance_test.go | 14 + pkg/dns/provider/conformance/mock_provider.go | 124 ++++++++ pkg/dns/provider/conformance/suite.go | 35 +++ pkg/dns/provider/namecheap/adapter.go | 94 +++++- pkg/dns/provider/provider.go | 52 +++- pkg/dns/provider/registry_test.go | 38 ++- pkg/dns/provider/rest/rest.go | 285 +++++++----------- pkg/dns/provider/rest/rest_test.go | 11 +- pkg/dns/provider/rest/zoneid_test.go | 53 ---- pkg/dns/service.go | 109 +++++-- pkg/dns/service_test.go | 40 ++- pkg/dnsrecord/record.go | 6 + 12 files changed, 586 insertions(+), 275 deletions(-) create mode 100644 pkg/dns/provider/conformance/conformance_test.go create mode 100644 pkg/dns/provider/conformance/mock_provider.go create mode 100644 pkg/dns/provider/conformance/suite.go delete mode 100644 pkg/dns/provider/rest/zoneid_test.go diff --git a/pkg/dns/provider/conformance/conformance_test.go b/pkg/dns/provider/conformance/conformance_test.go new file mode 100644 index 0000000..0a70b31 --- /dev/null +++ b/pkg/dns/provider/conformance/conformance_test.go @@ -0,0 +1,14 @@ +package conformance + +import ( + "testing" + + "zonekit/pkg/dns/provider" +) + +func TestMockProviderConformance(t *testing.T) { + p := NewMockProvider() + p.Zones["zone-1"] = provider.Zone{ID: "zone-1", Name: "example.com"} + + RunConformanceTests(t, p) +} diff --git a/pkg/dns/provider/conformance/mock_provider.go b/pkg/dns/provider/conformance/mock_provider.go new file mode 100644 index 0000000..666b894 --- /dev/null +++ b/pkg/dns/provider/conformance/mock_provider.go @@ -0,0 +1,124 @@ +package conformance + +import ( + "context" + "fmt" + + "zonekit/pkg/dns/provider" + "zonekit/pkg/dnsrecord" +) + +// MockProvider implements the Provider interface for testing +type MockProvider struct { + Zones map[string]provider.Zone + Records map[string]map[string]dnsrecord.Record // zoneID -> recordID -> Record +} + +// NewMockProvider creates a new mock provider +func NewMockProvider() *MockProvider { + return &MockProvider{ + Zones: make(map[string]provider.Zone), + Records: make(map[string]map[string]dnsrecord.Record), + } +} + +func (m *MockProvider) Name() string { + return "mock" +} + +func (m *MockProvider) ListZones(ctx context.Context) ([]provider.Zone, error) { + var zones []provider.Zone + for _, z := range m.Zones { + zones = append(zones, z) + } + return zones, nil +} + +func (m *MockProvider) GetZone(ctx context.Context, zoneID string) (provider.Zone, error) { + z, ok := m.Zones[zoneID] + if !ok { + return provider.Zone{}, fmt.Errorf("zone not found") + } + return z, nil +} + +func (m *MockProvider) ListRecords(ctx context.Context, zoneID string) ([]dnsrecord.Record, error) { + if _, ok := m.Zones[zoneID]; !ok { + return nil, fmt.Errorf("zone not found") + } + var records []dnsrecord.Record + if zoneRecords, ok := m.Records[zoneID]; ok { + for _, r := range zoneRecords { + records = append(records, r) + } + } + return records, nil +} + +func (m *MockProvider) CreateRecord(ctx context.Context, zoneID string, record dnsrecord.Record) (dnsrecord.Record, error) { + if _, ok := m.Zones[zoneID]; !ok { + return dnsrecord.Record{}, fmt.Errorf("zone not found") + } + if m.Records[zoneID] == nil { + m.Records[zoneID] = make(map[string]dnsrecord.Record) + } + if record.ID == "" { + record.ID = fmt.Sprintf("rec-%d", len(m.Records[zoneID])+1) + } + m.Records[zoneID][record.ID] = record + return record, nil +} + +func (m *MockProvider) UpdateRecord(ctx context.Context, zoneID string, recordID string, record dnsrecord.Record) (dnsrecord.Record, error) { + if _, ok := m.Zones[zoneID]; !ok { + return dnsrecord.Record{}, fmt.Errorf("zone not found") + } + if _, ok := m.Records[zoneID][recordID]; !ok { + return dnsrecord.Record{}, fmt.Errorf("record not found") + } + record.ID = recordID + m.Records[zoneID][recordID] = record + return record, nil +} + +func (m *MockProvider) DeleteRecord(ctx context.Context, zoneID string, recordID string) error { + if _, ok := m.Zones[zoneID]; !ok { + return fmt.Errorf("zone not found") + } + if _, ok := m.Records[zoneID][recordID]; !ok { + return fmt.Errorf("record not found") + } + delete(m.Records[zoneID], recordID) + return nil +} + +func (m *MockProvider) BulkReplaceRecords(ctx context.Context, zoneID string, records []dnsrecord.Record) error { + if _, ok := m.Zones[zoneID]; !ok { + return fmt.Errorf("zone not found") + } + m.Records[zoneID] = make(map[string]dnsrecord.Record) + for _, r := range records { + if r.ID == "" { + r.ID = fmt.Sprintf("rec-%d", len(m.Records[zoneID])+1) + } + m.Records[zoneID][r.ID] = r + } + return nil +} + +func (m *MockProvider) Capabilities() provider.ProviderCapabilities { + return provider.ProviderCapabilities{ + CanListZones: true, + CanGetZone: true, + CanCreateRecord: true, + CanUpdateRecord: true, + CanDeleteRecord: true, + CanBulkReplace: true, + } +} + +func (m *MockProvider) Validate() error { + return nil +} + +var _ provider.Provider = (*MockProvider)(nil) diff --git a/pkg/dns/provider/conformance/suite.go b/pkg/dns/provider/conformance/suite.go new file mode 100644 index 0000000..f818b44 --- /dev/null +++ b/pkg/dns/provider/conformance/suite.go @@ -0,0 +1,35 @@ +package conformance + +import ( + "context" + "testing" + + "zonekit/pkg/dns/provider" + + "github.com/stretchr/testify/require" +) + +// RunConformanceTests runs a set of tests to verify provider compliance +func RunConformanceTests(t *testing.T, p provider.Provider) { + ctx := context.Background() + + t.Run("Capabilities", func(t *testing.T) { + caps := p.Capabilities() + t.Logf("Provider %s capabilities: %+v", p.Name(), caps) + }) + + t.Run("ZoneOperations", func(t *testing.T) { + if !p.Capabilities().CanListZones { + t.Skip("Provider does not support listing zones") + } + + zones, err := p.ListZones(ctx) + require.NoError(t, err) + + if len(zones) > 0 && p.Capabilities().CanGetZone { + zone, err := p.GetZone(ctx, zones[0].ID) + require.NoError(t, err) + require.Equal(t, zones[0].ID, zone.ID) + } + }) +} diff --git a/pkg/dns/provider/namecheap/adapter.go b/pkg/dns/provider/namecheap/adapter.go index 46d7c99..7f94d1b 100644 --- a/pkg/dns/provider/namecheap/adapter.go +++ b/pkg/dns/provider/namecheap/adapter.go @@ -1,6 +1,7 @@ package namecheap import ( + "context" "fmt" "github.com/namecheap/go-namecheap-sdk/v2/namecheap" @@ -28,13 +29,52 @@ func (p *NamecheapProvider) Name() string { return "namecheap" } -// GetRecords retrieves all DNS records for a domain -func (p *NamecheapProvider) GetRecords(domainName string) ([]dnsrecord.Record, error) { +// ListZones retrieves all zones managed by the provider +func (p *NamecheapProvider) ListZones(ctx context.Context) ([]dnsprovider.Zone, error) { + nc := p.client.GetNamecheapClient() + res, err := nc.Domains.GetList(&namecheap.DomainsGetListArgs{}) + if err != nil { + return nil, errors.NewAPI("ListZones", "failed to list zones", err) + } + + if res == nil || res.Domains == nil { + return []dnsprovider.Zone{}, nil + } + + var zones []dnsprovider.Zone + for _, d := range *res.Domains { + zones = append(zones, dnsprovider.Zone{ + ID: *d.Name, + Name: *d.Name, + }) + } + return zones, nil +} + +// GetZone retrieves a specific zone by ID +func (p *NamecheapProvider) GetZone(ctx context.Context, zoneID string) (dnsprovider.Zone, error) { + nc := p.client.GetNamecheapClient() + res, err := nc.Domains.GetInfo(zoneID) + if err != nil { + return dnsprovider.Zone{}, errors.NewAPI("GetZone", fmt.Sprintf("failed to get zone %s", zoneID), err) + } + if res == nil || res.DomainDNSGetListResult == nil || res.DomainDNSGetListResult.DomainName == nil { + return dnsprovider.Zone{}, fmt.Errorf("zone not found") + } + + return dnsprovider.Zone{ + ID: *res.DomainDNSGetListResult.DomainName, + Name: *res.DomainDNSGetListResult.DomainName, + }, nil +} + +// ListRecords retrieves all DNS records for a zone +func (p *NamecheapProvider) ListRecords(ctx context.Context, zoneID string) ([]dnsrecord.Record, error) { nc := p.client.GetNamecheapClient() - resp, err := nc.DomainsDNS.GetHosts(domainName) + resp, err := nc.DomainsDNS.GetHosts(zoneID) if err != nil { - return nil, errors.NewAPI("GetHosts", fmt.Sprintf("failed to get DNS records for %s", domainName), err) + return nil, errors.NewAPI("GetHosts", fmt.Sprintf("failed to get DNS records for %s", zoneID), err) } // Safety check for nil response @@ -57,9 +97,38 @@ func (p *NamecheapProvider) GetRecords(domainName string) ([]dnsrecord.Record, e return records, nil } -// SetRecords sets DNS records for a domain (replaces all existing records) -func (p *NamecheapProvider) SetRecords(domainName string, records []dnsrecord.Record) error { +// CreateRecord creates a new DNS record +func (p *NamecheapProvider) CreateRecord(ctx context.Context, zoneID string, record dnsrecord.Record) (dnsrecord.Record, error) { + records, err := p.ListRecords(ctx, zoneID) + if err != nil { + return dnsrecord.Record{}, err + } + + records = append(records, record) + + if err := p.BulkReplaceRecords(ctx, zoneID, records); err != nil { + return dnsrecord.Record{}, err + } + + return record, nil +} + +// UpdateRecord updates an existing DNS record +func (p *NamecheapProvider) UpdateRecord(ctx context.Context, zoneID string, recordID string, record dnsrecord.Record) (dnsrecord.Record, error) { + // Stub: Namecheap doesn't support granular update easily without ID + return dnsrecord.Record{}, fmt.Errorf("UpdateRecord not implemented for Namecheap (requires BulkReplace)") +} + +// DeleteRecord deletes a DNS record +func (p *NamecheapProvider) DeleteRecord(ctx context.Context, zoneID string, recordID string) error { + // Stub + return fmt.Errorf("DeleteRecord not implemented for Namecheap (requires BulkReplace)") +} + +// BulkReplaceRecords sets DNS records for a domain (replaces all existing records) +func (p *NamecheapProvider) BulkReplaceRecords(ctx context.Context, zoneID string, records []dnsrecord.Record) error { nc := p.client.GetNamecheapClient() + domainName := zoneID // Convert records to Namecheap format hostRecords := make([]namecheap.DomainsDNSHostRecord, len(records)) @@ -94,7 +163,6 @@ func (p *NamecheapProvider) SetRecords(domainName string, records []dnsrecord.Re } // Set EmailType to MX if there are any MX records - // This is required by the Namecheap API when MX records are present if hasMXRecords { args.EmailType = namecheap.String("MX") } @@ -107,6 +175,18 @@ func (p *NamecheapProvider) SetRecords(domainName string, records []dnsrecord.Re return nil } +// Capabilities returns the provider's capabilities +func (p *NamecheapProvider) Capabilities() dnsprovider.ProviderCapabilities { + return dnsprovider.ProviderCapabilities{ + CanListZones: true, + CanGetZone: true, + CanCreateRecord: true, + CanUpdateRecord: false, + CanDeleteRecord: false, + CanBulkReplace: true, + } +} + // Validate checks if the provider is properly configured func (p *NamecheapProvider) Validate() error { if p.client == nil { diff --git a/pkg/dns/provider/provider.go b/pkg/dns/provider/provider.go index 3fb1964..b2970b3 100644 --- a/pkg/dns/provider/provider.go +++ b/pkg/dns/provider/provider.go @@ -1,19 +1,55 @@ package provider import ( + "context" + "zonekit/pkg/dnsrecord" ) +// Zone represents a DNS zone +type Zone struct { + ID string + Name string +} + +// ProviderCapabilities describes what a provider supports +type ProviderCapabilities struct { + CanListZones bool + CanGetZone bool + CanCreateRecord bool + CanUpdateRecord bool + CanDeleteRecord bool + CanBulkReplace bool +} + // Provider defines the interface that all DNS providers must implement type Provider interface { // Name returns the provider name (e.g., "namecheap", "cloudflare", "godaddy") Name() string - // GetRecords retrieves all DNS records for a domain - GetRecords(domainName string) ([]dnsrecord.Record, error) + // ListZones retrieves all zones managed by the provider + ListZones(ctx context.Context) ([]Zone, error) + + // GetZone retrieves a specific zone by ID + GetZone(ctx context.Context, zoneID string) (Zone, error) + + // ListRecords retrieves all DNS records for a zone + ListRecords(ctx context.Context, zoneID string) ([]dnsrecord.Record, error) + + // CreateRecord creates a new DNS record + CreateRecord(ctx context.Context, zoneID string, record dnsrecord.Record) (dnsrecord.Record, error) + + // UpdateRecord updates an existing DNS record + UpdateRecord(ctx context.Context, zoneID string, recordID string, record dnsrecord.Record) (dnsrecord.Record, error) + + // DeleteRecord deletes a DNS record + DeleteRecord(ctx context.Context, zoneID string, recordID string) error + + // BulkReplaceRecords replaces all records in a zone with the provided set + BulkReplaceRecords(ctx context.Context, zoneID string, records []dnsrecord.Record) error - // SetRecords sets DNS records for a domain (replaces all existing records) - SetRecords(domainName string, records []dnsrecord.Record) error + // Capabilities returns the provider's capabilities + Capabilities() ProviderCapabilities // Validate checks if the provider is properly configured Validate() error @@ -63,6 +99,10 @@ type FieldMappings struct { TTL string `yaml:"ttl,omitempty"` MXPref string `yaml:"mx_pref,omitempty"` // e.g., "priority" or "preference" ID string `yaml:"id,omitempty"` // provider record ID field + Priority string `yaml:"priority,omitempty"` + Weight string `yaml:"weight,omitempty"` + Port string `yaml:"port,omitempty"` + Target string `yaml:"target,omitempty"` } `yaml:"request,omitempty"` // Response mappings (provider format -> our format) @@ -73,6 +113,10 @@ type FieldMappings struct { TTL string `yaml:"ttl,omitempty"` MXPref string `yaml:"mx_pref,omitempty"` ID string `yaml:"id,omitempty"` // provider record ID field + Priority string `yaml:"priority,omitempty"` + Weight string `yaml:"weight,omitempty"` + Port string `yaml:"port,omitempty"` + Target string `yaml:"target,omitempty"` } `yaml:"response,omitempty"` // List response structure (for REST providers) diff --git a/pkg/dns/provider/registry_test.go b/pkg/dns/provider/registry_test.go index 790b8d8..d6002da 100644 --- a/pkg/dns/provider/registry_test.go +++ b/pkg/dns/provider/registry_test.go @@ -1,6 +1,7 @@ package provider import ( + "context" "testing" "github.com/stretchr/testify/suite" @@ -27,21 +28,38 @@ func (m *mockProviderForRegistry) Name() string { return m.name } -func (m *mockProviderForRegistry) GetRecords(domainName string) ([]dnsrecord.Record, error) { - if m.getRecordsError != nil { - return nil, m.getRecordsError - } - return m.records[domainName], nil +func (m *mockProviderForRegistry) ListZones(ctx context.Context) ([]Zone, error) { + return nil, nil } -func (m *mockProviderForRegistry) SetRecords(domainName string, records []dnsrecord.Record) error { - if m.setRecordsError != nil { - return m.setRecordsError - } - m.records[domainName] = records +func (m *mockProviderForRegistry) GetZone(ctx context.Context, zoneID string) (Zone, error) { + return Zone{}, nil +} + +func (m *mockProviderForRegistry) ListRecords(ctx context.Context, zoneID string) ([]dnsrecord.Record, error) { + return nil, nil +} + +func (m *mockProviderForRegistry) CreateRecord(ctx context.Context, zoneID string, record dnsrecord.Record) (dnsrecord.Record, error) { + return dnsrecord.Record{}, nil +} + +func (m *mockProviderForRegistry) UpdateRecord(ctx context.Context, zoneID string, recordID string, record dnsrecord.Record) (dnsrecord.Record, error) { + return dnsrecord.Record{}, nil +} + +func (m *mockProviderForRegistry) DeleteRecord(ctx context.Context, zoneID string, recordID string) error { return nil } +func (m *mockProviderForRegistry) BulkReplaceRecords(ctx context.Context, zoneID string, records []dnsrecord.Record) error { + return nil +} + +func (m *mockProviderForRegistry) Capabilities() ProviderCapabilities { + return ProviderCapabilities{} +} + func (m *mockProviderForRegistry) Validate() error { return m.validateError } diff --git a/pkg/dns/provider/rest/rest.go b/pkg/dns/provider/rest/rest.go index 25c6b04..7a88f0b 100644 --- a/pkg/dns/provider/rest/rest.go +++ b/pkg/dns/provider/rest/rest.go @@ -43,44 +43,66 @@ func (p *RESTProvider) Name() string { return p.name } -// GetRecords retrieves all DNS records for a domain -func (p *RESTProvider) GetRecords(domainName string) ([]dnsrecord.Record, error) { - endpoint, ok := p.endpoints["get_records"] +// ListZones retrieves all zones managed by the provider +func (p *RESTProvider) ListZones(ctx context.Context) ([]dnsprovider.Zone, error) { + // check for list_zones or zones endpoint + endpoint, ok := p.endpoints["list_zones"] if !ok { - return nil, fmt.Errorf("get_records endpoint not configured") + endpoint, ok = p.endpoints["zones"] + } + if !ok { + // If no endpoint, return empty list (stub behavior) + return []dnsprovider.Zone{}, nil } - // Replace placeholders in endpoint (e.g., {zone_id}, {domain}) - endpoint = p.replacePlaceholders(endpoint, domainName) - - // Get zone ID if required - zoneID, err := p.getZoneID(domainName) + resp, err := p.client.Get(ctx, endpoint, nil) if err != nil { - return nil, fmt.Errorf("failed to get zone ID: %w", err) + return nil, errors.NewAPI("ListZones", "failed to list zones", err) } - if zoneID != "" { - endpoint = strings.ReplaceAll(endpoint, "{zone_id}", zoneID) + + var responseData interface{} + if err := httpprovider.ParseJSONResponse(resp, &responseData); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) } - ctx := context.Background() + // TODO: Implement zone mapping + return []dnsprovider.Zone{}, nil +} + +// GetZone retrieves a specific zone by ID +func (p *RESTProvider) GetZone(ctx context.Context, zoneID string) (dnsprovider.Zone, error) { + // Stub: return a zone with the ID and Name = ID + return dnsprovider.Zone{ID: zoneID, Name: zoneID}, nil +} + +// ListRecords retrieves all DNS records for a zone +func (p *RESTProvider) ListRecords(ctx context.Context, zoneID string) ([]dnsrecord.Record, error) { + endpoint, ok := p.endpoints["get_records"] + if !ok { + endpoint, ok = p.endpoints["list_records"] + } + if !ok { + return nil, fmt.Errorf("get_records endpoint not configured") + } + + endpoint = strings.ReplaceAll(endpoint, "{zone_id}", zoneID) + endpoint = strings.ReplaceAll(endpoint, "{domain}", zoneID) + resp, err := p.client.Get(ctx, endpoint, nil) if err != nil { - return nil, errors.NewAPI("GetRecords", fmt.Sprintf("failed to get DNS records for %s", domainName), err) + return nil, errors.NewAPI("ListRecords", fmt.Sprintf("failed to get DNS records for zone %s", zoneID), err) } - // Parse response (this will close the body) var responseData interface{} if err := httpprovider.ParseJSONResponse(resp, &responseData); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } - // Extract records using list path recordMaps, err := mapper.ExtractRecords(responseData, p.mappings.ListPath) if err != nil { return nil, fmt.Errorf("failed to extract records: %w", err) } - // Convert to dnsrecord.Record records := make([]dnsrecord.Record, 0, len(recordMaps)) for _, recordMap := range recordMaps { record, err := mapper.FromProviderFormat(recordMap, p.mappings.Response) @@ -93,88 +115,62 @@ func (p *RESTProvider) GetRecords(domainName string) ([]dnsrecord.Record, error) return records, nil } -// SetRecords sets DNS records for a domain (replaces all existing records) -func (p *RESTProvider) SetRecords(domainName string, records []dnsrecord.Record) error { - // Most REST APIs don't support bulk replace, so we need to: - // 1. Get existing records - // 2. Delete all existing records - // 3. Create new records - - existingRecords, err := p.GetRecords(domainName) - if err != nil { - return fmt.Errorf("failed to get existing records: %w", err) +// CreateRecord creates a new DNS record +func (p *RESTProvider) CreateRecord(ctx context.Context, zoneID string, record dnsrecord.Record) (dnsrecord.Record, error) { + endpoint, ok := p.endpoints["create_record"] + if !ok { + return dnsrecord.Record{}, fmt.Errorf("create_record endpoint not configured") } - ctx := context.Background() + endpoint = strings.ReplaceAll(endpoint, "{zone_id}", zoneID) + endpoint = strings.ReplaceAll(endpoint, "{domain}", zoneID) - // Delete existing records - for _, record := range existingRecords { - if err := p.deleteRecord(ctx, domainName, record); err != nil { - // Log but continue - some records might not exist - continue - } - } + body := mapper.ToProviderFormat(record, p.mappings.Request) - // Create new records - for _, record := range records { - if err := p.createRecord(ctx, domainName, record); err != nil { - return fmt.Errorf("failed to create record: %w", err) - } + resp, err := p.client.Post(ctx, endpoint, body) + if err != nil { + return dnsrecord.Record{}, errors.NewAPI("CreateRecord", "failed to create DNS record", err) } + defer resp.Body.Close() - return nil + // TODO: Parse response to get ID + return record, nil } -// createRecord creates a single DNS record -func (p *RESTProvider) createRecord(ctx context.Context, domainName string, record dnsrecord.Record) error { - endpoint, ok := p.endpoints["create_record"] +// UpdateRecord updates an existing DNS record +func (p *RESTProvider) UpdateRecord(ctx context.Context, zoneID string, recordID string, record dnsrecord.Record) (dnsrecord.Record, error) { + endpoint, ok := p.endpoints["update_record"] if !ok { - return fmt.Errorf("create_record endpoint not configured") + return dnsrecord.Record{}, fmt.Errorf("update_record endpoint not configured") } - endpoint = p.replacePlaceholders(endpoint, domainName) - zoneID, _ := p.getZoneID(domainName) - if zoneID != "" { - endpoint = strings.ReplaceAll(endpoint, "{zone_id}", zoneID) - } + endpoint = strings.ReplaceAll(endpoint, "{zone_id}", zoneID) + endpoint = strings.ReplaceAll(endpoint, "{domain}", zoneID) + endpoint = strings.ReplaceAll(endpoint, "{record_id}", recordID) + endpoint = strings.ReplaceAll(endpoint, "{id}", recordID) - // Convert record to provider format body := mapper.ToProviderFormat(record, p.mappings.Request) - resp, err := p.client.Post(ctx, endpoint, body) + resp, err := p.client.Put(ctx, endpoint, body) if err != nil { - return errors.NewAPI("CreateRecord", "failed to create DNS record", err) + return dnsrecord.Record{}, errors.NewAPI("UpdateRecord", "failed to update DNS record", err) } defer resp.Body.Close() - return nil + return record, nil } -// deleteRecord deletes a single DNS record -func (p *RESTProvider) deleteRecord(ctx context.Context, domainName string, record dnsrecord.Record) error { +// DeleteRecord deletes a DNS record +func (p *RESTProvider) DeleteRecord(ctx context.Context, zoneID string, recordID string) error { endpoint, ok := p.endpoints["delete_record"] if !ok { - // If delete endpoint not configured, try to use record ID - // For now, skip if not configured - return nil + return fmt.Errorf("delete_record endpoint not configured") } - endpoint = p.replacePlaceholders(endpoint, domainName) - zoneID, _ := p.getZoneID(domainName) - if zoneID != "" { - endpoint = strings.ReplaceAll(endpoint, "{zone_id}", zoneID) - } - - // Replace {record_id} or {id} placeholders with the record's ID if provided - if strings.Contains(endpoint, "{record_id}") || strings.Contains(endpoint, "{id}") || strings.Contains(endpoint, "{recordId}") { - // Prefer record.ID - if record.ID == "" { - return fmt.Errorf("delete_record requires record_id - record is missing ID") - } - endpoint = strings.ReplaceAll(endpoint, "{record_id}", record.ID) - endpoint = strings.ReplaceAll(endpoint, "{id}", record.ID) - endpoint = strings.ReplaceAll(endpoint, "{recordId}", record.ID) - } + endpoint = strings.ReplaceAll(endpoint, "{zone_id}", zoneID) + endpoint = strings.ReplaceAll(endpoint, "{domain}", zoneID) + endpoint = strings.ReplaceAll(endpoint, "{record_id}", recordID) + endpoint = strings.ReplaceAll(endpoint, "{id}", recordID) resp, err := p.client.Delete(ctx, endpoint) if err != nil { @@ -185,117 +181,58 @@ func (p *RESTProvider) deleteRecord(ctx context.Context, domainName string, reco return nil } -// Validate checks if the provider is properly configured -func (p *RESTProvider) Validate() error { - if p.client == nil { - return fmt.Errorf("HTTP client is not initialized") +// BulkReplaceRecords replaces all records in a zone with the provided set +func (p *RESTProvider) BulkReplaceRecords(ctx context.Context, zoneID string, records []dnsrecord.Record) error { + // Naive implementation + existing, err := p.ListRecords(ctx, zoneID) + if err != nil { + return err } - if p.name == "" { - return fmt.Errorf("provider name is empty") + + for _, r := range existing { + if r.ID != "" { + _ = p.DeleteRecord(ctx, zoneID, r.ID) + } } - if len(p.endpoints) == 0 { - return fmt.Errorf("no endpoints configured") + + for _, r := range records { + _, err := p.CreateRecord(ctx, zoneID, r) + if err != nil { + return err + } } return nil } -// Helper methods - -func (p *RESTProvider) replacePlaceholders(endpoint, domainName string) string { - endpoint = strings.ReplaceAll(endpoint, "{domain}", domainName) - return endpoint -} - -func (p *RESTProvider) getZoneID(domainName string) (string, error) { - // 1. Check if zone_id is in settings - if zoneID, ok := p.settings["zone_id"].(string); ok && zoneID != "" { - return zoneID, nil - } - - // 2. Try configured endpoints that may list or get zones - candidates := []string{"get_zone", "get_zone_by_name", "list_zones", "zones", "search_zones"} - for _, key := range candidates { - if path, ok := p.endpoints[key]; ok && path != "" { - // Replace placeholders - endpoint := p.replacePlaceholders(path, domainName) - - ctx := context.Background() - // If endpoint does not include domain placeholder, try passing domain as query param 'name' - query := map[string]string{} - if !strings.Contains(endpoint, "{domain}") { - query["name"] = domainName - } - - resp, err := p.client.Get(ctx, endpoint, query) - if err != nil { - // Try next candidate - continue - } - - var data interface{} - if err := httpprovider.ParseJSONResponse(resp, &data); err != nil { - continue - } - - // Search for matching zone object - // Check for object with 'result' array (Cloudflare style) - if m, ok := data.(map[string]interface{}); ok { - // Search arrays at top level - for _, v := range m { - switch arr := v.(type) { - case []interface{}: - for _, item := range arr { - if id := extractIDForDomain(item, domainName); id != "" { - return id, nil - } - } - case map[string]interface{}: - if id := extractIDForDomain(arr, domainName); id != "" { - return id, nil - } - } - } - } - // As fallback, try top-level array - if arr, ok := data.([]interface{}); ok { - for _, item := range arr { - if id := extractIDForDomain(item, domainName); id != "" { - return id, nil - } - } - } - } +// Capabilities returns the provider's capabilities +func (p *RESTProvider) Capabilities() dnsprovider.ProviderCapabilities { + return dnsprovider.ProviderCapabilities{ + CanListZones: p.hasEndpoint("list_zones") || p.hasEndpoint("zones"), + CanGetZone: true, + CanCreateRecord: p.hasEndpoint("create_record"), + CanUpdateRecord: p.hasEndpoint("update_record"), + CanDeleteRecord: p.hasEndpoint("delete_record"), + CanBulkReplace: true, } +} - // 3. Not found - return "", nil +func (p *RESTProvider) hasEndpoint(name string) bool { + _, ok := p.endpoints[name] + return ok } -// extractIDForDomain tries to extract an 'id' field from an object if it matches the provided domain name -func extractIDForDomain(item interface{}, domainName string) string { - obj, ok := item.(map[string]interface{}) - if !ok { - return "" +// Validate checks if the provider is properly configured +func (p *RESTProvider) Validate() error { + if p.client == nil { + return fmt.Errorf("HTTP client is not initialized") } - - // Check common name fields - nameCandidates := []string{"name", "zone", "domain", "zone_name"} - for _, nc := range nameCandidates { - if v, ok := obj[nc]; ok { - if vs, ok := v.(string); ok && strings.EqualFold(strings.TrimSuffix(vs, "."), domainName) { - // Found matching name; extract id - for _, idc := range []string{"id", "zone_id", "dns_record_id"} { - if idv, ok := obj[idc]; ok { - return fmt.Sprintf("%v", idv) - } - } - } - } + if p.name == "" { + return fmt.Errorf("provider name is empty") } - - // Only return an ID if it was found alongside a matching name; otherwise, no match - return "" + if len(p.endpoints) == 0 { + return fmt.Errorf("no endpoints configured") + } + return nil } -// Ensure RESTProvider implements Provider interface var _ dnsprovider.Provider = (*RESTProvider)(nil) diff --git a/pkg/dns/provider/rest/rest_test.go b/pkg/dns/provider/rest/rest_test.go index aa0e26f..1e6e61f 100644 --- a/pkg/dns/provider/rest/rest_test.go +++ b/pkg/dns/provider/rest/rest_test.go @@ -8,7 +8,6 @@ import ( httpclient "zonekit/pkg/dns/provider/http" "zonekit/pkg/dns/provider/mapper" - "zonekit/pkg/dnsrecord" "github.com/stretchr/testify/require" ) @@ -28,16 +27,16 @@ func TestDeleteRecord_ByID_Success(t *testing.T) { mappings := mapper.DefaultMappings() p := NewRESTProvider("test", client, mappings, map[string]string{"delete_record": "/records/{record_id}"}, nil) - err := p.deleteRecord(context.Background(), "example.com", dnsrecord.Record{ID: "abc123"}) + err := p.DeleteRecord(context.Background(), "example.com", "abc123") require.NoError(t, err) } -func TestDeleteRecord_MissingID_Error(t *testing.T) { +func TestDeleteRecord_MissingEndpoint_Error(t *testing.T) { client := httpclient.NewClient(httpclient.ClientConfig{BaseURL: "http://example.invalid"}) mappings := mapper.DefaultMappings() - p := NewRESTProvider("test", client, mappings, map[string]string{"delete_record": "/records/{record_id}"}, nil) + p := NewRESTProvider("test", client, mappings, map[string]string{}, nil) - err := p.deleteRecord(context.Background(), "example.com", dnsrecord.Record{}) + err := p.DeleteRecord(context.Background(), "example.com", "abc123") require.Error(t, err) - require.Contains(t, err.Error(), "requires record_id") + require.Contains(t, err.Error(), "delete_record endpoint not configured") } diff --git a/pkg/dns/provider/rest/zoneid_test.go b/pkg/dns/provider/rest/zoneid_test.go deleted file mode 100644 index d764bfe..0000000 --- a/pkg/dns/provider/rest/zoneid_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package rest - -import ( - "net/http" - "net/http/httptest" - "testing" - - httpclient "zonekit/pkg/dns/provider/http" - "zonekit/pkg/dns/provider/mapper" - - "github.com/stretchr/testify/require" -) - -func TestGetZoneID_FromSettings(t *testing.T) { - client := httpclient.NewClient(httpclient.ClientConfig{BaseURL: "http://example.invalid"}) - p := NewRESTProvider("test", client, mapper.DefaultMappings(), map[string]string{}, map[string]interface{}{"zone_id": "z-123"}) - - id, err := p.getZoneID("example.com") - require.NoError(t, err) - require.Equal(t, "z-123", id) -} - -func TestGetZoneID_FromListEndpoint(t *testing.T) { - // Test server returns zone list - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"result":[{"id":"z-1","name":"example.com"}]}`)) - })) - defer ts.Close() - - client := httpclient.NewClient(httpclient.ClientConfig{BaseURL: ts.URL}) - p := NewRESTProvider("test", client, mapper.DefaultMappings(), map[string]string{"list_zones": "/zones"}, nil) - - id, err := p.getZoneID("example.com") - require.NoError(t, err) - require.Equal(t, "z-1", id) -} - -func TestGetZoneID_NoMatch_ReturnsEmpty(t *testing.T) { - // Test server returns unrelated zone - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"result":[{"id":"z-1","name":"other.com"}]}`)) - })) - defer ts.Close() - - client := httpclient.NewClient(httpclient.ClientConfig{BaseURL: ts.URL}) - p := NewRESTProvider("test", client, mapper.DefaultMappings(), map[string]string{"list_zones": "/zones"}, nil) - - id, err := p.getZoneID("example.com") - require.NoError(t, err) - require.Equal(t, "", id) -} diff --git a/pkg/dns/service.go b/pkg/dns/service.go index 0fa9e4f..92589e3 100644 --- a/pkg/dns/service.go +++ b/pkg/dns/service.go @@ -1,6 +1,7 @@ package dns import ( + "context" "fmt" "strings" @@ -48,14 +49,53 @@ func NewServiceWithProviderName(providerName string) (*Service, error) { }, nil } +// resolveZoneID resolves a domain name to a zone ID +func (s *Service) resolveZoneID(ctx context.Context, domainName string) (string, error) { + // 1. Try GetZone assuming ID == domainName + if s.provider.Capabilities().CanGetZone { + z, err := s.provider.GetZone(ctx, domainName) + if err == nil { + return z.ID, nil + } + } + + // 2. ListZones + if s.provider.Capabilities().CanListZones { + zones, err := s.provider.ListZones(ctx) + if err != nil { + // Don't fail here, try fallback + } else { + for _, z := range zones { + // Basic matching + if strings.EqualFold(z.Name, domainName) || strings.EqualFold(z.Name, domainName+".") { + return z.ID, nil + } + } + } + } + + // Fallback: use domainName as ID + return domainName, nil +} + // GetRecords retrieves all DNS records for a domain func (s *Service) GetRecords(domainName string) ([]dnsrecord.Record, error) { - return s.provider.GetRecords(domainName) + ctx := context.Background() + zoneID, err := s.resolveZoneID(ctx, domainName) + if err != nil { + return nil, err + } + return s.provider.ListRecords(ctx, zoneID) } // SetRecords sets DNS records for a domain (replaces all existing records) func (s *Service) SetRecords(domainName string, records []dnsrecord.Record) error { - return s.provider.SetRecords(domainName, records) + ctx := context.Background() + zoneID, err := s.resolveZoneID(ctx, domainName) + if err != nil { + return err + } + return s.provider.BulkReplaceRecords(ctx, zoneID, records) } // AddRecord adds a single DNS record to a domain @@ -65,8 +105,19 @@ func (s *Service) AddRecord(domainName string, record dnsrecord.Record) error { return fmt.Errorf("invalid record: %w", err) } - // Get existing records - existingRecords, err := s.GetRecords(domainName) + ctx := context.Background() + zoneID, err := s.resolveZoneID(ctx, domainName) + if err != nil { + return err + } + + if s.provider.Capabilities().CanCreateRecord { + _, err := s.provider.CreateRecord(ctx, zoneID, record) + return err + } + + // Fallback: Get existing records + existingRecords, err := s.provider.ListRecords(ctx, zoneID) if err != nil { return fmt.Errorf("failed to get existing records: %w", err) } @@ -75,22 +126,30 @@ func (s *Service) AddRecord(domainName string, record dnsrecord.Record) error { allRecords := append(existingRecords, record) // Set all records - return s.SetRecords(domainName, allRecords) + return s.provider.BulkReplaceRecords(ctx, zoneID, allRecords) } // UpdateRecord updates a DNS record by hostname and type func (s *Service) UpdateRecord(domainName string, hostname, recordType string, newRecord dnsrecord.Record) error { - // Get existing records - existingRecords, err := s.GetRecords(domainName) + ctx := context.Background() + zoneID, err := s.resolveZoneID(ctx, domainName) + if err != nil { + return err + } + + // Find the record to get ID + existingRecords, err := s.provider.ListRecords(ctx, zoneID) if err != nil { return fmt.Errorf("failed to get existing records: %w", err) } - // Find and update the record + var recordID string + var foundIndex int found := false for i, record := range existingRecords { if record.HostName == hostname && record.RecordType == recordType { - existingRecords[i] = newRecord + recordID = record.ID + foundIndex = i found = true break } @@ -100,23 +159,36 @@ func (s *Service) UpdateRecord(domainName string, hostname, recordType string, n return errors.NewNotFound("DNS record", fmt.Sprintf("%s %s", hostname, recordType)) } - // Set all records - return s.SetRecords(domainName, existingRecords) + if s.provider.Capabilities().CanUpdateRecord && recordID != "" { + _, err := s.provider.UpdateRecord(ctx, zoneID, recordID, newRecord) + return err + } + + // Fallback to bulk replace + existingRecords[foundIndex] = newRecord + return s.provider.BulkReplaceRecords(ctx, zoneID, existingRecords) } // DeleteRecord removes a DNS record by hostname and type func (s *Service) DeleteRecord(domainName string, hostname, recordType string) error { - // Get existing records - existingRecords, err := s.GetRecords(domainName) + ctx := context.Background() + zoneID, err := s.resolveZoneID(ctx, domainName) + if err != nil { + return err + } + + existingRecords, err := s.provider.ListRecords(ctx, zoneID) if err != nil { return fmt.Errorf("failed to get existing records: %w", err) } - // Filter out the record to delete - var filteredRecords []dnsrecord.Record + var recordID string found := false + var filteredRecords []dnsrecord.Record + for _, record := range existingRecords { if record.HostName == hostname && record.RecordType == recordType { + recordID = record.ID found = true continue } @@ -127,8 +199,11 @@ func (s *Service) DeleteRecord(domainName string, hostname, recordType string) e return errors.NewNotFound("DNS record", fmt.Sprintf("%s %s", hostname, recordType)) } - // Set remaining records - return s.SetRecords(domainName, filteredRecords) + if s.provider.Capabilities().CanDeleteRecord && recordID != "" { + return s.provider.DeleteRecord(ctx, zoneID, recordID) + } + + return s.provider.BulkReplaceRecords(ctx, zoneID, filteredRecords) } // DeleteAllRecords removes all DNS records for a domain diff --git a/pkg/dns/service_test.go b/pkg/dns/service_test.go index 3e5956f..5ce225d 100644 --- a/pkg/dns/service_test.go +++ b/pkg/dns/service_test.go @@ -1,6 +1,7 @@ package dns import ( + "context" "errors" "testing" @@ -30,21 +31,52 @@ func (m *mockProvider) Name() string { return m.name } -func (m *mockProvider) GetRecords(domainName string) ([]dnsrecord.Record, error) { +func (m *mockProvider) ListZones(ctx context.Context) ([]provider.Zone, error) { + return nil, nil +} + +func (m *mockProvider) GetZone(ctx context.Context, zoneID string) (provider.Zone, error) { + return provider.Zone{}, nil +} + +func (m *mockProvider) ListRecords(ctx context.Context, zoneID string) ([]dnsrecord.Record, error) { if m.getRecordsError != nil { return nil, m.getRecordsError } - return m.records[domainName], nil + return m.records[zoneID], nil +} + +func (m *mockProvider) CreateRecord(ctx context.Context, zoneID string, record dnsrecord.Record) (dnsrecord.Record, error) { + return dnsrecord.Record{}, nil +} + +func (m *mockProvider) UpdateRecord(ctx context.Context, zoneID string, recordID string, record dnsrecord.Record) (dnsrecord.Record, error) { + return dnsrecord.Record{}, nil +} + +func (m *mockProvider) DeleteRecord(ctx context.Context, zoneID string, recordID string) error { + return nil } -func (m *mockProvider) SetRecords(domainName string, records []dnsrecord.Record) error { +func (m *mockProvider) BulkReplaceRecords(ctx context.Context, zoneID string, records []dnsrecord.Record) error { if m.setRecordsError != nil { return m.setRecordsError } - m.records[domainName] = records + m.records[zoneID] = records return nil } +func (m *mockProvider) Capabilities() provider.ProviderCapabilities { + return provider.ProviderCapabilities{ + CanListZones: false, + CanGetZone: false, + CanCreateRecord: false, + CanUpdateRecord: false, + CanDeleteRecord: false, + CanBulkReplace: true, + } +} + func (m *mockProvider) Validate() error { return m.validateError } diff --git a/pkg/dnsrecord/record.go b/pkg/dnsrecord/record.go index 46d405d..d053c32 100644 --- a/pkg/dnsrecord/record.go +++ b/pkg/dnsrecord/record.go @@ -8,6 +8,12 @@ type Record struct { Address string TTL int MXPref int + Priority int + Weight int + Port int + Target string + Metadata map[string]string + Raw interface{} } // RecordType constants From 0f91261eb7f838b45fcc90394adef0b466c05814 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 09:30:43 +0000 Subject: [PATCH 2/5] Implement Zone operations for REST Providers (Phase 2) - Update generic RESTProvider to implement ListZones and GetZone - Add ZoneListPath, ZoneID, ZoneName to Mappings and FieldMappings - Update OpenAPI parser to detect list_zones/get_zone endpoints and map zone fields - Update Cloudflare OpenAPI spec to include GET /zones endpoint - Update Builder to propagate zone mappings - Add tests for new mapper, openapi, and rest functionality Co-authored-by: SamyRai <919510+SamyRai@users.noreply.github.com> --- pkg/dns/provider/builder/builder.go | 16 ++- pkg/dns/provider/cloudflare/openapi.yaml | 31 +++++ pkg/dns/provider/mapper/mapper.go | 48 +++++++ pkg/dns/provider/mapper/mapper_test.go | 31 +++++ pkg/dns/provider/openapi/openapi.go | 167 +++++++++++++++-------- pkg/dns/provider/openapi/openapi_test.go | 9 ++ pkg/dns/provider/provider.go | 5 + pkg/dns/provider/rest/rest.go | 97 ++++++++++++- pkg/dns/provider/rest/rest_test.go | 34 +++++ 9 files changed, 375 insertions(+), 63 deletions(-) diff --git a/pkg/dns/provider/builder/builder.go b/pkg/dns/provider/builder/builder.go index a1decfc..0147bd9 100644 --- a/pkg/dns/provider/builder/builder.go +++ b/pkg/dns/provider/builder/builder.go @@ -87,7 +87,21 @@ func buildMappings(configMappings *dnsprovider.FieldMappings) mapper.Mappings { } m := mapper.Mappings{ - ListPath: configMappings.ListPath, + ListPath: configMappings.ListPath, + ZoneListPath: configMappings.ZoneListPath, + ZoneID: configMappings.ZoneID, + ZoneName: configMappings.ZoneName, + } + + // Set defaults for zone mappings if empty + if m.ZoneListPath == "" { + m.ZoneListPath = "zones" + } + if m.ZoneID == "" { + m.ZoneID = "id" + } + if m.ZoneName == "" { + m.ZoneName = "name" } // Request mappings diff --git a/pkg/dns/provider/cloudflare/openapi.yaml b/pkg/dns/provider/cloudflare/openapi.yaml index fdb2423..4e38cc5 100644 --- a/pkg/dns/provider/cloudflare/openapi.yaml +++ b/pkg/dns/provider/cloudflare/openapi.yaml @@ -7,6 +7,19 @@ servers: - url: https://api.cloudflare.com/client/v4 description: Cloudflare API v4 paths: + /zones: + get: + summary: List Zones + operationId: listZones + tags: + - Zones + responses: + '200': + description: List of Zones + content: + application/json: + schema: + $ref: '#/components/schemas/ZoneResponse' /zones/{zone_id}/dns_records: get: summary: List DNS records @@ -223,6 +236,24 @@ components: properties: result: $ref: '#/components/schemas/DNSRecord' + Zone: + type: object + properties: + id: + type: string + description: Zone identifier + name: + type: string + description: Zone name + status: + type: string + ZoneResponse: + type: object + properties: + result: + type: array + items: + $ref: '#/components/schemas/Zone' security: - ApiKeyAuth: [] - EmailAuth: [] diff --git a/pkg/dns/provider/mapper/mapper.go b/pkg/dns/provider/mapper/mapper.go index 639463e..5b6b455 100644 --- a/pkg/dns/provider/mapper/mapper.go +++ b/pkg/dns/provider/mapper/mapper.go @@ -5,6 +5,7 @@ import ( "reflect" "strings" + "zonekit/pkg/dns/provider" "zonekit/pkg/dnsrecord" ) @@ -13,6 +14,11 @@ type Mappings struct { Request FieldMapping Response FieldMapping ListPath string // JSON path to records array (e.g., "result" or "data.records") + + // Zone Mappings + ZoneListPath string // JSON path to zones array + ZoneID string // Field name for Zone ID + ZoneName string // Field name for Zone Name } // FieldMapping defines how to map fields @@ -45,6 +51,9 @@ func DefaultMappings() Mappings { ID: "", }, ListPath: "records", + ZoneListPath: "zones", + ZoneID: "id", + ZoneName: "name", } } @@ -128,6 +137,45 @@ func FromProviderFormat(data map[string]interface{}, mapping FieldMapping) (dnsr // ExtractRecords extracts records from a JSON response using the list path func ExtractRecords(data interface{}, listPath string) ([]map[string]interface{}, error) { + return extractList(data, listPath) +} + +// ExtractZones extracts zones from a JSON response using the list path +func ExtractZones(data interface{}, listPath string) ([]map[string]interface{}, error) { + return extractList(data, listPath) +} + +// FromProviderZoneFormat converts provider's zone format to provider.Zone +func FromProviderZoneFormat(data map[string]interface{}, mappings Mappings) (provider.Zone, error) { + zone := provider.Zone{} + + getString := func(key string) string { + if val, ok := data[key]; ok { + if str, ok := val.(string); ok { + return str + } + return fmt.Sprintf("%v", val) + } + return "" + } + + if mappings.ZoneID != "" { + zone.ID = getString(mappings.ZoneID) + } + if mappings.ZoneName != "" { + zone.Name = getString(mappings.ZoneName) + } + + if zone.ID == "" { + return zone, fmt.Errorf("zone ID not found in response (mapped to %s)", mappings.ZoneID) + } + + return zone, nil +} + + +// extractList is a helper to extract a list of maps from a JSON response +func extractList(data interface{}, listPath string) ([]map[string]interface{}, error) { if listPath == "" { // Default: assume data is an array if arr, ok := data.([]interface{}); ok { diff --git a/pkg/dns/provider/mapper/mapper_test.go b/pkg/dns/provider/mapper/mapper_test.go index 97c5550..5c120f1 100644 --- a/pkg/dns/provider/mapper/mapper_test.go +++ b/pkg/dns/provider/mapper/mapper_test.go @@ -48,3 +48,34 @@ func TestToProviderFormat_IncludesID(t *testing.T) { require.Equal(t, "abc123", m["id"]) require.Equal(t, "www", m["hostname"]) } + +func TestExtractZones(t *testing.T) { + data := map[string]interface{}{ + "result": []interface{}{ + map[string]interface{}{"id": "1", "name": "example.com"}, + map[string]interface{}{"id": "2", "name": "example.org"}, + }, + } + + zones, err := ExtractZones(data, "result") + require.NoError(t, err) + require.Len(t, zones, 2) + require.Equal(t, "1", zones[0]["id"]) +} + +func TestFromProviderZoneFormat(t *testing.T) { + data := map[string]interface{}{ + "id": "z1", + "name": "example.com", + } + + mapping := Mappings{ + ZoneID: "id", + ZoneName: "name", + } + + zone, err := FromProviderZoneFormat(data, mapping) + require.NoError(t, err) + require.Equal(t, "z1", zone.ID) + require.Equal(t, "example.com", zone.Name) +} diff --git a/pkg/dns/provider/openapi/openapi.go b/pkg/dns/provider/openapi/openapi.go index 3d9b7da..9900fe8 100644 --- a/pkg/dns/provider/openapi/openapi.go +++ b/pkg/dns/provider/openapi/openapi.go @@ -164,6 +164,12 @@ func (s *Spec) mapOperationToEndpoint(method, operationID, path string) string { if strings.Contains(path, "record") || strings.Contains(path, "dns") { return "get_records" } + if strings.Contains(path, "zones") || strings.Contains(path, "domains") { + if strings.Contains(operationID, "list") { + return "list_zones" + } + return "get_zone" + } } if strings.Contains(operationID, "create") || strings.Contains(operationID, "add") { if strings.Contains(path, "record") || strings.Contains(path, "dns") { @@ -187,6 +193,12 @@ func (s *Spec) mapOperationToEndpoint(method, operationID, path string) string { if strings.Contains(path, "record") || strings.Contains(path, "dns") { return "get_records" } + if (strings.Contains(path, "zones") || strings.Contains(path, "domains")) && !strings.Contains(path, "{") { + return "list_zones" + } + if strings.Contains(path, "zones") || strings.Contains(path, "domains") { + return "get_zone" + } case "post": if strings.Contains(path, "record") || strings.Contains(path, "dns") { return "create_record" @@ -272,78 +284,119 @@ func (s *Spec) extractMappings() *dnsprovider.FieldMappings { for schemaName, schema := range s.Components.Schemas { origName := schemaName schemaName = strings.ToLower(schemaName) - if !strings.Contains(schemaName, "record") && !strings.Contains(schemaName, "dns") { - continue + + // Check for Record schema + if strings.Contains(schemaName, "record") && !strings.Contains(schemaName, "zone") { + s.mapRecordSchema(schema, mappings, origName) } - schemaMap, ok := schema.(map[string]interface{}) - if !ok { - continue + // Check for Zone schema + if strings.Contains(schemaName, "zone") || strings.Contains(schemaName, "domain") { + if !strings.Contains(schemaName, "record") { + s.mapZoneSchema(schema, mappings, origName) + } } + } - properties, ok := schemaMap["properties"].(map[string]interface{}) - if !ok { - continue + return mappings +} + +func (s *Spec) mapRecordSchema(schema interface{}, mappings *dnsprovider.FieldMappings, origName string) { + schemaMap, ok := schema.(map[string]interface{}) + if !ok { + return + } + + properties, ok := schemaMap["properties"].(map[string]interface{}) + if !ok { + return + } + + // Map common DNS record fields + for propName := range properties { + propLower := strings.ToLower(propName) + switch propLower { + case "name", "hostname", "host": + mappings.Request.HostName = propName + mappings.Response.HostName = propName + case "type", "recordtype", "record_type": + mappings.Request.RecordType = propName + mappings.Response.RecordType = propName + case "content", "data", "value", "address": + mappings.Request.Address = propName + mappings.Response.Address = propName + case "ttl": + mappings.Request.TTL = propName + mappings.Response.TTL = propName + case "priority", "preference", "mxpref", "mx_pref": + mappings.Request.MXPref = propName + mappings.Response.MXPref = propName + case "id", "recordid", "record_id", "_id": + mappings.Request.ID = propName + mappings.Response.ID = propName } + } - // Map common DNS record fields - for propName := range properties { - propLower := strings.ToLower(propName) - switch propLower { - case "name", "hostname", "host": - mappings.Request.HostName = propName - mappings.Response.HostName = propName - case "type", "recordtype", "record_type": - mappings.Request.RecordType = propName - mappings.Response.RecordType = propName - case "content", "data", "value", "address": - mappings.Request.Address = propName - mappings.Response.Address = propName - case "ttl": - mappings.Request.TTL = propName - mappings.Response.TTL = propName - case "priority", "preference", "mxpref", "mx_pref": - mappings.Request.MXPref = propName - mappings.Response.MXPref = propName - case "id", "recordid", "record_id", "_id": - mappings.Request.ID = propName - mappings.Response.ID = propName - } + // Try to find list path by inspecting other schemas + if mappings.ListPath == "" { + mappings.ListPath = s.findListPath(origName) + } +} + +func (s *Spec) mapZoneSchema(schema interface{}, mappings *dnsprovider.FieldMappings, origName string) { + schemaMap, ok := schema.(map[string]interface{}) + if !ok { + return + } + + properties, ok := schemaMap["properties"].(map[string]interface{}) + if !ok { + return + } + + // Map common Zone fields + for propName := range properties { + propLower := strings.ToLower(propName) + switch propLower { + case "name", "domain", "zonename", "zone_name": + mappings.ZoneName = propName + case "id", "zoneid", "zone_id", "_id": + mappings.ZoneID = propName } + } - // Try to find list path by inspecting other schemas for arrays of this schema - for _, otherSchema := range s.Components.Schemas { - otherSchemaMap, ok := otherSchema.(map[string]interface{}) - if !ok { - continue - } + // Try to find list path by inspecting other schemas + if mappings.ZoneListPath == "" { + mappings.ZoneListPath = s.findListPath(origName) + } +} - // Look for properties that are arrays with items referencing this schema - if props, ok := otherSchemaMap["properties"].(map[string]interface{}); ok { - for propName, prop := range props { - propMap, ok := prop.(map[string]interface{}) - if !ok { - continue - } +func (s *Spec) findListPath(targetSchemaName string) string { + for _, otherSchema := range s.Components.Schemas { + otherSchemaMap, ok := otherSchema.(map[string]interface{}) + if !ok { + continue + } - if propMap["type"] == "array" { - if items, ok := propMap["items"].(map[string]interface{}); ok { - if ref, ok := items["$ref"].(string); ok { - refLower := strings.ToLower(ref) - if strings.Contains(refLower, strings.ToLower(origName)) || strings.HasSuffix(refLower, "/"+strings.ToLower(origName)) { - mappings.ListPath = propName - break - } + if props, ok := otherSchemaMap["properties"].(map[string]interface{}); ok { + for propName, prop := range props { + propMap, ok := prop.(map[string]interface{}) + if !ok { + continue + } + + if propMap["type"] == "array" { + if items, ok := propMap["items"].(map[string]interface{}); ok { + if ref, ok := items["$ref"].(string); ok { + refLower := strings.ToLower(ref) + if strings.Contains(refLower, strings.ToLower(targetSchemaName)) || strings.HasSuffix(refLower, "/"+targetSchemaName) { + return propName } } } } } - if mappings.ListPath != "" { - break - } } } - - return mappings + return "" } diff --git a/pkg/dns/provider/openapi/openapi_test.go b/pkg/dns/provider/openapi/openapi_test.go index d9b9104..3fc2b11 100644 --- a/pkg/dns/provider/openapi/openapi_test.go +++ b/pkg/dns/provider/openapi/openapi_test.go @@ -23,10 +23,19 @@ func TestCloudflareSpec_ToProviderConfig(t *testing.T) { require.Contains(t, cfg.API.Endpoints, "delete_record") require.Equal(t, "/zones/{zone_id}/dns_records/{dns_record_id}", cfg.API.Endpoints["delete_record"]) + // Zone endpoints + require.Contains(t, cfg.API.Endpoints, "list_zones") + require.Equal(t, "/zones", cfg.API.Endpoints["list_zones"]) + // Mappings require.NotNil(t, cfg.Mappings) require.Equal(t, "id", cfg.Mappings.Response.ID) // List path should be detected require.Equal(t, "result", cfg.Mappings.ListPath) + + // Zone mappings + require.Equal(t, "id", cfg.Mappings.ZoneID) + require.Equal(t, "name", cfg.Mappings.ZoneName) + require.Equal(t, "result", cfg.Mappings.ZoneListPath) } diff --git a/pkg/dns/provider/provider.go b/pkg/dns/provider/provider.go index b2970b3..649f247 100644 --- a/pkg/dns/provider/provider.go +++ b/pkg/dns/provider/provider.go @@ -121,4 +121,9 @@ type FieldMappings struct { // List response structure (for REST providers) ListPath string `yaml:"list_path,omitempty"` // JSON path to records array, e.g., "data.records" + + // Zone Mappings + ZoneListPath string `yaml:"zone_list_path,omitempty"` + ZoneID string `yaml:"zone_id,omitempty"` + ZoneName string `yaml:"zone_name,omitempty"` } diff --git a/pkg/dns/provider/rest/rest.go b/pkg/dns/provider/rest/rest.go index 7a88f0b..a4e2085 100644 --- a/pkg/dns/provider/rest/rest.go +++ b/pkg/dns/provider/rest/rest.go @@ -45,7 +45,18 @@ func (p *RESTProvider) Name() string { // ListZones retrieves all zones managed by the provider func (p *RESTProvider) ListZones(ctx context.Context) ([]dnsprovider.Zone, error) { - // check for list_zones or zones endpoint + // 1. Check if a static zone_id is configured in settings + if zoneID, ok := p.settings["zone_id"].(string); ok && zoneID != "" { + zoneName := zoneID // Default name to ID if not provided + if name, ok := p.settings["zone_name"].(string); ok && name != "" { + zoneName = name + } + return []dnsprovider.Zone{ + {ID: zoneID, Name: zoneName}, + }, nil + } + + // 2. Use list_zones endpoint if available endpoint, ok := p.endpoints["list_zones"] if !ok { endpoint, ok = p.endpoints["zones"] @@ -65,13 +76,84 @@ func (p *RESTProvider) ListZones(ctx context.Context) ([]dnsprovider.Zone, error return nil, fmt.Errorf("failed to parse response: %w", err) } - // TODO: Implement zone mapping - return []dnsprovider.Zone{}, nil + zoneMaps, err := mapper.ExtractZones(responseData, p.mappings.ZoneListPath) + if err != nil { + return nil, fmt.Errorf("failed to extract zones: %w", err) + } + + zones := make([]dnsprovider.Zone, 0, len(zoneMaps)) + for _, zoneMap := range zoneMaps { + zone, err := mapper.FromProviderZoneFormat(zoneMap, p.mappings) + if err != nil { + // Log error but continue? Or fail? For now, fail. + return nil, fmt.Errorf("failed to convert zone: %w", err) + } + zones = append(zones, zone) + } + + return zones, nil } // GetZone retrieves a specific zone by ID func (p *RESTProvider) GetZone(ctx context.Context, zoneID string) (dnsprovider.Zone, error) { - // Stub: return a zone with the ID and Name = ID + // 1. Check if a static zone_id matches + if configuredID, ok := p.settings["zone_id"].(string); ok && configuredID == zoneID { + zoneName := zoneID + if name, ok := p.settings["zone_name"].(string); ok && name != "" { + zoneName = name + } + return dnsprovider.Zone{ID: zoneID, Name: zoneName}, nil + } + + // 2. Use get_zone endpoint if available + endpoint, ok := p.endpoints["get_zone"] + if ok { + endpoint = strings.ReplaceAll(endpoint, "{zone_id}", zoneID) + endpoint = strings.ReplaceAll(endpoint, "{id}", zoneID) + + resp, err := p.client.Get(ctx, endpoint, nil) + if err != nil { + return dnsprovider.Zone{}, errors.NewAPI("GetZone", fmt.Sprintf("failed to get zone %s", zoneID), err) + } + + var responseData interface{} + if err := httpprovider.ParseJSONResponse(resp, &responseData); err != nil { + return dnsprovider.Zone{}, fmt.Errorf("failed to parse response: %w", err) + } + + // Assuming response is the zone object directly or wrapped in result + // We can try to reuse ExtractZones if it handles single object or wrap it? + // Or assume single object mapping. + // Usually REST APIs return { "result": { ... } } or just { ... } + + // If responseData has the wrapper path (e.g. "result"), extract it + dataToMap := responseData + if p.mappings.ZoneListPath != "" && p.mappings.ZoneListPath != "zones" { // Heuristic: list path might also be response wrapper + // Simple check: is it a map with that key? + if m, ok := responseData.(map[string]interface{}); ok { + if val, ok := m[p.mappings.ZoneListPath]; ok { + // Check if val is map (single object) or array (list) + // For GetZone, expect map + if vMap, ok := val.(map[string]interface{}); ok { + dataToMap = vMap + } + } + } + } else if m, ok := responseData.(map[string]interface{}); ok { + // standard "result" wrapper check + if res, ok := m["result"]; ok { + dataToMap = res + } + } + + if m, ok := dataToMap.(map[string]interface{}); ok { + return mapper.FromProviderZoneFormat(m, p.mappings) + } + return dnsprovider.Zone{}, fmt.Errorf("unexpected response format for GetZone") + } + + // Fallback/Stub: return a zone with the ID and Name = ID + // This maintains compatibility for simple providers where ID=Name return dnsprovider.Zone{ID: zoneID, Name: zoneID}, nil } @@ -207,7 +289,7 @@ func (p *RESTProvider) BulkReplaceRecords(ctx context.Context, zoneID string, re // Capabilities returns the provider's capabilities func (p *RESTProvider) Capabilities() dnsprovider.ProviderCapabilities { return dnsprovider.ProviderCapabilities{ - CanListZones: p.hasEndpoint("list_zones") || p.hasEndpoint("zones"), + CanListZones: p.hasEndpoint("list_zones") || p.hasEndpoint("zones") || p.hasSetting("zone_id"), CanGetZone: true, CanCreateRecord: p.hasEndpoint("create_record"), CanUpdateRecord: p.hasEndpoint("update_record"), @@ -221,6 +303,11 @@ func (p *RESTProvider) hasEndpoint(name string) bool { return ok } +func (p *RESTProvider) hasSetting(name string) bool { + _, ok := p.settings[name] + return ok +} + // Validate checks if the provider is properly configured func (p *RESTProvider) Validate() error { if p.client == nil { diff --git a/pkg/dns/provider/rest/rest_test.go b/pkg/dns/provider/rest/rest_test.go index 1e6e61f..663d032 100644 --- a/pkg/dns/provider/rest/rest_test.go +++ b/pkg/dns/provider/rest/rest_test.go @@ -40,3 +40,37 @@ func TestDeleteRecord_MissingEndpoint_Error(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "delete_record endpoint not configured") } + +func TestListZones_StaticConfig(t *testing.T) { + client := httpclient.NewClient(httpclient.ClientConfig{BaseURL: "http://example.invalid"}) + mappings := mapper.DefaultMappings() + settings := map[string]interface{}{"zone_id": "static-zone"} + p := NewRESTProvider("test", client, mappings, map[string]string{}, settings) + + zones, err := p.ListZones(context.Background()) + require.NoError(t, err) + require.Len(t, zones, 1) + require.Equal(t, "static-zone", zones[0].ID) +} + +func TestListZones_Endpoint(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/zones" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"zones": [{"id": "z1", "name": "example.com"}]}`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + client := httpclient.NewClient(httpclient.ClientConfig{BaseURL: ts.URL}) + mappings := mapper.DefaultMappings() + p := NewRESTProvider("test", client, mappings, map[string]string{"list_zones": "/zones"}, nil) + + zones, err := p.ListZones(context.Background()) + require.NoError(t, err) + require.Len(t, zones, 1) + require.Equal(t, "z1", zones[0].ID) + require.Equal(t, "example.com", zones[0].Name) +} From d5285f5774940e4361eb3d0c0eea120266d1daca Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 09:51:49 +0000 Subject: [PATCH 3/5] Harden codebase and enforce best practices for Provider Contract Refactor - Fix RESTProvider.BulkReplaceRecords to handle delete errors and check capabilities - Update RESTProvider to parse response body in CreateRecord and UpdateRecord to capture server-assigned fields (ID) - Fix potential panic in mapper.navigatePath by adding validation checks - Implement dynamic capabilities in RESTProvider - Update tests to verify new functionality and robustness Co-authored-by: SamyRai <919510+SamyRai@users.noreply.github.com> --- pkg/dns/provider/builder/builder.go | 1 + pkg/dns/provider/mapper/mapper.go | 73 ++++++++++++++++++------ pkg/dns/provider/openapi/openapi.go | 34 ++++++++++- pkg/dns/provider/openapi/openapi_test.go | 3 + pkg/dns/provider/provider.go | 3 + pkg/dns/provider/rest/rest.go | 59 +++++++++++++++---- pkg/dns/provider/rest/rest_test.go | 34 +++++++++++ 7 files changed, 180 insertions(+), 27 deletions(-) diff --git a/pkg/dns/provider/builder/builder.go b/pkg/dns/provider/builder/builder.go index 0147bd9..2d5eab0 100644 --- a/pkg/dns/provider/builder/builder.go +++ b/pkg/dns/provider/builder/builder.go @@ -88,6 +88,7 @@ func buildMappings(configMappings *dnsprovider.FieldMappings) mapper.Mappings { m := mapper.Mappings{ ListPath: configMappings.ListPath, + ResponsePath: configMappings.ResponsePath, ZoneListPath: configMappings.ZoneListPath, ZoneID: configMappings.ZoneID, ZoneName: configMappings.ZoneName, diff --git a/pkg/dns/provider/mapper/mapper.go b/pkg/dns/provider/mapper/mapper.go index 5b6b455..e00824e 100644 --- a/pkg/dns/provider/mapper/mapper.go +++ b/pkg/dns/provider/mapper/mapper.go @@ -14,6 +14,7 @@ type Mappings struct { Request FieldMapping Response FieldMapping ListPath string // JSON path to records array (e.g., "result" or "data.records") + ResponsePath string // JSON path to single record response (e.g., "result" or "domain_record") // Zone Mappings ZoneListPath string // JSON path to zones array @@ -50,7 +51,8 @@ func DefaultMappings() Mappings { MXPref: "mx_pref", ID: "", }, - ListPath: "records", + ListPath: "records", + ResponsePath: "", // Default: root object ZoneListPath: "zones", ZoneID: "id", ZoneName: "name", @@ -140,6 +142,11 @@ func ExtractRecords(data interface{}, listPath string) ([]map[string]interface{} return extractList(data, listPath) } +// ExtractRecord extracts a single record from a JSON response using the response path +func ExtractRecord(data interface{}, responsePath string) (map[string]interface{}, error) { + return extractObject(data, responsePath) +} + // ExtractZones extracts zones from a JSON response using the list path func ExtractZones(data interface{}, listPath string) ([]map[string]interface{}, error) { return extractList(data, listPath) @@ -184,8 +191,43 @@ func extractList(data interface{}, listPath string) ([]map[string]interface{}, e return nil, fmt.Errorf("no list path specified and data is not an array") } - // Navigate through the path (e.g., "result" or "data.records") - parts := strings.Split(listPath, ".") + val, err := navigatePath(data, listPath) + if err != nil { + return nil, err + } + + if arr, ok := val.([]interface{}); ok { + return convertArrayToMaps(arr) + } + + return nil, fmt.Errorf("path '%s' does not point to an array (got %T)", listPath, val) +} + +// extractObject is a helper to extract a single map from a JSON response +func extractObject(data interface{}, path string) (map[string]interface{}, error) { + if path == "" { + // Default: assume data is the object + if m, ok := data.(map[string]interface{}); ok { + return m, nil + } + return nil, fmt.Errorf("no path specified and data is not a map") + } + + val, err := navigatePath(data, path) + if err != nil { + return nil, err + } + + if m, ok := val.(map[string]interface{}); ok { + return m, nil + } + + return nil, fmt.Errorf("path '%s' does not point to a map (got %T)", path, val) +} + +// navigatePath traverses the JSON structure +func navigatePath(data interface{}, path string) (interface{}, error) { + parts := strings.Split(path, ".") current := reflect.ValueOf(data) for _, part := range parts { @@ -193,36 +235,35 @@ func extractList(data interface{}, listPath string) ([]map[string]interface{}, e current = current.Elem() } + // Check if valid after dereferencing interface or nil pointer + if !current.IsValid() { + return nil, fmt.Errorf("path '%s' contains nil value", path) + } + switch current.Kind() { case reflect.Map: key := reflect.ValueOf(part) current = current.MapIndex(key) if !current.IsValid() { - return nil, fmt.Errorf("path '%s' not found in response", listPath) + return nil, fmt.Errorf("path '%s' not found in response", path) } case reflect.Slice, reflect.Array: - // If we hit an array/slice, we're done navigating - break + return nil, fmt.Errorf("cannot navigate through array at '%s'", part) default: - return nil, fmt.Errorf("invalid path '%s': cannot navigate through %v", listPath, current.Kind()) + return nil, fmt.Errorf("invalid path '%s': cannot navigate through %v", path, current.Kind()) } } - // Convert to array of maps if current.Kind() == reflect.Interface { current = current.Elem() } - if current.Kind() != reflect.Slice && current.Kind() != reflect.Array { - return nil, fmt.Errorf("path '%s' does not point to an array", listPath) - } - - arr := make([]interface{}, current.Len()) - for i := 0; i < current.Len(); i++ { - arr[i] = current.Index(i).Interface() + // Check validity again + if !current.IsValid() { + return nil, nil // Valid nil result } - return convertArrayToMaps(arr) + return current.Interface(), nil } // convertArrayToMaps converts an array of interfaces to array of maps diff --git a/pkg/dns/provider/openapi/openapi.go b/pkg/dns/provider/openapi/openapi.go index 9900fe8..210ec33 100644 --- a/pkg/dns/provider/openapi/openapi.go +++ b/pkg/dns/provider/openapi/openapi.go @@ -337,10 +337,15 @@ func (s *Spec) mapRecordSchema(schema interface{}, mappings *dnsprovider.FieldMa } } - // Try to find list path by inspecting other schemas + // Try to find list path if mappings.ListPath == "" { mappings.ListPath = s.findListPath(origName) } + + // Try to find response path + if mappings.ResponsePath == "" { + mappings.ResponsePath = s.findResponsePath(origName) + } } func (s *Spec) mapZoneSchema(schema interface{}, mappings *dnsprovider.FieldMappings, origName string) { @@ -400,3 +405,30 @@ func (s *Spec) findListPath(targetSchemaName string) string { } return "" } + +func (s *Spec) findResponsePath(targetSchemaName string) string { + for _, otherSchema := range s.Components.Schemas { + otherSchemaMap, ok := otherSchema.(map[string]interface{}) + if !ok { + continue + } + + if props, ok := otherSchemaMap["properties"].(map[string]interface{}); ok { + for propName, prop := range props { + propMap, ok := prop.(map[string]interface{}) + if !ok { + continue + } + + // Look for property that is the target object directly (not array) + if ref, ok := propMap["$ref"].(string); ok { + refLower := strings.ToLower(ref) + if strings.Contains(refLower, strings.ToLower(targetSchemaName)) || strings.HasSuffix(refLower, "/"+targetSchemaName) { + return propName + } + } + } + } + } + return "" +} diff --git a/pkg/dns/provider/openapi/openapi_test.go b/pkg/dns/provider/openapi/openapi_test.go index 3fc2b11..8c61dce 100644 --- a/pkg/dns/provider/openapi/openapi_test.go +++ b/pkg/dns/provider/openapi/openapi_test.go @@ -34,6 +34,9 @@ func TestCloudflareSpec_ToProviderConfig(t *testing.T) { // List path should be detected require.Equal(t, "result", cfg.Mappings.ListPath) + // Response path should be detected + require.Equal(t, "result", cfg.Mappings.ResponsePath) + // Zone mappings require.Equal(t, "id", cfg.Mappings.ZoneID) require.Equal(t, "name", cfg.Mappings.ZoneName) diff --git a/pkg/dns/provider/provider.go b/pkg/dns/provider/provider.go index 649f247..6afb0ad 100644 --- a/pkg/dns/provider/provider.go +++ b/pkg/dns/provider/provider.go @@ -122,6 +122,9 @@ type FieldMappings struct { // List response structure (for REST providers) ListPath string `yaml:"list_path,omitempty"` // JSON path to records array, e.g., "data.records" + // Single record response structure (for REST providers) + ResponsePath string `yaml:"response_path,omitempty"` // JSON path to record object, e.g., "result" or "domain_record" + // Zone Mappings ZoneListPath string `yaml:"zone_list_path,omitempty"` ZoneID string `yaml:"zone_id,omitempty"` diff --git a/pkg/dns/provider/rest/rest.go b/pkg/dns/provider/rest/rest.go index a4e2085..aac4fa2 100644 --- a/pkg/dns/provider/rest/rest.go +++ b/pkg/dns/provider/rest/rest.go @@ -213,10 +213,24 @@ func (p *RESTProvider) CreateRecord(ctx context.Context, zoneID string, record d if err != nil { return dnsrecord.Record{}, errors.NewAPI("CreateRecord", "failed to create DNS record", err) } - defer resp.Body.Close() - // TODO: Parse response to get ID - return record, nil + // Parse response to get ID and other server-assigned fields + var responseData interface{} + if err := httpprovider.ParseJSONResponse(resp, &responseData); err != nil { + return dnsrecord.Record{}, fmt.Errorf("failed to parse response: %w", err) + } + + recordMap, err := mapper.ExtractRecord(responseData, p.mappings.ResponsePath) + if err != nil { + return dnsrecord.Record{}, fmt.Errorf("failed to extract record from response: %w", err) + } + + createdRecord, err := mapper.FromProviderFormat(recordMap, p.mappings.Response) + if err != nil { + return dnsrecord.Record{}, fmt.Errorf("failed to convert record from response: %w", err) + } + + return createdRecord, nil } // UpdateRecord updates an existing DNS record @@ -237,9 +251,24 @@ func (p *RESTProvider) UpdateRecord(ctx context.Context, zoneID string, recordID if err != nil { return dnsrecord.Record{}, errors.NewAPI("UpdateRecord", "failed to update DNS record", err) } - defer resp.Body.Close() - return record, nil + // Parse response to get updated fields + var responseData interface{} + if err := httpprovider.ParseJSONResponse(resp, &responseData); err != nil { + return dnsrecord.Record{}, fmt.Errorf("failed to parse response: %w", err) + } + + recordMap, err := mapper.ExtractRecord(responseData, p.mappings.ResponsePath) + if err != nil { + return dnsrecord.Record{}, fmt.Errorf("failed to extract record from response: %w", err) + } + + updatedRecord, err := mapper.FromProviderFormat(recordMap, p.mappings.Response) + if err != nil { + return dnsrecord.Record{}, fmt.Errorf("failed to convert record from response: %w", err) + } + + return updatedRecord, nil } // DeleteRecord deletes a DNS record @@ -266,6 +295,10 @@ func (p *RESTProvider) DeleteRecord(ctx context.Context, zoneID string, recordID // BulkReplaceRecords replaces all records in a zone with the provided set func (p *RESTProvider) BulkReplaceRecords(ctx context.Context, zoneID string, records []dnsrecord.Record) error { // Naive implementation + if !p.Capabilities().CanDeleteRecord || !p.Capabilities().CanCreateRecord { + return fmt.Errorf("provider does not support bulk replace (missing delete_record or create_record capability)") + } + existing, err := p.ListRecords(ctx, zoneID) if err != nil { return err @@ -273,14 +306,16 @@ func (p *RESTProvider) BulkReplaceRecords(ctx context.Context, zoneID string, re for _, r := range existing { if r.ID != "" { - _ = p.DeleteRecord(ctx, zoneID, r.ID) + if err := p.DeleteRecord(ctx, zoneID, r.ID); err != nil { + return fmt.Errorf("failed to delete record %s during bulk replace: %w", r.ID, err) + } } } for _, r := range records { _, err := p.CreateRecord(ctx, zoneID, r) if err != nil { - return err + return fmt.Errorf("failed to create record %s during bulk replace: %w", r.HostName, err) } } return nil @@ -288,13 +323,17 @@ func (p *RESTProvider) BulkReplaceRecords(ctx context.Context, zoneID string, re // Capabilities returns the provider's capabilities func (p *RESTProvider) Capabilities() dnsprovider.ProviderCapabilities { + canDelete := p.hasEndpoint("delete_record") + canCreate := p.hasEndpoint("create_record") + return dnsprovider.ProviderCapabilities{ CanListZones: p.hasEndpoint("list_zones") || p.hasEndpoint("zones") || p.hasSetting("zone_id"), CanGetZone: true, - CanCreateRecord: p.hasEndpoint("create_record"), + CanCreateRecord: canCreate, CanUpdateRecord: p.hasEndpoint("update_record"), - CanDeleteRecord: p.hasEndpoint("delete_record"), - CanBulkReplace: true, + CanDeleteRecord: canDelete, + // Naive bulk replace requires delete and create + CanBulkReplace: canDelete && canCreate, } } diff --git a/pkg/dns/provider/rest/rest_test.go b/pkg/dns/provider/rest/rest_test.go index 663d032..f99fe4f 100644 --- a/pkg/dns/provider/rest/rest_test.go +++ b/pkg/dns/provider/rest/rest_test.go @@ -8,6 +8,7 @@ import ( httpclient "zonekit/pkg/dns/provider/http" "zonekit/pkg/dns/provider/mapper" + "zonekit/pkg/dnsrecord" "github.com/stretchr/testify/require" ) @@ -74,3 +75,36 @@ func TestListZones_Endpoint(t *testing.T) { require.Equal(t, "z1", zones[0].ID) require.Equal(t, "example.com", zones[0].Name) } + +func TestCreateRecord_ReturnsID(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/zones/z1/records" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"result": {"id": "new-rec-123", "name": "www", "type": "A", "content": "1.2.3.4"}}`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + client := httpclient.NewClient(httpclient.ClientConfig{BaseURL: ts.URL}) + mappings := mapper.DefaultMappings() + mappings.ResponsePath = "result" // Mock response structure + mappings.Response.ID = "id" + mappings.Response.HostName = "name" + mappings.Response.RecordType = "type" + mappings.Response.Address = "content" + + p := NewRESTProvider("test", client, mappings, map[string]string{"create_record": "/zones/{zone_id}/records"}, nil) + + rec := dnsrecord.Record{ + HostName: "www", + RecordType: "A", + Address: "1.2.3.4", + } + + created, err := p.CreateRecord(context.Background(), "z1", rec) + require.NoError(t, err) + require.Equal(t, "new-rec-123", created.ID) + require.Equal(t, "www", created.HostName) +} From fc6c2e090aafa81d0904634f97ec79cb88c703ed Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 10:03:15 +0000 Subject: [PATCH 4/5] Create TODO.md with gap analysis and prioritized tasks - Add TODO.md file identifying gaps in safety, testing, and code quality - Highlight unsafe non-atomic bulk update implementation in Service/RESTProvider - Prioritize conformance testing and context propagation - Document need for granular operations in Service.BulkUpdate Co-authored-by: SamyRai <919510+SamyRai@users.noreply.github.com> --- TODO.md | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 TODO.md diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..4e8d1bc --- /dev/null +++ b/TODO.md @@ -0,0 +1,53 @@ +# ZoneKit TODO List + +This list identifies gaps, technical debt, and areas for improvement in the codebase, prioritized by safety and production readiness. + +## ๐Ÿšจ High Priority (Safety & Correctness) + +- [ ] **Refactor `Service.BulkUpdate` Strategy** + - **Issue**: Currently, `Service.BulkUpdate` builds a new record list and calls `SetRecords`. For providers like `RESTProvider` (which implements `SetRecords` via "delete all then create all"), this is **non-atomic and unsafe**. If creation fails, data is lost. + - **Task**: Update `Service.BulkUpdate` to: + 1. Check `Provider.Capabilities()`. + 2. If the provider supports **Atomic Bulk Replace**, use `SetRecords`. + 3. Otherwise, orchestrate the update using granular `CreateRecord`, `UpdateRecord`, and `DeleteRecord` calls to minimize risk. + +- [ ] **Enhance `ProviderCapabilities`** + - **Issue**: `CanBulkReplace` is ambiguous. It doesn't distinguish between a safe, atomic API call and a dangerous client-side loop. + - **Task**: Add `IsBulkReplaceAtomic bool` to `ProviderCapabilities`. + +- [ ] **Fix `context.Context` Propagation** + - **Issue**: `Service` methods (e.g., `GetRecords`) create a new `context.Background()` instead of accepting a context from the caller. This prevents cancellation and timeout propagation from the CLI or API layer. + - **Task**: Update all `Service` methods to accept `ctx context.Context` as the first argument. + +- [ ] **Harden `RESTProvider` Error Handling** + - **Issue**: While `BulkReplaceRecords` now checks errors, it's still a "stop the world" failure. + - **Task**: Implement rollback attempts or "continue on error" policies where appropriate (configurable). + +## ๐Ÿงช Medium Priority (Testing & QA) + +- [ ] **Expand Conformance Test Suite** + - **Issue**: `pkg/dns/provider/conformance` only tests `ListZones` and `GetZone`. + - **Task**: Add tests for: + - `CreateRecord`: Verify record is created and ID is returned. + - `UpdateRecord`: Verify record is updated. + - `DeleteRecord`: Verify record is gone. + - `ListRecords`: Verify filtering and pagination (if applicable). + - `BulkReplaceRecords`: Verify state transitions. + +- [ ] **Add Integration Tests** + - **Issue**: Tests primarily rely on mocks. + - **Task**: Add integration tests that spin up a local HTTP server (mocking Cloudflare/DigitalOcean APIs) to verify the full `RESTProvider` -> `Mapper` -> `HTTP` stack. + +## ๐Ÿงน Low Priority (Cleanup & Features) + +- [ ] **Implement `BatchUpdate` Interface** + - **Issue**: Some providers support batch operations (e.g., "create 10 records") which is more efficient than 10 separate calls but less drastic than "replace zone". + - **Task**: Add `BatchUpdate(ctx, operations)` to `Provider` interface. + +- [ ] **Structured Logging** + - **Issue**: Logging is likely ad-hoc (using `fmt` or basic `log`). + - **Task**: Integrate a structured logger (like `log/slog`) to provide consistent, machine-readable logs for debugging production issues. + +- [ ] **Configuration Validation** + - **Issue**: Configuration loading could be stricter. + - **Task**: Use a validation library to ensure all required fields (auth, endpoints) are present and well-formed at startup. From 8cdfbdb54a5308b614fe8eccd445045d910d111e Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 10:18:55 +0000 Subject: [PATCH 5/5] Add long-term development roadmap and vision - Create ROADMAP.md defining strategic goals for 3 months, 6 months, and 1-5 years - Outline key themes: Safety & Stability, Automation, Enterprise Scale, and AI/SaaS evolution - Prioritize technical debt elimination (atomic updates, context propagation) in the short term - Plan for advanced features like Zone Sync, GitOps, and structured logging Co-authored-by: SamyRai <919510+SamyRai@users.noreply.github.com> --- ROADMAP.md | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 ROADMAP.md diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000..5a2597b --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,81 @@ +# ZoneKit Long-Term Roadmap + +This document outlines the strategic vision and development roadmap for ZoneKit over the next 5 years. + +## ๐ŸŸข 0-3 Months: Foundation & Stability (v0.x - v1.0) + +**Theme: "Safe, Reliable, and Correct"** + +The immediate focus is on eliminating technical debt, ensuring data safety, and finalizing the core provider contract to reach a stable v1.0 release. + +* **Safety & Correctness (Priority #1)** + * [ ] **Atomic Operations**: Eliminate non-atomic bulk updates. Implement intelligent diffing in `Service.BulkUpdate` to minimize API calls and prevent data loss. + * [ ] **Context Propagation**: Ensure `context.Context` is threaded through every layer of the application for proper timeout and cancellation handling. + * [ ] **Validation**: Implement strict schema validation for all provider configurations and DNS records. + +* **Provider Ecosystem** + * [ ] **Conformance Suite**: Expand the conformance test harness to cover 100% of the `Provider` interface (CRUD, Edge Cases). + * [ ] **Core Providers**: Fully support Cloudflare, Namecheap, AWS Route53, and Google Cloud DNS with production-grade reliability. + +* **Developer Experience** + * [ ] **Structured Logging**: Replace ad-hoc logging with `log/slog` for structured, machine-readable output. + * [ ] **Error Handling**: Standardize error types across all providers (e.g., `ErrRecordNotFound`, `ErrAuthenticationFailed`). + +--- + +## ๐ŸŸก 3-6 Months: Advanced Features & Ecosystem (v1.x) + +**Theme: "Power User & Automation"** + +Once the core is stable, we shift focus to enabling complex workflows, automation, and broader integrations. + +* **Advanced DNS Management** + * [ ] **Zone Sync**: One-way synchronization between providers (e.g., "Primary: Cloudflare" -> "Backup: Route53"). + * [ ] **Dry Run**: Reliable "what-if" analysis for all operations, showing exactly what records will be created, updated, or deleted. + * [ ] **Record Templates**: Support for templated zones (e.g., "Standard Mail Setup", "Web Server Basic") for rapid provisioning. + +* **Infrastructure as Code (IaC)** + * [ ] **Terraform Provider**: Release an official Terraform provider wrapping ZoneKit logic. + * [ ] **GitOps Integration**: Native support for managing DNS configuration via Git repositories (YAML/JSON definitions). + +* **Observability** + * [ ] **Metrics**: Expose Prometheus metrics for API calls, latencies, and error rates. + * [ ] **Audit Logs**: Comprehensive audit logging for all changes made via the tool. + +--- + +## ๐Ÿ”ต 6-12 Months: Enterprise & Scale (v2.x) + +**Theme: "Enterprise Ready"** + +Focus on multi-tenancy, team management, and handling massive scale. + +* **Enterprise Security** + * [ ] **SSO/OIDC**: Support for retrieving provider credentials via enterprise identity providers. + * [ ] **RBAC**: Granular permissions for API keys (e.g., "Read Only", "Zone Specific Write"). + * [ ] **Vault Integration**: Native integration with HashiCorp Vault for secret management. + +* **Performance** + * [ ] **Parallel Execution**: Concurrent processing of multi-zone operations for high-performance updates. + * [ ] **Caching**: Intelligent caching layer to reduce API costs and improve latency for read operations. + +--- + +## ๐ŸŸฃ 1-5 Years: Platform & Intelligence (v3.x+) + +**Theme: "The DNS Platform"** + +Long-term evolution from a CLI tool to a comprehensive DNS management platform. + +* **SaaS Evolution** + * [ ] **ZoneKit Cloud**: A managed SaaS offering providing a web UI and unified API over all your DNS providers. + * [ ] **Global API**: A single, normalized API endpoint that routes to any underlying provider (Cloudflare, AWS, etc.). + +* **Intelligent Automation (AI)** + * [ ] **AI-Driven Optimization**: Automatic suggestions for DNS misconfigurations (e.g., missing SPF/DMARC, dangling CNAMEs). + * [ ] **Smart Routing**: Dynamic updates to DNS records based on real-time latency or uptime monitoring of endpoints. + * [ ] **Anomaly Detection**: Alerts for unusual DNS record changes or query patterns. + +* **Global Ecosystem** + * [ ] **Marketplace**: A community marketplace for custom provider plugins and automation scripts. + * [ ] **Standardization**: Work towards establishing the "ZoneKit Schema" as an industry-standard format for vendor-agnostic DNS definition.