diff --git a/internal/api/api.go b/internal/api/api.go index f97adbc..71f55cb 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -8,14 +8,13 @@ import ( ) const ( - ApiIPsPath = "/api/v1/ips" + ApiIPsPath = "/api/v1/ips" + ApiAllIPsPath = "/api/v1/allips" ) // For now we don't really have any API, just parsing JSON response with []string data in it. - -// GetIPs fetches the IP addresses from the PD Assistant instances. -func GetIPs(conf cfg.AppConfig, pdaAddress string) ([]string, error) { - fullAddress := pdaAddress + ApiIPsPath +func getIPs(conf cfg.AppConfig, pdaAddress, path string) ([]string, error) { + fullAddress := pdaAddress + path resp, err := utils.MakeHTTPRequest(fullAddress, "", "", "", conf.PDAssistantTLSInsecure, conf.HTTPRequestTimeout, conf.BearerToken) // Check if the request was successful if err != nil { @@ -36,3 +35,13 @@ func GetIPs(conf cfg.AppConfig, pdaAddress string) ([]string, error) { return ips, nil } + +// GetIPs fetches local IP addresses from the PD Assistant instances. +func GetLocalIPs(conf cfg.AppConfig, pdaAddress string) ([]string, error) { + return getIPs(conf, pdaAddress, ApiIPsPath) +} + +// GetIAllPs fetches all IP addresses from the PD Assistant instances. +func GetAllIPs(conf cfg.AppConfig, pdaAddress string) ([]string, error) { + return getIPs(conf, pdaAddress, ApiAllIPsPath) +} diff --git a/internal/cfg/cfg.go b/internal/cfg/cfg.go index edd52e4..dffb91e 100644 --- a/internal/cfg/cfg.go +++ b/internal/cfg/cfg.go @@ -50,6 +50,7 @@ type AppConfig struct { PDAssistantScheme string PDAssistantPort string PDAssistantTLSInsecure bool + PDAssistantConsensus bool // HTTPRequestTimeout is the timeout for HTTP requests in seconds. HTTPRequestTimeout int diff --git a/internal/cfg/cfg_test.go b/internal/cfg/cfg_test.go index e53808d..afc2c18 100644 --- a/internal/cfg/cfg_test.go +++ b/internal/cfg/cfg_test.go @@ -25,6 +25,9 @@ func TestLoadCertificateYaml(t *testing.T) { if len(cert.Spec.DNSNames) != 4 { t.Errorf("expected 4 DNS names, got %d", len(cert.Spec.DNSNames)) } + if len(cert.Spec.IPAddresses) != 5 { + t.Errorf("expected 5 IP addresses, got %d", len(cert.Spec.IPAddresses)) + } if cert.Spec.SecretName != "example-certificate-secret" { t.Errorf("expected secret name 'example-certificate-secret', got '%s'", cert.Spec.SecretName) } diff --git a/internal/k8s/k8s.go b/internal/k8s/k8s.go index 80fbf5e..0de401b 100644 --- a/internal/k8s/k8s.go +++ b/internal/k8s/k8s.go @@ -106,25 +106,29 @@ func (c *Client) GetCiliumNodes() ([]string, error) { } // UpdateCertificate updates the certificate in Kubernetes with the provided IP addresses. -func (c *Client) UpdateCertificate(conf cfg.AppConfig, IPs []string) error { +func (c *Client) UpdateCertificate(conf cfg.AppConfig, inIPs []string) error { client, err := cmclient.NewForConfig(c.Config) if err != nil { return err } - // Override IP addresses from the configuration - conf.Certificate.Spec.IPAddresses = IPs - conf.Certificate.SetAnnotations(injectAnnotations(conf.Certificate)) + // Add the IPs to the certificate loaded from the configuration + IPs := append(conf.Certificate.Spec.IPAddresses, inIPs...) + // Check if the certificate already exists certificate, err := client.CertmanagerV1().Certificates(conf.Certificate.Namespace).Get(context.TODO(), conf.Certificate.Name, metav1.GetOptions{}) if err != nil { if errors.IsNotFound(err) { - glog.Infof("Certificate %s/%s not found, creating a new one", conf.Certificate.Namespace, conf.Certificate.Name) - _, err = client.CertmanagerV1().Certificates(conf.Certificate.Namespace).Create(context.TODO(), &conf.Certificate, metav1.CreateOptions{}) + // Override IP addresses from the configuration + newCert := conf.Certificate.DeepCopy() + newCert.Spec.IPAddresses = IPs + newCert.SetAnnotations(injectAnnotations(conf.Certificate)) + glog.Infof("Certificate %s/%s not found, creating a new one", newCert.Namespace, newCert.Name) + _, err = client.CertmanagerV1().Certificates(newCert.Namespace).Create(context.TODO(), newCert, metav1.CreateOptions{}) if err != nil { - return fmt.Errorf("failed to create certificate %s/%s: %s", conf.Certificate.Namespace, conf.Certificate.Name, err.Error()) + return fmt.Errorf("failed to create certificate %s/%s: %s", newCert.Namespace, newCert.Name, err.Error()) } - glog.Infof("Certificate %s/%s created successfully", conf.Certificate.Namespace, conf.Certificate.Name) + glog.Infof("Certificate %s/%s created successfully", newCert.Namespace, newCert.Name) return nil } return fmt.Errorf("failed to get certificate %s/%s: %s", certificate.Namespace, conf.Certificate.Name, err.Error()) diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 05e6233..9498704 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -18,6 +18,7 @@ type AppMetrics struct { // Counters CertUpdateErrors *prometheus.CounterVec PDAssistantFetchErrors *prometheus.CounterVec + ConsensusErrors *prometheus.CounterVec } func InitMetrics(version string) AppMetrics { @@ -68,11 +69,21 @@ func InitMetrics(version string) AppMetrics { Name: "fetch_errors_total", Help: "Total number of errors fetching data from PD Assistants", }, - []string{"pd_assistant"}, + []string{"pd_assistant", "type"}, + ) + + am.ConsensusErrors = promauto.With(am.Registry).NewCounterVec( + prometheus.CounterOpts{ + Namespace: "pd_assistant", + Name: "consensus_errors_total", + Help: "Total number of errors in consensus check", + }, + []string{}, ) am.Config.WithLabelValues(version).Set(1) am.CertUpdateErrors.WithLabelValues().Add(0) + am.ConsensusErrors.WithLabelValues().Add(0) am.Registry.MustRegister() return am diff --git a/internal/server/server.go b/internal/server/server.go index 21d9e6b..fe13930 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,6 +14,7 @@ import ( "github.com/impossiblecloud/pd-cert-assistant/internal/k8s" "github.com/impossiblecloud/pd-cert-assistant/internal/metrics" "github.com/impossiblecloud/pd-cert-assistant/internal/tidb" + "github.com/impossiblecloud/pd-cert-assistant/internal/utils" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -80,26 +81,59 @@ func authHandler(endpoint http.HandlerFunc, cfg cfg.AppConfig) http.HandlerFunc func (s *State) getAllIPAddresses(conf cfg.AppConfig, pdaAddresses []string) ([]string, error) { allIPAddresses := []string{} + // Iterate over pd-assistant addresses and fetch their local IPs for _, pdaAddress := range pdaAddresses { - glog.V(6).Infof("Fetching IPs from pd-assistant: %s", pdaAddress) - ips, err := api.GetIPs(conf, pdaAddress) + glog.V(6).Infof("Fetching local IPs from pd-assistant: %s", pdaAddress) + ips, err := api.GetLocalIPs(conf, pdaAddress) if err != nil { - s.Metrics.PDAssistantFetchErrors.WithLabelValues(pdaAddress).Inc() + s.Metrics.PDAssistantFetchErrors.WithLabelValues(pdaAddress, "local").Inc() return nil, fmt.Errorf("failed to fetch IPs from pd-assistant %s: %v", pdaAddress, err) } if len(ips) == 0 { - s.Metrics.PDAssistantFetchErrors.WithLabelValues(pdaAddress).Inc() + s.Metrics.PDAssistantFetchErrors.WithLabelValues(pdaAddress, "local").Inc() return nil, fmt.Errorf("no IPs found in pd-assistant %s", pdaAddress) } // Update the state with the fetched IPs allIPAddresses = append(allIPAddresses, ips...) - glog.V(6).Infof("Fetched IPs from pd-assistant %s: %+v", pdaAddress, ips) + glog.V(6).Infof("Fetched local IPs from pd-assistant %s: %+v", pdaAddress, ips) } return allIPAddresses, nil } +func (s *State) allIPsConsesusCheck(conf cfg.AppConfig, pdaAddresses []string) (bool, error) { + var sampleIPs []string + + // Iterate over pd-assistant addresses, fetch all IPs they've found and compare them between each other + for id, pdaAddress := range pdaAddresses { + glog.V(6).Infof("Fetching all IPs from pd-assistant: %s", pdaAddress) + ips, err := api.GetAllIPs(conf, pdaAddress) + if err != nil { + s.Metrics.PDAssistantFetchErrors.WithLabelValues(pdaAddress, "all").Inc() + return false, fmt.Errorf("failed to fetch all IPs from pd-assistant %s: %v", pdaAddress, err) + + } + if len(ips) == 0 { + s.Metrics.PDAssistantFetchErrors.WithLabelValues(pdaAddress, "all").Inc() + return false, fmt.Errorf("no all IPs found in pd-assistant %s", pdaAddress) + } + if id == 0 { + // Snapshot sample IPs which will be used for comparison + sampleIPs = ips + } else { + // Compare the fetched IPs with the sample IPs + if !utils.IPListsEqual(sampleIPs, ips) { + glog.V(0).Infof("All IPs consensus error: all IPs are not equal between pd-assistants: %s and %s", pdaAddresses[0], pdaAddress) + glog.V(8).Infof("Sample all IPs from %s: %+v", pdaAddresses[0], sampleIPs) + glog.V(8).Infof("Fetched all IPs from %s: %+v", pdaAddress, ips) + return false, nil + } + } + } + return true, nil +} + // IPWatchLoop continuously fetches CiliumNode IPs and updates the state func (s *State) IPWatchLoop(conf cfg.AppConfig, kc k8s.Client) { for { @@ -152,9 +186,23 @@ func (s *State) FetchIPsAndUpdateCertLoop(conf cfg.AppConfig, kc k8s.Client) { // Atomic update of AllIPAddresses in the state, only if all IPs are fetched successfully s.AllIPAddresses = allIPAddresses s.Metrics.AllIPs.WithLabelValues().Set(float64(len(allIPAddresses))) - glog.V(4).Infof("All IPs fetched from pd-assistants: %+v", s.AllIPAddresses) + glog.V(6).Infof("All IPs fetched from pd-assistants: %+v", s.AllIPAddresses) glog.V(6).Info("Checking for certificate updates") + // Check IP address consensus + if conf.PDAssistantConsensus { + if consensus, err := s.allIPsConsesusCheck(conf, pdaAddresses); err != nil { + s.Metrics.ConsensusErrors.WithLabelValues().Inc() + glog.Errorf("Failed to check IP address consensus: %v", err) + continue + } else if !consensus { + s.Metrics.ConsensusErrors.WithLabelValues().Inc() + glog.Errorf("IP address consensus check failed, skipping certificate update") + continue + } + glog.V(6).Info("IP address consensus check passed") + } + // Update the certificate with the new IPs if needed err = kc.UpdateCertificate(conf, allIPAddresses) if err != nil { @@ -164,11 +212,11 @@ func (s *State) FetchIPsAndUpdateCertLoop(conf cfg.AppConfig, kc k8s.Client) { } } -// GetIPs handler with bearer token authentication +// GetIPs returns local IP addresses in JSON format func (s *State) GetIPs(w http.ResponseWriter, r *http.Request) { glog.V(10).Infof("Got HTTP request for %s", api.ApiIPsPath) - // Marshal the IP addresses to JSON + // Marshal local IP addresses to JSON jsonResponse, err := json.Marshal(s.IPAddresses) if err != nil { glog.Errorf("Failed to marshal IP addresses: %v", err) @@ -183,6 +231,25 @@ func (s *State) GetIPs(w http.ResponseWriter, r *http.Request) { w.Write(jsonResponse) } +// GetAllIPs returns all IP addresses in JSON format +func (s *State) GetAllIPs(w http.ResponseWriter, r *http.Request) { + glog.V(10).Infof("Got HTTP request for %s", api.ApiIPsPath) + + // Marshal all IP addresses to JSON + jsonResponse, err := json.Marshal(s.AllIPAddresses) + if err != nil { + glog.Errorf("Failed to marshal all IP addresses: %v", err) + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, `{"error": "Failed to encode all IP addresses"}`) + return + } + + // Respond with the IP addresses + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + w.Write(jsonResponse) +} + // Main web server func (s *State) RunMainWebServer(config cfg.AppConfig, listen string) { // Setup http router @@ -192,6 +259,7 @@ func (s *State) RunMainWebServer(config cfg.AppConfig, listen string) { router.HandleFunc("/health", healthHandler).Methods("GET") router.HandleFunc("/metrics", s.handleMetrics(config)).Methods("GET") router.HandleFunc(api.ApiIPsPath, authHandler(s.GetIPs, config)).Methods("GET") + router.HandleFunc(api.ApiAllIPsPath, authHandler(s.GetAllIPs, config)).Methods("GET") router.HandleFunc("/", rootHandler).Methods("GET") // Run main http router diff --git a/main.go b/main.go index d769271..3c23f33 100644 --- a/main.go +++ b/main.go @@ -37,16 +37,17 @@ func main() { flag.BoolVar(&showVersion, "version", false, "Show version and exit") // Kubernetes parameters flag.StringVar(&kubeconfig, "kubeconfig", "", "Path to the kubeconfig file (optional)") - flag.IntVar(&config.KubernetesPollInterval, "k8s-poll-interval", 180, "Interval for polling Kubernetes in seconds") + flag.IntVar(&config.KubernetesPollInterval, "k8s-poll-interval", 60, "Interval for polling Kubernetes in seconds") // PD assistant parameters - flag.IntVar(&config.PDAssistantPollInterval, "pd-assistant-poll-interval", 300, "Interval for polling all pd-assistants in seconds") + flag.IntVar(&config.PDAssistantPollInterval, "pd-assistant-poll-interval", 120, "Interval for polling all pd-assistants in seconds") flag.StringVar(&config.PDAssistantHostPrefix, "pd-assistant-host-prefix", "pd-assistant", "Host prefix for PD Assistant instances") flag.StringVar(&config.PDAssistantScheme, "pd-assistant-scheme", "https", "Scheme for PD Assistant instances (http or https)") flag.StringVar(&config.PDAssistantPort, "pd-assistant-port", "443", "Port for PD Assistant instances") flag.BoolVar(&config.PDAssistantTLSInsecure, "pd-assistant-tls-insecure", false, "Skip TLS verification for PD Assistant instances (not recommended)") flag.StringVar(&pdAssistantURLs, "pd-assistant-urls", "", "List of PD Assistant URLs (comma-separated). Overrides --pd-assistant-host-prefix and ignores --pd-address auto-discovery if provided") + flag.BoolVar(&config.PDAssistantConsensus, "pd-assistant-consensus", false, "Require consensus from all PD Assistant instances before updating the certificate") // Certificate parameters - flag.IntVar(&config.CertUpdateInterval, "cert-update-interval", 300, "Interval for updating PD certificate in seconds") + flag.IntVar(&config.CertUpdateInterval, "cert-update-interval", 180, "Interval for updating PD certificate in seconds") flag.StringVar(&certFilePath, "certificate-file", "/app/conf/", "Path to a Certificate YAML file to be used as a template") // PD discovery parameters flag.StringVar(&config.PDDiscoveryConfig.URL, "pd-discovery-url", "", "PD Discovery service URL") @@ -86,6 +87,11 @@ func main() { glog.V(4).Infof("PD Discovery URL: %s", config.PDDiscoveryConfig.URL) } glog.V(4).Infof("Loaded certificate YAML file %q: name=%s, namespace=%s", certFilePath, config.Certificate.Name, config.Certificate.Namespace) + if config.PDAssistantConsensus { + glog.V(4).Infof("PD Assistant consensus check is enabled") + } else { + glog.V(4).Infof("PD Assistant consensus check is disabled") + } // Let's rock and roll! // Watch CliliumNode IPs and update the state