diff --git a/agent/api/agent_handler.go b/agent/api/agent_handler.go index 1044ca7b..6c0d9ea0 100644 --- a/agent/api/agent_handler.go +++ b/agent/api/agent_handler.go @@ -35,18 +35,18 @@ var ( agentService = agentservice.AgentService{} ) -// @Summary join the specified agent -// @Description join the specified agent -// @Tags agent -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.JoinApiParam true "agent info with zone name" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 400 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/agent [post] -// @Router /api/v1/agent/join [post] +// @Summary join the specified agent +// @Description join the specified agent +// @Tags agent +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.JoinApiParam true "agent info with zone name" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/agent [post] +// @Router /api/v1/agent/join [post] func agentJoinHandler(c *gin.Context) { var param param.JoinApiParam if err := c.BindJSON(¶m); err != nil { @@ -54,7 +54,7 @@ func agentJoinHandler(c *gin.Context) { return } if !meta.OCS_AGENT.IsSingleAgent() { - common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s:%d is not single agent", meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort())) + common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s is not single agent", meta.OCS_AGENT.String())) return } @@ -86,7 +86,13 @@ func agentJoinHandler(c *gin.Context) { param.AgentInfo.String(), agentStatus.OBVersion, meta.OCS_AGENT.String(), obVersion)) return } - dag, err = agent.CreateJoinMasterDag(param.AgentInfo, param.ZoneName) + // send token to master early. + if err = agent.SendTokenToMaster(param.AgentInfo, param.MasterPassword); err != nil { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskCreateFailed, err)) + return + } + + dag, err = agent.CreateJoinMasterDag(param.AgentInfo, param.ZoneName, param.MasterPassword) } if err != nil { @@ -96,18 +102,18 @@ func agentJoinHandler(c *gin.Context) { common.SendResponse(c, task.NewDagDetailDTO(dag), nil) } -// @Summary remove the specified agent -// @Description remove the specified agent -// @Tags agent -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body meta.AgentInfo true "agent info" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 400 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/agent [delete] -// @Router /api/v1/agent/remove [post] +// @Summary remove the specified agent +// @Description remove the specified agent +// @Tags agent +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body meta.AgentInfo true "agent info" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/agent [delete] +// @Router /api/v1/agent/remove [post] func agentRemoveHandler(c *gin.Context) { var param meta.AgentInfo if err := c.BindJSON(¶m); err != nil { @@ -178,3 +184,13 @@ func agentRemoveHandler(c *gin.Context) { } common.SendResponse(c, task.NewDagDetailDTO(dag), nil) } + +func agentSetPasswordHandler(c *gin.Context) { + var param param.SetAgentPasswordParam + if err := c.BindJSON(¶m); err != nil { + common.SendResponse(c, nil, err) + return + } + + common.SendResponse(c, nil, agentService.SetAgentPassword(param.Password)) +} diff --git a/agent/api/agent_route.go b/agent/api/agent_route.go index aed220a6..ef47bb0d 100644 --- a/agent/api/agent_route.go +++ b/agent/api/agent_route.go @@ -52,9 +52,14 @@ func InitOcsAgentRoutes(s *http2.State, r *gin.Engine, isLocalRoute bool) { constant.URI_API_V1+constant.URI_TENANT_GROUP+constant.URI_PATH_PARAM_NAME+constant.URI_BACKUP+constant.URI_CONFIG, constant.URI_API_V1+constant.URI_TENANT_GROUP+constant.URI_RESTORE, constant.URI_API_V1+constant.URI_RESTORE+constant.URI_WINDOWS, + constant.URI_API_V1+constant.URI_INIT, constant.URI_TASK_RPC_PREFIX+constant.URI_SUB_TASK, constant.URI_API_V1+constant.URI_TENANT_GROUP, constant.URI_TENANT_API_PREFIX+constant.URI_PATH_PARAM_NAME+constant.URI_ROOTPASSWORD, + constant.URI_TENANT_API_PREFIX+constant.URI_PATH_PARAM_NAME+constant.URI_USER, + constant.URI_TENANT_API_PREFIX+constant.URI_PATH_PARAM_NAME, + constant.URI_OBPROXY_API_PREFIX, + constant.URI_TENANT_API_PREFIX+constant.URI_PATH_PARAM_NAME+constant.URI_VARIABLES, ), common.SetContentType, ) @@ -107,6 +112,7 @@ func InitOcsAgentRoutes(s *http2.State, r *gin.Engine, isLocalRoute bool) { InitTenantRoutes(v1, isLocalRoute) InitBackupRoutes(v1, isLocalRoute) InitRestoreRoutes(v1, isLocalRoute) + InitObproxyRoutes(v1, isLocalRoute) // ob routes ob.POST(constant.URI_INIT, obInitHandler) @@ -126,6 +132,7 @@ func InitOcsAgentRoutes(s *http2.State, r *gin.Engine, isLocalRoute bool) { agent.POST(constant.URI_REMOVE, agentRemoveHandler) agent.POST(constant.URI_UPGRADE, agentUpgradeHandler) agent.POST(constant.URI_UPGRADE+constant.URI_CHECK, agentUpgradeCheckHandler) + agent.POST(constant.URI_PASSWORD, agentSetPasswordHandler) // agents routes agents.GET(constant.URI_STATUS, GetAllAgentStatus(s)) diff --git a/agent/api/common/forward_handler.go b/agent/api/common/forward_handler.go index 73ce6970..68aab998 100644 --- a/agent/api/common/forward_handler.go +++ b/agent/api/common/forward_handler.go @@ -44,7 +44,7 @@ func autoForward(c *gin.Context) { agentService := agentservice.AgentService{} master := agentService.GetMasterAgentInfo() if master == nil { - SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized)) + SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized, "master agent not found")) return } @@ -62,13 +62,13 @@ func autoForward(c *gin.Context) { headerByte, exist := c.Get(constant.OCS_HEADER) if headerByte == nil || !exist { - SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized)) + SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized, "header not found")) return } header, ok := headerByte.(secure.HttpHeader) if !ok { - SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized)) + SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized, "header type error")) return } @@ -104,7 +104,7 @@ func sendRequsetForForward(c *gin.Context, ctx context.Context, agentInfo meta.A } request.SetBody(body) - uri := fmt.Sprintf("%s://%s:%d%s", global.Protocol, agentInfo.GetIp(), agentInfo.GetPort(), c.Request.URL) + uri := fmt.Sprintf("%s://%s%s", global.Protocol, agentInfo.String(), c.Request.URL) response, err := request.Execute(c.Request.Method, uri) if err != nil { log.WithError(err).Errorf("API response failed : [%v %v, client=%v, agent=%v]", c.Request.Method, c.Request.URL, c.ClientIP(), agentInfo.String()) diff --git a/agent/api/common/middleware.go b/agent/api/common/middleware.go index 38c9210f..4b60d1f8 100644 --- a/agent/api/common/middleware.go +++ b/agent/api/common/middleware.go @@ -31,6 +31,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/engine/task" "github.com/oceanbase/obshell/agent/errors" ocshttp "github.com/oceanbase/obshell/agent/lib/http" "github.com/oceanbase/obshell/agent/lib/path" @@ -307,20 +308,30 @@ func BodyDecrypt(skipRoutes ...string) func(*gin.Context) { } var err error - encryptedHeader := c.Request.Header.Get(constant.OCS_HEADER) - if encryptedHeader == "" { + if c.Request.Header.Get(constant.OCS_HEADER) == "" && c.Request.Header.Get(constant.OCS_AGENT_HEADER) == "" { c.Next() return } var header secure.HttpHeader - header, err = secure.DecryptHeader(c.Request.Header.Get(constant.OCS_HEADER)) - if err != nil { - log.WithContext(NewContextWithTraceId(c)).Errorf("header decrypt failed, err: %v", err) - c.Abort() - SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized)) - return + if c.Request.Header.Get(constant.OCS_AGENT_HEADER) != "" { + header, err = secure.DecryptHeader(c.Request.Header.Get(constant.OCS_AGENT_HEADER)) + if err != nil { + log.WithContext(NewContextWithTraceId(c)).Errorf("header decrypt failed, err: %v", err) + c.Abort() + SendResponse(c, nil, errors.Occurf(errors.ErrUnauthorized, "header decrypt failed")) + return + } + c.Set(constant.OCS_AGENT_HEADER, header) + } else { + header, err = secure.DecryptHeader(c.Request.Header.Get(constant.OCS_HEADER)) + if err != nil { + log.WithContext(NewContextWithTraceId(c)).Errorf("header decrypt failed, err: %v", err) + c.Abort() + SendResponse(c, nil, errors.Occurf(errors.ErrUnauthorized, "header decrypt failed")) + return + } + c.Set(constant.OCS_HEADER, header) } - c.Set(constant.OCS_HEADER, header) for _, route := range secure.GetSkipBodyEncryptRoutes() { if route == c.Request.RequestURI { @@ -334,7 +345,7 @@ func BodyDecrypt(skipRoutes ...string) func(*gin.Context) { if err != nil { log.WithContext(NewContextWithTraceId(c)).Errorf("read body failed, err: %v", err) c.Abort() - SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized)) + SendResponse(c, nil, errors.Occurf(errors.ErrUnauthorized, "read body failed")) return } if len(encryptedBody) == 0 { @@ -345,7 +356,7 @@ func BodyDecrypt(skipRoutes ...string) func(*gin.Context) { if err != nil { log.WithContext(NewContextWithTraceId(c)).Errorf("body decrypt failed, err: %v", err) c.Abort() - SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized)) + SendResponse(c, nil, errors.Occurf(errors.ErrUnauthorized, "body decrypt failed")) return } c.Request.Body = io.NopCloser(bytes.NewBuffer(body)) @@ -355,91 +366,210 @@ func BodyDecrypt(skipRoutes ...string) func(*gin.Context) { } } -func Verify(skipRoutes ...string) func(*gin.Context) { - return func(c *gin.Context) { - log.WithContext(NewContextWithTraceId(c)).Infof("verfiy request: %s", c.Request.RequestURI) - for _, route := range skipRoutes { - if route == c.Request.RequestURI { - c.Next() - return - } - } - var header secure.HttpHeader - headerByte, exist := c.Get(constant.OCS_HEADER) - - if headerByte == nil || !exist { - log.WithContext(NewContextWithTraceId(c)).Error("header not found") - c.Abort() - SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized)) - return +func VerifyObRouters(c *gin.Context, curTs int64, header *secure.HttpHeader, passwordType secure.VerifyType) { + pass := false + var err error + switch meta.OCS_AGENT.GetIdentity() { + case meta.SINGLE: + if err = secure.VerifyToken(header.Token); err == nil { + pass = true + break } - pass := false - - header, ok := headerByte.(secure.HttpHeader) - if !ok { - log.WithContext(NewContextWithTraceId(c)).Error("header type error") - c.Abort() - SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized)) - return + if meta.AGENT_PWD.Inited() { + if passwordType == secure.AGENT_PASSWORD { + if err = secure.VerifyAuth(header.Auth, header.Ts, curTs, secure.AGENT_PASSWORD); err == nil { + pass = true + } else { + decryptAgentPassword, err1 := secure.Decrypt(header.Auth) + if err1 == nil && secure.VerifyAuth(decryptAgentPassword, header.Ts, curTs, secure.AGENT_PASSWORD) == nil { + pass = true + } + } + } else { + err = errors.New("agent password has been set, use agent password to verify") + } + } else { + pass = true } - - switch meta.OCS_AGENT.GetIdentity() { - case meta.SINGLE: + case meta.FOLLOWER: + // Follower verify token only. + if err = secure.VerifyToken(header.Token); err == nil { pass = true - case meta.FOLLOWER: - // Follower verify token only. - if secure.VerifyToken(header.Token) == nil { + } else { + if IsApiRoute(c) && header.ForwardType != secure.ManualForward { + // If the request is api and is not manual forwarded, auto forward it. + autoForward(c) + c.Abort() + return + } + } + case meta.MASTER: + if header.ForwardType == secure.ManualForward { + // When a request is manually forwarded, it must have a valid follower token. + if err = secure.VerifyTokenByAgentInfo(header.Token, header.ForwardAgent); err == nil { pass = true - } else { - if IsApiRoute(c) && header.ForwardType != secure.ManualForward { - // If the request is api and is not manual forwarded, auto forward it. - autoForward(c) - c.Abort() - return - } } - case meta.MASTER: - if header.ForwardType == secure.ManualForward { - // When a request is manually forwarded, it must have a valid follower token. - if err := secure.VerifyTokenByAgentInfo(header.Token, header.ForwardAgent); err == nil { + break + } else if header.ForwardType == secure.AutoForward { + // If the request is auto-forwarded, set IsAutoForwardedFlag to true for parse password. + c.Set(IsAutoForwardedFlag, true) + c.Set(FollowerAgentOfForward, header.ForwardAgent) + } + fallthrough + default: + if passwordType == secure.OCEANBASE_PASSWORD { + if !meta.OCEANBASE_PASSWORD_INITIALIZED && meta.AGENT_PWD.Inited() { + err = errors.New("oceanbase password is not initialized, use agent password to verify") + } else { + if err = secure.VerifyAuth(header.Auth, header.Ts, curTs, secure.OCEANBASE_PASSWORD); err == nil { pass = true } - break - } else if header.ForwardType == secure.AutoForward { - // If the request is auto-forwarded, set IsAutoForwardedFlag to true for parse password. - c.Set(IsAutoForwardedFlag, true) - c.Set(FollowerAgentOfForward, header.ForwardAgent) } - fallthrough - default: - curTs := time.Now().Unix() - if r, ok := c.Get(constant.REQUEST_RECEIVED_TIME); ok { - if receivedTs, ok := r.(int64); ok { - curTs = receivedTs - } + } else { + if meta.OCEANBASE_PASSWORD_INITIALIZED && !meta.AGENT_PWD.Inited() { + err = errors.New("agent password is not initialized, use oceanbase password to verify") + } else if !meta.AGENT_PWD.Inited() { + pass = true + } else if err = secure.VerifyAuth(header.Auth, header.Ts, curTs, secure.AGENT_PASSWORD); err == nil { + pass = true } - if err := secure.VerifyAuth(header.Auth, header.Ts, curTs); err != nil { - log.WithContext(NewContextWithTraceId(c)).Error(err.Error()) - } else { + } + } + if !pass { + log.WithContext(NewContextWithTraceId(c)).Errorf("Security verification failed: %s", err.Error()) + c.Abort() + SendResponse(c, nil, errors.Occurf(errors.ErrUnauthorized, err.Error())) + return + } +} + +func VerifyForSetAgentPassword(c *gin.Context, curTs int64, header *secure.HttpHeader, passwordType secure.VerifyType) { + pass := false + var err error + if meta.AGENT_PWD.Inited() { + if passwordType == secure.AGENT_PASSWORD { + if err = secure.VerifyAuth(header.Auth, header.Ts, curTs, secure.AGENT_PASSWORD); err == nil { + pass = true + } + } else { + err = errors.New("agent password has been set, use agent password to verify") + } + } else if meta.OCS_AGENT.IsClusterAgent() { + if passwordType == secure.OCEANBASE_PASSWORD { + if err = secure.VerifyAuth(header.Auth, header.Ts, curTs, secure.OCEANBASE_PASSWORD); err == nil { pass = true } + } else { + err = errors.New("oceanbase password has been set, use oceanbase password to verify") + } + } else if meta.OCS_AGENT.IsSingleAgent() { + pass = true + } + if !pass { + log.WithContext(NewContextWithTraceId(c)).Errorf("Security verification failed: %s", err.Error()) + c.Abort() + SendResponse(c, nil, errors.Occurf(errors.ErrUnauthorized, err.Error())) + return + } +} + +func VerifyAgentRoutes(c *gin.Context, curTs int64, header *secure.HttpHeader, passwordType secure.VerifyType) { + pass := false + var err error + if passwordType != secure.AGENT_PASSWORD { + err = errors.New("Please use agent password to verify") + } else { + if meta.AGENT_PWD.Inited() { + if err = secure.VerifyAuth(header.Auth, header.Ts, curTs, secure.AGENT_PASSWORD); err == nil { + pass = true + } + } else { + err = errors.New("agent password is not initialized") + } + } + if !pass { + log.WithContext(NewContextWithTraceId(c)).Errorf("Security verification failed: %s", err.Error()) + c.Abort() + SendResponse(c, nil, errors.Occurf(errors.ErrUnauthorized, err.Error())) + return + } +} + +func VerifyTaskRoutes(c *gin.Context, curTs int64, header *secure.HttpHeader, passwordType secure.VerifyType) { + id := c.Param("id") + if id == "" { + VerifyObRouters(c, curTs, header, passwordType) + return + } + if task.IsObproxyTask(id) { + VerifyAgentRoutes(c, curTs, header, passwordType) + return + } else { + VerifyObRouters(c, curTs, header, passwordType) + return + } +} + +func Verify(routeType ...secure.RouteType) func(*gin.Context) { + return func(c *gin.Context) { + log.WithContext(NewContextWithTraceId(c)).Infof("verfiy request: %s", c.Request.RequestURI) + var header secure.HttpHeader + obHeaderByte, _ := c.Get(constant.OCS_HEADER) + agentHeaderByte, _ := c.Get(constant.OCS_AGENT_HEADER) + var headerByte any + var passwordType secure.VerifyType + if agentHeaderByte != nil { + passwordType = secure.AGENT_PASSWORD + headerByte = agentHeaderByte + } else if obHeaderByte != nil { + passwordType = secure.OCEANBASE_PASSWORD + headerByte = obHeaderByte + } else { + log.WithContext(NewContextWithTraceId(c)).Error("header not found") + c.Abort() + SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized, "header not found")) + return } - if !pass { - log.WithContext(NewContextWithTraceId(c)).Error("Security verification failed") + if passwordType != secure.AGENT_PASSWORD && len(routeType) != 0 && routeType[0] == secure.ROUTE_OBPROXY { + log.WithContext(NewContextWithTraceId(c)).Error("agent header not found") c.Abort() - SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized)) + SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized, "aegnt header not found")) + return + } + + header, ok := headerByte.(secure.HttpHeader) + if !ok { + log.WithContext(NewContextWithTraceId(c)).Error("header type error") + c.Abort() + SendResponse(c, nil, errors.Occur(errors.ErrUnauthorized, "header type error")) return } // Verify the URI in the header matches the URI of the request. if header.Uri != c.Request.RequestURI { log.WithContext(NewContextWithTraceId(c)).Errorf("verify uri failed, uri: %s, expect: %s", header.Uri, c.Request.RequestURI) - authErr := errors.Occur(errors.ErrUnauthorized) + authErr := errors.Occurf(errors.ErrUnauthorized, "uri mismatch") c.Abort() SendResponse(c, nil, authErr) return } + curTs := time.Now().Unix() + if r, ok := c.Get(constant.REQUEST_RECEIVED_TIME); ok { + if receivedTs, ok := r.(int64); ok { + curTs = receivedTs + } + } + + if c.Request.RequestURI == constant.URI_AGENT_API_PREFIX+constant.URI_PASSWORD { + VerifyForSetAgentPassword(c, curTs, &header, passwordType) + } else if len(routeType) != 0 && routeType[0] == secure.ROUTE_OBPROXY { + VerifyAgentRoutes(c, curTs, &header, passwordType) + } else if len(routeType) != 0 && routeType[0] == secure.ROUTE_TASK { + VerifyTaskRoutes(c, curTs, &header, passwordType) + } else { + VerifyObRouters(c, curTs, &header, passwordType) + } // Verification succeeded, continue to the next middleware. c.Next() } diff --git a/agent/api/info_handler.go b/agent/api/info_handler.go index d7989005..a041eaaa 100644 --- a/agent/api/info_handler.go +++ b/agent/api/info_handler.go @@ -63,7 +63,7 @@ func TimeHandler(c *gin.Context) { func InfoHandler(s *http.State) gin.HandlerFunc { return func(c *gin.Context) { obVersion, err := binary.GetMyOBVersion() - agentStatus := meta.NewAgentStatus(meta.OCS_AGENT, global.Pid, s.GetState(), global.StartAt, global.HomePath, obVersion) + agentStatus := meta.NewAgentStatus(meta.OCS_AGENT, global.Pid, s.GetState(), global.StartAt, global.HomePath, obVersion, meta.AGENT_PWD.Inited()) common.SendResponse(c, agentStatus, err) } } diff --git a/agent/api/obcluster_handler.go b/agent/api/obcluster_handler.go index e0576134..daa81f14 100644 --- a/agent/api/obcluster_handler.go +++ b/agent/api/obcluster_handler.go @@ -48,20 +48,20 @@ func parseRootPwd(pwd string, isForward bool) (string, error) { return pwd, nil } -// @ID obclusterConfig -// @Summary put ob cluster configs -// @Description put ob cluster configs -// @Tags ob -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.ObClusterConfigParams true "obcluster configs" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 401 object http.OcsAgentResponse -// @Failure 400 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/obcluster/config [put] -// @Router /api/v1/obcluster/config [post] +// @ID obclusterConfig +// @Summary put ob cluster configs +// @Description put ob cluster configs +// @Tags ob +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.ObClusterConfigParams true "obcluster configs" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 401 object http.OcsAgentResponse +// @Failure 400 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/obcluster/config [put] +// @Router /api/v1/obcluster/config [post] func obclusterConfigHandler(deleteAll bool) func(c *gin.Context) { return func(c *gin.Context) { var params param.ObClusterConfigParams @@ -120,20 +120,20 @@ func obclusterConfigHandler(deleteAll bool) func(c *gin.Context) { } } -// @ID obServerConfig -// @Summary put observer configs -// @Description put observer configs -// @Tags ob -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.ObServerConfigParams true "ob server configs" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 400 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/observer/config [put] -// @Router /api/v1/observer/config [post] +// @ID obServerConfig +// @Summary put observer configs +// @Description put observer configs +// @Tags ob +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.ObServerConfigParams true "ob server configs" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/observer/config [put] +// @Router /api/v1/observer/config [post] func obServerConfigHandler(deleteAll bool) func(c *gin.Context) { return func(c *gin.Context) { var params param.ObServerConfigParams @@ -178,18 +178,18 @@ func obServerConfigHandler(deleteAll bool) func(c *gin.Context) { } } -// @ID obInit -// @Summary init ob cluster -// @Description init ob cluster -// @Tags ob -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 400 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/ob/init [post] +// @ID obInit +// @Summary init ob cluster +// @Description init ob cluster +// @Tags ob +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/ob/init [post] func obInitHandler(c *gin.Context) { var param param.ObInitParam if err := c.BindJSON(¶m); err != nil { @@ -214,19 +214,19 @@ func obInitHandler(c *gin.Context) { } } -// @ID obStop -// @Summary stop observers -// @Description stop observers or the whole cluster, use param to specify -// @Tags ob -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.ObStopParam true "use 'Scope' to specify the servers/zones/cluster, use 'Force'(optional) to specify whether alter system, use 'ForcePassDag'(optional) to force pass the prev stop dag if need" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 400 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/ob/stop [post] +// @ID obStop +// @Summary stop observers +// @Description stop observers or the whole cluster, use param to specify +// @Tags ob +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.ObStopParam true "use 'Scope' to specify the servers/zones/cluster, use 'Force'(optional) to specify whether alter system, use 'ForcePassDag'(optional) to force pass the prev stop dag if need" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/ob/stop [post] func obStopHandler(c *gin.Context) { var param param.ObStopParam if err := c.BindJSON(¶m); err != nil { @@ -247,19 +247,19 @@ func obStopHandler(c *gin.Context) { } } -// @ID obStart -// @Summary start observers -// @Description start observers or the whole cluster, use param to specify -// @Tags ob -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.StartObParam true "use 'Scope' to specify the servers/zones/cluster, use 'ForcePassDag'(optional) to force pass the prev start dag if need" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 400 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/ob/start [post] +// @ID obStart +// @Summary start observers +// @Description start observers or the whole cluster, use param to specify +// @Tags ob +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.StartObParam true "use 'Scope' to specify the servers/zones/cluster, use 'ForcePassDag'(optional) to force pass the prev start dag if need" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/ob/start [post] func obStartHandler(c *gin.Context) { var param param.StartObParam if err := c.BindJSON(¶m); err != nil { @@ -298,18 +298,18 @@ func obStartHandler(c *gin.Context) { } } -// @ID ScaleOut -// @Summary cluster scale-out -// @Description cluster scale-out -// @Tags ob -// @Accept application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.ClusterScaleOutParam true "scale-out param" -// @Produce application/json -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/ob/scale_out [get] +// @ID ScaleOut +// @Summary cluster scale-out +// @Description cluster scale-out +// @Tags ob +// @Accept application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.ClusterScaleOutParam true "scale-out param" +// @Produce application/json +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/ob/scale_out [get] func obClusterScaleOutHandler(c *gin.Context) { var param param.ClusterScaleOutParam if err := c.BindJSON(¶m); err != nil { @@ -320,19 +320,19 @@ func obClusterScaleOutHandler(c *gin.Context) { common.SendResponse(c, data, err) } -// @Summary cluster scale-in -// @Description cluster scale-in -// @Tags ob -// @Accept application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.ClusterScaleInParam true "scale-in param" -// @Produce application/json -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Success 204 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/ob/scale_in [post] -// @Router /api/v1/observer [delete] +// @Summary cluster scale-in +// @Description cluster scale-in +// @Tags ob +// @Accept application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.ClusterScaleInParam true "scale-in param" +// @Produce application/json +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Success 204 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/ob/scale_in [post] +// @Router /api/v1/observer [delete] func obClusterScaleInHandler(c *gin.Context) { var param param.ClusterScaleInParam if err := c.BindJSON(¶m); err != nil { @@ -359,15 +359,15 @@ func obClusterScaleInHandler(c *gin.Context) { } } -// @ID GetObInfo -// @Summary get ob and agent info -// @Description get ob and agent info -// @Tags ob -// @Accept application/json -// @Produce application/json -// @Success 200 object http.OcsAgentResponse{data=param.ObInfoResp} -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/ob/info [get] +// @ID GetObInfo +// @Summary get ob and agent info +// @Description get ob and agent info +// @Tags ob +// @Accept application/json +// @Produce application/json +// @Success 200 object http.OcsAgentResponse{data=param.ObInfoResp} +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/ob/info [get] func obInfoHandler(c *gin.Context) { if meta.OCS_AGENT.IsFollowerAgent() { master := agentService.GetMasterAgentInfo() @@ -392,19 +392,19 @@ func isEmergencyMode(c *gin.Context, scope *param.Scope) (res bool, agentErr *er return false, nil } -// @ID agentUpgradeCheck -// @Summary check agent upgrade -// @Description check agent upgrade -// @Tags upgrade -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.UpgradeCheckParam true "agent upgrade check params" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 400 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/agent/upgrade/check [post] +// @ID agentUpgradeCheck +// @Summary check agent upgrade +// @Description check agent upgrade +// @Tags upgrade +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.UpgradeCheckParam true "agent upgrade check params" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/agent/upgrade/check [post] func agentUpgradeCheckHandler(c *gin.Context) { var param param.UpgradeCheckParam if err := c.BindJSON(¶m); err != nil { @@ -415,19 +415,19 @@ func agentUpgradeCheckHandler(c *gin.Context) { common.SendResponse(c, task, err) } -// @ID obUpgradeCheck -// @Summary check ob upgrade -// @Description check ob upgrade -// @Tags upgrade -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.UpgradeCheckParam true "ob upgrade check params" -// @Success 200 object http.OcsAgentResponse -// @Failure 400 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/ob/upgrade/check [post] +// @ID obUpgradeCheck +// @Summary check ob upgrade +// @Description check ob upgrade +// @Tags upgrade +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.UpgradeCheckParam true "ob upgrade check params" +// @Success 200 object http.OcsAgentResponse +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/ob/upgrade/check [post] func obUpgradeCheckHandler(c *gin.Context) { var param param.UpgradeCheckParam if err := c.BindJSON(¶m); err != nil { @@ -438,18 +438,18 @@ func obUpgradeCheckHandler(c *gin.Context) { common.SendResponse(c, task, err) } -// @ID UpgradePkgUpload -// @Summary upload upgrade package -// @Description upload upgrade package -// @Tags upgrade -// @Accept multipart/form-data -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param file formData file true "ob upgrade package" -// @Success 200 object http.OcsAgentResponse{data=oceanbase.UpgradePkgInfo} -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/upgrade/package [post] +// @ID UpgradePkgUpload +// @Summary upload upgrade package +// @Description upload upgrade package +// @Tags upgrade +// @Accept multipart/form-data +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param file formData file true "ob upgrade package" +// @Success 200 object http.OcsAgentResponse{data=oceanbase.UpgradePkgInfo} +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/upgrade/package [post] func pkgUploadHandler(c *gin.Context) { if !meta.OCS_AGENT.IsClusterAgent() { common.SendResponse(c, nil, errors.Occur(errors.ErrObclusterNotFound, "Unable to proceed with package upload. Please ensure the 'init' command is executed before attempting to upload.")) @@ -465,53 +465,53 @@ func pkgUploadHandler(c *gin.Context) { common.SendResponse(c, &data, agentErr) } -// @ID ParamsBackup -// @Summary backup params -// @Description backup params -// @Tags upgrade -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Success 200 object http.OcsAgentResponse{data=[]oceanbase.ObParameters} -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/upgrade/params/backup [post] +// @ID ParamsBackup +// @Summary backup params +// @Description backup params +// @Tags upgrade +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Success 200 object http.OcsAgentResponse{data=[]oceanbase.ObParameters} +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/upgrade/params/backup [post] func paramsBackupHandler(c *gin.Context) { data, err := ob.ParamsBackup() common.SendResponse(c, data, err) } -// @ID ParamsRestore -// @Summary restore params -// @Description restore params -// @Tags upgrade -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.RestoreParams true "restore params" -// @Success 200 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/upgrade/params/restore [post] +// @ID ParamsRestore +// @Summary restore params +// @Description restore params +// @Tags upgrade +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.RestoreParams true "restore params" +// @Success 200 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/upgrade/params/restore [post] func paramsRestoreHandler(c *gin.Context) { var param param.RestoreParams err := ob.ParamsRestore(param) common.SendResponse(c, nil, err) } -// @ID agentUpgrade -// @Summary upgrade agent -// @Description upgrade agent -// @Tags upgrade -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.UpgradeCheckParam true "agent upgrade check params" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 400 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/agent/upgrade [post] +// @ID agentUpgrade +// @Summary upgrade agent +// @Description upgrade agent +// @Tags upgrade +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.UpgradeCheckParam true "agent upgrade check params" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/agent/upgrade [post] func agentUpgradeHandler(c *gin.Context) { var param param.UpgradeCheckParam if err := c.BindJSON(¶m); err != nil { @@ -522,19 +522,19 @@ func agentUpgradeHandler(c *gin.Context) { common.SendResponse(c, dag, err) } -// @ID obUpgrade -// @Summary upgrade ob -// @Description upgrade ob -// @Tags upgrade -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param body body param.ObUpgradeParam true "ob upgrade params" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Failure 400 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/ob/upgrade [post] +// @ID obUpgrade +// @Summary upgrade ob +// @Description upgrade ob +// @Tags upgrade +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.ObUpgradeParam true "ob upgrade params" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/ob/upgrade [post] func obUpgradeHandler(c *gin.Context) { var param param.ObUpgradeParam if err := c.BindJSON(¶m); err != nil { diff --git a/agent/api/obproxy_route.go b/agent/api/obproxy_route.go new file mode 100644 index 00000000..702ca827 --- /dev/null +++ b/agent/api/obproxy_route.go @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package api + +import ( + "github.com/gin-gonic/gin" + + "github.com/oceanbase/obshell/agent/api/common" + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/executor/obproxy" + "github.com/oceanbase/obshell/agent/secure" + "github.com/oceanbase/obshell/param" +) + +func InitObproxyRoutes(r *gin.RouterGroup, isLocalRoute bool) { + obproxy := r.Group(constant.URI_OBPROXY_GROUP) + if !isLocalRoute { + obproxy.Use(common.Verify(secure.ROUTE_OBPROXY)) + } + + // obproxy routes + obproxy.POST("", obproxyAddHandler) + obproxy.DELETE("", obproxyDeleteHandler) + obproxy.POST(constant.URI_START, obproxyStartHandler) + obproxy.POST(constant.URI_STOP, obproxyStopHandler) + obproxy.POST("package", obproxyPkgUploadHandler) + obproxy.POST(constant.URI_UPGRADE, obproxyUpgradeHandler) +} + +// @ID obproxyAdd +// @Summary Add obproxy +// @Tags Obproxy +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Agent-Header header string true "Authorization" +// @Param body body param.AddObproxyParam true "Add obproxy" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/obproxy [post] +func obproxyAddHandler(c *gin.Context) { + var param param.AddObproxyParam + if err := c.BindJSON(¶m); err != nil { + common.SendResponse(c, nil, err) + return + } + dag, err := obproxy.AddObproxy(param) + common.SendResponse(c, dag, err) +} + +// @ID obproxyStop +// @Summary Stop obproxy +// @Tags Obproxy +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Agent-Header header string true "Authorization" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/obproxy/stop [post] +func obproxyStopHandler(c *gin.Context) { + dag, err := obproxy.StopObproxy() + common.SendResponse(c, dag, err) +} + +// @ID obproxyStart +// @Summary Start obproxy +// @Tags Obproxy +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Agent-Header header string true "Authorization" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/obproxy/start [post] +func obproxyStartHandler(c *gin.Context) { + dag, err := obproxy.StartObproxy() + common.SendResponse(c, dag, err) +} + +// @ID obproxyDelete +// @Summary Delete obproxy +// @Tags Obproxy +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Agent-Header header string true "Authorization" +// @Success 204 object http.OcsAgentResponse +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/obproxy [delete] +func obproxyDeleteHandler(c *gin.Context) { + dag, err := obproxy.DeleteObproxy() + if dag == nil && err == nil { + common.SendNoContentResponse(c, nil) + } + common.SendResponse(c, dag, err) +} + +// @ID obproxyUpgrade +// @Summary Upgrade obproxy +// @Tags Obproxy +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Agent-Header header string true "Authorization" +// @Param body body param.UpgradeObproxyParam true "Upgrade obproxy" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/obproxy/upgrade [post] +func obproxyUpgradeHandler(c *gin.Context) { + var param param.UpgradeObproxyParam + if err := c.BindJSON(¶m); err != nil { + common.SendResponse(c, nil, err) + return + } + + dag, err := obproxy.UpgradeObproxy(param) + common.SendResponse(c, dag, err) +} + +// @ID obproxyPkgUpload +// @Summary Upload obproxy package +// @Tags Obproxy +// @Accept multipart/form-data +// @Produce application/json +// @Param X-OCS-Agent-Header header string true "Authorization" +// @Param file formData file true "Obproxy package" +// @Success 200 object http.OcsAgentResponse{data=sqlite.UpgradePkgInfo} +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/obproxy/package [post] +func obproxyPkgUploadHandler(c *gin.Context) { + file, _, err := c.Request.FormFile("file") + if err != nil { + common.SendResponse(c, nil, errors.Occur(errors.ErrKnown, "get file failed.", err)) + return + } + defer file.Close() + data, agentErr := obproxy.UpgradePkgUpload(file) + common.SendResponse(c, &data, agentErr) +} diff --git a/agent/api/task_handler.go b/agent/api/task_handler.go index 472a7563..6570ff10 100644 --- a/agent/api/task_handler.go +++ b/agent/api/task_handler.go @@ -22,16 +22,18 @@ import ( "github.com/oceanbase/obshell/agent/api/common" "github.com/oceanbase/obshell/agent/constant" "github.com/oceanbase/obshell/agent/executor/task" + "github.com/oceanbase/obshell/agent/secure" ) func InitTaskRoutes(r *gin.RouterGroup, isLocalRoute bool) { group := r.Group(constant.URI_TASK_GROUP) if !isLocalRoute { - group.Use(common.Verify()) + group.Use(common.Verify(secure.ROUTE_TASK)) } group.GET(constant.URI_SUB_TASK+"/:id", task.GetSubTaskDetail) group.GET(constant.URI_NODE+"/:id", task.GetNodeDetail) group.GET(constant.URI_DAG+"/:id", task.GetDagDetail) + group.POST(constant.URI_DAG+"/:id", task.DagHandler) group.GET(constant.URI_DAG+constant.URI_MAINTAIN+constant.URI_OB_GROUP, task.GetObLastMaintenanceDag) group.GET(constant.URI_DAG+constant.URI_MAINTAIN+constant.URI_AGENT_GROUP, task.GetAgentLastMaintenanceDag) group.GET(constant.URI_DAG+constant.URI_MAINTAIN+constant.URI_AGENTS_GROUP, task.GetAllAgentLastMaintenanceDag) @@ -39,5 +41,4 @@ func InitTaskRoutes(r *gin.RouterGroup, isLocalRoute bool) { group.GET(constant.URI_DAG+constant.URI_OB_GROUP+constant.URI_UNFINISH, task.GetClusterUnfinishDags) group.GET(constant.URI_DAG+constant.URI_AGENT_GROUP+constant.URI_UNFINISH, task.GetAgentUnfinishDags) - group.POST(constant.URI_DAG+"/:id", task.DagHandler) } diff --git a/agent/api/tenant_handler.go b/agent/api/tenant_handler.go index 02352c39..479354db 100644 --- a/agent/api/tenant_handler.go +++ b/agent/api/tenant_handler.go @@ -53,6 +53,8 @@ func InitTenantRoutes(v1 *gin.RouterGroup, isLocalRoute bool) { tenant.GET(constant.URI_PATH_PARAM_NAME+constant.URI_VARIABLE+constant.URI_PATH_PARAM_VAR, getTenantVariable) tenant.GET(constant.URI_PATH_PARAM_NAME+constant.URI_PARAMETERS, getTenantParameters) tenant.GET(constant.URI_PATH_PARAM_NAME+constant.URI_VARIABLES, getTenantVariables) + tenant.POST(constant.URI_PATH_PARAM_NAME+constant.URI_USER, createUserHandler) + tenant.DELETE(constant.URI_PATH_PARAM_NAME+constant.URI_USER+constant.URI_PATH_PARAM_USER, dropUserHandler) tenants.GET(constant.URI_OVERVIEW, getTenantOverView) } @@ -435,7 +437,7 @@ func tenantSetVariableHandler(c *gin.Context) { common.SendResponse(c, nil, err) return } - common.SendResponse(c, nil, tenant.SetTenantVariables(name, param.Variables)) + common.SendResponse(c, nil, tenant.SetTenantVariables(c, name, param)) } // @ID getTenantInfo @@ -587,3 +589,66 @@ func getTenantOverView(c *gin.Context) { tenants, err := tenant.GetTenantsOverView() common.SendResponse(c, tenants, err) } + +// @ID createUser +// @Summary create user +// @Description create user +// @Tags tenant +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param body body param.CreateUserParam true "create user params" +// @Success 200 object http.OcsAgentResponse +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/tenant/{name}/user [post] +func createUserHandler(c *gin.Context) { + name, err := tenantCheckWithName(c) + if err != nil { + common.SendResponse(c, nil, err) + return + } + var param param.CreateUserParam + if err := c.BindJSON(¶m); err != nil { + common.SendResponse(c, nil, err) + return + } + common.SendResponse(c, nil, tenant.CreateUser(name, param)) +} + +// @ID dropUser +// @Summary drop user +// @Description drop user +// @Tags tenant +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param name path string true "tenant name" +// @Param user path string true "user name" +// @Param body body param.DropUserParam true "drop user params" +// @Success 200 object http.OcsAgentResponse +// @Failure 400 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/tenant/{name}/user/{user} [delete] +func dropUserHandler(c *gin.Context) { + name, err := tenantCheckWithName(c) + if err != nil { + common.SendResponse(c, nil, err) + return + } + userName := c.Param(constant.URI_PARAM_USER) + if userName == "" { + common.SendResponse(c, nil, errors.Occur(errors.ErrIllegalArgument, "User name is empty.")) + return + } + + var param param.DropUserParam + if err := c.BindJSON(¶m); err != nil { + common.SendResponse(c, nil, err) + return + } + + common.SendResponse(c, nil, tenant.DropUser(name, userName, param.RootPassword)) +} diff --git a/agent/api/zone_handler.go b/agent/api/zone_handler.go index 7852884a..932e2967 100644 --- a/agent/api/zone_handler.go +++ b/agent/api/zone_handler.go @@ -25,20 +25,20 @@ import ( "github.com/oceanbase/obshell/agent/meta" ) -// @ID DeleteZone +// @ID DeleteZone // -// @Summary delete zone -// @Description delete zone -// @Tags ob -// @Accept application/json -// @Produce application/json -// @Param X-OCS-Header header string true "Authorization" -// @Param zoneName path string true "zone name" -// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} -// @Success 204 object http.OcsAgentResponse -// @Failure 401 object http.OcsAgentResponse -// @Failure 500 object http.OcsAgentResponse -// @Router /api/v1/zone/{zoneName} [delete] +// @Summary delete zone +// @Description delete zone +// @Tags ob +// @Accept application/json +// @Produce application/json +// @Param X-OCS-Header header string true "Authorization" +// @Param zoneName path string true "zone name" +// @Success 200 object http.OcsAgentResponse{data=task.DagDetailDTO} +// @Success 204 object http.OcsAgentResponse +// @Failure 401 object http.OcsAgentResponse +// @Failure 500 object http.OcsAgentResponse +// @Router /api/v1/zone/{zoneName} [delete] func zoneDeleteHandler(c *gin.Context) { zoneName := c.Param(constant.URI_PARAM_NAME) if zoneName == "" { diff --git a/agent/assets/i18n/error/en.json b/agent/assets/i18n/error/en.json index 1cbbb92d..698b8128 100644 --- a/agent/assets/i18n/error/en.json +++ b/agent/assets/i18n/error/en.json @@ -10,5 +10,5 @@ "err.obcluster.not.found": "There is no obcluster now. %v", "err.user.permission.denied": "Permission denied: %v", - "err.unauthorized": "Verification failed" + "err.unauthorized": "Verification failed: %v" } diff --git a/agent/cmd/admin/start.go b/agent/cmd/admin/start.go index c93f7146..28740a0d 100644 --- a/agent/cmd/admin/start.go +++ b/agent/cmd/admin/start.go @@ -130,8 +130,12 @@ func isDaemonRunning() (pid int32, res bool) { if err != nil { return 0, false } - if _, err = proc.NewProcess(pid); err != nil { + if pidInfo, err := proc.NewProcess(pid); err != nil { return pid, false + } else { + if name, err := pidInfo.Name(); err == nil && name != constant.PROC_OBSHELL { + return pid, false + } } return pid, true } diff --git a/agent/cmd/daemon/start_daemon.go b/agent/cmd/daemon/start_daemon.go index 11c1a047..7250b2a2 100644 --- a/agent/cmd/daemon/start_daemon.go +++ b/agent/cmd/daemon/start_daemon.go @@ -20,7 +20,6 @@ import ( "fmt" "net" "os" - "syscall" "time" log "github.com/sirupsen/logrus" @@ -123,22 +122,7 @@ func (d *Daemon) startSocket(socketListener *net.UnixListener) { func (d *Daemon) writePid() (err error) { pid := os.Getpid() log.Info("obshell daemon pid is ", pid) - return writePid(path.DaemonPidPath(), pid) -} - -// writePid writes the pid to the specified path atomically. -// If the file already exists, an error is returned. -func writePid(path string, pid int) (err error) { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL|os.O_SYNC|syscall.O_CLOEXEC, 0644) - if err != nil { - return err - } - defer f.Close() - _, err = fmt.Fprint(f, pid) - if err != nil { - return err - } - return nil + return process.WritePid(path.DaemonPidPath(), pid) } func (d *Daemon) isForUpgrade() bool { diff --git a/agent/cmd/daemon/start_process.go b/agent/cmd/daemon/start_process.go index d302f0f4..08d4c373 100644 --- a/agent/cmd/daemon/start_process.go +++ b/agent/cmd/daemon/start_process.go @@ -120,7 +120,7 @@ func (s *Server) startServerProc() (err error) { } func (s *Server) writePid() error { - return writePid(path.ObshellPidPath(), s.GetPid()) + return process.WritePid(path.ObshellPidPath(), s.GetPid()) } func (s *Server) handleProcExited(procState process.ProcState, count *int) (err error) { diff --git a/agent/cmd/server/init.go b/agent/cmd/server/init.go index d3aa3e5b..245fa950 100644 --- a/agent/cmd/server/init.go +++ b/agent/cmd/server/init.go @@ -18,7 +18,7 @@ package server import ( "fmt" - "os" + "path/filepath" "syscall" "time" @@ -31,6 +31,7 @@ import ( "github.com/oceanbase/obshell/agent/errors" "github.com/oceanbase/obshell/agent/executor/agent" "github.com/oceanbase/obshell/agent/executor/ob" + "github.com/oceanbase/obshell/agent/executor/obproxy" "github.com/oceanbase/obshell/agent/executor/pool" "github.com/oceanbase/obshell/agent/executor/recyclebin" "github.com/oceanbase/obshell/agent/executor/script" @@ -98,14 +99,12 @@ func (a *Agent) initSqlite() (err error) { // initServerForUpgrade will only start the unix socket service When upgrading. func (a *Agent) initServerForUpgrade() error { log.Info("init local server [upgrade mode]") - serverConfig := config.ServerConfig{ - Ip: "0.0.0.0", - Port: meta.OCS_AGENT.GetPort(), - Address: fmt.Sprintf("0.0.0.0:%d", meta.OCS_AGENT.GetPort()), - RunDir: path.RunDir(), - UpgradeMode: true, + serverConfig, err := config.NewServerConfig(meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort(), path.RunDir(), true) + if err != nil { + return err } - a.server = web.NewServerOnlyLocal(config.DebugMode, serverConfig) + + a.server = web.NewServerOnlyLocal(config.DebugMode, *serverConfig) socketListener, err := a.server.NewUnixListener() if err != nil { return err @@ -227,14 +226,9 @@ func (a *Agent) checkAgentInfo() { // initServer will only initialize the Server and will not start the service. func (a *Agent) initServer() { log.Info("init server") - serverConfig := config.ServerConfig{ - Ip: "0.0.0.0", - Port: meta.OCS_AGENT.GetPort(), - Address: fmt.Sprintf("0.0.0.0:%d", meta.OCS_AGENT.GetPort()), - RunDir: path.RunDir(), - } + serverConfig, _ := config.NewServerConfig(meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort(), path.RunDir(), false) log.Infof("server config is %v", serverConfig) - a.server = web.NewServer(config.DebugMode, serverConfig) + a.server = web.NewServer(config.DebugMode, *serverConfig) a.startChan = make(chan bool, 1) } @@ -253,6 +247,7 @@ func (a *Agent) initTask() { recyclebin.RegisterRecyclebinTask() task.RegisterTaskType(script.ImportScriptForTenantTask{}) pool.RegisterPoolTask() + obproxy.RegisterTaskType() } // Check if the ob config file exists. @@ -267,12 +262,16 @@ func (a *Agent) isUpgradeMode() bool { if a.OldServerPid != 0 { // If the old agent is running in the same directory as the new agent, // it is considered an upgrade. - cwdDir, err := os.Readlink(fmt.Sprintf("/proc/%d/cwd", a.OldServerPid)) + cwdDir, err := filepath.EvalSymlinks(fmt.Sprintf("/proc/%d/cwd", a.OldServerPid)) + if err != nil { + return false + } + curDir, err := filepath.EvalSymlinks(global.HomePath) if err != nil { return false } log.Infof("the cwd of %d is %s", a.OldServerPid, cwdDir) - if global.HomePath == cwdDir { + if curDir == cwdDir { log.Info("The obshell is in upgrade mode.") a.upgradeMode = true // Unset root password env to avoid cover sqlite when upgrade (agent restart) diff --git a/agent/cmd/server/run.go b/agent/cmd/server/run.go index dc1da2c6..5315158d 100644 --- a/agent/cmd/server/run.go +++ b/agent/cmd/server/run.go @@ -79,17 +79,23 @@ func (a *Agent) restoreSecure() (err error) { log.WithError(err).Error("reinit secure failed") return err } - } else { - log.Info("restore secure info successed, check password in sqlite") - err = secure.LoadPassword(a.GetRootPassword()) - if err != nil { - log.WithError(err).Info("check password in sqlite failed") - if !meta.OCS_AGENT.IsClusterAgent() { - process.ExitWithFailure(constant.EXIT_CODE_ERROR_NOT_CLUSTER_AGENT, "check password in sqlite failed: not cluster agent") - } - } else { - log.Info("check password in sqlite successed") + } + + log.Info("restore secure info successed, check password of root@sys in sqlite") + err = secure.LoadOceanbasePassword(a.GetRootPassword()) + if err != nil { + log.WithError(err).Info("check password of root@sys in sqlite failed") + if !meta.OCS_AGENT.IsClusterAgent() { + process.ExitWithFailure(constant.EXIT_CODE_ERROR_NOT_CLUSTER_AGENT, "check password of root@sys in sqlite failed: not cluster agent") } + } else { + log.Info("check password of root@sys in sqlite successed") + } + + log.Info("check agent password from sqlite") + err = secure.LoadAgentPassword() + if err != nil { + log.WithError(err).Error("check agent password from sqlite failed") } return nil } diff --git a/agent/config/oceanbase.go b/agent/config/oceanbase.go index c92ca16d..4e25cce6 100644 --- a/agent/config/oceanbase.go +++ b/agent/config/oceanbase.go @@ -21,12 +21,13 @@ import ( "strings" "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/meta" ) func NewObDataSourceConfig() *ObDataSourceConfig { return &ObDataSourceConfig{ username: constant.DB_USERNAME, - ip: constant.LOCAL_IP, + ip: meta.OCS_AGENT.GetLocalIp(), dBName: constant.DB_OCS, charset: constant.DB_DEFAULT_CHARSET, parseTime: true, @@ -37,6 +38,19 @@ func NewObDataSourceConfig() *ObDataSourceConfig { } } +func NewObproxyDataSourceConfig() *ObDataSourceConfig { + return &ObDataSourceConfig{ + username: constant.DB_PROXYSYS_USERNAME, + ip: constant.LOCAL_IP, + charset: constant.DB_DEFAULT_CHARSET, + parseTime: true, + location: constant.DB_DEFAULT_LOCATION, + maxIdleConns: constant.DB_DEFAULT_MAX_IDLE_CONNS, + maxOpenConns: constant.DB_DEFAULT_MAX_OPEN_CONNS, + connMaxLifetime: constant.DB_DEFAULT_CONN_MAX_LIFETIME, + } +} + type ObDataSourceConfig struct { // dsn config username string @@ -182,7 +196,7 @@ func (config *ObDataSourceConfig) GetSkipPwdCheck() bool { } func (config *ObDataSourceConfig) GetDSN() string { - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/", config.username, config.password, config.ip, config.port) + dsn := fmt.Sprintf("%s:%s@tcp(%s)/", config.username, config.password, meta.NewAgentInfo(config.ip, config.port).String()) if config.dBName != "" { dsn += config.dBName } diff --git a/agent/config/server.go b/agent/config/server.go index fc18033a..f2a5aea7 100644 --- a/agent/config/server.go +++ b/agent/config/server.go @@ -16,6 +16,12 @@ package config +import ( + "errors" + "fmt" + "net" +) + type AgentMode = string const ( @@ -30,3 +36,29 @@ type ServerConfig struct { RunDir string UpgradeMode bool } + +func NewServerConfig(ip string, port int, runDir string, UpgradeMode bool) (*ServerConfig, error) { + address, err := generateAddress(ip, port) + if err != nil { + return nil, err + } + + return &ServerConfig{ + Ip: ip, + Port: port, + Address: address, + RunDir: runDir, + UpgradeMode: UpgradeMode, + }, nil +} + +func generateAddress(ip string, port int) (string, error) { + ipParsed := net.ParseIP(ip) + if ipParsed == nil { + return "", errors.New("invalid ip") + } + if ipParsed.To4() != nil { + return fmt.Sprint("0.0.0.0:", port), nil + } + return fmt.Sprint("[::]:", port), nil +} diff --git a/agent/constant/agent.go b/agent/constant/agent.go index d578b6ab..f8bbe204 100644 --- a/agent/constant/agent.go +++ b/agent/constant/agent.go @@ -78,6 +78,8 @@ const ( PROC_OBSHELL_CLIENT = "client" PROC_OBSERVER = "observer" + + PROC_OBPROXY = "obproxy" ) // upload pkg names @@ -85,6 +87,7 @@ const ( PKG_OBSHELL = "obshell" PKG_OCEANBASE_CE = "oceanbase-ce" PKG_OCEANBASE_CE_LIBS = "oceanbase-ce-libs" + PKG_OBPROXY_CE = "obproxy-ce" ) var SUPPORT_PKG_NAMES = []string{ diff --git a/agent/constant/gorm.go b/agent/constant/gorm.go index 9ef3507a..747c7fc6 100644 --- a/agent/constant/gorm.go +++ b/agent/constant/gorm.go @@ -18,9 +18,11 @@ package constant const ( // source config default value - DB_USERNAME = "root" - LOCAL_IP = "127.0.0.1" - DB_DEFAILT_TIMEOUT = 10 + DB_USERNAME = "root" + DB_PROXYSYS_USERNAME = "root@proxysys" + LOCAL_IP = "127.0.0.1" + LOCAL_IP_V6 = "::1" + DB_DEFAULT_CHARSET = "utf8mb4" DB_DEFAULT_LOCATION = "Local" diff --git a/agent/constant/obproxy.go b/agent/constant/obproxy.go new file mode 100644 index 00000000..1bdb3741 --- /dev/null +++ b/agent/constant/obproxy.go @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package constant + +const ( + OBPROXY_INFO_SQL_PORT = "sql_port" + OBPROXY_INFO_OBPROXY_SYS_PASSWORD = "obproxy_sys_password" + OBPROXY_INFO_HOME_PATH = "home_path" + OBPROXY_INFO_PROXYRO_PASSWORD = "proxyro_password" + OBPROXY_INFO_VERSION = "version" + + OBPROXY_CONFIG_PROMETHUES_LISTEN_PORT = "prometheus_listen_port" + OBPROXY_CONFIG_RS_LIST = "rootservice_list" + OBPROXY_CONFIG_CONFIG_SERVER_URL = "obproxy_config_server_url" + OBPROXY_CONFIG_LISTEN_PORT = "listen_port" + OBPROXY_CONFIG_CLUSTER_NAME = "cluster_name" + OBPROXY_CONFIG_RPC_LISTEN_PORT = "rpc_listen_port" + OBPROXY_CONFIG_OBPROXY_SYS_PASSWORD = "obproxy_sys_password" + OBPROXY_CONFIG_ROOT_SERVICE_CLUSTER_NAME = "rootservice_cluster_name" + OBPROXY_CONFIG_PROXYRO_PASSWORD = "observer_sys_password" + OBPROXY_CONFIG_PROXY_LOCAL_CMD = "proxy_local_cmd" + OBPROXY_CONFIG_HOT_UPGRADE_ROLLBACK_TIMEOUT = "hot_upgrade_rollback_timeout" + OBPROXY_CONFIG_HOT_UPGRADE_EXIT_TIMEOUT = "hot_upgrade_exit_timeout" + + OBPROXY_MIN_VERSION_SUPPORT = "4.0.0" + + OBPROXY_INFO_STATUS = "status" + + OBPROXY_DIR_ETC = "etc" + OBPROXY_DIR_BIN = "bin" + OBPROXY_DIR_LIB = "lib" + OBPROXY_DIR_LOG = "log" + OBPROXY_DIR_RUN = "run" + BIN_OBPROXY = "obproxy" + BIN_OBPROXYD = "obproxyd" + + RESTART_FOR_PROXY_LOCAL_CMD = "2" + + OBPROXY_DEFAULT_SQL_PORT = 2883 + OBPROXY_DEFAULT_EXPORTER_PORT = 2884 + OBPROXY_DEFAULT_RPC_PORT = 2885 + + DEFAULT_HOT_RESTART_TIME_OUT = 1800 // 30 minutes +) diff --git a/agent/constant/oceanbase.go b/agent/constant/oceanbase.go index 3368d057..e09c2fe7 100644 --- a/agent/constant/oceanbase.go +++ b/agent/constant/oceanbase.go @@ -28,18 +28,23 @@ const ( DB_OCEANBASE = "oceanbase" DB_OCS = "ocs" + DEFAULT_HOST = "%" + + SYS_USER_PROXYRO = "proxyro" + CONFIG_RPC_PORT = "rpc_port" CONFIG_MYSQL_PORT = "mysql_port" DEFAULT_MYSQL_PORT = 2881 DEFAULT_RPC_PORT = 2882 - CONFIG_HOME_PATH = "homePath" - CONFIG_ROOT_PWD = "rootPwd" - CONFIG_DATA_DIR = "data_dir" - CONFIG_REDO_DIR = "redo_dir" - CONFIG_CLOG_DIR = "clog_dir" - CONFIG_SLOG_DIR = "slog_dir" + CONFIG_HOME_PATH = "homePath" + CONFIG_ROOT_PWD = "rootPwd" + CONFIG_AGENT_PASSWORD = "agentRootPwd" + CONFIG_DATA_DIR = "data_dir" + CONFIG_REDO_DIR = "redo_dir" + CONFIG_CLOG_DIR = "clog_dir" + CONFIG_SLOG_DIR = "slog_dir" CONFIG_LOCAL_IP = "local_ip" CONFIG_DEV_NAME = "devname" diff --git a/agent/constant/secure.go b/agent/constant/secure.go index 596622c1..a376170b 100644 --- a/agent/constant/secure.go +++ b/agent/constant/secure.go @@ -20,6 +20,7 @@ import "time" const ( OCS_HEADER = "X-OCS-Header" + OCS_AGENT_HEADER = "X-OCS-Agent-Header" REQUEST_RECEIVED_TIME = "request_received_time" RESPONSE_PWD_KEY = "password" AGENT_PRIVATE_KEY = "private_key" diff --git a/agent/constant/task.go b/agent/constant/task.go index 00378a9e..98c1c5f0 100644 --- a/agent/constant/task.go +++ b/agent/constant/task.go @@ -25,10 +25,12 @@ const ( ) const ( - CLUSTER_TASK_ID_PREFIX = '1' - LOCAL_TASK_IPV4_ID_PREFIX = '2' - LOCAL_TASK_IPV6_ID_PREFIX = '3' - ENGINE_WAIT_TIME = 30 * time.Second + CLUSTER_TASK_ID_PREFIX = '1' + LOCAL_TASK_IPV4_ID_PREFIX = '2' + LOCAL_TASK_IPV6_ID_PREFIX = '3' + OBPROXY_TASK_IPV4_ID_PREFIX = '4' + OBPROXY_TASK_IPV6_ID_PREFIX = '5' + ENGINE_WAIT_TIME = 30 * time.Second SYNC_INTERVAL = 1 * time.Second SYNC_TASK_BUFFER_SIZE = 10000 diff --git a/agent/constant/uri.go b/agent/constant/uri.go index 9e81baea..43480349 100644 --- a/agent/constant/uri.go +++ b/agent/constant/uri.go @@ -34,6 +34,7 @@ const ( URI_POOL_GROUP = "/resource-pool" URI_POOLS_GROUP = "/resource-pools" URI_RECYCLEBIN_GROUP = "/recyclebin" + URI_OBPROXY_GROUP = "/obproxy" URI_INFO = "/info" URI_TIME = "/time" @@ -41,8 +42,10 @@ const ( URI_STATUS = "/status" URI_SECRET = "secret" - URI_JOIN = "/join" - URI_REMOVE = "/remove" + URI_JOIN = "/join" + URI_REMOVE = "/remove" + URI_PASSWORD = "/password" + URI_TOKEN = "/token" URI_SYNC_BIN = "/sync-bin" @@ -90,6 +93,7 @@ const ( URI_PARAMETER = "/parameter" URI_OVERVIEW = "/overview" URI_TENANT = "/tenant" + URI_USER = "/user" URI_PARAM_NAME = "name" URI_PATH_PARAM_NAME = "/:" + URI_PARAM_NAME @@ -97,6 +101,8 @@ const ( URI_PATH_PARAM_VAR = "/:" + URI_PARAM_VAR URI_PARAM_PARA = "parameter" URI_PATH_PARAM_PARA = "/:" + URI_PARAM_PARA + URI_PARAM_USER = "user" + URI_PATH_PARAM_USER = "/:" + URI_PARAM_USER // Used for backup URI_ARCHIVE = "/log" @@ -112,6 +118,7 @@ const ( URI_OBSERVER_API_PREFIX = URI_API_V1 + URI_OBSERVER_GROUP URI_ZONE_API_PREFIX = URI_API_V1 + URI_ZONE_GROUP URI_TENANT_API_PREFIX = URI_API_V1 + URI_TENANT_GROUP + URI_OBPROXY_API_PREFIX = URI_API_V1 + URI_OBPROXY_GROUP URI_TASK_RPC_PREFIX = URI_RPC_V1 + URI_TASK_GROUP URI_AGENT_RPC_PREFIX = URI_RPC_V1 + URI_AGENT_GROUP diff --git a/agent/engine/coordinator/coordinator.go b/agent/engine/coordinator/coordinator.go index ff3be726..19a21d18 100644 --- a/agent/engine/coordinator/coordinator.go +++ b/agent/engine/coordinator/coordinator.go @@ -225,7 +225,7 @@ func (c *Coordinator) buildMaintainerByPolling() error { } if err := c.getMaintainerbyRpc(&agent); err != nil { - log.WithError(err).Warnf("get maintainer from '%s:%d' failed", agent.GetIp(), agent.GetPort()) + log.WithError(err).Warnf("get maintainer from '%s' failed", agent.String()) continue } return nil @@ -235,7 +235,7 @@ func (c *Coordinator) buildMaintainerByPolling() error { func (c *Coordinator) getMaintainerbyRpc(agentInfo meta.AgentInfoInterface) error { now := time.Now() - log.Infof("try get maintainer rpc request from '%s:%d' to '%s:%d' ", meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort(), agentInfo.GetIp(), agentInfo.GetPort()) + log.Infof("try get maintainer rpc request from '%s' to '%s' ", meta.OCS_AGENT.String(), agentInfo.String()) maintainer := Maintainer{} if err := secure.SendGetRequest(agentInfo, constant.URI_RPC_V1+constant.URI_MAINTAINER, nil, &maintainer); err != nil { return err diff --git a/agent/engine/executor/executor.go b/agent/engine/executor/executor.go index d1d56511..a4d077e5 100644 --- a/agent/engine/executor/executor.go +++ b/agent/engine/executor/executor.go @@ -233,7 +233,7 @@ func sendUpdateTaskRpc(remoteTaskId int64, task task.ExecutableTask) error { if coordinator.OCS_COORDINATOR.IsFaulty() { return errors.New("faulty does not have maintainer") } - log.Infof("send update task rpc to %s:%d, remote task id %d", coordinator.OCS_COORDINATOR.Maintainer.GetIp(), coordinator.OCS_COORDINATOR.Maintainer.GetPort(), remoteTaskId) + log.Infof("send update task rpc to %s, remote task id %d", coordinator.OCS_COORDINATOR.Maintainer.String(), remoteTaskId) remoteTask := createRemoteTask(remoteTaskId, task) maintainerAgent := coordinator.OCS_COORDINATOR.Maintainer return secure.SendPatchRequest(maintainerAgent, constant.URI_TASK_RPC_PREFIX+constant.URI_SUB_TASK, remoteTask, nil) diff --git a/agent/engine/executor/task_log_sync.go b/agent/engine/executor/task_log_sync.go index d78a0355..f1e21a79 100644 --- a/agent/engine/executor/task_log_sync.go +++ b/agent/engine/executor/task_log_sync.go @@ -104,6 +104,6 @@ func (synchronizer *taskLogSynchronizer) syncTaskLog(taskLog *sqlite.SubTaskLog) func postTaskLogToRemote(taskLog task.TaskExecuteLogDTO) error { maintainerAgent := coordinator.OCS_COORDINATOR.Maintainer - log.Infof("send task log to %s:%d", maintainerAgent.GetIp(), maintainerAgent.GetPort()) + log.Infof("send task log to %s", maintainerAgent.String()) return secure.SendPostRequest(maintainerAgent, constant.URI_TASK_RPC_PREFIX+constant.URI_LOG, taskLog, nil) } diff --git a/agent/engine/scheduler/task_handler.go b/agent/engine/scheduler/task_handler.go index e88ca706..d1975589 100644 --- a/agent/engine/scheduler/task_handler.go +++ b/agent/engine/scheduler/task_handler.go @@ -238,7 +238,7 @@ func (s *Scheduler) updateExecuterAgent(node *task.Node, subTask task.Executable ctx := node.GetContext() agents := ctx.GetParam(task.EXECUTE_AGENTS) if agents == nil { // Not specified execute agent. - log.withScheduler(s).Infof("subtask %d update executer agent %s:%d to %s:%d\n", subTask.GetID(), subTask.GetExecuteAgent().Ip, subTask.GetExecuteAgent().Port, meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort()) + log.withScheduler(s).Infof("subtask %d update executer agent %s to %s\n", subTask.GetID(), subTask.GetExecuteAgent().String(), meta.OCS_AGENT.String()) subTask.SetExecuteAgent(*meta.NewAgentInfoByInterface(meta.OCS_AGENT)) } } @@ -278,14 +278,14 @@ func (s *Scheduler) runSubTask(subTask *task.RemoteTask) error { } } else { if err := s.sendRunSubTaskRpc(subTask); err != nil { - return errors.Wrapf(err, "send run sub task rpc to %s:%d error", agentInfo.Ip, agentInfo.Port) + return errors.Wrapf(err, "send run sub task rpc to %s error", agentInfo.String()) } } return nil } func (s *Scheduler) sendRunSubTaskRpc(subTask *task.RemoteTask) error { - log.withScheduler(s).Infof("send run sub task %d to %s:%d", subTask.TaskID, subTask.ExecuterAgent.Ip, subTask.ExecuterAgent.Port) + log.withScheduler(s).Infof("send run sub task %d to %s", subTask.TaskID, subTask.ExecuterAgent.String()) return secure.SendPostRequest(&subTask.ExecuterAgent, constant.URI_TASK_RPC_PREFIX+constant.URI_SUB_TASK, subTask, nil) } diff --git a/agent/engine/task/context.go b/agent/engine/task/context.go index bddcbcc6..bfecef20 100644 --- a/agent/engine/task/context.go +++ b/agent/engine/task/context.go @@ -58,7 +58,7 @@ func (ctx *TaskContext) GetData(key string) interface{} { } func (ctx *TaskContext) GetAgentData(agent meta.AgentInfoInterface, key string) interface{} { - return ctx.GetAgentDataByAgentKey(fmt.Sprintf("%s:%d", agent.GetIp(), agent.GetPort()), key) + return ctx.GetAgentDataByAgentKey(agent.String(), key) } func (ctx *TaskContext) GetAgentDataByAgentKey(agentKey string, key string) interface{} { diff --git a/agent/engine/task/dag.go b/agent/engine/task/dag.go index 8e4bd5ab..f4a0d448 100644 --- a/agent/engine/task/dag.go +++ b/agent/engine/task/dag.go @@ -20,6 +20,20 @@ import ( "time" ) +type DagType uint8 + +const ( + DAG_OB DagType = iota + DAG_OBPROXY +) + +var ( + DAG_TYPE_MAP = map[DagType]string{ + DAG_OB: "ob", + DAG_OBPROXY: "obproxy", + } +) + type Dag struct { dagType string stage int diff --git a/agent/engine/task/maintenance_type.go b/agent/engine/task/maintenance_type.go index 27ac2b3b..49c33eec 100644 --- a/agent/engine/task/maintenance_type.go +++ b/agent/engine/task/maintenance_type.go @@ -44,6 +44,7 @@ const ( NOT_UNDER_MAINTENANCE GLOBAL_MAINTENANCE TENANT_MAINTENANCE + OBPROXY_MAINTENACE ) func UnMaintenance() Maintainer { @@ -65,6 +66,12 @@ func TenantMaintenance(tenantName string) Maintainer { } } +func ObproxyMaintenance() Maintainer { + return &maintenance{ + maintenanceType: OBPROXY_MAINTENACE, + } +} + func NewMaintenance(maintenanceType int, maintenanceKey string) Maintainer { return &maintenance{ maintenanceType: maintenanceType, diff --git a/agent/engine/task/node.go b/agent/engine/task/node.go index 3e00b648..072dcb11 100644 --- a/agent/engine/task/node.go +++ b/agent/engine/task/node.go @@ -33,6 +33,7 @@ type Node struct { nodeType string upStream *Node downStream *Node + dagId int TaskInfo ctx *TaskContext } @@ -45,6 +46,10 @@ func (node *Node) GetNodeType() string { return node.nodeType } +func (node *Node) GetDagId() int { + return node.dagId +} + func (node *Node) GetSubTasks() []ExecutableTask { return node.subtasks } @@ -166,12 +171,13 @@ func NewNodeWithContext(task ExecutableTask, paralle bool, ctx *TaskContext) *No return node } -func NewNodeWithId(id int64, name string, nodeType string, state int, operator int, structName string, ctx *TaskContext, isLocalTask bool, startTime time.Time, endTime time.Time) *Node { +func NewNodeWithId(id int64, name string, dagId int, nodeType string, state int, operator int, structName string, ctx *TaskContext, isLocalTask bool, startTime time.Time, endTime time.Time) *Node { node := &Node{ taskType: TASK_TYPE[structName], subtasks: make([]ExecutableTask, 0), nodeType: nodeType, ctx: ctx, + dagId: dagId, TaskInfo: TaskInfo{ id: id, name: name, diff --git a/agent/engine/task/task.go b/agent/engine/task/task.go index 2fbe6774..1a7d3a27 100644 --- a/agent/engine/task/task.go +++ b/agent/engine/task/task.go @@ -564,7 +564,7 @@ func CreateSubTaskInstance( }, executeTimes: executeTimes, executerAgent: executerAgent, - localAgentKey: fmt.Sprintf("%s:%d", executerAgent.GetIp(), executerAgent.GetPort()), + localAgentKey: executerAgent.String(), } taskInstance := reflect.New(TASK_TYPE[taskType]).Elem() diff --git a/agent/engine/task/task_dto.go b/agent/engine/task/task_dto.go index 90960bcc..e7decf53 100644 --- a/agent/engine/task/task_dto.go +++ b/agent/engine/task/task_dto.go @@ -204,28 +204,28 @@ func NewTaskStatusDTO(task *TaskInfo) *TaskStatusDTO { func NewDagDetailDTO(dag *Dag) *DagDetailDTO { return &DagDetailDTO{ - GenericDTO: newGenericDTO(dag), + GenericDTO: newGenericDTO(dag, dag.GetDagType()), DagDetail: NewDagDetail(dag), } } -func NewNodeDetailDTO(node *Node) *NodeDetailDTO { +func NewNodeDetailDTO(node *Node, dagType string) *NodeDetailDTO { return &NodeDetailDTO{ - GenericDTO: newGenericDTO(node), + GenericDTO: newGenericDTO(node, dagType), NodeDetail: NewNodeDetail(node), } } -func NewTaskDetailDTO(task ExecutableTask) *TaskDetailDTO { +func NewTaskDetailDTO(task ExecutableTask, dagType string) *TaskDetailDTO { return &TaskDetailDTO{ - GenericDTO: newGenericDTO(task), + GenericDTO: newGenericDTO(task, dagType), TaskDetail: NewTaskDetail(task), } } -func newGenericDTO(instance TaskInfoInterface) *GenericDTO { +func newGenericDTO(instance TaskInfoInterface, dagType string) *GenericDTO { return &GenericDTO{ - GenericID: ConvertToGenericID(instance), + GenericID: ConvertToGenericID(instance, dagType), } } @@ -268,23 +268,37 @@ func NewTaskDetail(task ExecutableTask) *TaskDetail { } // ConvertToGenericID will convert task instance id to generic dto id. -func ConvertToGenericID(instance TaskInfoInterface) string { +func ConvertToGenericID(instance TaskInfoInterface, dagType string) string { if instance.IsLocalTask() { - return ConvertLocalIDToGenericID(instance.GetID()) + return ConvertLocalIDToGenericID(instance.GetID(), dagType) } return fmt.Sprintf("1%d", instance.GetID()) } -func ConvertIDToGenericID(dagID int64, isLocal bool) string { +func ConvertIDToGenericID(dagID int64, isLocal bool, dagType string) string { if isLocal { - return ConvertLocalIDToGenericID(dagID) + return ConvertLocalIDToGenericID(dagID, dagType) } else { return fmt.Sprintf("1%d", dagID) } } +func ConvertObproxyIDToGenericID(id int64) string { + ipParsed := net.ParseIP(meta.OCS_AGENT.GetIp()) + if ipParsed.To4() != nil { + bigInt := new(big.Int).SetBytes(ipParsed.To4()) + return fmt.Sprintf("4%010d%05d%d", bigInt, meta.OCS_AGENT.GetPort(), id) + } else { + bigInt := new(big.Int).SetBytes(ipParsed.To16()) + return fmt.Sprintf("5%039d%05d%d", bigInt, meta.OCS_AGENT.GetPort(), id) + } +} + // ConvertLocalIDToGenericID will convert id of local task to generic id. -func ConvertLocalIDToGenericID(id int64) string { +func ConvertLocalIDToGenericID(id int64, dagType string) string { + if DAG_TYPE_MAP[DAG_OBPROXY] == dagType { + return ConvertObproxyIDToGenericID(id) + } ipParsed := net.ParseIP(meta.OCS_AGENT.GetIp()) if ipParsed.To4() != nil { bigInt := new(big.Int).SetBytes(ipParsed.To4()) @@ -295,11 +309,17 @@ func ConvertLocalIDToGenericID(id int64) string { } } +func IsObproxyTask(genericID string) bool { + return genericID[0] == constant.OBPROXY_TASK_IPV4_ID_PREFIX || genericID[0] == constant.OBPROXY_TASK_IPV6_ID_PREFIX +} + // ConvertGenericID will onvert dto id to instance id. func ConvertGenericID(genericID string) (id int64, agent meta.AgentInfoInterface, err error) { if genericID[0] == constant.CLUSTER_TASK_ID_PREFIX && len(genericID) <= 1 || - genericID[0] == constant.LOCAL_TASK_IPV4_ID_PREFIX && len(genericID) <= 16 || - genericID[0] == constant.LOCAL_TASK_IPV6_ID_PREFIX && len(genericID) <= 45 { + (genericID[0] == constant.LOCAL_TASK_IPV4_ID_PREFIX || + genericID[0] == constant.OBPROXY_TASK_IPV4_ID_PREFIX) && len(genericID) <= 16 || + (genericID[0] == constant.LOCAL_TASK_IPV6_ID_PREFIX || + genericID[0] == constant.OBPROXY_TASK_IPV6_ID_PREFIX) && len(genericID) <= 45 { err = fmt.Errorf("invalid id: %s", genericID) return } @@ -309,10 +329,10 @@ func ConvertGenericID(genericID string) (id int64, agent meta.AgentInfoInterface switch genericID[0] { case constant.CLUSTER_TASK_ID_PREFIX: idIdx = 1 - case constant.LOCAL_TASK_IPV4_ID_PREFIX: + case constant.LOCAL_TASK_IPV4_ID_PREFIX, constant.OBPROXY_TASK_IPV4_ID_PREFIX: // Ipv4 address. ipIdx, portIdx, idIdx = 11, 16, 16 - case constant.LOCAL_TASK_IPV6_ID_PREFIX: + case constant.LOCAL_TASK_IPV6_ID_PREFIX, constant.OBPROXY_TASK_IPV6_ID_PREFIX: // Ipv6 address. ipIdx, portIdx, idIdx, isV6 = 40, 45, 45, true default: diff --git a/agent/engine/task/template.go b/agent/engine/task/template.go index cf60be10..f0018c23 100644 --- a/agent/engine/task/template.go +++ b/agent/engine/task/template.go @@ -21,6 +21,7 @@ type Template struct { nodes []*Node Name string maintenance Maintainer + Type string } func (template *Template) AddNode(node *Node) { @@ -89,3 +90,8 @@ func (builder *TemplateBuilder) SetMaintenance(maintenanceType Maintainer) *Temp builder.Template.maintenance = maintenanceType return builder } + +func (builder *TemplateBuilder) SetType(dagType DagType) *TemplateBuilder { + builder.Template.Type = DAG_TYPE_MAP[dagType] + return builder +} diff --git a/agent/errors/type.go b/agent/errors/type.go index 52e802c4..daba3ddd 100644 --- a/agent/errors/type.go +++ b/agent/errors/type.go @@ -27,3 +27,7 @@ func IsTaskNotFoundErr(err error) bool { func IsUnkonwnTimeZoneErr(err error) bool { return strings.Contains(err.Error(), "Unknown or incorrect time zone") } + +func IsRecordNotFoundErr(err error) bool { + return strings.Contains(err.Error(), "record not found") +} diff --git a/agent/executor/agent/enter.go b/agent/executor/agent/enter.go index 66a1e1d9..f8009da0 100644 --- a/agent/executor/agent/enter.go +++ b/agent/executor/agent/enter.go @@ -33,6 +33,7 @@ var ( const ( // task param PARAM_MASTER_AGENT = "masterAgent" + PARAM_MASTER_AGENT_PASSWORD = "masterAgentPassword" PARAM_ZONE = "zone" PARAM_AGENT = "agent" PARAM_TAKE_OVER_MASTER_AGENT = "takeOverMasterAgent" diff --git a/agent/executor/agent/follower_remove.go b/agent/executor/agent/follower_remove.go index 89a21339..afa3852f 100644 --- a/agent/executor/agent/follower_remove.go +++ b/agent/executor/agent/follower_remove.go @@ -77,7 +77,7 @@ func CreateRemoveFollowerAgentDag(agent meta.AgentInfo, fromAPI bool) (*task.Dag // Follower agent send rpc to master agent to remove itself or master agent receive api to remove follower agent. // Then, master agent create a task to remove follower agent. // Master will clear observer and zone config if there is no other follower agent in the zone. - name := fmt.Sprintf("Remove follower agent %s:%d", agent.Ip, agent.Port) + name := fmt.Sprintf("Remove follower agent %s", agent.String()) builder := task.NewTemplateBuilder(name) if fromAPI { builder.AddNode(newAgentRemoveFollowerRPCNode([]meta.AgentInfo{agent})) @@ -118,13 +118,13 @@ func (t *RemoveFollowerAgentTask) Execute() (err error) { return errors.Wrap(err, "get param failed") } - t.ExecuteLogf("finding agent %s:%d info", agent.Ip, agent.Port) + t.ExecuteLogf("finding agent %s info", agent.String()) agentInstance, err := GetFollowerAgent(&agent) if err != nil { return errors.Wrap(err, "get follower agent failed") } if agentInstance == nil { - t.ExecuteLogf("agent %s:%d is not exists", agent.Ip, agent.Port) + t.ExecuteLogf("agent %s is not exists", agent.String()) return nil } @@ -145,11 +145,11 @@ func (t *RemoveFollowerAgentTask) Execute() (err error) { } } - t.ExecuteLogf("deleting agent %s:%d", agent.Ip, agent.Port) + t.ExecuteLogf("deleting agent %s", agent.String()) if err = agentService.DeleteAgent(&agent); err != nil { return errors.Wrap(err, "delete agent failed") } - t.ExecuteLogf("remove follower agent %s:%d success", agent.Ip, agent.Port) + t.ExecuteLogf("remove follower agent %s success", agent.String()) return nil } @@ -185,7 +185,7 @@ func GetFollowerAgent(agent meta.AgentInfoInterface) (agentInstance *meta.AgentI if err != nil { err = errors.Wrap(err, "get agent instance failed") } else if agentInstance != nil && !agentInstance.IsFollowerAgent() { - err = errors.Errorf("agent %s:%d is not follower", agent.GetIp(), agent.GetPort()) + err = errors.Errorf("agent %s is not follower", agent.String()) } return } diff --git a/agent/executor/agent/join_follower.go b/agent/executor/agent/join_follower.go index 345cff6c..d5742787 100644 --- a/agent/executor/agent/join_follower.go +++ b/agent/executor/agent/join_follower.go @@ -23,6 +23,7 @@ import ( "github.com/oceanbase/obshell/agent/engine/task" "github.com/oceanbase/obshell/agent/errors" "github.com/oceanbase/obshell/agent/global" + "github.com/oceanbase/obshell/agent/lib/http" "github.com/oceanbase/obshell/agent/meta" "github.com/oceanbase/obshell/agent/secure" "github.com/oceanbase/obshell/param" @@ -30,13 +31,29 @@ import ( type AgentJoinMasterTask struct { task.Task + masterPassword string } type AgentBeFollowerTask struct { task.Task } -func CreateJoinMasterDag(masterAgent meta.AgentInfo, zone string) (*task.Dag, error) { +func SendTokenToMaster(agentInfo meta.AgentInfo, masterPassword string) error { + token, err := secure.NewToken(&agentInfo) + if err != nil { + return errors.Wrap(err, "get token failed") + } + param := param.AddTokenParam{ + AgentInfo: *meta.NewAgentInfoByInterface(meta.OCS_AGENT), + Token: token, + } + if err := secure.SendRequestWithPassword(&agentInfo, constant.URI_AGENT_RPC_PREFIX+constant.URI_TOKEN, http.POST, masterPassword, param, nil); err != nil { + return errors.Wrap(err, "send post request failed") + } + return nil +} + +func CreateJoinMasterDag(masterAgent meta.AgentInfo, zone string, masterPassword string) (*task.Dag, error) { // Agent receive api to join master, then create a task to be follower. builder := task.NewTemplateBuilder(DAG_JOIN_TO_MASTER) @@ -53,8 +70,16 @@ func CreateJoinMasterDag(masterAgent meta.AgentInfo, zone string) (*task.Dag, er builder.AddTask(beFollowerAgent, false) builder.SetMaintenance(task.GlobalMaintenance()) + + // Encrypt master agent password. + agentPassword, err := secure.Encrypt(masterPassword) + if err != nil { + return nil, errors.Wrap(err, "encrypt master password failed") + } template := builder.Build() - ctx := task.NewTaskContext().SetParam(PARAM_ZONE, zone).SetParam(PARAM_MASTER_AGENT, masterAgent) + ctx := task.NewTaskContext().SetParam(PARAM_ZONE, zone).SetParam(PARAM_MASTER_AGENT, masterAgent). + SetParam(PARAM_MASTER_AGENT_PASSWORD, agentPassword) + return localTaskService.CreateDagInstanceByTemplate(template, ctx) } @@ -64,6 +89,14 @@ func (t *AgentJoinMasterTask) Execute() error { if err := taskCtx.GetParamWithValue(PARAM_MASTER_AGENT, &masterAgent); err != nil { return errors.Wrapf(err, "Get Param %s failed", PARAM_MASTER_AGENT) } + if err := taskCtx.GetParamWithValue(PARAM_MASTER_AGENT_PASSWORD, &t.masterPassword); err != nil { + return errors.Wrapf(err, "Get Param %s failed", PARAM_MASTER_AGENT_PASSWORD) + } + // Decrypt master agent password. + masterPassword, err := secure.Decrypt(t.masterPassword) + if err != nil { + return errors.Wrap(err, "decrypt master password failed") + } zone, ok := t.GetContext().GetParam(PARAM_ZONE).(string) if !ok { return errors.New("zone is not set") @@ -89,8 +122,9 @@ func (t *AgentJoinMasterTask) Execute() error { Token: token, } t.ExecuteLog("send join rpc to master") + var masterAgentInstance meta.AgentInstance - if err := secure.SendPostRequest(&masterAgent, constant.URI_AGENT_RPC_PREFIX, param, &masterAgentInstance); err != nil { + if err := secure.SendRequestWithPassword(&masterAgent, constant.URI_AGENT_RPC_PREFIX, http.POST, masterPassword, param, &masterAgentInstance); err != nil { return errors.Wrap(err, "send post request failed") } t.ExecuteLog(fmt.Sprintf("join to master success, master agent info: %v", masterAgentInstance)) @@ -126,7 +160,7 @@ func (t *AgentBeFollowerTask) Execute() error { func AddFollowerAgent(param param.JoinMasterParam) *errors.OcsAgentError { targetToken, err := secure.Crypter.Decrypt(param.Token) if err != nil { - return errors.Occurf(errors.ErrKnown, "decrypt token of '%s:%d' failed: %v", param.JoinApiParam.AgentInfo.GetIp(), param.JoinApiParam.AgentInfo.GetPort(), err) + return errors.Occurf(errors.ErrKnown, "decrypt token of '%s' failed: %v", param.JoinApiParam.AgentInfo.String(), err) } agentInstance := meta.NewAgentInstanceByAgentInfo(¶m.JoinApiParam.AgentInfo, param.JoinApiParam.ZoneName, meta.FOLLOWER, param.Version) @@ -136,19 +170,14 @@ func AddFollowerAgent(param param.JoinMasterParam) *errors.OcsAgentError { return nil } -func UpdateFollowerAgent(agentInstance meta.Agent, param param.JoinMasterParam) *errors.OcsAgentError { - // Agent already exists. - if agentInstance.GetIdentity() != meta.FOLLOWER || agentInstance.GetVersion() != param.Version || agentInstance.GetZone() != param.JoinApiParam.ZoneName { - return errors.Occur(errors.ErrBadRequest, "agent already exists") - } - +func AddSingleToken(param param.AddTokenParam) *errors.OcsAgentError { targetToken, err := secure.Crypter.Decrypt(param.Token) if err != nil { - return errors.Occurf(errors.ErrKnown, "decrypt token of '%s:%d' failed: %v", param.JoinApiParam.AgentInfo.GetIp(), param.JoinApiParam.AgentInfo.GetPort(), err) + return errors.Occurf(errors.ErrKnown, "decrypt token of '%s' failed: %v", param.AgentInfo.String(), err) } - if err = agentService.UpdateAgent(agentInstance, param.HomePath, param.Os, param.Architecture, param.PublicKey, targetToken); err != nil { - return errors.Occurf(errors.ErrKnown, "update agent failed: %v", err) + if err = agentService.AddSingleToken(¶m.AgentInfo, targetToken); err != nil { + return errors.Occurf(errors.ErrKnown, "insert token failed: %v", err) } return nil } diff --git a/agent/executor/ob/bootstrap.go b/agent/executor/ob/bootstrap.go index ae0fcb43..e0fed06c 100644 --- a/agent/executor/ob/bootstrap.go +++ b/agent/executor/ob/bootstrap.go @@ -90,7 +90,8 @@ func (t *ClusterBoostrapTask) generateBootstrapCmd() (string, error) { if !ok { return "", fmt.Errorf("zone %s has no rs", zone) } - list = append(list, fmt.Sprintf("ZONE '%s' SERVER '%s:%d'", zone, observerInfo.Ip, observerInfo.Port)) + agent := meta.NewAgentInfo(observerInfo.Ip, observerInfo.Port) + list = append(list, fmt.Sprintf("ZONE '%s' SERVER '%s'", zone, agent.String())) } bootstrapCmd = bootstrapCmd + strings.Join(list, ", ") return bootstrapCmd, nil @@ -113,7 +114,8 @@ func (t *ClusterBoostrapTask) execBootstrap(cmd string) error { func (t *ClusterBoostrapTask) addServers() error { for zone, serverList := range t.unRS { for _, server := range serverList { - sql := fmt.Sprintf("ALTER SYSTEM ADD SERVER '%s:%d' ZONE '%s'", server.Ip, server.Port, zone) + agent := meta.NewAgentInfo(server.Ip, server.Port) + sql := fmt.Sprintf("ALTER SYSTEM ADD SERVER '%s' ZONE '%s'", agent.String(), zone) t.ExecuteInfoLogf("add server: %s", sql) if err := obclusterService.ExecuteSqlWithoutIdentityCheck(sql); err != nil { return err diff --git a/agent/executor/ob/config.go b/agent/executor/ob/config.go index 2c9b53d8..4cf123e9 100644 --- a/agent/executor/ob/config.go +++ b/agent/executor/ob/config.go @@ -19,7 +19,6 @@ package ob import ( "fmt" "path/filepath" - "regexp" "strconv" "strings" @@ -135,28 +134,15 @@ func (t *UpdateOBServerConfigTask) updateZoneConfig() error { func (t *UpdateOBServerConfigTask) updateServerConfig() error { agents := make([]meta.AgentInfoInterface, 0) for _, server := range t.config.Scope.Target { - agent := ConvertAgentInfo(server) - if agent == nil { - return fmt.Errorf("invalid server: %s", server) + agent, err := meta.ConvertAddressToAgentInfo(server) + if err != nil { + return err } agents = append(agents, agent) } return observerService.UpdateServerConfig(t.config.ObServerConfig, agents, t.deleteAll) } -func ConvertAgentInfo(str string) meta.AgentInfoInterface { - re := regexp.MustCompile(`(\[?[^\[\]:]+\]?):(\d+)`) - matches := re.FindAllStringSubmatch(str, -1) - if len(matches) != 1 || len(matches[0]) != 3 { - return nil - } - port, err := strconv.Atoi(matches[0][2]) - if err != nil { - return nil - } - return meta.NewAgentInfo(matches[0][1], port) -} - func newUpdateOBClusterConfigTask() *UpdateOBClusterConfigTask { newTask := &UpdateOBClusterConfigTask{ Task: *task.NewSubTask(TASK_NAME_UPDATE_CONFIG), @@ -269,6 +255,7 @@ func (t *IntegrateObConfigTask) getAgents() error { agent := *meta.NewAgentInfo(agentDO.Ip, agentDO.Port) t.agents[agent] = agentDO } + t.ExecuteInfoLogf("agents: %v", t.agents) return nil } @@ -362,10 +349,31 @@ func (t *IntegrateObConfigTask) hasRSList() error { rsServers := strings.Split(val.Value, ";") t.ExecuteLogf("check rsList config: %s", rsServers) for _, rsServer := range rsServers { - info := strings.Split(rsServer, ":") + if rsServer == "" { + return fmt.Errorf("invalid rsList config: %s", val.Value) + } + + var info []string + if rsServer[0] == '[' { + t.ExecuteLogf("rsServer: %s is IPv6 address", rsServer) + // It means the rsServer is IPv6 address. + info = strings.Split(rsServer, "]") + if len(info) != 2 { + return fmt.Errorf("invalid rsList config: %s", rsServer) + } + + ip := info[0][1:len(info[0])] + info = strings.Split(info[1], ":") + info[0] = ip + } else { + t.ExecuteLogf("rsServer: %s is IPv4 address", rsServer) + // It means the rsServer is IPv4 address. + info = strings.Split(rsServer, ":") + } if len(info) != 3 { return fmt.Errorf("invalid rsList config: %s", rsServer) } + rpcPort, err := strconv.Atoi(info[1]) if err != nil { return fmt.Errorf("invalid rsList config: %s", rsServer) @@ -419,7 +427,7 @@ func (t *IntegrateObConfigTask) setUNRootServer() error { } func (t *IntegrateObConfigTask) getObserverConfig(agent meta.AgentInfo) (map[string]string, error) { - agentStr := fmt.Sprintf("%s:%d", agent.Ip, agent.Port) + agentStr := agent.String() t.ExecuteLogf("Integrating %s agent config", agentStr) t.ExecuteLogf("get %s agent config", agentStr) diff --git a/agent/executor/ob/convert.go b/agent/executor/ob/convert.go index a36fd7c1..ab458443 100644 --- a/agent/executor/ob/convert.go +++ b/agent/executor/ob/convert.go @@ -92,7 +92,7 @@ func (t *ConvertFollowerToClusterAgentTask) Execute() (err error) { } } for _, agent := range agents { - t.ExecuteLogf("convert agent %s:%d to cluster agent", agent.GetIp(), agent.GetPort()) + t.ExecuteLogf("convert agent %s to cluster agent", agent.String()) if err := agentService.ConvertToClusterAgent(&agent); err != nil { return err } diff --git a/agent/executor/ob/create_user.go b/agent/executor/ob/create_user.go new file mode 100644 index 00000000..7ee5a769 --- /dev/null +++ b/agent/executor/ob/create_user.go @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ob + +import ( + "github.com/oceanbase/obshell/agent/engine/task" + "github.com/oceanbase/obshell/agent/secure" +) + +// CreateDefaultUserTask creates default user, currently limited to 'proxyro'. +type CreateDefaultUserTask struct { + task.Task + encryptProxyroPassword string +} + +func newCreateDefaultUserNode(proxyroPassword string) (*task.Node, error) { + subtask := newCreateDefaultUserTask() + encryptedProxyroPassword, err := secure.Encrypt(proxyroPassword) + if err != nil { + return nil, err + } + ctx := task.NewTaskContext().SetParam(PARAM_PROXYRO_PASSWORD, encryptedProxyroPassword) + return task.NewNodeWithContext(subtask, false, ctx), nil +} + +func newCreateDefaultUserTask() *CreateDefaultUserTask { + newTask := &CreateDefaultUserTask{ + Task: *task.NewSubTask(TASK_NAME_CREATE_USER), + } + newTask.SetCanRetry().SetCanContinue().SetCanCancel() + return newTask +} + +func (t *CreateDefaultUserTask) Execute() error { + if err := t.GetContext().GetParamWithValue(PARAM_PROXYRO_PASSWORD, &t.encryptProxyroPassword); err != nil { + return err + } + + // decrypt password + proxyroPassword, err := secure.Decrypt(t.encryptProxyroPassword) + if err != nil { + return err + } + + if err := obclusterService.CreateProxyroUser(proxyroPassword); err != nil { + return err + } + return nil +} diff --git a/agent/executor/ob/enter.go b/agent/executor/ob/enter.go index a7b74f33..50808afc 100644 --- a/agent/executor/ob/enter.go +++ b/agent/executor/ob/enter.go @@ -62,6 +62,7 @@ const ( PARAM_EXPECT_MAIN_NEXT_STAGE = "expectedMainNextStage" PARAM_URI = "uri" PARAM_HEALTH_CHECK = "healthCheck" + PARAM_PROXYRO_PASSWORD = "proxyroPassword" // scale out PARAM_EXPECT_DEPLOY_NEXT_STAGE = "expectedDeployNextStage" PARAM_EXPECT_START_NEXT_STAGE = "expectedStartNextStage" @@ -102,6 +103,7 @@ const ( PARAM_TENANT_NAME = "tenantName" PARAM_TARGET_AGENT_BUILD_VERSION = "targetAgentBuildVersion" + PARAM_TARGET_AGENT_PASSWORD = "targetAgentPassword" // for backup PARAM_NEED_BACKUP_TENANT = "needBackupTenants" @@ -125,6 +127,9 @@ const ( PARAM_RESTORE_SCN = "restoreScn" PARAM_NEED_DELETE_RP = "needDeleteRp" + PARAM_USER_NAME = "userName" + PARAM_USER_PASSWORD = "userPassword" + DATA_ALL_AGENT_DAG_MAP = "allAgentDagMap" DATA_SKIP_START_TASK = "skipStartTask" @@ -155,6 +160,7 @@ const ( TASK_NAME_MIGRATE_TABLE = "Migrate table" TASK_NAME_MODIFY_PWD = "Modify password" TASK_NAME_MIGRATE_DATA = "Migrate data" + TASK_NAME_CREATE_USER = "Create user" TASK_NAME_UPDATE_AGENT = "Update all agents" TASK_NAME_INITIALIZE_DATA = "Initialize data" TASK_NAME_AGENT_SYNC = "Synchronize agent from cluster" @@ -375,6 +381,7 @@ func RegisterObInitTask() { task.RegisterTaskType(ClusterBoostrapTask{}) task.RegisterTaskType(MigrateTableTask{}) task.RegisterTaskType(ModifyPwdTask{}) + task.RegisterTaskType(CreateDefaultUserTask{}) task.RegisterTaskType(MigrateDataTask{}) task.RegisterTaskType(ConvertFollowerToClusterAgentTask{}) task.RegisterTaskType(AgentSyncTask{}) diff --git a/agent/executor/ob/info.go b/agent/executor/ob/info.go index b625df28..7e370e08 100644 --- a/agent/executor/ob/info.go +++ b/agent/executor/ob/info.go @@ -217,16 +217,11 @@ func IsValidScope(s *param.Scope) (err error) { return errors.New("server scope must have target") } for _, server := range s.Target { - info := strings.Split(server, ":") - if len(info) != 2 { - return errors.Errorf("invalid server '%s'", server) - } - port, err := strconv.Atoi(info[1]) + agentInfo, err := meta.ConvertAddressToAgentInfo(server) if err != nil { - return errors.Errorf("invalid server '%s' port '%s'", server, info[1]) + return err } - agentInfo := meta.AgentInfo{Ip: info[0], Port: port} - exist, err := agentService.IsAgentExist(&agentInfo) + exist, err := agentService.IsAgentExist(agentInfo) if err != nil { return err } @@ -269,14 +264,14 @@ func GetObAgents() (agents []meta.AgentInfo, err error) { } for _, server := range serversWithRpcPort { - agents = append(agents, meta.AgentInfo{Ip: server[0], Port: meta.OCS_AGENT.GetPort()}) + agents = append(agents, *meta.NewAgentInfo(server.GetIp(), server.GetPort())) } } return } -func GetAllServerFromOBConf() (serversWithRpcPort [][2]string, err error) { +func GetAllServerFromOBConf() (serversWithRpcPort []meta.AgentInfoInterface, err error) { f := path.ObConfigPath() log.Info("get conf from ", f) file, err := os.Open(f) @@ -291,7 +286,7 @@ func GetAllServerFromOBConf() (serversWithRpcPort [][2]string, err error) { } re := regexp.MustCompile("\x00*([_a-zA-Z]+)=(.*)") - var servers, items []string + var servers []string for scanner.Scan() { line := scanner.Text() match := re.FindStringSubmatch(line) @@ -301,15 +296,11 @@ func GetAllServerFromOBConf() (serversWithRpcPort [][2]string, err error) { if match[1] == ETC_KEY_ALL_SERVER_LIST { servers = strings.Split(match[2], ",") for _, server := range servers { - items = strings.Split(server, ":") - if len(items) != 2 { - return nil, errors.Errorf("invalid server '%s'", server) - } - _, err = strconv.Atoi(items[1]) + serverInfo, err := meta.ConvertAddressToAgentInfo(server) if err != nil { - return nil, errors.Wrapf(err, "invalid server '%s' port '%s'", server, items[1]) + return nil, err } - serversWithRpcPort = append(serversWithRpcPort, [2]string{items[0], items[1]}) + serversWithRpcPort = append(serversWithRpcPort, serverInfo) } log.Infof("get servers from ob.conf %v", serversWithRpcPort) return diff --git a/agent/executor/ob/init.go b/agent/executor/ob/init.go index a517f660..1338dccd 100644 --- a/agent/executor/ob/init.go +++ b/agent/executor/ob/init.go @@ -36,8 +36,15 @@ func CreateInitDag(param param.ObInitParam) (*task.DagDetailDTO, error) { AddTask(newStartObServerTask(), true). AddTask(newClusterBoostrapTask(), false). AddTask(newMigrateTableTask(), false). - AddTask(newModifyPwdTask(), false). - AddTask(newMigrateDataTask(), false). + AddTask(newModifyPwdTask(), false) + if param.CreateProxyroUser { + createUserNode, err := newCreateDefaultUserNode(param.ProxyroPassword) + if err != nil { + return nil, err + } + builder.AddNode(createUserNode) + } + builder.AddTask(newMigrateDataTask(), false). AddTemplate(newConvertClusterTemplate()) if param.ImportScript { builder.AddNode(script.NewImportScriptForTenantNode(false)) diff --git a/agent/executor/ob/minor_freeze.go b/agent/executor/ob/minor_freeze.go index bafdb0ea..992cab3a 100644 --- a/agent/executor/ob/minor_freeze.go +++ b/agent/executor/ob/minor_freeze.go @@ -17,8 +17,6 @@ package ob import ( - "strconv" - "strings" "time" "github.com/oceanbase/obshell/agent/engine/task" @@ -62,10 +60,11 @@ func (t *MinorFreezeTask) GetAllObServer() (servers []oceanbase.OBServer, err er case SCOPE_SERVER: for _, server := range t.scope.Target { - info := strings.Split(server, ":") - ip := info[0] - port, _ := strconv.Atoi(info[1]) - server, err := obclusterService.GetOBServerByAgentInfo(meta.AgentInfo{Ip: ip, Port: port}) + serverInfo, err := meta.ConvertAddressToAgentInfo(server) + if err != nil { + return nil, errors.Wrap(err, "convert address to agent info failed") + } + server, err := obclusterService.GetOBServerByAgentInfo(*serverInfo) if err != nil { return nil, errors.Wrap(err, "get server by agent info failed") } @@ -123,11 +122,11 @@ func (t *MinorFreezeTask) isMinorFreezeOver(servers []oceanbase.OBServer, oldChe // checkpoint_scn is 0, means there is no ls in this server continue } else if checkpointScn > oldCheckpointScn[server] { - t.ExecuteLogf("[server: %s:%d]smallest checkpoint_scn %+v bigger than expired timestamp %+v, check pass ", server.SvrIp, server.SvrPort, checkpointScn, oldCheckpointScn[server]) + t.ExecuteLogf("[server: %s]smallest checkpoint_scn %+v bigger than expired timestamp %+v, check pass ", meta.NewAgentInfo(server.SvrIp, server.SvrPort).String(), checkpointScn, oldCheckpointScn[server]) checkedServer[server] = true continue } else { - t.ExecuteLogf("[server: %s:%d]smallest checkpoint_scn: %+v smaller than expired timestamp %+v, waiting...", server.SvrIp, server.SvrPort, checkpointScn, oldCheckpointScn[server]) + t.ExecuteLogf("[server: %s]smallest checkpoint_scn: %+v smaller than expired timestamp %+v, waiting...", meta.NewAgentInfo(server.SvrIp, server.SvrPort).String(), checkpointScn, oldCheckpointScn[server]) return false, nil } } diff --git a/agent/executor/ob/scale_out.go b/agent/executor/ob/scale_out.go index 76f123b7..2bbf4c56 100644 --- a/agent/executor/ob/scale_out.go +++ b/agent/executor/ob/scale_out.go @@ -104,11 +104,23 @@ func HandleClusterScaleOut(param param.ClusterScaleOutParam) (*task.DagDetailDTO return nil, errors.Occur(errors.ErrUnexpected, err.Error()) } + var rpcPort int + rpcPortStr, ok := param.ObConfigs[constant.CONFIG_RPC_PORT] + if ok { + var err error + if rpcPort, err = strconv.Atoi(rpcPortStr); err != nil { + return nil, errors.Occur(errors.ErrIllegalArgument, "rpc_port is not a number") + } + } else { + rpcPort = constant.DEFAULT_RPC_PORT + } + srvInfo := meta.NewAgentInfo(param.AgentInfo.Ip, rpcPort) + // Check the server is not already in the cluster. - if exist, err := obclusterService.IsServerExist(param.AgentInfo.Ip, param.ObConfigs[constant.CONFIG_RPC_PORT]); err != nil { + if exist, err := obclusterService.IsServerExist(*srvInfo); err != nil { return nil, errors.Occur(errors.ErrUnexpected, err.Error()) } else if exist { - return nil, errors.Occurf(errors.ErrBadRequest, "server %s:%s already exists in the cluster", param.AgentInfo.Ip, param.ObConfigs[constant.CONFIG_RPC_PORT]) + return nil, errors.Occurf(errors.ErrBadRequest, "server %s already exists in the cluster", srvInfo.String()) } // Create Cluster Scale Out Dag @@ -157,8 +169,14 @@ func CreateClusterScaleOutDag(param param.ClusterScaleOutParam, targetVersion st if err != nil { return nil, errors.Wrap(err, "check zone exist failed") } + // get target agent pk + encryptAgentPassword, err := secure.EncryptForAgent(param.TargetAgentPassword, ¶m.AgentInfo) + if err != nil { + return nil, errors.Wrap(err, "encrypt agent password failed") + } + template := buildClusterScaleOutTaskTemplate(param, !isZoneExist) - context := buildClusterScaleOutDagContext(param, !isZoneExist, targetVersion) + context := buildClusterScaleOutDagContext(param, !isZoneExist, targetVersion, encryptAgentPassword) dag, err := clusterTaskService.CreateDagInstanceByTemplate(template, context) if err != nil { return nil, errors.Wrap(err, "create dag instance failed") @@ -233,13 +251,14 @@ func buildLocalScaleOutTaskTemplate(param param.LocalScaleOutParam) *task.Templa Build() } -func buildClusterScaleOutDagContext(param param.ClusterScaleOutParam, isNewZone bool, targetVersion string) *task.TaskContext { +func buildClusterScaleOutDagContext(param param.ClusterScaleOutParam, isNewZone bool, targetVersion string, targetAgentPassword string) *task.TaskContext { context := task.NewTaskContext(). SetParam(PARAM_ZONE, param.Zone). SetParam(PARAM_IS_NEW_ZONE, isNewZone). SetParam(PARAM_AGENT_INFO, param.AgentInfo). SetParam(PARAM_CONFIG, param.ObConfigs). - SetParam(PARAM_TARGET_AGENT_VERSION, targetVersion) + SetParam(PARAM_TARGET_AGENT_VERSION, targetVersion). + SetParam(PARAM_TARGET_AGENT_PASSWORD, targetAgentPassword) return context } @@ -726,6 +745,7 @@ func (t *IntegrateSingleObConfigTask) Execute() error { type CreateLocalScaleOutDagTask struct { scaleCoordinateTask + targetAgentPassword string } func newCreateLocalScaleOutDagTask() *CreateLocalScaleOutDagTask { @@ -741,6 +761,9 @@ func (t *CreateLocalScaleOutDagTask) Execute() error { if err := t.GetContext().GetParamWithValue(PARAM_AGENT_INFO, &agentInfo); err != nil { return errors.Wrap(err, "get agent info failed") } + if err := t.GetContext().GetParamWithValue(PARAM_TARGET_AGENT_PASSWORD, &t.targetAgentPassword); err != nil { + return errors.Wrap(err, "get target agent password failed") + } // Send rpc to target agent. param, err := t.buildLocalScaleOutParam() if err != nil { @@ -752,7 +775,7 @@ func (t *CreateLocalScaleOutDagTask) Execute() error { } var resp LocalScaleOutResp - if err := secure.SendPostRequest(&agentInfo, constant.URI_OB_RPC_PREFIX+constant.URI_SCALE_OUT, param, &resp); err != nil { + if err := secure.SendRequestWithPassword(&agentInfo, constant.URI_OB_RPC_PREFIX+constant.URI_SCALE_OUT, http.POST, t.targetAgentPassword, param, &resp); err != nil { return errors.Wrap(err, "send scale out rpc to target agent failed") } t.ExecuteLogf("create local scale out dag success, genericID:%s", resp.GenericID) @@ -1207,9 +1230,15 @@ func (t *AddServerTask) Execute() error { return errors.New("get zone failed") } - err := obclusterService.AddServer(agentInfo.Ip, configs[constant.CONFIG_RPC_PORT], zone) + port, err := strconv.Atoi(configs[constant.CONFIG_RPC_PORT]) + if err != nil { + return errors.Wrap(err, "convert port to integer failed") + } + + serverInfo := meta.NewAgentInfo(agentInfo.Ip, port) + err = obclusterService.AddServer(*serverInfo, zone) if err != nil { - return errors.Errorf("add server %s:%s failed", agentInfo.Ip, configs[constant.CONFIG_RPC_PORT]) + return errors.Errorf("add server %s failed", serverInfo.String()) } t.GetContext().SetParam(PARAM_ADD_SERVER_SUCCEED, true) @@ -1248,17 +1277,23 @@ func (t *AddServerTask) Rollback() error { return errors.New("get zone failed") } + port, err := strconv.Atoi(configs[constant.CONFIG_RPC_PORT]) + if err != nil { + return errors.Wrap(err, "convert rpc port to integer failed") + } + serverInfo := meta.NewAgentInfo(agentInfo.Ip, port) + // Check whether addserver task execute successfully. - exist, err := obclusterService.IsServerExistWithZone(agentInfo.Ip, configs[constant.CONFIG_RPC_PORT], zone) + exist, err := obclusterService.IsServerExistWithZone(*serverInfo, zone) if err != nil { - return errors.Errorf("check server %s:%s exist failed", agentInfo.Ip, configs[constant.CONFIG_RPC_PORT]) + return errors.Errorf("check server %s exist failed", agentInfo.String()) } if !exist { return nil } - if err = obclusterService.DeleteServerInZone(agentInfo.Ip, configs[constant.CONFIG_RPC_PORT], zone); err != nil { - return errors.Errorf("delete server %s:%s failed", agentInfo.Ip, configs[constant.CONFIG_RPC_PORT]) + if err = obclusterService.DeleteServerInZone(*serverInfo, zone); err != nil { + return errors.Errorf("delete server %s failed", serverInfo.String()) } return nil } diff --git a/agent/executor/ob/start.go b/agent/executor/ob/start.go index 85803921..b38304a4 100644 --- a/agent/executor/ob/start.go +++ b/agent/executor/ob/start.go @@ -248,6 +248,6 @@ func sendGetDagDetailRequest(id string) (*task.DagDetailDTO, error) { } func getDagGenericIDBySubTaskId(id int64) (string, error) { - dagID, err := localTaskService.GetDagIDBySubTaskId(id) - return task.ConvertLocalIDToGenericID(dagID), err + dag, err := localTaskService.GetDagBySubTaskId(id) + return task.ConvertLocalIDToGenericID(dag.GetID(), dag.GetDagType()), err } diff --git a/agent/executor/ob/start_obsvr.go b/agent/executor/ob/start_obsvr.go index aa44f492..71a24328 100644 --- a/agent/executor/ob/start_obsvr.go +++ b/agent/executor/ob/start_obsvr.go @@ -126,7 +126,11 @@ func (t *StartObserverTask) observerHealthCheck(mysqlPort int) error { for retryCount := 1; retryCount <= maxRetries; retryCount++ { time.Sleep(retryInterval) - t.ExecuteLogf("observer health check, retry [%d/%d]", retryCount, maxRetries) + if retryCount%10 == 0 { + t.TimeoutCheck() + } else { + t.ExecuteLogf("observer health check, retry [%d/%d]", retryCount, maxRetries) + } // Check if the observer process exists if exist, err := process.CheckObserverProcess(); !exist || err != nil { @@ -275,6 +279,7 @@ func fillStartConfig(config map[string]string) { func generateStartOpitonCmd(config map[string]string) string { cmd := "" + agentIp := config[constant.CONFIG_LOCAL_IP] for name, value := range startOptionsMap { if val, ok := config[name]; ok { if name == constant.CONFIG_RS_LIST { @@ -285,6 +290,10 @@ func generateStartOpitonCmd(config map[string]string) string { delete(config, name) } } + + if meta.NewAgentInfo(agentIp, 0).IsIPv6() { + cmd += "--ipv6 " + } return cmd } @@ -495,14 +504,15 @@ func (t *CheckObserverForStartTask) checkObsvrProcConfig() error { func (t *AlterStartServerTask) Execute() error { t.ExecuteInfoLog("exec start server sql") - conf, err := observerService.GetObConfigByName(constant.CONFIG_RPC_PORT) - if err != nil { + var rpcPort int + if err := observerService.GetObConfigValueByName(constant.CONFIG_RPC_PORT, &rpcPort); err != nil { return errors.Wrap(err, "get rpc port failed") } if err := getOceanbaseInstance(); err != nil { return err } - sql := fmt.Sprintf("alter system start server '%s:%s'", meta.OCS_AGENT.GetIp(), conf.Value) + + sql := fmt.Sprintf("alter system start server '%s'", meta.NewAgentInfo(meta.OCS_AGENT.GetIp(), rpcPort).String()) log.Info(sql) if err := obclusterService.ExecuteSql(sql); err != nil { return errors.Wrap(err, "alter start server failed") diff --git a/agent/executor/ob/stop.go b/agent/executor/ob/stop.go index 893b2d3a..5509697c 100644 --- a/agent/executor/ob/stop.go +++ b/agent/executor/ob/stop.go @@ -18,8 +18,6 @@ package ob import ( "fmt" - "strconv" - "strings" log "github.com/sirupsen/logrus" @@ -239,12 +237,14 @@ func (t *ExecStopSqlTask) stopServer() (err error) { } for _, server := range t.scope.Target { t.ExecuteLogf("Stop %s", server) - info := strings.Split(server, ":") - ip := info[0] - port, _ := strconv.Atoi(info[1]) + agentInfo, err := meta.ConvertAddressToAgentInfo(server) + if err != nil { + return errors.Errorf("convert server '%s' to agent info failed: %v", server, err) + } for _, agent := range agents { - if ip == agent.Ip && port == agent.Port { - sql := fmt.Sprintf("alter system stop server '%s:%d'", ip, agent.RpcPort) + if agentInfo.Ip == agent.Ip && agentInfo.Port == agent.Port { + serverInfo := meta.NewAgentInfo(agent.Ip, agent.RpcPort) + sql := fmt.Sprintf("alter system stop server '%s'", serverInfo.String()) log.Info(sql) if err = obclusterService.ExecuteSql(sql); err != nil { return err diff --git a/agent/executor/ob/stop_obsvr.go b/agent/executor/ob/stop_obsvr.go index 71731569..f7a0c9fe 100644 --- a/agent/executor/ob/stop_obsvr.go +++ b/agent/executor/ob/stop_obsvr.go @@ -19,8 +19,6 @@ package ob import ( "fmt" "os/exec" - "strconv" - "strings" "time" log "github.com/sirupsen/logrus" @@ -251,15 +249,13 @@ func GenerateTargetAgentList(scope param.Scope) ([]meta.AgentInfo, error) { } case SCOPE_SERVER: for _, server := range scope.Target { - var info meta.AgentInfo - info.Ip = server[0:strings.LastIndex(server, ":")] - info.Port, err = strconv.Atoi(server[strings.LastIndex(server, ":")+1:]) + info, err := meta.ConvertAddressToAgentInfo(server) if err != nil { log.WithError(err).Errorf("parse server '%s' failed", server) return nil, err } - targetAgents = append(targetAgents, info) + targetAgents = append(targetAgents, *info) } } return targetAgents, nil diff --git a/agent/executor/ob/update_agents.go b/agent/executor/ob/update_agents.go index 3ecf3b86..94dc57fd 100644 --- a/agent/executor/ob/update_agents.go +++ b/agent/executor/ob/update_agents.go @@ -49,7 +49,7 @@ func (t *UpdateAllAgentsTask) Execute() error { } for _, agent := range agents { - t.ExecuteLogf("convert agent %s:%d to cluster agent", agent.GetIp(), agent.GetPort()) + t.ExecuteLogf("convert agent %s to cluster agent", agent.String()) if err := agentService.ConvertToClusterAgent(&agent); err != nil { return err } diff --git a/agent/executor/ob/upgrade_agent_check.go b/agent/executor/ob/upgrade_agent_check.go index 7f5d4d69..d8236c4c 100644 --- a/agent/executor/ob/upgrade_agent_check.go +++ b/agent/executor/ob/upgrade_agent_check.go @@ -71,7 +71,7 @@ func preCheckForAgentUpgrade(param param.UpgradeCheckParam) (agentErr *errors.Oc } func checkTargetVersionSupport(version, release string) error { - buildNumber, _, err := splitRelease(release) + buildNumber, _, err := pkg.SplitRelease(release) if err != nil { return err } @@ -90,7 +90,7 @@ func findTargetPkg(version, release string) error { return err } var errs []error - buildNumber, distribution, _ := splitRelease(release) + buildNumber, distribution, _ := pkg.SplitRelease(release) for _, arch := range archList { _, err := obclusterService.GetUpgradePkgInfoByVersionAndRelease(constant.PKG_OBSHELL, version, buildNumber, distribution, arch) if err != nil { @@ -106,7 +106,7 @@ func findTargetPkg(version, release string) error { func buildAgentUpgradeCheckTaskContext(param param.UpgradeCheckParam, agents []meta.AgentInfo) *task.TaskContext { ctx := task.NewTaskContext() - buildNumber, distribution, _ := splitRelease(param.Release) + buildNumber, distribution, _ := pkg.SplitRelease(param.Release) taskTime := strconv.Itoa(int(time.Now().UnixMilli())) ctx.SetParam(PARAM_ALL_AGENTS, agents). SetParam(task.EXECUTE_AGENTS, agents). diff --git a/agent/executor/ob/upgrade_ob.go b/agent/executor/ob/upgrade_ob.go index eb0a7913..66975384 100644 --- a/agent/executor/ob/upgrade_ob.go +++ b/agent/executor/ob/upgrade_ob.go @@ -26,6 +26,7 @@ import ( "github.com/oceanbase/obshell/agent/engine/task" "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/lib/pkg" "github.com/oceanbase/obshell/agent/meta" "github.com/oceanbase/obshell/agent/repository/model/oceanbase" "github.com/oceanbase/obshell/param" @@ -74,7 +75,7 @@ func CheckAndUpgradeOb(param param.ObUpgradeParam) (*task.DagDetailDTO, *errors. func buildCheckAndUpgradeObTaskContext(p *obUpgradeParams) *task.TaskContext { ctx := task.NewTaskContext() - buildNumber, distribution, _ := splitRelease(p.RequestParam.Release) + buildNumber, distribution, _ := pkg.SplitRelease(p.RequestParam.Release) taskTime := strconv.Itoa(int(time.Now().UnixMilli())) ctx.SetParam(PARAM_ALL_AGENTS, p.agents). SetParam(PARAM_UPGRADE_DIR, p.RequestParam.UpgradeDir). diff --git a/agent/executor/ob/upgrade_ob_check.go b/agent/executor/ob/upgrade_ob_check.go index f37c7168..71346cc1 100644 --- a/agent/executor/ob/upgrade_ob_check.go +++ b/agent/executor/ob/upgrade_ob_check.go @@ -68,7 +68,7 @@ func ObUpgradeCheck(param param.UpgradeCheckParam) (*task.DagDetailDTO, *errors. func buildObUpgradeCheckTaskContext(param param.UpgradeCheckParam, upgradeRoute []RouteNode, agents []meta.AgentInfo) *task.TaskContext { ctx := task.NewTaskContext() - buildNumer, distribution, _ := splitRelease(param.Release) + buildNumer, distribution, _ := pkg.SplitRelease(param.Release) taskTime := strconv.Itoa(int(time.Now().UnixMilli())) ctx.SetParam(task.EXECUTE_AGENTS, agents). SetParam(PARAM_ALL_AGENTS, agents). @@ -133,19 +133,9 @@ func preCheckForObUpgradeCheck(param param.UpgradeCheckParam) (upgradeRoute []Ro return upgradeRoute, nil } -func splitRelease(release string) (buildNumber, distribution string, err error) { - releaseSplit := strings.Split(release, ".") - if len(releaseSplit) < 2 { - return "", "", fmt.Errorf("release format %s is illegal", release) - } - buildNumber = releaseSplit[0] - distribution = releaseSplit[len(releaseSplit)-1] - return -} - func checkForAllRequiredPkgs(targetVersion, targetRelease string) ([]RouteNode, error) { // Param 'targetRelease' is like '***.**.el7'. - targetBuildNumber, targetDistribution, err := splitRelease(targetRelease) + targetBuildNumber, targetDistribution, err := pkg.SplitRelease(targetRelease) if err != nil { return nil, err } diff --git a/agent/executor/ob/upgrade_pkg_install.go b/agent/executor/ob/upgrade_pkg_install.go index b77c23cd..022b13b3 100644 --- a/agent/executor/ob/upgrade_pkg_install.go +++ b/agent/executor/ob/upgrade_pkg_install.go @@ -18,20 +18,16 @@ package ob import ( "fmt" - "io" "os" "os/exec" "path" - "path/filepath" - "github.com/cavaliergopher/cpio" - "github.com/cavaliergopher/rpm" log "github.com/sirupsen/logrus" - "github.com/ulikunitz/xz" "github.com/oceanbase/obshell/agent/constant" "github.com/oceanbase/obshell/agent/engine/task" "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/lib/pkg" "github.com/oceanbase/obshell/agent/lib/system" ) @@ -108,7 +104,7 @@ func (t *InstallAllRequiredPkgsTask) installAllRequiredPkgs() (err error) { continue } t.ExecuteLogf("Unpack '%s'", rpmPkgInfo.RpmPkgPath) - if err = InstallRpmPkgInPlace(rpmPkgInfo.RpmPkgPath); err != nil { + if err = pkg.InstallRpmPkgInPlace(rpmPkgInfo.RpmPkgPath); err != nil { success = false t.ExecuteErrorLog(err) continue @@ -156,92 +152,6 @@ func (t *InstallAllRequiredPkgsTask) getAgentVersion(rpmPkgInfo *rpmPacakgeInsta return nil } -func InstallRpmPkgInPlace(path string) (err error) { - log.Infof("InstallRpmPkg: %s", path) - f, err := os.Open(path) - if err != nil { - return - } - defer f.Close() - - pkg, err := rpm.Read(f) - if err != nil { - return - } - if err = checkCompressAndFormat(pkg); err != nil { - return - } - - xzReader, err := xz.NewReader(f) - if err != nil { - return - } - installPath := filepath.Dir(path) - cpioReader := cpio.NewReader(xzReader) - - for { - hdr, err := cpioReader.Next() - if err == io.EOF { - break - } - if err != nil { - return err - } - - m := hdr.Mode - if m.IsDir() { - dest := filepath.Join(installPath, hdr.Name) - log.Infof("%s is a directory, creating %s", hdr.Name, dest) - if err := os.MkdirAll(dest, 0755); err != nil { - return errors.Wrapf(err, "mkdir failed %s", hdr.Name) - } - - } else if m.IsRegular() { - if err := handleRegularFile(hdr, cpioReader, installPath); err != nil { - return err - } - - } else if hdr.Linkname != "" { - if err := handleSymlink(hdr, installPath); err != nil { - return err - } - } else { - log.Infof("Skipping unsupported file %s type: %v", hdr.Name, m) - } - } - - return nil -} - -func handleRegularFile(hdr *cpio.Header, cpioReader *cpio.Reader, installPath string) error { - dest := filepath.Join(installPath, hdr.Name) - if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil { - log.WithError(err).Error("mkdir failed") - return err - } - - outFile, err := os.Create(dest) - if err != nil { - return err - } - defer outFile.Close() - - log.Infof("Extracting %s", hdr.Name) - if _, err := io.Copy(outFile, cpioReader); err != nil { - return err - } - return nil -} - -func handleSymlink(hdr *cpio.Header, installPath string) error { - dest := filepath.Join(installPath, hdr.Name) - if err := os.Symlink(hdr.Linkname, dest); err != nil { - return errors.Wrapf(err, "create symlink failed %s -> %s", dest, hdr.Linkname) - } - log.Infof("Creating symlink %s -> %s", dest, hdr.Linkname) - return nil -} - func (t *InstallAllRequiredPkgsTask) checkObserverBinAvailable(pkgInfo rpmPacakgeInstallInfo) (err error) { t.ExecuteLog("Check if the observer binary is available.") observerBinPath := path.Join(pkgInfo.RpmPkgHomepath, constant.DIR_BIN, constant.PROC_OBSERVER) diff --git a/agent/executor/ob/upgrade_pkg_upload.go b/agent/executor/ob/upgrade_pkg_upload.go index fb1051eb..7a82db8e 100644 --- a/agent/executor/ob/upgrade_pkg_upload.go +++ b/agent/executor/ob/upgrade_pkg_upload.go @@ -144,7 +144,7 @@ func (r *upgradeRpmPkgInfo) fileCheck() (err error) { func (r *upgradeRpmPkgInfo) checkVersion() (err error) { log.Info("version is ", r.version) - r.release, r.distribution, err = splitRelease(r.rpmPkg.Release()) + r.release, r.distribution, err = pkg.SplitRelease(r.rpmPkg.Release()) if err != nil { return } @@ -198,19 +198,9 @@ func (r *upgradeRpmPkgInfo) findAllExpectedFiles(expected []string) (err error) return nil } -func checkCompressAndFormat(pkg *rpm.Package) error { - if pkg.PayloadCompression() != "xz" { - return fmt.Errorf("unsupported compression '%s', the supported compression is 'xz'", pkg.PayloadCompression()) - } - if pkg.PayloadFormat() != "cpio" { - return fmt.Errorf("unsupported payload format '%s', the supported payload format is 'cpio'", pkg.PayloadFormat()) - } - return nil -} - func (r *upgradeRpmPkgInfo) GetUpgradeDepYml() (err error) { log.Info("start to get upgrade dep yml") - if err = checkCompressAndFormat(r.rpmPkg); err != nil { + if err = pkg.CheckCompressAndFormat(r.rpmPkg); err != nil { return } xzReader, err := xz.NewReader(r.rpmFile) diff --git a/agent/executor/obproxy/add.go b/agent/executor/obproxy/add.go new file mode 100644 index 00000000..3413bd31 --- /dev/null +++ b/agent/executor/obproxy/add.go @@ -0,0 +1,570 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "syscall" + "time" + + "github.com/oceanbase/obshell/agent/config" + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/engine/task" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/lib/process" + "github.com/oceanbase/obshell/agent/meta" + obproxydb "github.com/oceanbase/obshell/agent/repository/db/obproxy" + "github.com/oceanbase/obshell/agent/repository/db/oceanbase" + "github.com/oceanbase/obshell/agent/secure" + "github.com/oceanbase/obshell/param" + "github.com/oceanbase/obshell/utils" + log "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +type addObproxyOptions struct { + appName string + homePath string + version string + clusterName string // only for "RS_LIST" mode + encryptedSysPwd string + encryptedProxyroPassword string + parameters map[string]string + sqlPort int + exportPort int +} + +func buildAddObproxyOptions(param *param.AddObproxyParam) (*addObproxyOptions, error) { + version, err := getObproxyVersion(param.HomePath) + if err != nil { + return nil, err + } + + options := addObproxyOptions{ + appName: param.Name, + homePath: param.HomePath, + version: version, + parameters: make(map[string]string), + } + + for k, v := range param.Parameters { + options.parameters[k] = v + } + + options.encryptedSysPwd, err = secure.Encrypt(param.ObproxySysPassword) + if err != nil { + return nil, err + } + + options.encryptedProxyroPassword, err = secure.Encrypt(param.ProxyroPassword) + if err != nil { + return nil, err + } + return &options, nil +} + +func AddObproxy(param param.AddObproxyParam) (*task.DagDetailDTO, *errors.OcsAgentError) { + if meta.IsObproxyAgent() { + return nil, errors.Occur(errors.ErrBadRequest, "agent has already managed obproxy") + } + + if err := checkObproxyHomePath(param.HomePath); err != nil { + return nil, errors.Occurf(errors.ErrBadRequest, "invalid obproxy home path: %s", err) + } + + options, err := buildAddObproxyOptions(¶m) + if err != nil { + return nil, errors.Occur(errors.ErrUnexpected, err) + } + + if err := checkAndFillObproxyPort(¶m, options); err != nil { + return nil, err + } + + // Check obproxy version + if err := checkAndFillObproxyVersion(¶m, options); err != nil { + return nil, err + } + + if err := checkAndFillWorkMode(¶m, options); err != nil { + return nil, err + } + + if rsList, ok := options.parameters[constant.OBPROXY_CONFIG_RS_LIST]; ok && rsList != "" { + if clusterName, err := checkProxyroPasswordAndGetClusterName(rsList, param.ProxyroPassword); err != nil { + return nil, errors.Occur(errors.ErrBadRequest, err) + } else { + options.clusterName = clusterName + log.Infof("cluster name: %s", clusterName) + } + } + + ctx := buildAddObproxyContext(options) + template := buildAddObproxyTemplate(options) + dag, err := localTaskService.CreateDagInstanceByTemplate(template, ctx) + if err != nil { + return nil, errors.Occur(errors.ErrUnexpected, err) + } + return task.NewDagDetailDTO(dag), nil +} + +func checkAndFillWorkMode(param *param.AddObproxyParam, options *addObproxyOptions) *errors.OcsAgentError { + // Check work mode. + if param.RsList != nil && param.ConfigUrl != nil { + return errors.Occur(errors.ErrBadRequest, "rs_list and config_url can not be specified at the same time") + } + + if param.RsList != nil { + options.parameters[constant.OBPROXY_CONFIG_RS_LIST] = *param.RsList + options.parameters[constant.OBPROXY_CONFIG_CONFIG_SERVER_URL] = "" + } else if param.ConfigUrl != nil { + options.parameters[constant.OBPROXY_CONFIG_CONFIG_SERVER_URL] = *param.ConfigUrl + } else { + if !meta.OCS_AGENT.IsClusterAgent() { + return errors.Occur(errors.ErrBadRequest, "rs_list or config_url must be specified when agent is not cluster agent") + } else { + // Use the rs_list of current ob cluster. + rsListStr, err := obclusterService.GetRsListStr() + if err != nil { + // The observer may be inactive. + return errors.Occur(errors.ErrUnexpected, err) + } + + options.parameters[constant.OBPROXY_CONFIG_RS_LIST] = convertToRootServerList(rsListStr) + options.parameters[constant.OBPROXY_CONFIG_CONFIG_SERVER_URL] = "" + } + } + return nil +} + +func checkObproxyHomePath(homePath string) error { + if err := utils.CheckDirExists(homePath); err != nil { + return err + } + + err := syscall.Access(homePath, syscall.O_RDWR) + if err != nil { + return errors.Errorf("no read/write permission for directory '%s'", homePath) + } + + // Check obproxy is installed. + if err := utils.CheckDirExists(filepath.Join(homePath, constant.OBPROXY_DIR_BIN)); err != nil { + return err + } + if err := utils.CheckDirExists(filepath.Join(homePath, constant.OBPROXY_DIR_LIB)); err != nil { + return err + } + + // Check if obproxy has run in the home path. + entrys, err := os.ReadDir(filepath.Join(homePath, constant.OBPROXY_DIR_ETC)) + if err != nil { + if !os.IsNotExist(err) { + return err + } + } + if len(entrys) != 0 { + return errors.New("obproxy etc directory is not empty") + } + + return nil +} + +// checkObproxyVersion checks the version of obproxy located at the given homePath. +// If the version is lower than the minimum supported version (4.0.0), it returns an error. +func checkAndFillObproxyVersion(param *param.AddObproxyParam, options *addObproxyOptions) *errors.OcsAgentError { + version, err := getObproxyVersion(param.HomePath) + if err != nil { + return errors.Occur(errors.ErrBadRequest, "get obproxy version failed: %v", err) + } + if version < constant.OBPROXY_MIN_VERSION_SUPPORT { + return errors.Occurf(errors.ErrBadRequest, "obproxy version %s is lower than the minimum supported version %s", version, constant.OBPROXY_MIN_VERSION_SUPPORT) + } + options.version = version + return nil +} + +func checkProxyroPasswordAndGetClusterName(rsListStr string, password string) (clusterName string, err error) { + rsList := strings.Split(rsListStr, ";") + dsConfig := config.NewObDataSourceConfig(). + SetTryTimes(1). + SetDBName(constant.DB_OCEANBASE). + SetTimeout(10). + SetPassword(password). + SetUsername(constant.SYS_USER_PROXYRO) + var tempDb *gorm.DB + defer func() { + if tempDb != nil { + oceanbaseDB, _ := tempDb.DB() + oceanbaseDB.Close() + } + }() + for _, rs := range rsList { + observerInfo := meta.NewAgentInfoByString(rs) + if observerInfo == nil { + err = errors.Errorf("invalid observer info: %s", rs) + continue + } + dsConfig.SetIp(observerInfo.GetIp()).SetPort(observerInfo.GetPort()) + tempDb, err = oceanbase.LoadTempOceanbaseInstance(dsConfig) + if err != nil { + continue + } + clusterName, err = obproxyService.GetObclusterName(tempDb) + if err != nil { + continue + } + + return clusterName, nil + } + return "", err +} + +func checkAndFillObproxyPort(param *param.AddObproxyParam, options *addObproxyOptions) *errors.OcsAgentError { + if param.SqlPort != nil { + options.parameters[constant.OBPROXY_CONFIG_LISTEN_PORT] = strconv.Itoa(*param.SqlPort) + } else if options.parameters[constant.OBPROXY_CONFIG_LISTEN_PORT] == "" { + options.parameters[constant.OBPROXY_CONFIG_LISTEN_PORT] = strconv.Itoa(constant.OBPROXY_DEFAULT_SQL_PORT) + } + if param.RpcPort != nil { + options.parameters[constant.OBPROXY_CONFIG_RPC_LISTEN_PORT] = strconv.Itoa(*param.RpcPort) + } else if options.parameters[constant.OBPROXY_CONFIG_RPC_LISTEN_PORT] == "" { + options.parameters[constant.OBPROXY_CONFIG_RPC_LISTEN_PORT] = strconv.Itoa(constant.OBPROXY_DEFAULT_RPC_PORT) + } + if param.ExporterPort != nil { + options.parameters[constant.OBPROXY_CONFIG_PROMETHUES_LISTEN_PORT] = strconv.Itoa(*param.ExporterPort) + } else if options.parameters[constant.OBPROXY_CONFIG_PROMETHUES_LISTEN_PORT] == "" { + options.parameters[constant.OBPROXY_CONFIG_PROMETHUES_LISTEN_PORT] = strconv.Itoa(constant.OBPROXY_DEFAULT_EXPORTER_PORT) + } + + // Check port is valid. + var ports = []string{constant.OBPROXY_CONFIG_LISTEN_PORT, constant.OBPROXY_CONFIG_PROMETHUES_LISTEN_PORT, constant.OBPROXY_CONFIG_RPC_LISTEN_PORT} + for _, port := range ports { + if _, err := strconv.Atoi(options.parameters[port]); err != nil { + return errors.Occur(errors.ErrBadRequest, "invalid port: %s", options.parameters[port]) + } + } + options.sqlPort, _ = strconv.Atoi(options.parameters[constant.OBPROXY_CONFIG_LISTEN_PORT]) + options.exportPort, _ = strconv.Atoi(options.parameters[constant.OBPROXY_CONFIG_PROMETHUES_LISTEN_PORT]) + return nil +} + +func buildAddObproxyContext(options *addObproxyOptions) *task.TaskContext { + ctx := task.NewTaskContext(). + SetParam(PARAM_OBPROXY_HOME_PATH, options.homePath). + SetParam(PARAM_OBPROXY_SQL_PORT, options.sqlPort). + SetParam(PARAM_OBPROXY_EXPORTER_PORT, options.exportPort). + SetParam(PARAM_OBPROXY_APP_NAME, options.appName). + SetParam(PARAM_OBPROXY_VERSION, options.version). + SetParam(PARAM_OBPROXY_CLUSTER_NAME, options.clusterName). + SetParam(PARAM_OBPROXY_SYS_PASSWORD, options.encryptedSysPwd) + return ctx +} + +func buildAddObproxyTemplate(options *addObproxyOptions) *task.Template { + templateBuilder := task.NewTemplateBuilder(DAG_ADD_OBPROXY). + SetType(task.DAG_OBPROXY). + AddNode(newPrepareForObproxyAgentNode(false)). + AddNode(newStartObproxyNode(options.parameters)). + AddNode(NewSetObproxyUserPasswordForObNode(options.encryptedProxyroPassword)). + AddTask(newPersistObproxyInfoTask(), false). + SetMaintenance(task.ObproxyMaintenance()) + return templateBuilder.Build() +} + +func convertToRootServerList(rsListStr string) string { + var result []string + entries := strings.Split(rsListStr, ";") + for _, entry := range entries { + parts := strings.Split(entry, ":") + if len(parts) == 3 { + result = append(result, fmt.Sprintf("%s:%s", parts[0], parts[2])) + } + } + return strings.Join(result, ";") +} + +type StartObproxyTask struct { + task.Task + homePath string + sqlPort int + parameters map[string]string + appName string + optionsStr string + clusterName string + obproxySysPassword string + encryptedSysPwd string + + startWithOption bool +} + +func newStartObproxyNode(parameters map[string]string) *task.Node { + newTask := newStartObproxyTask() + ctx := task.NewTaskContext().SetParam(PARAM_OBPROXY_START_PARAMS, parameters).SetParam(PARAM_OBPROXY_START_WITH_OPTIONS, true) + return task.NewNodeWithContext(newTask, false, ctx) +} + +func newStartObproxyWithoutOptionsNode() *task.Node { + newTask := newStartObproxyTask() + ctx := task.NewTaskContext().SetParam(PARAM_OBPROXY_START_WITH_OPTIONS, false) + return task.NewNodeWithContext(newTask, false, ctx) +} + +func newStartObproxyTask() *StartObproxyTask { + newTask := &StartObproxyTask{ + Task: *task.NewSubTask(TASK_START_OBPROXY), + } + newTask.SetCanRetry().SetCanContinue() + return newTask +} + +func (t *StartObproxyTask) Execute() error { + var err error + if err = t.GetContext().GetParamWithValue(PARAM_OBPROXY_HOME_PATH, &t.homePath); err != nil { + return err + } + if err = t.GetContext().GetParamWithValue(PARAM_OBPROXY_START_WITH_OPTIONS, &t.startWithOption); err != nil { + return err + } + + var startCmd string + if !t.startWithOption { + startCmd = t.buildAtartObproxyWithoutOptionsCmd(t.homePath) + } else { + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_START_PARAMS, &t.parameters); err != nil { + return err + } + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_APP_NAME, &t.appName); err != nil { + return err + } + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_CLUSTER_NAME, &t.clusterName); err != nil { + return err + } + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_SYS_PASSWORD, &t.encryptedSysPwd); err != nil { + return err + } + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_SQL_PORT, &t.sqlPort); err != nil { + return err + } + t.obproxySysPassword, err = secure.Decrypt(t.encryptedSysPwd) + if err != nil { + return errors.Errorf("decrypt obproxy sys password failed: %v", err) + } + + if err := t.buildStartOptionStr(); err != nil { + return err + } + + startCmd = fmt.Sprintf("cd %s; ./bin/obproxy -o %s", t.homePath, t.optionsStr) + if t.appName != "" { + startCmd = fmt.Sprintf("%s -n %s", startCmd, t.appName) + } + if t.clusterName != "" { + startCmd = fmt.Sprintf("%s -c %s", startCmd, t.clusterName) + } + } + t.ExecuteLogf("start obproxy cmd: %s", startCmd) + if output, err := exec.Command("/bin/bash", "-c", startCmd).CombinedOutput(); err != nil { + return errors.Errorf("failed to start obproxy: %v, output: %s", err, string(output)) + } + + if err := t.healthCheck(); err != nil { + return errors.Wrap(err, "obproxy start failed") + } + + if pid, err := process.FindPIDByPort(uint32(t.sqlPort)); err != nil { + return errors.Errorf("get obproxy pid failed: %v", err) + } else if err := process.WritePidForce(filepath.Join(t.homePath, constant.OBPROXY_DIR_RUN, "obproxy.pid"), int(pid)); err != nil { + return errors.Errorf("write obproxy pid failed: %v", err) + } + return nil +} + +func (t *StartObproxyTask) buildStartOptionStr() error { + parameters := t.parameters + // Add single quotes to rs_list. + if rsList, ok := parameters[constant.OBPROXY_CONFIG_RS_LIST]; ok && !strings.HasPrefix(rsList, "'") && !strings.HasSuffix(rsList, "'") { + parameters[constant.OBPROXY_CONFIG_RS_LIST] = fmt.Sprintf("'%s'", rsList) + } + + if t.obproxySysPassword != "" { + parameters[constant.OBPROXY_CONFIG_OBPROXY_SYS_PASSWORD] = utils.Sha1(t.obproxySysPassword) + } else { + // If obproxy sys password is empty, do not need to sha1 it. + parameters[constant.OBPROXY_CONFIG_OBPROXY_SYS_PASSWORD] = "" + } + + optionStrs := make([]string, 0, len(parameters)) + for k, v := range parameters { + optionStrs = append(optionStrs, fmt.Sprintf("%s=%s", k, v)) + } + t.optionsStr = strings.Join(optionStrs, ",") + return nil +} + +func (t *StartObproxyTask) buildAtartObproxyWithoutOptionsCmd(homePath string) string { + return fmt.Sprintf("cd %s; ./bin/obproxy", homePath) +} + +func (t *StartObproxyTask) healthCheck() error { + // Try to connect to obproxy to confirm that it has started. + if t.sqlPort == 0 { + t.sqlPort = meta.OBPROXY_SQL_PORT + } + if t.obproxySysPassword == "" { + t.obproxySysPassword = meta.OBPROXY_SYS_PWD + } + t.ExecuteLog("start obproxy health check") + dsConfig := config.NewObproxyDataSourceConfig().SetPort(t.sqlPort).SetPassword(t.obproxySysPassword) + for retryCount := 1; retryCount <= obproxydb.WAIT_OBPROXY_CONNECTED_MAX_TIMES; retryCount++ { + time.Sleep(obproxydb.WAIT_OBPROXY_CONNECTED_MAX_INTERVAL) + if err := obproxydb.LoadObproxyInstanceForHealthCheck(dsConfig); err != nil { + t.ExecuteWarnLogf("obproxy health check failed: %v", err) + if strings.Contains(err.Error(), "connection refused") { + return err + } + continue + } + if err := obproxyService.UpdateSqlPort(t.sqlPort); err != nil { + return errors.Errorf("update obproxy sql port failed: %v", err) + } + if err := obproxyService.UpdateObproxySysPassword(t.obproxySysPassword); err != nil { + return errors.Errorf("update obproxy sys password failed: %v", err) + } + return nil + } + + return errors.New("obproxy health check timeout") +} + +type PersistObproxyInfoTask struct { + task.Task + homePath string + sqlPort int + version string + encryptedSysPwd string + encryptedProxyroPassword string +} + +func newPersistObproxyInfoTask() *PersistObproxyInfoTask { + newTask := &PersistObproxyInfoTask{ + Task: *task.NewSubTask(TASK_PERSIST_OBPROXY_INFP), + } + newTask.SetCanRetry().SetCanContinue() + return newTask +} + +func (t *PersistObproxyInfoTask) Execute() error { + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_HOME_PATH, &t.homePath); err != nil { + return err + } + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_SQL_PORT, &t.sqlPort); err != nil { + return err + } + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_VERSION, &t.version); err != nil { + return err + } + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_SYS_PASSWORD, &t.encryptedSysPwd); err != nil { + return err + } + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_PROXYRO_PASSWORD, &t.encryptedProxyroPassword); err != nil { + return err + } + if err := agentService.AddObproxy(t.homePath, t.sqlPort, t.version, t.encryptedSysPwd, t.encryptedProxyroPassword); err != nil { + return err + } + + return nil +} + +type PrepareForAddObproxyTask struct { + task.Task + expectObproxyAgent bool + homePath string +} + +// PrepareForAddObproxyNode will check if the agent is an obproxy agent. +func newPrepareForObproxyAgentNode(expectObproxyAgent bool) *task.Node { + newTask := &PrepareForAddObproxyTask{ + Task: *task.NewSubTask(TASK_CHECK_OBPROXY_STATUS), + } + newTask.SetCanRetry().SetCanContinue() + + ctx := task.NewTaskContext().SetParam(task.FAILURE_EXIT_MAINTENANCE, true).SetParam(PARAM_EXPECT_OBPROXY_AGENT, expectObproxyAgent) + return task.NewNodeWithContext(newTask, false, ctx) +} + +func (t *PrepareForAddObproxyTask) Execute() error { + // Double check if the agent identify. + if err := t.GetContext().GetParamWithValue(PARAM_EXPECT_OBPROXY_AGENT, &t.expectObproxyAgent); err != nil { + return err + } + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_HOME_PATH, &t.homePath); err != nil { + return err + } + if t.expectObproxyAgent && !meta.IsObproxyAgent() { + return errors.Errorf("This is not an obproxy agent") + } + if !t.expectObproxyAgent { + if meta.IsObproxyAgent() { + return errors.Errorf("agent has already managed obproxy") + } + // Create obproxy run path + runPath := filepath.Join(t.homePath, constant.OBPROXY_DIR_RUN) + if err := os.MkdirAll(runPath, 0755); err != nil { + return errors.Errorf("create obproxy run path failed: %v", err) + } + } + return nil +} + +// Only support set global proxyro password currently +type SetObproxyUserPasswordForObTask struct { + task.Task + encryptedProxyroPassword string +} + +func NewSetObproxyUserPasswordForObNode(encryptedProxyroPassword string) *task.Node { + newTask := &SetObproxyUserPasswordForObTask{ + Task: *task.NewSubTask(TASK_SET_OBPROXY_USER_PASSWORD), + } + newTask.SetCanRetry().SetCanContinue() + ctx := task.NewTaskContext().SetParam(PARAM_OBPROXY_PROXYRO_PASSWORD, encryptedProxyroPassword) + return task.NewNodeWithContext(newTask, false, ctx) +} + +func (t *SetObproxyUserPasswordForObTask) Execute() error { + if t.GetContext().GetParamWithValue(PARAM_OBPROXY_PROXYRO_PASSWORD, &t.encryptedProxyroPassword) != nil { + return errors.Errorf("get obproxy user password failed") + } + + // Decrypt proxyro password. + proxyroPassword, err := secure.Decrypt(t.encryptedProxyroPassword) + if err != nil { + return errors.Errorf("decrypt proxyro password failed: %v", err) + } + t.ExecuteLog("set obproxy user password") + + if err := obproxyService.SetProxyroPassword(proxyroPassword); err != nil { + return errors.Errorf("set obproxy user password failed: %v", err) + } + return nil +} diff --git a/agent/executor/obproxy/delete.go b/agent/executor/obproxy/delete.go new file mode 100644 index 00000000..1b3c0868 --- /dev/null +++ b/agent/executor/obproxy/delete.go @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "os" + "path/filepath" + + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/engine/task" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/meta" +) + +func DeleteObproxy() (*task.DagDetailDTO, *errors.OcsAgentError) { + if !meta.IsObproxyAgent() { + return nil, nil + } + + templateBuilder := task.NewTemplateBuilder(DAG_DELETE_OBPROXY). + SetMaintenance(task.ObproxyMaintenance()). + SetType(task.DAG_OBPROXY). + AddNode(newPrepareForObproxyAgentNode(true)). + AddTask(newStopObproxyTask(), false). + AddTask(newDeleteObproxyTask(), false). + AddTask(newCleanObproxyDirTask(), false) + + context := task.NewTaskContext().SetParam(PARAM_OBPROXY_HOME_PATH, meta.OBPROXY_HOME_PATH) + dag, err := localTaskService.CreateDagInstanceByTemplate(templateBuilder.Build(), context) + if err != nil { + return nil, errors.Occur(errors.ErrUnexpected, err) + } + return task.NewDagDetailDTO(dag), nil +} + +// DeleteObproxyTask will delete the obproxy home path +type DeleteObproxyTask struct { + task.Task +} + +func newDeleteObproxyTask() *DeleteObproxyTask { + newTask := &DeleteObproxyTask{ + Task: *task.NewSubTask(TASK_DELETE_OBPROXY), + } + newTask.SetCanRetry().SetCanContinue().SetCanCancel() + return newTask +} + +func (t *DeleteObproxyTask) Execute() (err error) { + if err := agentService.DeleteObproxy(); err != nil { + return err + } + return nil +} + +type CleanObproxyDirTask struct { + task.Task + obproxyHomePath string +} + +func newCleanObproxyDirTask() *CleanObproxyDirTask { + newTask := &CleanObproxyDirTask{ + Task: *task.NewSubTask(TASK_CLEAN_OBPROXY_DIR), + } + newTask.SetCanRetry().SetCanContinue().SetCanCancel() + return newTask +} + +func (t *CleanObproxyDirTask) Execute() (err error) { + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_HOME_PATH, &t.obproxyHomePath); err != nil { + return err + } + deleteFiles := []string{constant.OBPROXY_DIR_ETC, constant.OBPROXY_DIR_LOG, constant.OBPROXY_DIR_RUN, + constant.OBPROXY_DIR_BIN, constant.OBPROXY_DIR_LIB} + for _, file := range deleteFiles { + if err := os.RemoveAll(filepath.Join(t.obproxyHomePath, file)); err != nil { + return errors.Occur(errors.ErrUnexpected, err) + } + } + return nil +} diff --git a/agent/executor/obproxy/enter.go b/agent/executor/obproxy/enter.go new file mode 100644 index 00000000..4f73624f --- /dev/null +++ b/agent/executor/obproxy/enter.go @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "github.com/oceanbase/obshell/agent/engine/task" + agentservice "github.com/oceanbase/obshell/agent/service/agent" + obclusterservice "github.com/oceanbase/obshell/agent/service/obcluster" + obproxyservice "github.com/oceanbase/obshell/agent/service/obproxy" + taskservice "github.com/oceanbase/obshell/agent/service/task" +) + +var obproxyService = obproxyservice.ObproxyService{} +var obclusterService = obclusterservice.ObclusterService{} +var localTaskService = taskservice.NewLocalTaskService() +var agentService = agentservice.AgentService{} + +const ( + WORK_MODE_RS_LIST = "rsList" + WORK_MODE_CONFIG_URL = "configUrl" +) + +var ( + // task name for obproxy + DAG_ADD_OBPROXY = "Add obproxy" + DAG_START_OBPROXY = "Start obproxy" + DAG_STOP_OBPROXY = "Stop obproxy" + DAG_UPGRADE_OBPROXY = "Upgrade obproxy" + DAG_DELETE_OBPROXY = "Delete obproxy" + + TASK_START_OBPROXY = "Start obproxy" + TASK_START_OBPROXYD = "Start obproxyd" + TASK_SET_OBPROXY_SYS_PASSWORD = "Set obproxy sys password" + TASK_SET_OBPROXY_USER_PASSWORD = "Set proxyro password for connect" + TASK_PERSIST_OBPROXY_INFP = "Persist obproxy info" + TASK_STOP_OBPROXY = "Stop obproxy" + TASK_CHECK_OBPROXY_STATUS = "Check obproxy status" + TASK_CHECK_PROXYRO_PASSWORD = "Check proxyro password" + TASK_DELETE_OBPROXY = "Delete obproxy" + TASK_CLEAN_OBPROXY_DIR = "Clean obproxy dir" + + TASK_COPY_CONFIG_DB_FILE = "Copy obproxy config db file" + TASK_HOT_RESTART_OBPROXY = "Hot restart obproxy" + TASK_WAIT_HOT_RESTART_OBPROXY_FINISH = "Wait hot restart obproxy finish" + TASK_RECORD_OBPROXY_INFO = "Record obproxy info" + TASK_REINSTALL_OBPROXY_BIN = "Reinstall obproxy bin" + TASK_DOWNLOAD_RPM_FROM_SQLITE = "Download obproxy pkg from sqlite" + TASK_CHECK_OBPROXY_PKG = "Check obproxy pkg" + TASK_INSTALL_ALL_REQUIRED_PKGS = "Install all required pkgs" + TASK_BACKUP_FOR_UPGRADE = "Backup for upgrade" + TASK_REMOVE_UPGRADE_DIR = "Remove upgrade dir" + TASK_CREATE_UPGRADE_DIR = "Create upgrade dir" + + PARAM_ADD_OBPROXY_OPTION = "addObproxyOption" + PARAM_OBPROXY_HOME_PATH = "homePath" + PARAM_OBPROXY_SQL_PORT = "sqlPort" + PARAM_OBPROXY_EXPORTER_PORT = "exporterPort" + PARAM_OBPROXY_VERSION = "version" + PARAM_OBPROXY_START_PARAMS = "startParams" + PARAM_OBPROXY_START_WITH_OPTIONS = "startWithOptions" + PARAM_OBPROXY_APP_NAME = "appName" + PARAM_OBPROXY_WORK_MODE = "workMode" + PARAM_OBPROXY_RS_LIST = "rsList" + PARAM_OBPROXY_CONFIG_URL = "configUrl" + PARAM_OBPROXY_SYS_PASSWORD = "obproxySysPassword" + PARAM_OBPROXY_PROXYRO_PASSWORD = "proxyroPassword" + PARAM_OBPROXY_CLUSTER_NAME = "clusterName" + PARAM_PERSIST_OBPROXY_INFO_PARAM = "persistObproxyInfoParam" + PARAM_EXPECT_OBPROXY_AGENT = "expectObproxyAgent" + + PARAM_HOT_UPGRADE_ROLLBACK_TIMEOUT = "hotUpgradeRollbackTimeout" + PARAM_HOT_UPGRADE_EXIT_TIMEOUT = "hotUpgradeExitTimeout" + PARAM_OLD_OBPROXY_PID = "oldObproxyPid" + + // for upgrade + PARAM_VERSION = "version" + PARAM_BUILD_NUMBER = "buildNumber" + PARAM_DISTRIBUTION = "distribution" + PARAM_RELEASE_DISTRIBUTION = "releaseDistribution" + PARAM_UPGRADE_DIR = "upgradeDir" + PARAM_TASK_TIME = "taskTime" + PARAM_ONLY_FOR_AGENT = "onlyForAgent" + PARAM_SCRIPT_FILE = "scriptFile" + PARAM_OBPROXY_RPM_PKG_PATH = "obproxyRpmPkgPath" + PARAM_CREATE_UPGRADE_DIR_FLAG = "createUpgradeDirFlag" + + // stop obproxy or obproxyd retry times + STOP_PROCESS_MAX_RETRY_TIME = 15 + STOP_PROCESS_RETRY_INTERVAL = 5 +) + +func RegisterTaskType() { + task.RegisterTaskType(StartObproxyTask{}) + task.RegisterTaskType(SetObproxyUserPasswordForObTask{}) + task.RegisterTaskType(PersistObproxyInfoTask{}) + task.RegisterTaskType(StopObproxyTask{}) + task.RegisterTaskType(PrepareForAddObproxyTask{}) + + task.RegisterTaskType(StopObproxyTask{}) + + task.RegisterTaskType(CopyConfigDbFileTask{}) + task.RegisterTaskType(HotRestartObproxyTask{}) + task.RegisterTaskType(WaitHotRestartObproxyFinishTask{}) + task.RegisterTaskType(RecordObproxyInfoTask{}) + task.RegisterTaskType(ReinstallObproxyBinTask{}) + task.RegisterTaskType(GetObproxyPkgTask{}) + task.RegisterTaskType(CheckObproxyPkgTask{}) + task.RegisterTaskType(BackupObproxyForUpgradeTask{}) + task.RegisterTaskType(RemoveUpgradeObproxyDirTask{}) + task.RegisterTaskType(CreateObproxyUpgradeDirTask{}) + + task.RegisterTaskType(CleanObproxyDirTask{}) + task.RegisterTaskType(DeleteObproxyTask{}) +} diff --git a/agent/executor/obproxy/package.go b/agent/executor/obproxy/package.go new file mode 100644 index 00000000..f35a81d7 --- /dev/null +++ b/agent/executor/obproxy/package.go @@ -0,0 +1,398 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "github.com/cavaliergopher/cpio" + "github.com/cavaliergopher/rpm" + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/engine/task" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/global" + "github.com/oceanbase/obshell/agent/lib/path" + "github.com/oceanbase/obshell/agent/lib/pkg" + "github.com/oceanbase/obshell/agent/lib/system" + "github.com/ulikunitz/xz" + + "github.com/oceanbase/obshell/agent/repository/model/sqlite" + log "github.com/sirupsen/logrus" +) + +var ( + confficient = 1.1 +) + +type GetObproxyPkgTask struct { + task.Task + targetBuildNumber string + targetVersion string + distribution string + upgradeDir string + + upgradePkgInfo sqlite.UpgradePkgInfo +} + +func newGetObproxyPkgTask() *GetObproxyPkgTask { + newTask := &GetObproxyPkgTask{ + Task: *task.NewSubTask(TASK_DOWNLOAD_RPM_FROM_SQLITE), + } + newTask.SetCanContinue().SetCanRetry().SetCanRollback() + return newTask +} + +func (t *GetObproxyPkgTask) getParams() (err error) { + if err = t.GetContext().GetParamWithValue(PARAM_UPGRADE_DIR, &t.upgradeDir); err != nil { + return err + } + if err = t.GetContext().GetParamWithValue(PARAM_BUILD_NUMBER, &t.targetBuildNumber); err != nil { + return err + } + if err = t.GetContext().GetParamWithValue(PARAM_VERSION, &t.targetVersion); err != nil { + return err + } + if err = t.GetContext().GetParamWithValue(PARAM_DISTRIBUTION, &t.distribution); err != nil { + return err + } + return nil +} + +func (t *GetObproxyPkgTask) Execute() (err error) { + if t.IsContinue() { + t.ExecuteLog("The task is continuing.") + if err = t.Rollback(); err != nil { + return err + } + } + + if err = t.getAllRequiredPkgs(); err != nil { + return + } + return nil +} + +func (t *GetObproxyPkgTask) getAllRequiredPkgs() (err error) { + if err = t.getParams(); err != nil { + return err + } + + t.ExecuteLogf("The directory for this upgrade check task is %s", t.upgradeDir) + if err = os.MkdirAll(t.upgradeDir, 0755); err != nil { + return err + } + + t.ExecuteLog("Confirm that all the required packages have been uploaded.") + + if t.upgradePkgInfo, err = agentService.GetUpgradePkgInfoByVersionAndRelease(constant.PKG_OBPROXY_CE, t.targetVersion, t.targetBuildNumber, t.distribution, global.Architecture); err != nil { + return err + } + + if err = t.CheckDiskFreeSpace(); err != nil { + return + } + + return t.downloadAllRequiredPkgs() +} + +func (t *GetObproxyPkgTask) CheckDiskFreeSpace() error { + t.ExecuteLog("Check the remaining disk space.") + t.ExecuteLogf("The directory being checked is %s", t.upgradeDir) + expectedSize := (t.upgradePkgInfo.Size + t.upgradePkgInfo.PayloadSize) * uint64(confficient) + t.ExecuteLogf("The required disk size is %d", expectedSize) + diskInfo, err := system.GetDiskInfo(t.upgradeDir) + if err != nil { + return errors.Wrap(err, "failed to get disk info") + } + t.ExecuteLogf("The remaining disk size is %d", diskInfo.FreeSizeBytes) + if diskInfo.FreeSizeBytes < expectedSize { + return fmt.Errorf("the remaining disk space is insufficient, the remaining disk space is %d, and the required disk space is %d", diskInfo.FreeSizeBytes, expectedSize) + } + return nil +} + +func (t *GetObproxyPkgTask) downloadAllRequiredPkgs() (err error) { + t.ExecuteLogf("Download all packages to %s", t.upgradeDir) + pkgInfo := t.upgradePkgInfo + rpmDir := GenerateUpgradeRpmDir(t.upgradeDir, pkgInfo.Version, pkgInfo.Architecture) + if err := os.MkdirAll(rpmDir, 0755); err != nil { + return err + } + rpmPkgPath := GenerateRpmPkgPath(rpmDir, pkgInfo.Name) + if err = agentService.DownloadUpgradePkgChunkInBatch(rpmPkgPath, pkgInfo.PkgId, pkgInfo.ChunkCount); err != nil { + return err + } + t.GetContext().SetParam(PARAM_OBPROXY_RPM_PKG_PATH, rpmPkgPath) + t.ExecuteLogf("Downloaded pkg '%s' to '%s'", pkgInfo.Name, rpmPkgPath) + return nil +} + +func (t *GetObproxyPkgTask) Rollback() (err error) { + t.ExecuteLog("Rolling back...") + if err = t.deleteAllRequiredPkgs(); err != nil { + return + } + t.ExecuteLog("Successfully deleted.") + return nil +} + +func (t *GetObproxyPkgTask) deleteAllRequiredPkgs() (err error) { + t.ExecuteLog("Delete all previously downloaded packages.") + if err = t.GetContext().GetParamWithValue(PARAM_UPGRADE_DIR, &t.upgradeDir); err != nil { + return err + } + return os.RemoveAll(t.upgradeDir) + +} + +func GenerateUpgradeRpmDir(upgradeDir, version, arch string) string { + return filepath.Join(upgradeDir, arch, version) +} + +func GenerateRpmPkgPath(rpmDir, rpmName string) string { + return fmt.Sprintf("%s/%s.rpm", rpmDir, rpmName) +} + +type CheckObproxyPkgTask struct { + task.Task + pkgPath string +} + +func newCheckObproxyPkgTask() *CheckObproxyPkgTask { + newTask := &CheckObproxyPkgTask{ + Task: *task.NewSubTask(TASK_CHECK_OBPROXY_PKG), + } + newTask. + SetCanContinue(). + SetCanRollback(). + SetCanRetry(). + SetCanPass(). + SetCanCancel() + return newTask +} + +func (t *CheckObproxyPkgTask) Execute() (err error) { + if t.GetContext().GetParamWithValue(PARAM_OBPROXY_RPM_PKG_PATH, &t.pkgPath); err != nil { + return err + } + if err = t.checkRequiredPkgs(); err != nil { + return + } + return nil +} + +func (t *CheckObproxyPkgTask) checkRequiredPkgs() (err error) { + if err = t.checkUpgradePkgFromDb(t.pkgPath); err != nil { + return err + } + t.ExecuteInfoLog("obproxy-ce package is checked successfully.") + return nil +} + +func (t *CheckObproxyPkgTask) checkUpgradePkgFromDb(filePath string) (err error) { + input, err := os.Open(filePath) + if err != nil { + return err + } + defer input.Close() + r := &upgradeRpmPkgInfo{ + rpmFile: input, + } + + if err = r.CheckUpgradePkg(false); err != nil { + return err + } + return nil +} + +type ReinstallObproxyBinTask struct { + task.Task + rpmPkgPath string +} + +func newReinstallObproxyBinTask() *ReinstallObproxyBinTask { + newTask := &ReinstallObproxyBinTask{ + Task: *task.NewSubTask(TASK_REINSTALL_OBPROXY_BIN), + } + newTask.SetCanContinue(). + SetCanRetry(). + SetCanRollback(). + SetCanCancel(). + SetCanPass() + return newTask +} + +func (t *ReinstallObproxyBinTask) Execute() error { + if err := t.GetContext().GetParamWithValue(PARAM_OBPROXY_RPM_PKG_PATH, &t.rpmPkgPath); err != nil { + return err + } + if err := t.installRpmPkgInPlace(t.rpmPkgPath); err != nil { + return err + } + t.ExecuteLogf("Successfully installed %s", t.rpmPkgPath) + return nil +} + +func (t *ReinstallObproxyBinTask) installRpmPkgInPlace(rpmPkgPath string) (err error) { + log.Infof("InstallRpmPkg: %s", rpmPkgPath) + f, err := os.Open(rpmPkgPath) + if err != nil { + return + } + defer f.Close() + + rpmPkg, err := rpm.Read(f) + if err != nil { + return + } + + if err = pkg.CheckCompressAndFormat(rpmPkg); err != nil { + return + } + + xzReader, err := xz.NewReader(f) + if err != nil { + return + } + cpioReader := cpio.NewReader(xzReader) + + for { + hdr, err := cpioReader.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + m := hdr.Mode + if m.IsRegular() && hdr.FileInfo().Name() == "obproxy" { + // Remove obproxy + if err := os.RemoveAll(path.ObproxyBinPath()); err != nil { + return err + } + outFile, err := os.OpenFile(path.ObproxyBinPath(), os.O_CREATE|os.O_WRONLY, 0755) + if err != nil { + return err + } + defer outFile.Close() + log.Infof("Extracting %s", hdr.Name) + if _, err := io.Copy(outFile, cpioReader); err != nil { + return err + } + } + } + + return nil +} + +func (t *ReinstallObproxyBinTask) Rollback() (err error) { + t.ExecuteLog("uninstall new obproxy") + var upgradeDir string + if err = t.GetContext().GetParamWithValue(PARAM_UPGRADE_DIR, &upgradeDir); err != nil { + return err + } + + backupDir := filepath.Join(upgradeDir, "backup") + + dest := path.ObproxyBinPath() + if err := os.RemoveAll(dest); err != nil { + return err + } + return system.CopyFile(fmt.Sprintf("%s/%s", backupDir, constant.PROC_OBPROXY), dest) +} + +type BackupObproxyForUpgradeTask struct { + task.Task + upgradeDir string + backupDir string +} + +func newBackupObproxyForUpgradeTask() *BackupObproxyForUpgradeTask { + newTask := &BackupObproxyForUpgradeTask{ + Task: *task.NewSubTask(TASK_BACKUP_FOR_UPGRADE), + } + newTask. + SetCanRetry(). + SetCanContinue(). + SetCanRollback(). + SetCanPass(). + SetCanCancel() + return newTask +} + +func (t *BackupObproxyForUpgradeTask) Execute() (err error) { + if t.IsContinue() { + t.ExecuteLog("The task is continuing.") + if err = t.Rollback(); err != nil { + return err + } + } + + if err = t.BackupObproxyForUpgrade(); err != nil { + return + } + return nil +} + +func (t *BackupObproxyForUpgradeTask) getParams() (err error) { + if err = t.GetContext().GetParamWithValue(PARAM_UPGRADE_DIR, &t.upgradeDir); err != nil { + return err + } + + t.backupDir = filepath.Join(t.upgradeDir, "backup") + return nil +} + +func (t *BackupObproxyForUpgradeTask) BackupObproxyForUpgrade() error { + t.ExecuteLog("Backup important files.") + if err := t.getParams(); err != nil { + return err + } + + t.ExecuteLogf("The directory for backup is %s", t.backupDir) + t.ExecuteLogf("Backup the bin directory %s", path.BinDir()) + if err := system.CopyDirs(path.ObproxyBinDir(), t.backupDir); err != nil { + return err + } + return nil +} + +func (t *BackupObproxyForUpgradeTask) Rollback() (err error) { + t.ExecuteLog("Rolling back...") + if err = t.deleteBackupDir(); err != nil { + return err + } + t.ExecuteLog("Successfully deleted") + return nil +} + +func (t *BackupObproxyForUpgradeTask) deleteBackupDir() (err error) { + if err = t.getParams(); err != nil { + return err + } + if t.backupDir != "" { + t.ExecuteLog("Delete " + t.backupDir) + if err := os.RemoveAll(t.backupDir); err != nil { + return err + } + } + return nil +} diff --git a/agent/executor/obproxy/start.go b/agent/executor/obproxy/start.go new file mode 100644 index 00000000..f3ab36b8 --- /dev/null +++ b/agent/executor/obproxy/start.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "github.com/oceanbase/obshell/agent/engine/task" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/meta" +) + +func StartObproxy() (*task.DagDetailDTO, *errors.OcsAgentError) { + if !meta.IsObproxyAgent() { + return nil, errors.Occur(errors.ErrBadRequest, "This is not an obproxy agent") + } + + template := task.NewTemplateBuilder(DAG_START_OBPROXY). + SetType(task.DAG_OBPROXY). + AddNode(newPrepareForObproxyAgentNode(true)). + AddNode(newStartObproxyWithoutOptionsNode()).Build() + context := task.NewTaskContext().SetParam(PARAM_OBPROXY_HOME_PATH, meta.OBPROXY_HOME_PATH) + dag, err := localTaskService.CreateDagInstanceByTemplate(template, context) + if err != nil { + return nil, errors.Occur(errors.ErrUnexpected, err) + } + return task.NewDagDetailDTO(dag), nil + +} diff --git a/agent/executor/obproxy/stop.go b/agent/executor/obproxy/stop.go new file mode 100644 index 00000000..7aa06e15 --- /dev/null +++ b/agent/executor/obproxy/stop.go @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "os/exec" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/oceanbase/obshell/agent/engine/task" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/lib/process" + "github.com/oceanbase/obshell/agent/meta" +) + +func StopObproxy() (*task.DagDetailDTO, *errors.OcsAgentError) { + if !meta.IsObproxyAgent() { + return nil, errors.Occur(errors.ErrBadRequest, "This is not an obproxy agent") + } + + template := task.NewTemplateBuilder(DAG_STOP_OBPROXY). + SetMaintenance(task.ObproxyMaintenance()). + SetType(task.DAG_OBPROXY). + AddNode(newPrepareForObproxyAgentNode(true)). + AddTask(newStopObproxyTask(), false).Build() + + ctx := task.NewTaskContext().SetParam(PARAM_OBPROXY_HOME_PATH, meta.OBPROXY_HOME_PATH) + dag, err := localTaskService.CreateDagInstanceByTemplate(template, ctx) + if err != nil { + return nil, errors.Occur(errors.ErrUnexpected, err) + } + return task.NewDagDetailDTO(dag), nil +} + +// StopObproxyTask will stop obproyxd and obproxy. +type StopObproxyTask struct { + task.Task +} + +func newStopObproxyTask() *StopObproxyTask { + newTask := &StopObproxyTask{ + Task: *task.NewSubTask(TASK_STOP_OBPROXY), + } + newTask.SetCanRetry().SetCanContinue() + return newTask +} + +func (t *StopObproxyTask) Execute() error { + // if err := t.stopObproxyd(); err != nil { + // return err + // } + if err := t.stopObproxy(); err != nil { + return err + } + return nil +} + +func (t *StopObproxyTask) stopObproxy() error { + pid, err := process.GetObproxyPid() + if err != nil { + return err + } + t.ExecuteLogf("Get obproxy pid: %s", pid) + if pid == "" { + t.ExecuteLog("Obproxy is not running") + return nil + } + for i := 0; i < STOP_PROCESS_MAX_RETRY_TIME; i++ { + t.ExecuteLogf("Kill obproxy process %s", pid) + res := exec.Command("kill", "-9", pid) + if err := res.Run(); err != nil { + log.Warn("Kill obproxy process failed") + } + + time.Sleep(time.Second * time.Duration(STOP_PROCESS_RETRY_INTERVAL)) + t.TimeoutCheck() + + t.ExecuteLog("Check obproxy process") + exist, err := process.CheckObproxyProcess() + if err != nil { + log.Warnf("Check obproxy process failed: %v", err) + } else if !exist { + t.ExecuteLog("Successfully killed the obproxy process") + return nil + } + } + return errors.New("kill obproxy process timeout") +} + +func (t *StopObproxyTask) stopObproxyd() error { + pid, err := process.GetObproxydPid() + if err != nil { + return err + } + t.ExecuteLogf("Get obproxyd pid: %s", pid) + if pid == "" { + t.ExecuteLog("Obproxyd is not running") + return nil + } + for i := 0; i < STOP_PROCESS_MAX_RETRY_TIME; i++ { + t.ExecuteLogf("Kill obproxyd process %s", pid) + res := exec.Command("kill", "-9", pid) + if err := res.Run(); err != nil { + log.Warn("Kill obproxyd process failed") + } + + time.Sleep(time.Second * time.Duration(STOP_PROCESS_RETRY_INTERVAL)) + t.TimeoutCheck() + + t.ExecuteLog("Check obproxyd process") + exist, err := process.CheckObproxydProcess() + if err != nil { + log.Warnf("Check obproxyd process failed: %v", err) + } else if !exist { + t.ExecuteLog("Successfully killed the obproxyd process") + return nil + } + } + return errors.New("kill obproxyd process timeout") +} diff --git a/agent/executor/obproxy/upgrade.go b/agent/executor/obproxy/upgrade.go new file mode 100644 index 00000000..1c6a3ad7 --- /dev/null +++ b/agent/executor/obproxy/upgrade.go @@ -0,0 +1,416 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "fmt" + "os" + "regexp" + "strings" + "time" + + "github.com/oceanbase/obshell/agent/config" + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/engine/task" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/global" + "github.com/oceanbase/obshell/agent/lib/parse" + "github.com/oceanbase/obshell/agent/lib/path" + "github.com/oceanbase/obshell/agent/lib/pkg" + "github.com/oceanbase/obshell/agent/lib/process" + "github.com/oceanbase/obshell/agent/lib/system" + "github.com/oceanbase/obshell/agent/meta" + obproxydb "github.com/oceanbase/obshell/agent/repository/db/obproxy" + "github.com/oceanbase/obshell/agent/repository/model/bo" + "github.com/oceanbase/obshell/param" + "github.com/oceanbase/obshell/utils" + log "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +const waitPeriod = 5 // seconds + +func UpgradeObproxy(param param.UpgradeObproxyParam) (*task.DagDetailDTO, *errors.OcsAgentError) { + if !meta.IsObproxyAgent() { + return nil, errors.Occur(errors.ErrBadRequest, "not obproxy agent") + } + if alive, err := process.CheckObproxyProcess(); err != nil { + return nil, errors.Occur(errors.ErrUnexpected, err) + } else if !alive { + return nil, errors.Occur(errors.ErrBadRequest, "obproxy is not running") + } + + if err := checkVersionSupport(param.Version, param.Release); err != nil { + return nil, err + } + if err := checkUpgradeDir(¶m.UpgradeDir); err != nil { + return nil, errors.Occur(errors.ErrIllegalArgument, err) + } + if err := findTargetPkg(param.Version, param.Release); err != nil { + return nil, err + } + + template := buildUpgradeObproxyTemplate() + context := buildUpgradeObproxyTaskContext(param) + dag, err := localTaskService.CreateDagInstanceByTemplate(template, context) + if err != nil { + return nil, errors.Occur(errors.ErrUnexpected, err) + } + return task.NewDagDetailDTO(dag), nil +} + +func checkVersionSupport(version, release string) *errors.OcsAgentError { + // Check obproxy version + curObproxyVersion, err := obproxyService.GetObproxyVersion() + if err != nil { + return errors.Occur(errors.ErrUnexpected, err) + } + buildNumber, _, err := pkg.SplitRelease(release) + if err != nil { + return errors.Occur(errors.ErrUnexpected, err) + } + if pkg.CompareVersion(curObproxyVersion, fmt.Sprintf("%s-%s", version, buildNumber)) >= 0 { + return errors.Occur(errors.ErrBadRequest, "current obproxy version is greater than or equal to the target version") + } + return nil +} + +func findTargetPkg(version, release string) *errors.OcsAgentError { + buildNumber, distribution, _ := pkg.SplitRelease(release) + _, err := agentService.GetUpgradePkgInfoByVersionAndRelease(constant.PKG_OBPROXY_CE, version, buildNumber, distribution, global.Architecture) + if err != nil { + return errors.Occurf(errors.ErrBadRequest, "find target pkg '%s-%s-%s.%s.rpm' failed", constant.PKG_OBPROXY_CE, version, release, global.Architecture) + } + return nil +} + +func checkUpgradeDir(path *string) (err error) { + log.Infof("checking upgrade directory: '%s'", *path) + str := *path + + *path = strings.TrimSpace(*path) + if len(*path) == 0 { + return nil + } + + return utils.CheckPathValid(str) +} + +func buildUpgradeObproxyTemplate() *task.Template { + return task.NewTemplateBuilder(DAG_UPGRADE_OBPROXY). + SetMaintenance(task.ObproxyMaintenance()). + SetType(task.DAG_OBPROXY). + AddTask(newCreateObproxyUpgradeDirTask(), false). + AddTask(newGetObproxyPkgTask(), false). + AddTask(newCheckObproxyPkgTask(), false). + AddTask(newBackupObproxyForUpgradeTask(), false). + AddTask(newReinstallObproxyBinTask(), false). + AddTask(newCopyConfigDbFileTask(), false). + AddTask(newRecordObproxyInfoTask(), false). + AddTask(newHotRestartObproxyTask(), false). + AddTask(newWaitHotRestartObproxyFinishTask(), false). + AddTask(newRemoveUpgradeCheckDirTask(), false). + Build() +} + +func buildUpgradeObproxyTaskContext(param param.UpgradeObproxyParam) *task.TaskContext { + if param.UpgradeDir == "" { + param.UpgradeDir = meta.OBPROXY_HOME_PATH + } + buildNumber, distribution, _ := pkg.SplitRelease(param.Release) + return task.NewTaskContext(). + SetParam(PARAM_UPGRADE_DIR, fmt.Sprintf("%s/%s-%d", param.UpgradeDir, "obproxy-upgrade-dir", time.Now().Unix())). + SetParam(PARAM_VERSION, param.Version). + SetParam(PARAM_BUILD_NUMBER, buildNumber). + SetParam(PARAM_DISTRIBUTION, distribution). + SetParam(PARAM_RELEASE_DISTRIBUTION, param.Release) +} + +type CopyConfigDbFileTask struct { + task.Task + targetVersion string +} + +func newCopyConfigDbFileTask() *CopyConfigDbFileTask { + newTask := &CopyConfigDbFileTask{ + Task: *task.NewSubTask(TASK_COPY_CONFIG_DB_FILE), + } + newTask.SetCanContinue().SetCanRetry().SetCanRollback() + return newTask +} + +func (t *CopyConfigDbFileTask) Execute() error { + if err := t.GetContext().GetParamWithValue(PARAM_VERSION, &t.targetVersion); err != nil { + return err + } + if pkg.CompareVersion(t.targetVersion, "4.1.0.0") >= 0 { + if _, err := os.Stat(path.ObproxyNewConfigDbFile()); err == nil { + return nil + } else { + return system.CopyFile(path.ObproxyOldConfigDbFile(), path.ObproxyNewConfigDbFile()) + } + } + + return nil +} + +func (t *CopyConfigDbFileTask) Rollback() error { + if pkg.CompareVersion(t.targetVersion, "4.1.0.0") >= 0 { + if _, err := os.Stat(path.ObproxyNewConfigDbFile()); err == nil { + return nil + } else { + return system.CopyFile(path.ObproxyOldConfigDbFile(), path.ObproxyNewConfigDbFile()) + } + } + return nil +} + +type HotRestartObproxyTask struct { + task.Task +} + +func newHotRestartObproxyTask() *HotRestartObproxyTask { + newTask := &HotRestartObproxyTask{ + Task: *task.NewSubTask(TASK_HOT_RESTART_OBPROXY), + } + newTask.SetCanContinue().SetCanRetry().SetCanCancel() + return newTask +} + +func (t *HotRestartObproxyTask) Execute() error { + t.ExecuteLogf("set %s to %s", constant.OBPROXY_CONFIG_PROXY_LOCAL_CMD, constant.RESTART_FOR_PROXY_LOCAL_CMD) + return obproxyService.SetGlobalConfig(constant.OBPROXY_CONFIG_PROXY_LOCAL_CMD, constant.RESTART_FOR_PROXY_LOCAL_CMD) +} + +type RecordObproxyInfoTask struct { + task.Task +} + +func newRecordObproxyInfoTask() *RecordObproxyInfoTask { + newTask := &RecordObproxyInfoTask{ + Task: *task.NewSubTask(TASK_RECORD_OBPROXY_INFO), + } + newTask.SetCanContinue().SetCanRetry().SetCanRollback() + return newTask +} + +func (t *RecordObproxyInfoTask) Execute() error { + rollbackTimeout, err := obproxyService.GetGlobalConfig(constant.OBPROXY_CONFIG_HOT_UPGRADE_ROLLBACK_TIMEOUT) + if err != nil { + return errors.Wrapf(err, "get %s failed", constant.OBPROXY_CONFIG_HOT_UPGRADE_ROLLBACK_TIMEOUT) + } + pid, err := process.FindPIDByPort(uint32(meta.OBPROXY_SQL_PORT)) + if err != nil { + return errors.Wrapf(err, "find obproxy pid failed") + } + t.GetContext().SetData(PARAM_OLD_OBPROXY_PID, pid) + t.GetContext().SetData(PARAM_HOT_UPGRADE_ROLLBACK_TIMEOUT, rollbackTimeout) + return nil +} + +type WaitHotRestartObproxyFinishTask struct { + task.Task + rollbackTimeout string + oldPid int32 + targetVersion string + buildNumber string +} + +func newWaitHotRestartObproxyFinishTask() *WaitHotRestartObproxyFinishTask { + newTask := &WaitHotRestartObproxyFinishTask{ + Task: *task.NewSubTask(TASK_WAIT_HOT_RESTART_OBPROXY_FINISH), + } + newTask.SetCanContinue().SetCanRetry().SetCanCancel() + return newTask +} + +func (t *WaitHotRestartObproxyFinishTask) Execute() error { + var err error + if err = t.GetContext().GetDataWithValue(PARAM_OLD_OBPROXY_PID, &t.oldPid); err != nil { + return err + } + if err = t.GetContext().GetDataWithValue(PARAM_HOT_UPGRADE_ROLLBACK_TIMEOUT, &t.rollbackTimeout); err != nil { + return err + } + if err = t.GetContext().GetParamWithValue(PARAM_VERSION, &t.targetVersion); err != nil { + return err + } + if err = t.GetContext().GetParamWithValue(PARAM_BUILD_NUMBER, &t.buildNumber); err != nil { + return err + } + + // parse rollbackTimeout + rollbackTimeouot, err := parse.TimeParse(t.rollbackTimeout) + if err != nil { + return errors.Wrapf(err, "parse rollback timeout failed") + } + + retryTimes := rollbackTimeouot / waitPeriod + var pid int32 + for i := 0; i < retryTimes; i++ { + t.TimeoutCheck() + time.Sleep(time.Duration(waitPeriod) * time.Second) + pid, err = process.FindPIDByPort(uint32(meta.OBPROXY_SQL_PORT)) + if err != nil { + continue + } + t.ExecuteLogf("obproxy %d is running", pid) + if pid == t.oldPid { + t.ExecuteLogf("obproxy %d is still running, waiting for it to exit...", t.oldPid) + err = errors.New("obproxy is still running") + continue + } + err = t.checkVersion() + break + } + + if err == nil { + // Modify the pid file. + if err := process.WritePidForce(path.ObproxyPidPath(), int(pid)); err != nil { + return errors.Wrapf(err, "write obproxy pid file failed") + } + return nil + + } + + return errors.Wrapf(err, "wait hot restart obproxy finish timeout") +} + +func (t *WaitHotRestartObproxyFinishTask) checkVersion() (err error) { + dsConfig := config.NewObproxyDataSourceConfig().SetPort(meta.OBPROXY_SQL_PORT).SetPassword(meta.OBPROXY_SYS_PWD) + var tempDb *gorm.DB + defer func() { + if tempDb != nil { + db, _ := tempDb.DB() + db.Close() + } + }() + for retryCount := 1; retryCount <= obproxydb.WAIT_OBPROXY_CONNECTED_MAX_TIMES; retryCount++ { + t.ExecuteLogf("retry %d times", retryCount) + t.TimeoutCheck() + time.Sleep(obproxydb.WAIT_OBPROXY_CONNECTED_MAX_INTERVAL) + if tempDb, err = obproxydb.LoadTempObproxyInstance(dsConfig); err != nil { + t.ExecuteLogf("load obproxy instance failed: %s", err.Error()) + continue + } + var proxyInfo bo.ObproxyInfo + if err = tempDb.Raw("show proxyinfo binary").Scan(&proxyInfo).Error; err != nil { + t.ExecuteLogf("show proxyconfig failed: %s", err.Error()) + continue + } + // parse obproxy version + re := regexp.MustCompile(`\d+\.\d+\.\d+\.\d+-\d+`) + version := re.FindString(proxyInfo.Info) + if version != strings.Join([]string{t.targetVersion, t.buildNumber}, "-") { + t.ExecuteLogf("obproxy version is not the target version, current version: %s, target version: %s", version, t.targetVersion) + continue + } + return nil + } + return errors.New("check obproxy version timeout...") +} + +type CreateObproxyUpgradeDirTask struct { + task.Task + upgradeDir string +} + +func newCreateObproxyUpgradeDirTask() *CreateObproxyUpgradeDirTask { + newTask := &CreateObproxyUpgradeDirTask{ + Task: *task.NewSubTask(TASK_CREATE_UPGRADE_DIR), + } + newTask. + SetCanRetry(). + SetCanContinue(). + SetCanRollback(). + SetCanPass(). + SetCanCancel() + return newTask +} + +func (t *CreateObproxyUpgradeDirTask) Execute() (err error) { + if err = t.GetContext().GetParamWithValue(PARAM_UPGRADE_DIR, &t.upgradeDir); err != nil { + return err + } + t.ExecuteLogf("Upgrade dir is %s", t.upgradeDir) + if err = t.checkUpgradeDir(); err != nil { + return err + } + return nil +} + +func (t *CreateObproxyUpgradeDirTask) checkUpgradeDir() (err error) { + t.GetContext().SetData(PARAM_CREATE_UPGRADE_DIR_FLAG, false) + + t.ExecuteLogf("Mkdir %s ", t.upgradeDir) + if err = os.MkdirAll(t.upgradeDir, 0755); err != nil { + return err + } + + isDirEmpty, err := system.IsDirEmpty(t.upgradeDir) + if err != nil { + return err + } + if !isDirEmpty { + return fmt.Errorf("%s is not empty", t.upgradeDir) + } + t.GetContext().SetData(PARAM_CREATE_UPGRADE_DIR_FLAG, true) + return nil +} + +func (t *CreateObproxyUpgradeDirTask) Rollback() (err error) { + t.ExecuteLog("Rolling back...") + if t.GetContext().GetData(PARAM_CREATE_UPGRADE_DIR_FLAG) == nil { + return nil + } + t.ExecuteLog("Remove " + t.upgradeDir) + return os.RemoveAll(t.upgradeDir) +} + +// RemoveUpgradeObproxyDirTask remove upgrade dir +type RemoveUpgradeObproxyDirTask struct { + task.Task + upgradeDir string +} + +func newRemoveUpgradeCheckDirTask() *RemoveUpgradeObproxyDirTask { + newTask := &RemoveUpgradeObproxyDirTask{ + Task: *task.NewSubTask(TASK_REMOVE_UPGRADE_DIR), + } + newTask. + SetCanRetry(). + SetCanContinue(). + SetCanPass(). + SetCanCancel() + return newTask +} + +func (t *RemoveUpgradeObproxyDirTask) Execute() (err error) { + t.ExecuteLog("remove upgrade dir") + if err = t.removeUpgradeDir(); err != nil { + return + } + t.ExecuteLog("remove upgrade check dir finished") + return nil +} + +func (t *RemoveUpgradeObproxyDirTask) removeUpgradeDir() (err error) { + if err := t.GetContext().GetParamWithValue(PARAM_UPGRADE_DIR, &t.upgradeDir); err != nil { + return errors.New("get upgrade check task dir failed") + } + return os.RemoveAll(t.upgradeDir) +} diff --git a/agent/executor/obproxy/upload_pkg.go b/agent/executor/obproxy/upload_pkg.go new file mode 100644 index 00000000..05ab05ba --- /dev/null +++ b/agent/executor/obproxy/upload_pkg.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "fmt" + "mime/multipart" + + "github.com/cavaliergopher/rpm" + log "github.com/sirupsen/logrus" + + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/lib/pkg" + "github.com/oceanbase/obshell/agent/repository/model/sqlite" +) + +var defaultFileFormatForObproxy = "/home/admin/obproxy-%s/bin/obproxy" + +type upgradeRpmPkgInfo struct { + rpmFile multipart.File + rpmPkg *rpm.Package + version string + release string + distribution string +} + +func UpgradePkgUpload(input multipart.File) (*sqlite.UpgradePkgInfo, *errors.OcsAgentError) { + r := &upgradeRpmPkgInfo{ + rpmFile: input, + } + + if err := r.CheckUpgradePkg(true); err != nil { + return nil, errors.Occur(errors.ErrKnown, err) + } + + pkgInfo, err := obproxyService.DumpUpgradePkgInfoAndChunkTx(r.rpmPkg, r.rpmFile) + if err != nil { + return nil, errors.Occur(errors.ErrUnexpected, err) + } + return pkgInfo, nil +} + +func (r *upgradeRpmPkgInfo) CheckUpgradePkg(forUpload bool) (err error) { + if r.rpmPkg, err = pkg.ReadRpm(r.rpmFile); err != nil { + return + } + r.version = r.rpmPkg.Version() + + if r.rpmPkg.Name() != constant.PKG_OBPROXY_CE { + return fmt.Errorf("unsupported name '%s', the supported name is '%s'", r.rpmPkg.Name(), constant.PKG_OBPROXY_CE) + } + return r.fileCheck() +} + +func (r *upgradeRpmPkgInfo) fileCheck() (err error) { + // Check for the necessary files required for the agent upgrade process. + if err = r.checkVersion(); err != nil { + return errors.Wrap(err, "failed to check version and release") + } + return r.findAllExpectedFiles([]string{defaultFileFormatForObproxy}) +} + +func (r *upgradeRpmPkgInfo) checkVersion() (err error) { + log.Info("version is ", r.version) + r.release, r.distribution, err = pkg.SplitRelease(r.rpmPkg.Release()) + if err != nil { + return + } + if pkg.CompareVersion(r.rpmPkg.Version(), constant.OBPROXY_MIN_VERSION_SUPPORT) < 0 { + return fmt.Errorf("unsupported obproxy version '%s', the minimum supported version is '%s'", r.rpmPkg.Version(), constant.SUPPORT_MIN_VERSION) + } + return nil +} + +func (r *upgradeRpmPkgInfo) findAllExpectedFiles(expected []string) (err error) { + succeed := true + missingFiles := make([]string, 0) + for _, expect := range expected { + expect = fmt.Sprintf(expect, r.version) + var found bool + for _, actual := range r.rpmPkg.Files() { + if actual.Name() == expect { + log.Info("found file: ", expect) + found = true + break + } + } + if !found { + log.Errorf("file '%s' not found", expect) + missingFiles = append(missingFiles, expect) + succeed = false + } + } + if !succeed { + return fmt.Errorf("these files are missing: '%v'", missingFiles) + } + return nil +} diff --git a/agent/executor/obproxy/utils.go b/agent/executor/obproxy/utils.go new file mode 100644 index 00000000..d34e40f7 --- /dev/null +++ b/agent/executor/obproxy/utils.go @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "fmt" + "os/exec" + "regexp" + "strings" +) + +func getObproxyVersion(homepath string) (string, error) { + output, err := exec.Command(homepath+"/bin/obproxy", "-V").CombinedOutput() + if err != nil { + return "", err + } + re := regexp.MustCompile(`\(OceanBase (\d+\.\d+\.\d+\.\d+) (\d+)\)`) + matches := re.FindStringSubmatch(string(output)) + if len(matches) < 3 { + return "", fmt.Errorf("version not found in output: %s", output) + } + return strings.Join(matches[1:], "-"), nil +} diff --git a/agent/executor/script/script.go b/agent/executor/script/script.go index eabc21bd..1c0d288b 100644 --- a/agent/executor/script/script.go +++ b/agent/executor/script/script.go @@ -174,7 +174,7 @@ func (t *ImportScriptForTenantTask) importByPython(module, scriptPath, sqlfile s } t.ExecuteLogf("Use python to import %s.", module) - str := fmt.Sprintf("%s -h%s -P%d -t%s -f%s", scriptPath, constant.LOCAL_IP, meta.MYSQL_PORT, t.tenantName, sqlfile) + str := fmt.Sprintf("%s -h%s -P%d -t%s -f%s", scriptPath, meta.OCS_AGENT.GetLocalIp(), meta.MYSQL_PORT, t.tenantName, sqlfile) if meta.GetOceanbasePwd() != "" { pwd := strings.ReplaceAll(meta.GetOceanbasePwd(), "'", "'\"'\"'") str = fmt.Sprintf("%s -p'%s'", str, pwd) diff --git a/agent/executor/task/dag.go b/agent/executor/task/dag.go index fd5e33c8..037f9758 100644 --- a/agent/executor/task/dag.go +++ b/agent/executor/task/dag.go @@ -67,6 +67,9 @@ func GetDagDetail(c *gin.Context) { } if agent != nil && !meta.OCS_AGENT.Equal(agent) { + if task.IsObproxyTask(dagDTOParam.GenericID) { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, err)) + } if meta.OCS_AGENT.IsFollowerAgent() { // forward request to master master := agentService.GetMasterAgentInfo() @@ -113,6 +116,10 @@ func GetDagDetail(c *gin.Context) { return } + if task.ConvertToGenericID(dag, dag.GetDagType()) != dagDTOParam.GenericID { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, "dag id not match")) + return + } dagDetailDTO, err = convertDagDetailDTO(dag, *param.ShowDetails) common.SendResponse(c, dagDetailDTO, err) } @@ -167,7 +174,7 @@ func convertDagDetailDTO(dag *task.Dag, fillDeatil bool) (dagDetailDTO *task.Dag return } - nodeDetailDTO, err = getNodeDetail(service, nodes[i]) + nodeDetailDTO, err = getNodeDetail(service, nodes[i], dag.GetDagType()) if err != nil { return } @@ -215,6 +222,9 @@ func DagHandler(c *gin.Context) { } if agent != nil && !meta.OCS_AGENT.Equal(agent) { + if task.IsObproxyTask(dagOperator.GenericID) { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, err)) + } if meta.OCS_AGENT.IsFollowerAgent() { // forward request to master master := agentService.GetMasterAgentInfo() @@ -241,6 +251,11 @@ func DagHandler(c *gin.Context) { return } + if task.ConvertToGenericID(dag, dag.GetDagType()) != dagOperator.GenericID { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, "dag id not match")) + return + } + switch strings.ToUpper(dagOperator.Operator) { case task.ROLLBACK_STR: err = service.SetDagRollback(dag) diff --git a/agent/executor/task/node.go b/agent/executor/task/node.go index e2a9b858..c8706779 100644 --- a/agent/executor/task/node.go +++ b/agent/executor/task/node.go @@ -60,6 +60,9 @@ func GetNodeDetail(c *gin.Context) { } if agent != nil && !meta.OCS_AGENT.Equal(agent) { + if task.IsObproxyTask(nodeDTOParam.GenericID) { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, err)) + } if meta.OCS_AGENT.IsFollowerAgent() { // forward request to master master := agentService.GetMasterAgentInfo() @@ -86,6 +89,7 @@ func GetNodeDetail(c *gin.Context) { common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, err)) return } + if *param.ShowDetails { _, err = service.GetSubTasks(node) if err != nil { @@ -93,7 +97,16 @@ func GetNodeDetail(c *gin.Context) { return } - nodeDetailDTO, err = getNodeDetail(service, node) + dag, err := service.GetDagInstance(int64(node.GetDagId())) + if err != nil { + common.SendResponse(c, nil, errors.Occur(errors.ErrUnexpected, err)) + return + } + if task.ConvertToGenericID(dag, dag.GetDagType())[0] != nodeDTOParam.GenericID[0] { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, "node type not match")) + return + } + nodeDetailDTO, err = getNodeDetail(service, node, dag.GetDagType()) if err != nil { common.SendResponse(c, nil, errors.Occur(errors.ErrUnexpected, err)) return @@ -103,12 +116,12 @@ func GetNodeDetail(c *gin.Context) { common.SendResponse(c, nodeDetailDTO, nil) } -func getNodeDetail(service taskservice.TaskServiceInterface, node *task.Node) (nodeDetailDTO *task.NodeDetailDTO, err error) { - nodeDetailDTO = task.NewNodeDetailDTO(node) +func getNodeDetail(service taskservice.TaskServiceInterface, node *task.Node, dagType string) (nodeDetailDTO *task.NodeDetailDTO, err error) { + nodeDetailDTO = task.NewNodeDetailDTO(node, dagType) subTasks := node.GetSubTasks() n := len(subTasks) for i := 0; i < n; i++ { - taskDetailDTO, err := getSubTaskDetail(service, subTasks[i]) + taskDetailDTO, err := getSubTaskDetail(service, subTasks[i], dagType) if err != nil { return nil, err } diff --git a/agent/executor/task/sub_task.go b/agent/executor/task/sub_task.go index 5d0b4e44..cc99ee52 100644 --- a/agent/executor/task/sub_task.go +++ b/agent/executor/task/sub_task.go @@ -59,6 +59,9 @@ func GetSubTaskDetail(c *gin.Context) { } if agent != nil && !meta.OCS_AGENT.Equal(agent) { + if task.IsObproxyTask(taskDTOParam.GenericID) { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, err)) + } if meta.OCS_AGENT.IsFollowerAgent() { // forward request to master master := agentService.GetMasterAgentInfo() @@ -85,7 +88,22 @@ func GetSubTaskDetail(c *gin.Context) { return } - taskDetailDTO, err = getSubTaskDetail(service, subTask) + dagType := "" + if subTask.IsLocalTask() { + dag, err := service.GetDagBySubTaskId(subTask.GetID()) + if err != nil { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, err)) + return + } + dagType = dag.GetDagType() + if task.ConvertToGenericID(dag, dag.GetDagType())[0] != taskDTOParam.GenericID[0] { + common.SendResponse(c, nil, errors.Occur(errors.ErrTaskNotFound, "sub task type not match")) + return + } + + } + + taskDetailDTO, err = getSubTaskDetail(service, subTask, dagType) if err != nil { common.SendResponse(c, nil, errors.Occur(errors.ErrUnexpected, err)) return @@ -94,8 +112,8 @@ func GetSubTaskDetail(c *gin.Context) { common.SendResponse(c, taskDetailDTO, nil) } -func getSubTaskDetail(service taskservice.TaskServiceInterface, subTask task.ExecutableTask) (taskDetailDTO *task.TaskDetailDTO, err error) { - taskDetailDTO = task.NewTaskDetailDTO(subTask) +func getSubTaskDetail(service taskservice.TaskServiceInterface, subTask task.ExecutableTask, dagType string) (taskDetailDTO *task.TaskDetailDTO, err error) { + taskDetailDTO = task.NewTaskDetailDTO(subTask, dagType) if subTask.IsRunning() || subTask.IsFinished() { taskDetailDTO.TaskLogs, err = service.GetSubTaskLogsByTaskID(subTask.GetID()) } diff --git a/agent/executor/tenant/create_tenant.go b/agent/executor/tenant/create_tenant.go index 89f07dbc..221a32f4 100644 --- a/agent/executor/tenant/create_tenant.go +++ b/agent/executor/tenant/create_tenant.go @@ -31,6 +31,8 @@ import ( "github.com/oceanbase/obshell/agent/executor/pool" "github.com/oceanbase/obshell/agent/executor/script" "github.com/oceanbase/obshell/agent/executor/zone" + "github.com/oceanbase/obshell/agent/lib/path" + "github.com/oceanbase/obshell/agent/meta" tenantservice "github.com/oceanbase/obshell/agent/service/tenant" "github.com/oceanbase/obshell/param" "github.com/oceanbase/obshell/utils" @@ -101,7 +103,7 @@ func checkParameters(parameters map[string]interface{}) error { return nil } -func checkScenario(scenario string) error { +func checkAndLoadScenario(param *param.CreateTenantParam, scenario string) error { if scenario == "" { return nil } @@ -110,10 +112,30 @@ func checkScenario(scenario string) error { if len(scenarios) == 0 { return errors.New("current observer does not support scenario") } - if utils.ContainsString(scenarios, strings.ToLower(scenario)) { - return nil + if !utils.ContainsString(scenarios, strings.ToLower(scenario)) { + errors.Errorf("scenario only support to be one of %s", strings.Join(scenarios, ", ")) + } + + variables, err := parseTemplate(VARIABLES_TEMPLATE, path.ObshellDefaultVariablePath(), scenario) + if err != nil { + return errors.Wrap(err, "Parse variable template failed") + } + for key, value := range variables { + if _, exist := param.Variables[key]; !exist { + param.Variables[key] = value + } + } + + parameters, err := parseTemplate(PARAMETERS_TEMPLATE, path.ObshellDefaultParameterPath(), scenario) + if err != nil { + return errors.Wrap(err, "Parse parameter template failed") + } + for key, value := range parameters { + if _, exist := param.Parameters[key]; !exist { + param.Parameters[key] = value + } } - return errors.Errorf("scenario only support to be one of %s", strings.Join(scenarios, ", ")) + return nil } func renderCreateTenantParam(param *param.CreateTenantParam) error { @@ -182,10 +204,6 @@ func checkCreateTenantParam(param *param.CreateTenantParam) (err error) { return } - if err = checkScenario(param.Scenario); err != nil { - return - } - if err = checkCharsetAndCollation(param.Charset, param.Collation); err != nil { return } @@ -206,6 +224,10 @@ func checkCreateTenantParam(param *param.CreateTenantParam) (err error) { return } + if err = checkAndLoadScenario(param, param.Scenario); err != nil { + return + } + return nil } @@ -245,18 +267,20 @@ func checkZoneResourceForUnit(zone string, unitName string, unitNum int) error { if err != nil { return err } - log.Infof("server %s:%d used resource: %v", server.SvrIp, server.SvrPort, gatheredUnitInfo) + + serverStr := meta.NewAgentInfo(server.SvrIp, server.SvrPort).String() + log.Infof("server %s used resource: %v", serverStr, gatheredUnitInfo) if server.CpuCapacity-gatheredUnitInfo.MinCpu < unit.MinCpu || server.CpuCapacityMax-gatheredUnitInfo.MaxCpu < unit.MaxCpu { - checkErr = errors.Errorf("server %s:%d CPU resource not enough", server.SvrIp, server.SvrPort) + checkErr = errors.Errorf("server %s CPU resource not enough", serverStr) continue } if server.MemCapacity-gatheredUnitInfo.MemorySize < unit.MemorySize { - checkErr = errors.Errorf("server %s:%d MEMORY_SIZE resource not enough", server.SvrIp, server.SvrPort) + checkErr = errors.Errorf("server %s MEMORY_SIZE resource not enough", serverStr) continue } if server.LogDiskCapacity-gatheredUnitInfo.LogDiskSize < unit.LogDiskSize { - checkErr = errors.Errorf("server %s:%d LOG_DISK_SIZE resource not enough", server.SvrIp, server.SvrPort) + checkErr = errors.Errorf("server %s LOG_DISK_SIZE resource not enough", serverStr) continue } validServer += 1 @@ -277,7 +301,7 @@ type gatheredUnitInfo struct { func gatherAllUnitsOnServer(svrIp string, svrPort int) (*gatheredUnitInfo, error) { units, err := obclusterService.GetObUnitsOnServer(svrIp, svrPort) if err != nil { - return nil, errors.Errorf("Get all units on server %s:%d failed.", svrIp, svrPort) + return nil, errors.Errorf("Get all units on server %s failed.", meta.NewAgentInfo(svrIp, svrPort).String()) } used := &gatheredUnitInfo{} for _, unit := range units { @@ -320,7 +344,7 @@ func CreateTenant(param *param.CreateTenantParam) (*task.DagDetailDTO, *errors.O } // Create 'Create tenant' dag instance. - template, err := buildCreateTenatDagTemplate(param) + template, err := buildCreateTenantDagTemplate(param) if err != nil { return nil, errors.Occur(errors.ErrUnexpected, err.Error()) } @@ -332,7 +356,7 @@ func CreateTenant(param *param.CreateTenantParam) (*task.DagDetailDTO, *errors.O return task.NewDagDetailDTO(dag), nil } -func buildCreateTenatDagTemplate(param *param.CreateTenantParam) (*task.Template, error) { +func buildCreateTenantDagTemplate(param *param.CreateTenantParam) (*task.Template, error) { createTenantNode, err := newCreateTenantNode(param) if err != nil { return nil, err @@ -344,13 +368,24 @@ func buildCreateTenatDagTemplate(param *param.CreateTenantParam) (*task.Template templateBuilder.AddNode(newSetTenantTimeZoneNode(param.TimeZone)) } if param.Parameters != nil && len(param.Parameters) != 0 { - templateBuilder.AddTask(newSetTenantParameterTask(), false) - } - if param.Scenario != "" { - templateBuilder.AddNode(newOptimizeTenantNode(param.Scenario, param)) + templateBuilder.AddNode(newSetTenantParameterNode(param.Parameters)) } templateBuilder.AddNode(newModifyTenantWhitelistNode(*param.Whitelist)) + // Delete the read-only variables + for k := range param.Variables { + if utils.ContainsString(CREATE_TENANT_STATEMENT_VARIABLES, k) { + delete(param.Variables, k) + } + } + if param.Variables != nil && len(param.Variables) != 0 { + node, err := newSetTenantVariableNode(param.Variables) + if err != nil { + return nil, err + } + templateBuilder.AddNode(node) + } + agents, err := agentService.GetAllAgentsInfo() if err != nil { return nil, err @@ -366,13 +401,13 @@ func buildCreateTenatDagTemplate(param *param.CreateTenantParam) (*task.Template } templateBuilder.AddNode(setRootPwdNode) } + return templateBuilder.Build(), nil } func buildCreateTenantDagContext(param *param.CreateTenantParam) *task.TaskContext { context := task.NewTaskContext() context.SetParam(PARAM_TENANT_NAME, param.Name). - SetParam(PARAM_TENANT_PARAMETER, param.Parameters). SetParam(task.FAILURE_EXIT_MAINTENANCE, true) return context } @@ -414,7 +449,7 @@ func newCreateTenantTask() *CreateTenantTask { return newTask } -func buildCreateTenantSql(param param.CreateTenantParam, poolList []string) (string, []interface{}) { +func buildCreateTenantSql(param *param.CreateTenantParam, poolList []string) (string, []interface{}) { resourcePoolList := "\"" + strings.Join(poolList, "\",\"") + "\"" sql := fmt.Sprintf(tenantservice.SQL_CREATE_TENANT_BASIC, *param.Name, resourcePoolList) @@ -462,13 +497,16 @@ func buildCreateTenantSql(param param.CreateTenantParam, poolList []string) (str transferNumber(param.Variables) for k, v := range param.Variables { - if _, ok := v.(string); ok { - sql += ", " + k + "= `%s`" - } else { - sql += ", " + k + "= %v" + if utils.ContainsString(CREATE_TENANT_STATEMENT_VARIABLES, k) { + if _, ok := v.(string); ok { + sql += ", " + k + "= `%s`" + } else { + sql += ", " + k + "= %v" + } + input = append(input, v) } - input = append(input, v) } + return sql, input } @@ -490,7 +528,7 @@ func (t *CreateTenantTask) Execute() error { for _, poolParam := range t.createResourcePoolParam { poolList = append(poolList, poolParam.PoolName) } - basic, input := buildCreateTenantSql(t.CreateTenantParam, poolList) + basic, input := buildCreateTenantSql(&t.CreateTenantParam, poolList) sql := fmt.Sprintf(basic, input...) t.ExecuteLogf("Create tenant sql: %s", sql) if err := tenantService.TryExecute(sql); err != nil { diff --git a/agent/executor/tenant/enter.go b/agent/executor/tenant/enter.go index bf83c0df..8e107ff5 100644 --- a/agent/executor/tenant/enter.go +++ b/agent/executor/tenant/enter.go @@ -34,14 +34,24 @@ var ( clusterTaskService = taskservice.NewClusterTaskService() unitService unit.UnitService agentService agent.AgentService + + // READONLY variables + CREATE_TENANT_STATEMENT_VARIABLES = []string{"lower_case_table_names"} + // Those variables could not set by sys tenant. + VARIAbLES_COLLATION_OR_CHARACTER = []string{ + "collation_server", + "collation_database", + "collation_connection", + "character_set_server", + "character_set_database", + "character_set_connection", + } ) const ( // task param name PARAM_CREATE_TENANT = "createTenant" - PARAM_OPTIMIZE_TENANT = "optimizeTenant" - PARAM_CREATE_TENANT_VARIABLES = "createTenantVariables" - PARAM_CREATE_TENANT_PARAMETERS = "createTenantParameters" + PARAM_TENANT_VARIABLES = "tenantVariables" PARAM_TENANT_NAME = "tenantName" PARAM_TENANT_TIME_ZONE = "timeZone" PARAM_TENANT_ID = "tenantId" @@ -74,6 +84,7 @@ const ( TASK_NAME_MODIFY_PRIMARY_ZONE = "Modify tenant primary zone" TASK_NAME_SET_ROOT_PWD = "Set root password" TASK_NAME_SET_TENANT_PARAM = "Set tenant parameters" + TASK_NAME_SET_TENANT_VARIABLE = "Set tenant variables" TASK_NAME_DROP_RESOURCE_POOL = "Drop resource pools" TASK_NAME_SET_TENANT_PARAMETER = "Set tenant parameter" TASK_NAME_DROP_TENANT = "Drop tenant" @@ -136,7 +147,7 @@ func RegisterTenantTask() { task.RegisterTaskType(SetRootPwdTask{}) task.RegisterTaskType(SetTenantTimeZoneTask{}) task.RegisterTaskType(SetTenantParamterTask{}) - task.RegisterTaskType(OptimizeTenantTask{}) + task.RegisterTaskType(SetTenantVariableTask{}) task.RegisterTaskType(DropTenantTask{}) task.RegisterTaskType(RecycleTenantTask{}) task.RegisterTaskType(BatchCreateResourcePoolTask{}) diff --git a/agent/executor/tenant/optimize.go b/agent/executor/tenant/optimize.go index 3f299015..b55d610f 100644 --- a/agent/executor/tenant/optimize.go +++ b/agent/executor/tenant/optimize.go @@ -20,37 +20,10 @@ import ( "io" "os" - "github.com/oceanbase/obshell/agent/engine/task" - "github.com/oceanbase/obshell/agent/errors" "github.com/oceanbase/obshell/agent/lib/json" "github.com/oceanbase/obshell/agent/lib/path" - "github.com/oceanbase/obshell/param" ) -type OptimizeTenantTask struct { - task.Task - tenantId int - template string - createTenantVariables map[string]interface{} - createTenantParameters map[string]interface{} -} - -func newOptimizeTenantTask() *OptimizeTenantTask { - newTask := &OptimizeTenantTask{ - Task: *task.NewSubTask(TASK_NAME_OPTIMIZE_TENANT), - } - newTask.SetCanRollback().SetCanRetry().SetCanCancel().SetCanContinue().SetCanPass() - return newTask -} - -func newOptimizeTenantNode(template string, createTenantParam *param.CreateTenantParam) *task.Node { - context := task.NewTaskContext(). - SetParam(PARAM_OPTIMIZE_TENANT, template). - SetParam(PARAM_CREATE_TENANT_VARIABLES, createTenantParam.Variables). - SetParam(PARAM_CREATE_TENANT_PARAMETERS, createTenantParam.Parameters) - return task.NewNodeWithContext(newOptimizeTenantTask(), false, context) -} - func getAllSupportedScenarios() (scenarios []string) { if _, err := os.Stat(path.ObshellDefaultVariablePath()); err != nil { return @@ -156,52 +129,3 @@ func parseTemplate(templateType, filepath, scenario string) (map[string]interfac } return res, nil } - -func (t *OptimizeTenantTask) Execute() error { - if err := t.GetContext().GetParamWithValue(PARAM_TENANT_ID, &t.tenantId); err != nil { - return errors.Wrap(err, "Get tenant id failed") - } - if err := t.GetContext().GetParamWithValue(PARAM_OPTIMIZE_TENANT, &t.template); err != nil { - return errors.Wrap(err, "Get template failed") - } - if err := t.GetContext().GetParamWithValue(PARAM_CREATE_TENANT_VARIABLES, &t.createTenantVariables); err != nil { - return errors.Wrap(err, "Get create tenant variables failed") - } - if err := t.GetContext().GetParamWithValue(PARAM_CREATE_TENANT_PARAMETERS, &t.createTenantParameters); err != nil { - return errors.Wrap(err, "Get create tenant parameters failed") - } - - tenantName, err := tenantService.GetTenantName(t.tenantId) - if err != nil { - return errors.Wrap(err, "Get tenant name failed") - } - - variables, err := parseTemplate(VARIABLES_TEMPLATE, path.ObshellDefaultVariablePath(), t.template) - if err != nil { - return errors.Wrap(err, "Parse variable template failed") - } - for key := range t.createTenantVariables { - delete(variables, key) - } - transferNumber(variables) - t.ExecuteLogf("optimize variables: %v\n", variables) - - parameters, err := parseTemplate(PARAMETERS_TEMPLATE, path.ObshellDefaultParameterPath(), t.template) - if err != nil { - return errors.Wrap(err, "Parse parameter template failed") - } - for key := range t.createTenantParameters { - delete(parameters, key) - } - transferNumber(parameters) - t.ExecuteLogf("optimize parameters: %v\n", parameters) - - if err = tenantService.SetTenantVariables(tenantName, variables); err != nil { - return errors.Wrap(err, "Set tenant variables failed") - } - - if err = tenantService.SetTenantParameters(tenantName, parameters); err != nil { - return errors.Wrap(err, "Set tenant parameters failed") - } - return nil -} diff --git a/agent/executor/tenant/parameter.go b/agent/executor/tenant/parameter.go index 4729e11c..8d235df2 100644 --- a/agent/executor/tenant/parameter.go +++ b/agent/executor/tenant/parameter.go @@ -75,6 +75,12 @@ type SetTenantParamterTask struct { tenantId int } +func newSetTenantParameterNode(parameters map[string]interface{}) *task.Node { + subtask := newSetTenantParameterTask() + ctx := task.NewTaskContext().SetParam(PARAM_TENANT_PARAMETER, parameters) + return task.NewNodeWithContext(subtask, false, ctx) +} + func newSetTenantParameterTask() *SetTenantParamterTask { newTask := &SetTenantParamterTask{ Task: *task.NewSubTask(TASK_NAME_SET_TENANT_PARAMETER), diff --git a/agent/executor/tenant/set_root_pwd.go b/agent/executor/tenant/set_root_pwd.go index d92b6306..5b10ec8f 100644 --- a/agent/executor/tenant/set_root_pwd.go +++ b/agent/executor/tenant/set_root_pwd.go @@ -32,7 +32,7 @@ type SetRootPwdTask struct { newPassword string } -func getExecuteAgentForSetTenantRootPwd(tenantName string) (meta.AgentInfoInterface, error) { +func getExecuteAgentForTenant(tenantName string) (meta.AgentInfoInterface, error) { isTenantOn, err := tenantService.IsTenantActiveAgent(tenantName, meta.OCS_AGENT.GetIp(), meta.RPC_PORT) if err != nil { return nil, err @@ -57,7 +57,7 @@ func ModifyTenantRootPassword(c *gin.Context, tenantName string, pwdParam param. if tenantName == constant.TENANT_SYS { return errors.Occur(errors.ErrIllegalArgument, "Can not modify root password for sys tenant."), false } - executeAgent, err := getExecuteAgentForSetTenantRootPwd(tenantName) + executeAgent, err := getExecuteAgentForTenant(tenantName) if err != nil { return errors.Occurf(errors.ErrUnexpected, "get execute agent failed: %s", err.Error()), false } diff --git a/agent/executor/tenant/user.go b/agent/executor/tenant/user.go new file mode 100644 index 00000000..8c80e831 --- /dev/null +++ b/agent/executor/tenant/user.go @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tenant + +import ( + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/repository/db/oceanbase" + "github.com/oceanbase/obshell/param" + "gorm.io/gorm" +) + +func CreateUser(tenantName string, param param.CreateUserParam) *errors.OcsAgentError { + if exist, err := tenantService.IsTenantExist(tenantName); err != nil { + return errors.Occurf(errors.ErrUnexpected, "check tenant '%s' exist failed", tenantName) + } else if !exist { + return errors.Occurf(errors.ErrBadRequest, "Tenant '%s' not exists.", tenantName) + } + + var db *gorm.DB + var err error + if tenantName == constant.TENANT_SYS { + db, err = oceanbase.GetInstance() + if err != nil { + return errors.Occurf(errors.ErrUnexpected, "get oceanbase instance failed") + } + } else { + defer func() { + if db != nil { + tempDb, _ := db.DB() + if tempDb != nil { + tempDb.Close() + } + } + }() + db, err = oceanbase.LoadGormWithTenant(tenantName, param.RootPassword) + if err != nil { + return errors.Occurf(errors.ErrUnexpected, "load gorm with tenant '%s' failed", tenantName) + } + } + + if param.HostName == "" { + param.HostName = constant.DEFAULT_HOST + } + + // Create user. + if err := tenantService.CreateUser(db, param.UserName, param.Password, param.HostName); err != nil { + return errors.Occurf(errors.ErrUnexpected, "create user '%s' failed: %s", param.UserName, err.Error()) + } + + // Grant privileges. + if len(param.GlobalPrivileges) != 0 { + if err := tenantService.GrantGlobalPrivileges(db, param.UserName, param.HostName, param.GlobalPrivileges); err != nil { + return errors.Occurf(errors.ErrUnexpected, "grant global privileges to user '%s' failed: %s", param.UserName, err.Error()) + } + } + + for _, dbPrivilege := range param.DbPrivileges { + if err := tenantService.GrantDbPrivileges(db, param.UserName, param.HostName, dbPrivilege); err != nil { + return errors.Occurf(errors.ErrUnexpected, "grant db privileges to user '%s' failed: %s", param.UserName, err.Error()) + } + } + + return nil +} + +func DropUser(tenantName, userName, rootPassword string) *errors.OcsAgentError { + if exist, err := tenantService.IsTenantExist(tenantName); err != nil { + return errors.Occurf(errors.ErrUnexpected, "check tenant '%s' exist failed", tenantName) + } else if !exist { + return errors.Occurf(errors.ErrBadRequest, "Tenant '%s' not exists.", tenantName) + } + + var db *gorm.DB + var err error + if tenantName == constant.TENANT_SYS { + db, err = oceanbase.GetInstance() + if err != nil { + return errors.Occurf(errors.ErrUnexpected, "get oceanbase instance failed") + } + } else { + defer func() { + if db != nil { + tempDb, _ := db.DB() + if tempDb != nil { + tempDb.Close() + } + } + }() + db, err = oceanbase.LoadGormWithTenant(tenantName, rootPassword) + if err != nil { + return errors.Occurf(errors.ErrUnexpected, "load gorm with tenant '%s' failed", tenantName) + } + } + + // Check user exist. + if exist, err := tenantService.IsUserExist(db, userName); err != nil { + return errors.Occurf(errors.ErrUnexpected, "check user '%s' exist failed", userName) + } else if !exist { + return nil + } + + // Drop user. + if err := tenantService.DropUser(db, userName); err != nil { + return errors.Occurf(errors.ErrUnexpected, "drop user '%s' failed: %s", userName, err.Error()) + } + + return nil +} diff --git a/agent/executor/tenant/variable.go b/agent/executor/tenant/variable.go index 28b99691..a3a73e5d 100644 --- a/agent/executor/tenant/variable.go +++ b/agent/executor/tenant/variable.go @@ -19,9 +19,15 @@ package tenant import ( "regexp" + "github.com/gin-gonic/gin" + "github.com/oceanbase/obshell/agent/api/common" "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/engine/task" "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/meta" "github.com/oceanbase/obshell/agent/repository/model/oceanbase" + "github.com/oceanbase/obshell/param" + "github.com/oceanbase/obshell/utils" ) func isUnkonwnTimeZoneErr(err error) bool { @@ -56,23 +62,48 @@ func GetTenantVariable(tenantName string, variableName string) (*oceanbase.CdbOb return variable, nil } -func SetTenantVariables(tenantName string, variables map[string]interface{}) *errors.OcsAgentError { +func SetTenantVariables(c *gin.Context, tenantName string, param param.SetTenantVariablesParam) *errors.OcsAgentError { if _, err := checkTenantExistAndStatus(tenantName); err != nil { return err } - for k, v := range variables { + for k, v := range param.Variables { if k == "" || v == nil { return errors.Occur(errors.ErrIllegalArgument, "variable name or value is empty") } } - transferNumber(variables) - if err := tenantService.SetTenantVariables(tenantName, variables); err != nil { - if errors.IsUnkonwnTimeZoneErr(err) { - if value, exist := variables[constant.VARIABLE_TIME_ZONE]; exist { - return timeZoneErrorReporter(value, err) + transferNumber(param.Variables) + + needConnectTenant := false + for k := range param.Variables { + if utils.ContainsString(VARIAbLES_COLLATION_OR_CHARACTER, k) { + needConnectTenant = true + break + } + } + + if !needConnectTenant { + if err := tenantService.SetTenantVariables(tenantName, param.Variables); err != nil { + if errors.IsUnkonwnTimeZoneErr(err) { + if value, exist := param.Variables[constant.VARIABLE_TIME_ZONE]; exist { + return timeZoneErrorReporter(value, err) + } + } + return errors.Occur(errors.ErrBadRequest, err) + } + } else { + executeAgent, err := getExecuteAgentForTenant(tenantName) + if err != nil { + return errors.Occurf(errors.ErrUnexpected, "get execute agent failed: %s", err.Error()) + } + + if meta.OCS_AGENT.Equal(executeAgent) { + if err := tenantService.SetTenantVariablesWithTenant(tenantName, param.TenantPassword, param.Variables); err != nil { + return errors.Occur(errors.ErrUnexpected, err) } + } else { + common.ForwardRequest(c, executeAgent, param) + return nil } - return errors.Occur(errors.ErrBadRequest, err) } return nil @@ -84,9 +115,61 @@ func timeZoneErrorReporter(timeZone interface{}, err error) *errors.OcsAgentErro re := regexp.MustCompile(pattern) if re.MatchString(v) { if empty, _ := tenantService.IsTimeZoneTableEmpty(); empty { - return errors.Occur(errors.ErrBadRequest, errors.Wrapf(err, "Please check whether the sys tenat has been import time zone info")) + return errors.Occur(errors.ErrBadRequest, errors.Wrapf(err, "Please check whether the sys tenant has been import time zone info")) } } } return errors.Occur(errors.ErrBadRequest, err) } + +type SetTenantVariableTask struct { + task.Task + variables map[string]interface{} + tenantName string +} + +func newSetTenantVariableNode(variables map[string]interface{}) (*task.Node, error) { + agents, err := agentService.GetAllAgentsInfoFromOB() + if err != nil { + return nil, errors.Wrap(err, "create set tenant variable task failed") + } + ctx := task.NewTaskContext(). + SetParam(task.EXECUTE_AGENTS, agents). + SetParam(PARAM_TENANT_VARIABLES, variables) + return task.NewNodeWithContext(newSetTenantVariableTask(), true, ctx), nil +} + +func newSetTenantVariableTask() *SetTenantVariableTask { + newTask := &SetTenantVariableTask{ + Task: *task.NewSubTask(TASK_NAME_SET_TENANT_VARIABLE), + } + + newTask.SetCanContinue().SetCanRollback().SetCanRetry().SetCanCancel() + return newTask +} + +func (t *SetTenantVariableTask) Execute() error { + if err := t.GetContext().GetParamWithValue(PARAM_TENANT_NAME, &t.tenantName); err != nil { + return errors.Wrap(err, "Get tenant name failed") + } + + if err := t.GetContext().GetParamWithValue(PARAM_TENANT_VARIABLES, &t.variables); err != nil { + return errors.Wrap(err, "Get tenant variables failed") + } + + executeAgent, err := tenantService.GetTenantActiveAgent(t.tenantName) + if err != nil { + return err + } + if executeAgent == nil { + return errors.New("tenant is not active") + } + + if meta.OCS_AGENT.Equal(executeAgent) { + transferNumber(t.variables) + if err := tenantService.SetTenantVariablesWithTenant(t.tenantName, "", t.variables); err != nil { + return errors.Occurf(errors.ErrUnexpected, "set tenant variables failed: %s", err.Error()) + } + } + return nil +} diff --git a/agent/global/variable.go b/agent/global/variable.go index b94678ca..373f1a30 100644 --- a/agent/global/variable.go +++ b/agent/global/variable.go @@ -31,17 +31,18 @@ import ( ) var ( - HomePath string - Uid uint32 - Gid uint32 - Pid = os.Getpid() - StartAt = time.Now().UnixNano() - Protocol = "http" - CaCertPool *x509.CertPool - SkipVerify bool - EnableHTTPS bool - Architecture string - Os string + HomePath string + Uid uint32 + Gid uint32 + Pid = os.Getpid() + StartAt = time.Now().UnixNano() + Protocol = "http" + CaCertPool *x509.CertPool + SkipVerify bool + EnableHTTPS bool + Architecture string + Os string + ObproxyHomePath string ) var ( diff --git a/agent/lib/binary/ob.go b/agent/lib/binary/ob.go index 1de1b0df..39cfe3cf 100644 --- a/agent/lib/binary/ob.go +++ b/agent/lib/binary/ob.go @@ -18,6 +18,7 @@ package binary import ( "fmt" + "os" "os/exec" "path/filepath" "regexp" @@ -30,9 +31,12 @@ import ( func GetMyOBVersion() (version string, err error) { myOBPath := filepath.Join(global.HomePath, constant.DIR_BIN, constant.PROC_OBSERVER) bash := fmt.Sprintf("export LD_LIBRARY_PATH='%s/lib'; %s -V", global.HomePath, myOBPath) + if os.Stat(myOBPath); err != nil { + return "", errors.Wrap(err, "get my ob version failed") + } out, err := exec.Command("/bin/bash", "-c", bash).CombinedOutput() if err != nil { - return "", errors.Wrap(err, "exec get my ob version failed") + return "", err } res := string(out) diff --git a/agent/lib/http/http.go b/agent/lib/http/http.go index bebe2f3d..1528e7f5 100644 --- a/agent/lib/http/http.go +++ b/agent/lib/http/http.go @@ -146,7 +146,7 @@ func NewClient() *resty.Client { func sendHttpRequest(agentInfo meta.AgentInfoInterface, uri string, method string, param, ret interface{}, headers map[string]string) (agentResponse ocsAgentResponse, err error) { var agentResp OcsAgentResponse var response *resty.Response - targetUrl := fmt.Sprintf("%s://%s:%d%s", global.Protocol, agentInfo.GetIp(), agentInfo.GetPort(), uri) + targetUrl := fmt.Sprintf("%s://%s%s", global.Protocol, agentInfo.String(), uri) request := NewClient().R() if ret != nil { request.SetResult(&agentResp) diff --git a/agent/lib/parse/time.go b/agent/lib/parse/time.go new file mode 100644 index 00000000..4ea86146 --- /dev/null +++ b/agent/lib/parse/time.go @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package parse + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "strings" +) + +const ( + TIME_SECOND = "S" + TIME_MINUTE = "M" + TIME_HOUR = "H" + TIME_DAY = "D" +) + +func TimeParse(input string) (int, error) { + // Compile a regular expression to match the input format + pattern := regexp.MustCompile(`^([0-9]+)([a-zA-Z]?)$`) + matches := pattern.FindStringSubmatch(input) + + // Check if the input matches the pattern + if matches == nil { + return 0, errors.New("The input string is invalid: " + input) + } + + // Convert the captured numeric part of the input to an integer + num, err := strconv.Atoi(matches[1]) + if err != nil { + return 0, fmt.Errorf("Error parsing number: %v", err) + } + + // Get the unit character (if any) and determine the conversion factor + unit := matches[2] + switch strings.ToUpper(unit) { + case "": + // Default unit is microseconds, so convert to seconds + return num / 1000 / 1000, nil + case TIME_SECOND: + return num, nil + case TIME_MINUTE: + return num * 60, nil + case TIME_HOUR: + return num * 60 * 60, nil + case TIME_DAY: + return num * 24 * 60 * 60, nil + default: + return 0, errors.New("The input string is invalid: " + input) + } +} diff --git a/agent/lib/path/obproxy.go b/agent/lib/path/obproxy.go new file mode 100644 index 00000000..f921b1dc --- /dev/null +++ b/agent/lib/path/obproxy.go @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package path + +import ( + "fmt" + "path/filepath" + + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/meta" +) + +func ObproxyEtcDir() string { + return filepath.Join(meta.OBPROXY_HOME_PATH, constant.OBPROXY_DIR_ETC) +} + +func ObproxyLibDir() string { + return filepath.Join(meta.OBPROXY_HOME_PATH, constant.OBPROXY_DIR_LIB) +} + +func ObproxyLogDir() string { + return filepath.Join(meta.OBPROXY_HOME_PATH, constant.OBPROXY_DIR_LIB) +} + +func ObproxyRunDir() string { + return filepath.Join(meta.OBPROXY_HOME_PATH, constant.OBPROXY_DIR_RUN) +} + +func ObproxyBinDir() string { + return filepath.Join(meta.OBPROXY_HOME_PATH, constant.OBPROXY_DIR_BIN) +} + +func ObproxyBinPath() string { + return filepath.Join(ObproxyBinDir(), constant.BIN_OBPROXY) +} + +func ObproxyPidPath() string { + return filepath.Join(ObproxyRunDir(), fmt.Sprintf("%s.pid", constant.BIN_OBPROXY)) +} + +func ObproxydPidPath() string { + return filepath.Join(ObproxyRunDir(), fmt.Sprintf("%s.pid", constant.BIN_OBPROXYD)) +} + +func ObproxyNewConfigDbFile() string { + return filepath.Join(ObproxyEtcDir(), "proxyconfig_v1.db") +} + +func ObproxyOldConfigDbFile() string { + return filepath.Join(ObproxyEtcDir(), "proxyconfig.db") +} diff --git a/agent/lib/pkg/rpm.go b/agent/lib/pkg/rpm.go new file mode 100644 index 00000000..9721ce70 --- /dev/null +++ b/agent/lib/pkg/rpm.go @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package pkg + +import ( + "fmt" + "io" + "mime/multipart" + "os" + "path/filepath" + "strings" + + + log "github.com/sirupsen/logrus" + + "github.com/cavaliergopher/cpio" + "github.com/cavaliergopher/rpm" + "github.com/oceanbase/obshell/agent/errors" + "github.com/ulikunitz/xz" +) + +func ReadRpm(input multipart.File) (pkg *rpm.Package, err error) { + if _, err = input.Seek(0, 0); err != nil { + return + } + if err = rpm.MD5Check(input); err != nil { + return + } + if _, err = input.Seek(0, 0); err != nil { + return + } + return rpm.Read(input) +} + +func SplitRelease(release string) (buildNumber, distribution string, err error) { + releaseSplit := strings.Split(release, ".") + if len(releaseSplit) < 2 { + return "", "", fmt.Errorf("release format %s is illegal", release) + } + buildNumber = releaseSplit[0] + distribution = releaseSplit[len(releaseSplit)-1] + return +} + +func InstallRpmPkgInPlace(path string) (err error) { + log.Infof("InstallRpmPkg: %s", path) + f, err := os.Open(path) + if err != nil { + return + } + defer f.Close() + + pkg, err := rpm.Read(f) + if err != nil { + return + } + if err = CheckCompressAndFormat(pkg); err != nil { + return + } + + xzReader, err := xz.NewReader(f) + if err != nil { + return + } + installPath := filepath.Dir(path) + cpioReader := cpio.NewReader(xzReader) + + for { + hdr, err := cpioReader.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + m := hdr.Mode + if m.IsDir() { + dest := filepath.Join(installPath, hdr.Name) + log.Infof("%s is a directory, creating %s", hdr.Name, dest) + if err := os.MkdirAll(dest, 0755); err != nil { + return errors.Wrapf(err, "mkdir failed %s", hdr.Name) + } + + } else if m.IsRegular() { + if err := handleRegularFile(hdr, cpioReader, installPath); err != nil { + return err + } + + } else if hdr.Linkname != "" { + if err := handleSymlink(hdr, installPath); err != nil { + return err + } + } else { + log.Infof("Skipping unsupported file %s type: %v", hdr.Name, m) + } + } + + return nil +} + +func handleRegularFile(hdr *cpio.Header, cpioReader *cpio.Reader, installPath string) error { + dest := filepath.Join(installPath, hdr.Name) + if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil { + log.WithError(err).Error("mkdir failed") + return err + } + + outFile, err := os.Create(dest) + if err != nil { + return err + } + defer outFile.Close() + + log.Infof("Extracting %s", hdr.Name) + if _, err := io.Copy(outFile, cpioReader); err != nil { + return err + } + return nil +} + +func handleSymlink(hdr *cpio.Header, installPath string) error { + dest := filepath.Join(installPath, hdr.Name) + if err := os.Symlink(hdr.Linkname, dest); err != nil { + return errors.Wrapf(err, "create symlink failed %s -> %s", dest, hdr.Linkname) + } + log.Infof("Creating symlink %s -> %s", dest, hdr.Linkname) + return nil +} + +func CheckCompressAndFormat(pkg *rpm.Package) error { + if pkg.PayloadCompression() != "xz" { + return fmt.Errorf("unsupported compression '%s', the supported compression is 'xz'", pkg.PayloadCompression()) + } + if pkg.PayloadFormat() != "cpio" { + return fmt.Errorf("unsupported payload format '%s', the supported payload format is 'cpio'", pkg.PayloadFormat()) + } + return nil +} diff --git a/agent/lib/process/process.go b/agent/lib/process/process.go index a4d0ee0f..faac14be 100644 --- a/agent/lib/process/process.go +++ b/agent/lib/process/process.go @@ -26,6 +26,7 @@ import ( "strings" "syscall" + "github.com/shirou/gopsutil/v3/net" log "github.com/sirupsen/logrus" "github.com/oceanbase/obshell/agent/lib/path" @@ -81,6 +82,28 @@ func getObserverProcess() (*ProcessInfo, error) { }, nil } +func getObproxyProcess() (*ProcessInfo, error) { + pid, err := getPid(path.ObproxyPidPath()) + if err != nil { + return nil, err + } + return &ProcessInfo{ + pid: fmt.Sprint(pid), + procPath: fmt.Sprintf("/proc/%d", pid), + }, nil +} + +func getObproxydProcess() (*ProcessInfo, error) { + pid, err := getPid(path.ObproxydPidPath()) + if err != nil { + return nil, err + } + return &ProcessInfo{ + pid: fmt.Sprint(pid), + procPath: fmt.Sprintf("/proc/%d", pid), + }, nil +} + func (p *ProcessInfo) exist() (bool, error) { if p.pid == "" { return false, nil @@ -139,6 +162,22 @@ func CheckObserverProcess() (bool, error) { return process.Exist() } +func CheckObproxyProcess() (bool, error) { + process, err := getObproxyProcess() + if err != nil { + return false, err + } + return process.Exist() +} + +func CheckObproxydProcess() (bool, error) { + process, err := getObproxydProcess() + if err != nil { + return false, err + } + return process.Exist() +} + func CheckProcessExist(pid int32) (bool, error) { proc, err := os.FindProcess(int(pid)) if err != nil { @@ -208,3 +247,65 @@ func ExecuteBinary(binaryPath string, inputs []string) (err error) { // Wait for command execution to complete. return cmd.Wait() } + +// for obproxy +func GetObproxyPid() (string, error) { + process, err := getObproxyProcess() + if err != nil { + return "", err + } + return process.Pid() +} + +// for obproxy +func GetObproxydPid() (string, error) { + process, err := getObproxydProcess() + if err != nil { + return "", err + } + return process.Pid() +} + +// writePid writes the pid to the specified path atomically. +// If the file already exists, an error is returned. +func WritePid(path string, pid int) (err error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL|os.O_SYNC|syscall.O_CLOEXEC, 0644) + if err != nil { + return err + } + defer f.Close() + _, err = fmt.Fprint(f, pid) + if err != nil { + return err + } + return nil +} + +// writePid writes the pid to the specified path atomically. +func WritePidForce(path string, pid int) (err error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC|os.O_SYNC|syscall.O_CLOEXEC, 0644) + if err != nil { + return err + } + defer f.Close() + _, err = fmt.Fprint(f, pid) + if err != nil { + return err + } + return nil +} + +func FindPIDByPort(port uint32) (int32, error) { + // NOTICE: use inet6 to support ipv6 + connections, err := net.Connections("inet") + if err != nil { + return 0, err + } + + for _, conn := range connections { + if conn.Laddr.Port == port { + return conn.Pid, nil + } + } + return 0, fmt.Errorf("no process found on port %d", port) +} diff --git a/agent/meta/agent.go b/agent/meta/agent.go index 2d876e97..d37bea51 100644 --- a/agent/meta/agent.go +++ b/agent/meta/agent.go @@ -18,8 +18,14 @@ package meta import ( "fmt" + "net" + "regexp" "strconv" "strings" + + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/utils" ) type AgentIdentity string @@ -66,6 +72,8 @@ type Agent interface { GetVersion() string GetAgentInfo() AgentInfo String() string + IsIPv6() bool + GetLocalIp() string Equal(other AgentInfoInterface) bool } @@ -80,11 +88,25 @@ func (agentInfo *AgentInfo) GetIp() string { return agentInfo.Ip } +func (agentInfo *AgentInfo) GetLocalIp() string { + if agentInfo.IsIPv6() { + return constant.LOCAL_IP_V6 + } + return constant.LOCAL_IP +} + func (agentInfo *AgentInfo) GetPort() int { return agentInfo.Port } -func (agentInfo *AgentInfo) String() string { +func (agentInfo *AgentInfo) IsIPv6() bool { + return strings.Contains(agentInfo.Ip, ":") +} + +func (agentInfo AgentInfo) String() string { + if agentInfo.IsIPv6() { + return fmt.Sprintf("[%s]:%d", agentInfo.Ip, agentInfo.Port) + } return fmt.Sprintf("%s:%d", agentInfo.Ip, agentInfo.Port) } @@ -133,6 +155,7 @@ type AgentStatus struct { HomePath string `json:"homePath"` OBVersion string `json:"obVersion"` AgentInstance + Security bool `json:"security"` SupportedAuth []string `json:"supportedAuth"` } @@ -199,21 +222,76 @@ func NewAgentInfo(ip string, port int) *AgentInfo { } } -func NewAgentInfoByString(info string) *AgentInfo { - portIndex := strings.LastIndex(info, ":") - if portIndex == -1 { - return nil +func ConvertAddressToAgentInfo(host string) (*AgentInfo, error) { + if host == "" { + return nil, errors.New("host is empty") + } + if strings.Contains(host, ".") { + // If the host contains '.', it might be an IPv4 address, but further validation is needed. + return convertIPv4ToAgentInfo(host) + } else { + // If the host contains '.', it might be an IPv6 address, but further validation is needed. + return convertIPv6ToAgentInfo(host) } +} - ip := info[:portIndex] - port, err := strconv.Atoi(info[portIndex+1:]) - if err != nil { - return nil +func convertIPv4ToAgentInfo(host string) (*AgentInfo, error) { + var ip string + var err error + var port = constant.DEFAULT_AGENT_PORT + matches := strings.Split(host, ":") + if len(matches) == 1 { + return NewAgentInfo(matches[0], constant.DEFAULT_AGENT_PORT), nil + } else if len(matches) == 2 { + if port, err = strconv.Atoi(matches[1]); err != nil || !utils.IsValidPortValue(port) { + return nil, errors.Errorf("Invalid port: %s. Port number should be in the range [1024, 65535].", matches[1]) + } + ip = matches[0] + } else { + return nil, errors.Errorf("Invalid server format: %s", host) } - return &AgentInfo{ - Ip: ip, - Port: port, + + ipv4 := net.ParseIP(ip) + if ipv4 == nil || ipv4.To4() == nil { + return nil, errors.Errorf("%s is not a valid IP address", ip) } + return NewAgentInfo(ip, port), nil +} + +func convertIPv6ToAgentInfo(host string) (*AgentInfo, error) { + re := regexp.MustCompile(`(?:\[([0-9a-fA-F:]+)\]|([0-9a-fA-F:]+))(?:\:(\d+))?`) + matches := re.FindStringSubmatch(host) + + if matches == nil { + return nil, errors.Errorf("Invalid server format: %s", host) + } + + var ip string + var err error + var port = constant.DEFAULT_AGENT_PORT + if matches[1] != "" { + ip = matches[1] + } else { + ip = matches[2] + } + + if matches[3] != "" { + if port, err = strconv.Atoi(matches[3]); err != nil || !utils.IsValidPortValue(port) { + return nil, errors.Errorf("Invalid port: %s. Port number should be in the range [1024, 65535].", matches[1]) + } + } + + ipv6 := net.ParseIP(ip) + if ipv6 == nil || ipv6.To4() != nil { + return nil, errors.Errorf("%s is not a valid IP address", ip) + } + return NewAgentInfo(ip, port), nil +} + +func NewAgentInfoByString(info string) *AgentInfo { + // if err != nil, agent will be nil. So, no need to check err. + agent, _ := ConvertAddressToAgentInfo(info) + return agent } func NewAgentInfoByInterface(agentInfo AgentInfoInterface) *AgentInfo { @@ -306,13 +384,14 @@ func NewAgentSecretByAgentInfo(agent AgentInfoInterface, publicKey string) *Agen } } -func NewAgentStatus(agent Agent, pid int, state int32, startAt int64, homePath string, obVersion string) *AgentStatus { +func NewAgentStatus(agent Agent, pid int, state int32, startAt int64, homePath string, obVersion string, isAgentPasswordSet bool) *AgentStatus { return &AgentStatus{ Pid: pid, State: state, StartAt: startAt, HomePath: homePath, OBVersion: obVersion, + Security: isAgentPasswordSet, AgentInstance: *NewAgentInstanceByAgent(agent), SupportedAuth: []string{AUTH_V2}, } diff --git a/agent/meta/obproxy.go b/agent/meta/obproxy.go new file mode 100644 index 00000000..beb67796 --- /dev/null +++ b/agent/meta/obproxy.go @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package meta + +var ( + OBPROXY_HOME_PATH = "" + OBPROXY_SQL_PORT = 0 +) + +type ObproxyInfo = AgentInfo + +func IsObproxyAgent() bool { + return OBPROXY_HOME_PATH != "" +} diff --git a/agent/meta/security.go b/agent/meta/security.go index c2f70d8b..ed5d7974 100644 --- a/agent/meta/security.go +++ b/agent/meta/security.go @@ -16,10 +16,31 @@ package meta +type AgentPwd struct { + inited bool + password string +} + var ( - OCEANBASE_PWD string + OCEANBASE_PWD string + OBPROXY_SYS_PWD string + AGENT_PWD AgentPwd + OCEANBASE_PASSWORD_INITIALIZED bool // Which means the oceanbase password has been initialized ) +func (p *AgentPwd) Inited() bool { + return p.inited +} + +func (p *AgentPwd) GetPassword() string { + return p.password +} + +func (p *AgentPwd) SetPassword(pwd string) { + p.password = pwd + p.inited = true +} + func GetOceanbasePwd() string { if OCS_AGENT != nil && (OCS_AGENT.IsClusterAgent() || OCS_AGENT.IsTakeover()) { return OCEANBASE_PWD @@ -29,4 +50,15 @@ func GetOceanbasePwd() string { func SetOceanbasePwd(pwd string) { OCEANBASE_PWD = pwd + if !OCEANBASE_PASSWORD_INITIALIZED { + OCEANBASE_PASSWORD_INITIALIZED = true + } +} + +func GetObproxySysPwd() string { + return OBPROXY_SYS_PWD +} + +func SetObproxySysPwd(pwd string) { + OBPROXY_SYS_PWD = pwd } diff --git a/agent/repository/db/obproxy/instance.go b/agent/repository/db/obproxy/instance.go new file mode 100644 index 00000000..d3d5065b --- /dev/null +++ b/agent/repository/db/obproxy/instance.go @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "time" + + log "github.com/sirupsen/logrus" + + "github.com/oceanbase/obshell/agent/config" + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/meta" + "github.com/oceanbase/obshell/agent/repository/driver" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +var ( + obproxyInstance *gorm.DB + + WAIT_OBPROXY_CONNECTED_MAX_TIMES = 100 + WAIT_OBPROXY_CONNECTED_MAX_INTERVAL = 10 * time.Second +) + +func LoadObproxyInstance() (db *gorm.DB, err error) { + if meta.OBPROXY_SQL_PORT == 0 { + return nil, errors.New("obproxy sql port has not been initialized") + } + dsConfig := config.NewObproxyDataSourceConfig().SetPort(meta.OBPROXY_SQL_PORT).SetPassword(meta.OBPROXY_SYS_PWD) + + gormConfig := gorm.Config{ + Logger: logger.Default.LogMode(dsConfig.GetLoggerLevel()), + NamingStrategy: schema.NamingStrategy{ + SingularTable: constant.DB_SINGULAR_TABLE, + }} + db, err = gorm.Open(driver.OpenObproxy(dsConfig.GetDSN()), &gormConfig) + if err == nil { + releaseDB(obproxyInstance) + obproxyInstance = db + } else { + return nil, errors.Wrap(err, "load obproxy instance failed") + } + return obproxyInstance, nil +} + +func LoadObproxyInstanceForHealthCheck(dsConfig *config.ObDataSourceConfig) (err error) { + db, err := LoadTempObproxyInstance(dsConfig) + if err != nil { + return errors.Wrap(err, "load obproxy instance failed") + } + if err := db.Exec("show proxyconfig").Error; err != nil { + return errors.Wrap(err, "check obproxy instance failed") + } + meta.OBPROXY_SYS_PWD = dsConfig.GetPassword() + releaseDB(db) + return err +} + +func LoadTempObproxyInstance(dsConfig *config.ObDataSourceConfig) (db *gorm.DB, err error) { + gormConfig := gorm.Config{ + Logger: logger.Default.LogMode(dsConfig.GetLoggerLevel()), + NamingStrategy: schema.NamingStrategy{ + SingularTable: constant.DB_SINGULAR_TABLE, + }} + db, err = gorm.Open(driver.OpenObproxy(dsConfig.GetDSN()), &gormConfig) + if err != nil { + return nil, errors.Wrap(err, "load temp obproxy instance failed") + } + return db, nil +} + +func GetObproxyInstance() (*gorm.DB, error) { + if obproxyInstance == nil { + log.Info("obproxy instance is nil, load obproxy instance") + if _, err := LoadObproxyInstance(); err != nil { + return nil, err + } + } + // health check + if err := obproxyInstance.Exec("show proxyconfig").Error; err != nil { + log.WithError(err).Warn("obproxy instance is not available") + return nil, err + } + return obproxyInstance, nil +} + +func releaseDB(preDB *gorm.DB) { + // Delay release db + if preDB != nil { + db, err := preDB.DB() + if err != nil { + log.WithError(err).Warn("release pre db failed") + } + + go func() { + defer func() { + err := recover() + if err != nil { + log.WithError(err.(error)).Warn("release pre db failed") + } + }() + + for db.Stats().InUse != 0 { + log.Debug("pre db is using, wait for release") + time.Sleep(time.Second) + } + db.Close() + }() + } +} diff --git a/agent/repository/db/oceanbase/loader.go b/agent/repository/db/oceanbase/loader.go index 2a96ecbc..a54228e8 100644 --- a/agent/repository/db/oceanbase/loader.go +++ b/agent/repository/db/oceanbase/loader.go @@ -281,3 +281,12 @@ func loadObGormForTest(dsConfig *config.ObDataSourceConfig) error { } return nil } + +func LoadTempOceanbaseInstance(dsConfig *config.ObDataSourceConfig) (*gorm.DB, error) { + db, err := gorm.Open(driver.Open(dsConfig.GetDSN())) + if err != nil { + log.WithError(err).Error("open ob db failed") + return nil, err + } + return db, nil +} diff --git a/agent/repository/db/sqlite/builder.go b/agent/repository/db/sqlite/builder.go index 4ed68d3e..4af7137b 100644 --- a/agent/repository/db/sqlite/builder.go +++ b/agent/repository/db/sqlite/builder.go @@ -119,6 +119,7 @@ var SqliteTables = []interface{}{ sqlite.AllAgent{}, sqlite.ObSysParameter{}, sqlite.OcsInfo{}, + sqlite.ObproxyInfo{}, sqlite.ObGlobalConfig{}, sqlite.ObZoneConfig{}, sqlite.ObServerConfig{}, @@ -130,6 +131,8 @@ var SqliteTables = []interface{}{ sqlite.SubTaskLog{}, sqlite.DagInstance{}, sqlite.NodeInstance{}, + sqlite.UpgradePkgInfo{}, + sqlite.UpgradePkgChunk{}, } // MigrateSqliteTables will check if the sqlite tables exist, if not, it will create them. diff --git a/agent/repository/driver/oceanbase.go b/agent/repository/driver/oceanbase.go index ef0678e1..08e60cb7 100644 --- a/agent/repository/driver/oceanbase.go +++ b/agent/repository/driver/oceanbase.go @@ -33,6 +33,14 @@ func Open(dsn string) gorm.Dialector { } } +func OpenObproxy(dsn string) gorm.Dialector { + mysqlDialector := mysql.Open(dsn).(*mysql.Dialector) + mysqlDialector.Config.SkipInitializeWithVersion = true + return Dialector{ + Dialector: *mysqlDialector, + } +} + func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { mysqlMigrator := mysql.Migrator{ Migrator: migrator.Migrator{ diff --git a/agent/repository/model/bo/obproxy_config.go b/agent/repository/model/bo/obproxy_config.go new file mode 100644 index 00000000..fa3e1a94 --- /dev/null +++ b/agent/repository/model/bo/obproxy_config.go @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bo + +type ProxyConfig struct { + Name string `json:"name"` + Value string `json:"value"` +} + +type ObproxyInfo struct { + Name string `json:"name"` + Info string `json:"info"` +} diff --git a/agent/repository/model/oceanbase/dag_instance.go b/agent/repository/model/oceanbase/dag_instance.go index a59994b3..6df75731 100644 --- a/agent/repository/model/oceanbase/dag_instance.go +++ b/agent/repository/model/oceanbase/dag_instance.go @@ -19,6 +19,7 @@ package oceanbase import ( "time" + "github.com/oceanbase/obshell/agent/engine/task" "github.com/oceanbase/obshell/agent/repository/model/bo" ) @@ -45,8 +46,8 @@ type DagInstance struct { func (d *DagInstance) ToBO() *bo.DagInstance { MaintenanceType := d.MaintenanceType - if d.IsMaintenance && MaintenanceType == 1 { - MaintenanceType = 2 + if d.IsMaintenance && MaintenanceType == task.NOT_UNDER_MAINTENANCE { + MaintenanceType = task.GLOBAL_MAINTENANCE } return &bo.DagInstance{ Id: d.Id, diff --git a/agent/repository/model/sqlite/dag_instance.go b/agent/repository/model/sqlite/dag_instance.go index a619066c..1af8443f 100644 --- a/agent/repository/model/sqlite/dag_instance.go +++ b/agent/repository/model/sqlite/dag_instance.go @@ -19,6 +19,7 @@ package sqlite import ( "time" + "github.com/oceanbase/obshell/agent/engine/task" "github.com/oceanbase/obshell/agent/repository/model/bo" ) @@ -42,9 +43,9 @@ type DagInstance struct { } func (d *DagInstance) ToBO() *bo.DagInstance { - MaintenanceType := 0 + MaintenanceType := task.NOT_UNDER_MAINTENANCE if d.IsMaintenance { - MaintenanceType = 2 + MaintenanceType = task.GLOBAL_MAINTENANCE } return &bo.DagInstance{ Id: d.Id, diff --git a/agent/repository/model/sqlite/obproxy_info.go b/agent/repository/model/sqlite/obproxy_info.go new file mode 100644 index 00000000..5f65e6be --- /dev/null +++ b/agent/repository/model/sqlite/obproxy_info.go @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sqlite + +type ObproxyInfo struct { + Name string `gorm:"type:varchar(128);not null;unique"` + Value string `gorm:"type:varchar(65536);not null"` + Info string +} diff --git a/agent/repository/model/sqlite/upgrade_pkg_chunk.go b/agent/repository/model/sqlite/upgrade_pkg_chunk.go new file mode 100644 index 00000000..2eafabff --- /dev/null +++ b/agent/repository/model/sqlite/upgrade_pkg_chunk.go @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sqlite + +type UpgradePkgChunk struct { + PkgId int `gorm:"primaryKey;column:pkg_id;not null"` + ChunkId int `gorm:"primaryKey;column:chunk_id;not null"` + ChunkCount int `gorm:"not null"` + Chunk []byte `gorm:"type:MEDIUMBLOB;not null"` +} diff --git a/agent/repository/model/sqlite/upgrade_pkg_info.go b/agent/repository/model/sqlite/upgrade_pkg_info.go new file mode 100644 index 00000000..fac4cd54 --- /dev/null +++ b/agent/repository/model/sqlite/upgrade_pkg_info.go @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sqlite + +import "time" + +type UpgradePkgInfo struct { + PkgId int `gorm:"primaryKey;autoIncrement;not null"` + Name string `gorm:"type:varchar(128);not null"` + Version string `gorm:"type:varchar(128);not null"` + ReleaseDistribution string `gorm:"type:varchar(128);not null"` + Distribution string `gorm:"type:varchar(128);not null"` + Release string `gorm:"type:varchar(128);not null"` + Architecture string `gorm:"type:varchar(128);not null"` + Size uint64 `gorm:"not null"` + PayloadSize uint64 `gorm:"not null"` + ChunkCount int `gorm:"not null"` + Md5 string `gorm:"type:varchar(128);not null"` + GmtModify time.Time `gorm:"type:TIMESTAMP;default:CURRENT_TIMESTAMP"` +} diff --git a/agent/rpc/agent_handler.go b/agent/rpc/agent_handler.go index d67e6058..264cc71f 100644 --- a/agent/rpc/agent_handler.go +++ b/agent/rpc/agent_handler.go @@ -30,11 +30,39 @@ import ( "github.com/oceanbase/obshell/param" ) -func agentJoinHandler(c *gin.Context) { +func agentAddTokenHandler(c *gin.Context) { if !meta.OCS_AGENT.IsMasterAgent() { common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s:%d is not master", meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort())) return } + var param param.AddTokenParam + ip := c.RemoteIP() + if err := c.Bind(¶m); err != nil { + return + } + if param.AgentInfo.Ip == "" { + param.AgentInfo.Ip = ip + } + + agentService := agentservice.AgentService{} + agentInstance, err := agentService.FindAgentInstance(¶m.AgentInfo) + if err != nil { + common.SendResponse(c, nil, err) + return + } + if agentInstance != nil { + common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s:%d already exists", agentInstance.Ip, agentInstance.Port)) + return + } + + common.SendResponse(c, nil, agent.AddSingleToken(param)) +} + +func agentJoinHandler(c *gin.Context) { + if !meta.OCS_AGENT.IsMasterAgent() { + common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s is not master", meta.OCS_AGENT.String())) + return + } ip := c.RemoteIP() var param param.JoinMasterParam @@ -53,7 +81,7 @@ func agentJoinHandler(c *gin.Context) { } if agentInstance != nil { - common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s:%d already exists", agentInstance.Ip, agentInstance.Port)) + common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s already exists", agentInstance.String())) return } else { if err := agent.AddFollowerAgent(param); err != nil { @@ -73,7 +101,7 @@ func agentRemoveHandler(c *gin.Context) { } else if meta.OCS_AGENT.IsSingleAgent() { common.SendResponse(c, task.DagDetailDTO{}, nil) } else { - common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s:%d is %s", meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort(), meta.OCS_AGENT.GetIdentity())) + common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s is %s", meta.OCS_AGENT.String(), meta.OCS_AGENT.GetIdentity())) } } @@ -132,7 +160,7 @@ func updateAllAgentsHandler(c *gin.Context) { func obServerDeployHandler(c *gin.Context) { if !meta.OCS_AGENT.IsFollowerAgent() { - common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s:%d is not follower agent", meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort())) + common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s is not follower agent", meta.OCS_AGENT.String())) return } var dirs param.DeployTaskParams @@ -151,7 +179,7 @@ func obServerDeployHandler(c *gin.Context) { func obServerDestroyHandler(c *gin.Context) { if !meta.OCS_AGENT.IsFollowerAgent() { - common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s:%d is not follower agent", meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort())) + common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s is not follower agent", meta.OCS_AGENT.String())) return } dag, err := ob.CreateDestroyDag() @@ -228,7 +256,7 @@ func agentUpdateHandler(c *gin.Context) { func takeOverAgentUpdateBinaryHandler(c *gin.Context) { if !meta.OCS_AGENT.IsTakeover() { - common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s:%d is not takeover agent", meta.OCS_AGENT.GetIp(), meta.OCS_AGENT.GetPort())) + common.SendResponse(c, nil, errors.Occurf(errors.ErrBadRequest, "%s is not takeover agent", meta.OCS_AGENT.String())) return } diff --git a/agent/rpc/agent_route.go b/agent/rpc/agent_route.go index c63a1bb5..d3184068 100644 --- a/agent/rpc/agent_route.go +++ b/agent/rpc/agent_route.go @@ -36,6 +36,7 @@ func InitOcsAgentRpcRoutes(s *http2.State, r *gin.Engine, isLocalRoute bool) { agent := v1.Group(constant.URI_AGENT_GROUP) agent.POST("", agentJoinHandler) + agent.POST(constant.URI_TOKEN, agentAddTokenHandler) agent.DELETE("", agentRemoveHandler) agent.POST(constant.URI_UPDATE, agentUpdateHandler) agent.POST(constant.URI_SYNC_BIN, takeOverAgentUpdateBinaryHandler) diff --git a/agent/secure/auth.go b/agent/secure/auth.go index d82114e3..6069cc1f 100644 --- a/agent/secure/auth.go +++ b/agent/secure/auth.go @@ -18,46 +18,71 @@ package secure import ( "errors" + "fmt" "strconv" "strings" log "github.com/sirupsen/logrus" "github.com/oceanbase/obshell/agent/config" + "github.com/oceanbase/obshell/agent/meta" "github.com/oceanbase/obshell/agent/repository/db/oceanbase" ) -type AgentAuth struct { - Password string - Ts int64 +type RouteType int +type VerifyType int + +const ( + ROUTE_OCEANBASE RouteType = iota + ROUTE_OBPROXY + ROUTE_TASK + + OCEANBASE_PASSWORD VerifyType = iota + AGENT_PASSWORD +) + +func VerifyTimeStamp(ts string, curTs int64) error { + tsInt, err := strconv.ParseInt(ts, 10, 64) + if err != nil { + log.WithError(err).Errorf("parse ts failed, ts:%v", ts) + return err + } + if curTs > int64(tsInt) { + log.Warnf("auth expired at: %v, current: %v", ts, curTs) + return errors.New("auth expired") + } + return nil } -func VerifyAuth(pwd string, ts string, curTs int64) error { +func VerifyAuth(pwd string, ts string, curTs int64, verifyType VerifyType) error { if pwd != "" { - tsInt, err := strconv.ParseInt(ts, 10, 64) - if err != nil { - log.WithError(err).Errorf("parse ts failed, ts:%v", ts) + if err := VerifyTimeStamp(ts, curTs); err != nil { return err } - if curTs > int64(tsInt) { - log.Warnf("auth expired at: %v, current: %v", ts, curTs) - return errors.New("auth expired") - } } - if pwd != meta.OCEANBASE_PWD { - if oceanbase.HasOceanbaseInstance() { - if err := VerifyOceanbasePassword(pwd); err != nil { - return err - } - if err := dumpPassword(); err != nil { - log.WithError(err).Error("dump password failed") - return err + if verifyType == AGENT_PASSWORD { + if pwd != meta.AGENT_PWD.GetPassword() { + log.Infof("agent password is incorrect, pwd:%v, agentPwd:%v", pwd, meta.AGENT_PWD.GetPassword()) + return fmt.Errorf("access denied: %s", "agent password is incorrect") + } + } else if verifyType == OCEANBASE_PASSWORD { + if pwd != meta.OCEANBASE_PWD { + if oceanbase.HasOceanbaseInstance() { + if err := VerifyOceanbasePassword(pwd); err != nil { + return err + } + if err := dumpPassword(); err != nil { + log.WithError(err).Error("dump password failed") + return err + } + } else { + return fmt.Errorf("access denied: %s", "oceanbase password is incorrect") } - } else { - return errors.New("access denied") } + } else { + return errors.New("unknown password type") } return nil } diff --git a/agent/secure/body.go b/agent/secure/body.go index dca99c21..b3d36642 100644 --- a/agent/secure/body.go +++ b/agent/secure/body.go @@ -86,7 +86,7 @@ func EncryptBodyWithRsa(agentInfo meta.AgentInfoInterface, body interface{}) (en } pk := GetAgentPublicKey(agentInfo) if pk == "" { - log.Warnf("no key for agent '%s:%d'", agentInfo.GetIp(), agentInfo.GetPort()) + log.Warnf("no key for agent '%s'", agentInfo.String()) return } encryptedBody, err = crypto.RSAEncrypt(mBody, pk) diff --git a/agent/secure/crypto.go b/agent/secure/crypto.go index b9f9ef0b..53907f77 100644 --- a/agent/secure/crypto.go +++ b/agent/secure/crypto.go @@ -88,19 +88,19 @@ func GetAgentPublicKey(agent meta.AgentInfoInterface) string { pk, err := getPublicKeyByAgentInfo(agent) if err != nil { // Need to query sqlite instead. - log.WithError(err).Errorf("query oceanbase '%s' for '%s:%d' failed", constant.TABLE_ALL_AGENT, agent.GetIp(), agent.GetPort()) + log.WithError(err).Errorf("query oceanbase '%s' for '%s' failed", constant.TABLE_ALL_AGENT, agent.String()) } if pk != "" { err = updateAgentPublicKey(agent, pk) if err != nil { - log.WithError(err).Errorf("update sqlite '%s' for '%s:%d' failed", constant.TABLE_ALL_AGENT, agent.GetIp(), agent.GetPort()) + log.WithError(err).Errorf("update sqlite '%s' for '%s' failed", constant.TABLE_ALL_AGENT, agent.String()) } // Although backup failed, the key should be returned. return pk } pk, err = getPublicKeyByAgentInfo(agent) if err != nil { - log.WithError(err).Errorf("query sqlite '%s' for '%s:%d' failed", constant.TABLE_ALL_AGENT, agent.GetIp(), agent.GetPort()) + log.WithError(err).Errorf("query sqlite '%s' for '%s' failed", constant.TABLE_ALL_AGENT, agent.String()) } if pk != "" { return pk @@ -113,12 +113,12 @@ func GetAgentPublicKey(agent meta.AgentInfoInterface) string { return "" } -// LoadPassword will load password from environment variable or sqlite. -func LoadPassword(password *string) error { +// LoadOceanbasePassword will load password from environment variable or sqlite. +func LoadOceanbasePassword(password *string) error { if password == nil { rootPwd, isSet := syscall.Getenv(constant.OB_ROOT_PASSWORD) if !isSet { - return CheckPasswordInSqlite() + return CheckObPasswordInSqlite() } log.Info("get password from environment variable") password = &rootPwd @@ -129,11 +129,33 @@ func LoadPassword(password *string) error { // clear root password, avoid to cover sqlite when agent restart syscall.Unsetenv(constant.OB_ROOT_PASSWORD) meta.SetOceanbasePwd(*password) - go dumpTempPassword(*password) + go dumpTempObPassword(*password) return nil } -func dumpTempPassword(pwd string) { +func LoadAgentPassword() error { + var pwd string + err := getOCSInfo(constant.CONFIG_AGENT_PASSWORD, &pwd) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + log.Info("no password in sqlite") + return nil + } + return err + } + // Decrypt password + if pwd != "" { + pwd, err = Crypter.Decrypt(pwd) + if err != nil { + return err + } + } + meta.AGENT_PWD.SetPassword(pwd) + return nil + +} + +func dumpTempObPassword(pwd string) { log.Info("current password is temporary, will dump it into sqlite") for meta.OCEANBASE_PWD == pwd { if oceanbase.HasOceanbaseInstance() { @@ -148,9 +170,9 @@ func dumpTempPassword(pwd string) { } } -// CheckPasswordInSqlite will try connecting ob using password stored in sqlite. -func CheckPasswordInSqlite() error { - log.Info("retore password from sqlite") +// CheckObPasswordInSqlite will try connecting ob using password stored in sqlite. +func CheckObPasswordInSqlite() error { + log.Info("retore password of root@sys from sqlite") password, err := getCipherPassword() if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { @@ -186,6 +208,19 @@ func dumpPassword() error { return updateOBConifg(constant.CONFIG_ROOT_PWD, passwrod) } +func dumpObproxyPassword() error { + passwrod := meta.OBPROXY_SYS_PWD + if meta.OBPROXY_SYS_PWD != "" { + cipherPassword, err := Crypter.Encrypt(meta.OBPROXY_SYS_PWD) + if err != nil { + log.WithError(err).Error("encrypt password failed") + return err + } + passwrod = cipherPassword + } + return updateObproxyConfig(constant.OBPROXY_CONFIG_OBPROXY_SYS_PASSWORD, passwrod) +} + func EncryptPwdInObConfigs(configs []sqlite.ObConfig) (err error) { for i := range configs { if configs[i].Name == constant.CONFIG_ROOT_PWD && configs[i].Value != "" { diff --git a/agent/secure/handler.go b/agent/secure/handler.go index 2d7ca2ae..a5dbac6d 100644 --- a/agent/secure/handler.go +++ b/agent/secure/handler.go @@ -34,7 +34,7 @@ func GetSecret(ctx context.Context) *meta.AgentSecret { } func sendGetSecretApi(agentInfo meta.AgentInfoInterface) *meta.AgentSecret { - log.Infof("Send get secret request from '%s:%d'", agentInfo.GetIp(), agentInfo.GetPort()) + log.Infof("Send get secret request from '%s'", agentInfo.String()) ret := &meta.AgentSecret{} err := http.SendGetRequest(agentInfo, "/api/v1/secret", nil, ret) if err != nil { diff --git a/agent/secure/header.go b/agent/secure/header.go index 58a20b51..94aaa296 100644 --- a/agent/secure/header.go +++ b/agent/secure/header.go @@ -20,13 +20,12 @@ import ( "fmt" "time" - "github.com/oceanbase/obshell/agent/lib/json" - log "github.com/sirupsen/logrus" "github.com/oceanbase/obshell/agent/constant" "github.com/oceanbase/obshell/agent/errors" "github.com/oceanbase/obshell/agent/lib/crypto" + "github.com/oceanbase/obshell/agent/lib/json" "github.com/oceanbase/obshell/agent/meta" ) @@ -46,12 +45,27 @@ type HttpHeader struct { ForwardAgent meta.AgentInfo } +func BuildAgentHeader(agentInfo meta.AgentInfoInterface, password string, uri string, isForword bool, keys ...[]byte) map[string]string { + auth := buildHeader(agentInfo, password, uri, isForword, keys...) + header := map[string]string{ + constant.OCS_AGENT_HEADER: auth, + } + return header +} + func BuildHeader(agentInfo meta.AgentInfoInterface, uri string, isForword bool, keys ...[]byte) map[string]string { - headers := make(map[string]string) + auth := buildHeader(agentInfo, meta.OCEANBASE_PWD, uri, isForword, keys...) + header := map[string]string{ + constant.OCS_HEADER: auth, + } + return header +} + +func buildHeader(agentInfo meta.AgentInfoInterface, password string, uri string, isForword bool, keys ...[]byte) string { pk := GetAgentPublicKey(agentInfo) if pk == "" { - log.Warnf("no key for agent '%s:%d'", agentInfo.GetIp(), agentInfo.GetPort()) - return nil + log.Warnf("no key for agent '%s'", agentInfo.String()) + return "" } var token string @@ -68,7 +82,7 @@ func BuildHeader(agentInfo meta.AgentInfoInterface, uri string, isForword bool, aesKeys = append(keys[0], keys[1]...) } header := HttpHeader{ - Auth: meta.OCEANBASE_PWD, + Auth: password, Ts: fmt.Sprintf("%d", time.Now().Add(getAuthExpiredDuration()).Unix()), Token: token, Uri: uri, @@ -83,15 +97,14 @@ func BuildHeader(agentInfo meta.AgentInfoInterface, uri string, isForword bool, mAuth, err := json.Marshal(header) if err != nil { log.WithError(err).Error("json marshal failed") - return nil + return "" } auth, err := crypto.RSAEncrypt(mAuth, pk) if err != nil { log.WithError(err).Error("rsa encrypt failed") - return nil + return "" } - headers[constant.OCS_HEADER] = auth - return headers + return auth } func DecryptHeader(ciphertext string) (HttpHeader, error) { diff --git a/agent/secure/http.go b/agent/secure/http.go index e66ee999..e7c632fc 100644 --- a/agent/secure/http.go +++ b/agent/secure/http.go @@ -144,3 +144,12 @@ func BuildBody(agentInfo meta.AgentInfoInterface, param interface{}) (encryptedB func BuildHeaderForForward(agentInfo meta.AgentInfoInterface, uri string, keys ...[]byte) map[string]string { return BuildHeader(agentInfo, uri, true, keys...) } + +func SendRequestWithPassword(agentInfo meta.AgentInfoInterface, uri string, method string, agentPassword string, param interface{}, ret interface{}) error { + encryptedBody, Key, Iv, err := BuildBody(agentInfo, param) + if err != nil { + return errors.Wrap(err, "build body failed") + } + header := BuildAgentHeader(agentInfo, agentPassword, uri, false, Key, Iv) + return http.SendRequestAndBuildReturn(agentInfo, uri, method, encryptedBody, ret, header) +} diff --git a/agent/secure/service.go b/agent/secure/service.go index d19f6e41..e8c66d69 100644 --- a/agent/secure/service.go +++ b/agent/secure/service.go @@ -74,6 +74,22 @@ func updateOBConifg(key string, value interface{}) (err error) { return updateOBConifgInTransaction(db, key, value) } +func updateObproxyConfig(key string, value interface{}) (err error) { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return + } + data := map[string]interface{}{ + "name": key, + "value": value, + } + err = db.Model(obConfigModel).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "name"}}, + DoUpdates: clause.AssignmentColumns([]string{"value"}), + }).Create(data).Error + return +} + func updateOBConifgInTransaction(tx *gorm.DB, key string, value interface{}) (err error) { data := map[string]interface{}{ "name": key, diff --git a/agent/secure/token.go b/agent/secure/token.go index b4c8202f..32e2cf6b 100644 --- a/agent/secure/token.go +++ b/agent/secure/token.go @@ -26,7 +26,13 @@ import ( // NewToken generates a token for the agent to join/scale-out an existing cluster func NewToken(targetAgent meta.AgentInfoInterface) (string, error) { - token := uuid.New().String() + token, err := getTokenByAgentInfo(meta.OCS_AGENT) + if err != nil { + return "", err + } + if token == "" { + token = uuid.New().String() + } if err := updateToken(meta.OCS_AGENT, token); err != nil { return "", err } @@ -42,7 +48,7 @@ func VerifyToken(token string) error { if err != nil { return err } - if agentToken != token { + if agentToken != token || token == "" { return errors.New("wrong token") } return nil diff --git a/agent/service/agent/all_agents.go b/agent/service/agent/all_agents.go index 02d3f9e6..8f3b9aa8 100644 --- a/agent/service/agent/all_agents.go +++ b/agent/service/agent/all_agents.go @@ -215,6 +215,25 @@ func (s *AgentService) addAgentToken(tx *gorm.DB, agentInfo meta.AgentInfoInterf }).Create(&ocsToken).Error } +func (s *AgentService) AddSingleToken(agentInfo meta.AgentInfoInterface, token string) error { + if token == "" { + return nil + } + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return err + } + ocsToken := sqlite.OcsToken{ + Ip: agentInfo.GetIp(), + Port: agentInfo.GetPort(), + Token: token, + } + return db.Model(&sqlite.OcsToken{}).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "ip"}, {Name: "port"}}, + DoUpdates: clause.AssignmentColumns([]string{"token"}), + }).Create(&ocsToken).Error +} + func (s *AgentService) addAgent(db *gorm.DB, agentInstance meta.Agent, homePath string, os string, arch string, publicKey string) error { agent := &sqlite.AllAgent{ Ip: agentInstance.GetIp(), @@ -230,20 +249,6 @@ func (s *AgentService) addAgent(db *gorm.DB, agentInstance meta.Agent, homePath return db.Create(agent).Error } -func (s *AgentService) UpdateAgent(agentInstance meta.Agent, homePath string, os string, arch string, publicKey string, token string) error { - db, err := sqlitedb.GetSqliteInstance() - if err != nil { - return err - } - return db.Transaction(func(tx *gorm.DB) error { - err = s.addAgentToken(tx, agentInstance, token) - if err == nil { - err = s.updateAgent(tx, agentInstance, homePath, os, arch, publicKey) - } - return err - }) -} - func (s *AgentService) UpdateAgentOBPort(agent meta.AgentInfoInterface, mysqlPort, rpcPort int) error { db, err := sqlitedb.GetSqliteInstance() if err != nil { @@ -394,7 +399,7 @@ func (s *AgentService) CheckCanBeTakeOverMaster() (bool, error) { } if !self_exist { - return false, fmt.Errorf("%s:%d not in cluster", meta.OCS_AGENT.GetIp(), meta.RPC_PORT) + return false, fmt.Errorf("%s not in cluster", meta.NewAgentInfo(meta.OCS_AGENT.GetIp(), meta.RPC_PORT).String()) } return other_exist, nil diff --git a/agent/service/agent/binary.go b/agent/service/agent/binary.go index 02c8339d..dd83cd9e 100644 --- a/agent/service/agent/binary.go +++ b/agent/service/agent/binary.go @@ -85,6 +85,9 @@ func (s *AgentService) UpgradeBinary() error { if err := tx.Exec("SET SESSION ob_query_timeout=1000000000").Error; err != nil { return err } + if err := tx.Exec("SET SESSION ob_trx_timeout=1000000000").Error; err != nil { + return err + } info := &oceanbase.AgentBinaryInfo{ Version: constant.VERSION, diff --git a/agent/service/agent/config.go b/agent/service/agent/config.go index 4cd2982b..300cd9ac 100644 --- a/agent/service/agent/config.go +++ b/agent/service/agent/config.go @@ -26,6 +26,7 @@ import ( "github.com/oceanbase/obshell/agent/meta" sqlitedb "github.com/oceanbase/obshell/agent/repository/db/sqlite" "github.com/oceanbase/obshell/agent/repository/model/sqlite" + "github.com/oceanbase/obshell/agent/secure" ) func (s *AgentService) UpdatePort(mysqlPort, rpcPort int) error { @@ -105,3 +106,26 @@ func (s *AgentService) getOBConifg(db *gorm.DB, name string, value interface{}) } return err } + +func (s *AgentService) SetAgentPassword(password string) error { + sqliteDb, err := sqlitedb.GetSqliteInstance() + if err != nil { + return err + } + encrptyPassword, err := secure.Encrypt(password) + if err != nil { + return err + } + + if err := sqliteDb.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "name"}}, + DoUpdates: clause.AssignmentColumns([]string{"value"}), + }).Create(&sqlite.OcsInfo{ + Name: constant.CONFIG_AGENT_PASSWORD, + Value: encrptyPassword}).Error; err != nil { + return err + } + + meta.AGENT_PWD.SetPassword(password) + return nil +} diff --git a/agent/service/agent/enter.go b/agent/service/agent/enter.go index e314e671..1fe5dc84 100644 --- a/agent/service/agent/enter.go +++ b/agent/service/agent/enter.go @@ -28,6 +28,7 @@ import ( "github.com/oceanbase/obshell/agent/meta" sqlitedb "github.com/oceanbase/obshell/agent/repository/db/sqlite" "github.com/oceanbase/obshell/agent/repository/model/sqlite" + "github.com/oceanbase/obshell/agent/secure" ) var ( @@ -75,10 +76,50 @@ func (s *AgentService) InitAgent() error { } default: } + + if err := s.initObproxy(); err != nil { + return err + } + meta.OCS_AGENT = ocsAgent return nil } +// initObproxy will initialize obproxy info of the agent. +func (s *AgentService) initObproxy() (err error) { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return + } + + if err = db.Model(&sqlite.ObproxyInfo{}). + Select("value"). + Where("name = ?", constant.OBPROXY_INFO_HOME_PATH). + Scan(&meta.OBPROXY_HOME_PATH).Error; err != nil { + return + } + + if err = db.Model(&sqlite.ObproxyInfo{}). + Select("value"). + Where("name = ?", constant.OBPROXY_INFO_SQL_PORT). + Scan(&meta.OBPROXY_SQL_PORT).Error; err != nil { + return + } + + encryptedSysPwd := "" + if err = db.Model(&sqlite.ObproxyInfo{}). + Select("value"). + Where("name = ?", constant.OBPROXY_CONFIG_OBPROXY_SYS_PASSWORD). + Scan(&encryptedSysPwd).Error; err != nil { + return err + } + if meta.OBPROXY_SYS_PWD, err = secure.Decrypt(encryptedSysPwd); err != nil { + return err + } + + return nil +} + func (s *AgentService) initOBPort() error { sqliteDb, err := sqlitedb.GetSqliteInstance() if err != nil { @@ -106,11 +147,17 @@ func (agentService *AgentService) InitializeAgentStatus() (err error) { } if err = db.Create(&sqlite.OcsInfo{Name: constant.OCS_INFO_STATUS, Value: strconv.Itoa(task.NOT_UNDER_MAINTENANCE)}).Error; err != nil { sqliteErr, ok := err.(sqlite3.Error) - if ok && sqliteErr.Code == sqlite3.ErrConstraint { - return nil + if !ok || sqliteErr.Code != sqlite3.ErrConstraint { + return } } - return + if err = db.Create(&sqlite.ObproxyInfo{Name: constant.OCS_INFO_STATUS, Value: strconv.Itoa(task.NOT_UNDER_MAINTENANCE)}).Error; err != nil { + sqliteErr, ok := err.(sqlite3.Error) + if !ok || sqliteErr.Code != sqlite3.ErrConstraint { + return + } + } + return nil } func (s *AgentService) getAgentInfo() (agentInfo meta.AgentInstance, err error) { diff --git a/agent/service/agent/info.go b/agent/service/agent/info.go index b9292093..fba2c2d2 100644 --- a/agent/service/agent/info.go +++ b/agent/service/agent/info.go @@ -308,6 +308,7 @@ func (s *AgentService) BeScalingOutAgent(zone string) error { }) } + func (s *AgentService) SyncAgentData() (err error) { oceanbaseDb, err := oceanbasedb.GetOcsInstance() if err != nil { diff --git a/agent/service/agent/obproxy.go b/agent/service/agent/obproxy.go new file mode 100644 index 00000000..db355104 --- /dev/null +++ b/agent/service/agent/obproxy.go @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agent + +import ( + "os" + "strconv" + + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/meta" + sqlitedb "github.com/oceanbase/obshell/agent/repository/db/sqlite" + "github.com/oceanbase/obshell/agent/repository/model/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func (*AgentService) DeleteObproxy() error { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return err + } + return db.Transaction(func(tx *gorm.DB) error { + if err := tx.Exec("DELETE FROM obproxy_info").Error; err != nil { + return err + } + meta.OBPROXY_HOME_PATH = "" + meta.OBPROXY_SQL_PORT = 0 + return nil + }) +} + +func (*AgentService) AddObproxy(homePath string, sqlPort int, version, enObproxySysPwd, enObproxyProxyroPwd string) error { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return err + } + infos := make(map[string]string) + infos[constant.OBPROXY_INFO_OBPROXY_SYS_PASSWORD] = enObproxySysPwd + infos[constant.OBPROXY_INFO_PROXYRO_PASSWORD] = enObproxyProxyroPwd + infos[constant.OBPROXY_INFO_HOME_PATH] = homePath + infos[constant.OBPROXY_INFO_SQL_PORT] = strconv.Itoa(sqlPort) + infos[constant.OBPROXY_INFO_VERSION] = version + + return db.Transaction(func(tx *gorm.DB) error { + for k, v := range infos { + // create or update + if err := tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "name"}}, + DoUpdates: clause.AssignmentColumns([]string{"value"}), + }).Create(&sqlite.ObproxyInfo{ + Name: k, + Value: v, + }).Error; err != nil { + return err + } + } + meta.OBPROXY_HOME_PATH = homePath + return nil + }) +} + +func (*AgentService) GetUpgradePkgInfoByVersion(name, version, arch, distribution string, deprecatedInfo []string) (pkgInfo sqlite.UpgradePkgInfo, err error) { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return + } + if len(deprecatedInfo) == 0 { + err = db.Model(&sqlite.UpgradePkgInfo{}).Where("name = ? and version = ? and distribution = ? and architecture = ? ", name, version, arch, distribution).Last(&pkgInfo).Error + } else { + err = db.Model(&sqlite.UpgradePkgInfo{}).Where("name = ? and version = ? and distribution = ? and architecture = ? and `release` not in ?", name, version, distribution, arch, deprecatedInfo).Last(&pkgInfo).Error + } + return +} + +func (*AgentService) GetUpgradePkgInfoByVersionAndRelease(name, version, release, distribution, arch string) (pkgInfo sqlite.UpgradePkgInfo, err error) { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return + } + err = db.Model(&sqlite.UpgradePkgInfo{}).Where("name = ? and version = ? and distribution = ? and architecture = ? and `release` = ?", name, version, distribution, arch, release).Last(&pkgInfo).Error + return +} + +func (agentService *AgentService) DownloadUpgradePkgChunkInBatch(filepath string, pkgId, chunkCount int) error { + file, err := os.OpenFile(filepath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0755) + if err != nil { + return err + } + defer file.Close() + + for i := 0; i < chunkCount; i++ { + chunk, err := agentService.GetUpgradePkgChunkByPkgIdAndChunkId(pkgId, i) + if err != nil { + return err + } + _, err = file.Write(chunk.Chunk) + if err != nil { + return err + } + } + return nil +} + +func (agentService *AgentService) GetUpgradePkgChunkByPkgIdAndChunkId(pkgId, chunkId int) (chunk sqlite.UpgradePkgChunk, err error) { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return chunk, err + } + err = db.Model(&sqlite.UpgradePkgChunk{}).Where("pkg_id = ? and chunk_id = ?", pkgId, chunkId).First(&chunk).Error + return +} + +func (agentService *AgentService) GetUpgradePkgChunkCountByPkgId(pkgId int) (count int64, err error) { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return 0, err + } + err = db.Model(&sqlite.UpgradePkgChunk{}).Where("pkg_id = ?", pkgId).Count(&count).Error + return +} diff --git a/agent/service/obcluster/config.go b/agent/service/obcluster/config.go index ad30fb82..18e72222 100644 --- a/agent/service/obcluster/config.go +++ b/agent/service/obcluster/config.go @@ -43,6 +43,15 @@ func (s *ObserverService) GetObConfigByName(name string) (config sqlite.ObConfig return } +func (s *ObserverService) GetObConfigValueByName(name string, val interface{}) (err error) { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return + } + err = db.Model(&sqlite.ObConfig{}).Select("value").Where("name = ?", name).Scan(val).Error + return +} + func (s *ObserverService) GetOBParatemerByName(name string, value interface{}) (err error) { db, err := oceanbase.GetInstance() if err != nil { diff --git a/agent/service/obcluster/obcluster.go b/agent/service/obcluster/obcluster.go index 38e412ee..15e1756b 100644 --- a/agent/service/obcluster/obcluster.go +++ b/agent/service/obcluster/obcluster.go @@ -125,10 +125,10 @@ func (obclusterService *ObclusterService) MinorFreeze(servers []oceanbase.OBServ } var targetCmd []string for _, server := range servers { - targetCmd = append(targetCmd, fmt.Sprintf("'%s:%d'", server.SvrIp, server.SvrPort)) + targetCmd = append(targetCmd, meta.NewAgentInfo(server.SvrIp, server.SvrPort).String()) } - serverList := strings.Join(targetCmd, ",") - sql := fmt.Sprintf("alter system minor freeze server = (%[1]s);", serverList) + serverList := strings.Join(targetCmd, "','") + sql := fmt.Sprintf("alter system minor freeze server = ('%[1]s');", serverList) return db.Exec(sql).Error } @@ -234,21 +234,21 @@ func (obclusterService *ObclusterService) GetUpgradePkgInfoByVersionAndRelease(n return } -func (obclusterService *ObclusterService) AddServer(ip, port, zoneName string) (err error) { +func (obclusterService *ObclusterService) AddServer(svrInfo meta.ObserverSvrInfo, zoneName string) (err error) { db, err := oceanbasedb.GetInstance() if err != nil { return err } - alterSql := fmt.Sprintf("ALTER SYSTEM ADD SERVER '%s:%s' ZONE '%s'", ip, port, zoneName) + alterSql := fmt.Sprintf("ALTER SYSTEM ADD SERVER '%s' ZONE '%s'", svrInfo.String(), zoneName) return db.Exec(alterSql).Error } -func (obclusterService *ObclusterService) DeleteServerInZone(ip, port, zoneName string) (err error) { +func (obclusterService *ObclusterService) DeleteServerInZone(svrInfo meta.ObserverSvrInfo, zoneName string) (err error) { db, err := oceanbasedb.GetInstance() if err != nil { return err } - alterSql := fmt.Sprintf("ALTER SYSTEM DELETE SERVER '%s:%s' ZONE '%s'", ip, port, zoneName) + alterSql := fmt.Sprintf("ALTER SYSTEM DELETE SERVER '%s' ZONE '%s'", svrInfo.String(), zoneName) return db.Exec(alterSql).Error } @@ -257,7 +257,7 @@ func (obclusterService *ObclusterService) DeleteServer(svrInfo meta.ObserverSvrI if err != nil { return err } - alterSql := fmt.Sprintf("ALTER SYSTEM DELETE SERVER '%s:%d'", svrInfo.GetIp(), svrInfo.GetPort()) + alterSql := fmt.Sprintf("ALTER SYSTEM DELETE SERVER '%s'", svrInfo.String()) return db.Exec(alterSql).Error } @@ -266,30 +266,30 @@ func (ObclusterService *ObclusterService) CancelDeleteServer(svrInfo meta.Observ if err != nil { return err } - alterSql := fmt.Sprintf("ALTER SYSTEM CANCEL DELETE SERVER '%s:%d'", svrInfo.GetIp(), svrInfo.GetPort()) + alterSql := fmt.Sprintf("ALTER SYSTEM CANCEL DELETE SERVER '%s'", svrInfo.String()) return db.Exec(alterSql).Error } -func (obclusterService *ObclusterService) IsServerExist(ip string, port string) (bool, error) { +func (obclusterService *ObclusterService) IsServerExist(svrInfo meta.ObserverSvrInfo) (bool, error) { db, err := oceanbasedb.GetInstance() if err != nil { return false, err } var count int - err = db.Raw("select count(*) from oceanbase.dba_ob_servers where svr_ip = ? and svr_port = ?", ip, port).First(&count).Error + err = db.Raw("select count(*) from oceanbase.dba_ob_servers where svr_ip = ? and svr_port = ?", svrInfo.GetIp(), svrInfo.GetPort()).First(&count).Error if err != nil { return false, err } return count > 0, nil } -func (obclusterService *ObclusterService) IsServerExistWithZone(ip string, port string, zone string) (bool, error) { +func (obclusterService *ObclusterService) IsServerExistWithZone(svrInfo meta.ObserverSvrInfo, zone string) (bool, error) { db, err := oceanbasedb.GetInstance() if err != nil { return false, err } var count int64 - err = db.Table(DBA_OB_SERVERS).Where("svr_ip = ? and svr_port = ? and zone = ?", ip, port, zone).Count(&count).Error + err = db.Table(DBA_OB_SERVERS).Where("svr_ip = ? and svr_port = ? and zone = ?", svrInfo.GetIp(), svrInfo.GetPort(), zone).Count(&count).Error if err != nil { return false, err } @@ -643,7 +643,7 @@ func (obclusterService *ObclusterService) RestoreParamsForUpgrade(params []ocean } sql = fmt.Sprintf("ALTER SYSTEM SET %s = '%s' TENANT = %s", param.Name, param.Value, tenantName) case "CLUSTER": - sql = fmt.Sprintf("ALTER SYSTEM SET %s = '%s' SERVER = '%s:%d'", param.Name, param.Value, param.SvrIp, param.SvrPort) + sql = fmt.Sprintf("ALTER SYSTEM SET %s = '%s' SERVER = '%s'", param.Name, param.Value, meta.NewAgentInfo(param.SvrIp, param.SvrPort).String()) default: return errors.New("unknown scope") } @@ -737,7 +737,7 @@ func (ObclusterService *ObclusterService) IsLsMultiPaxosAlive(lsId int, tenantId // GetLogInfosInServer returns the log stat in target server // only contains tenant_id and ls_id. -func (ObclusterService *ObclusterService) GetLogInfosInServer(svrInfo meta.ObserverSvrInfo) (logStats []oceanbase.ObLogStat, err error) { +func (*ObclusterService) GetLogInfosInServer(svrInfo meta.ObserverSvrInfo) (logStats []oceanbase.ObLogStat, err error) { oceanbaseDb, err := oceanbasedb.GetInstance() if err != nil { return nil, err @@ -746,7 +746,7 @@ func (ObclusterService *ObclusterService) GetLogInfosInServer(svrInfo meta.Obser return } -func (ObclusterService *ObclusterService) HasUnitInZone(zone string) (exist bool, err error) { +func (*ObclusterService) HasUnitInZone(zone string) (exist bool, err error) { oceanbaseDb, err := oceanbasedb.GetInstance() if err != nil { return false, err @@ -755,3 +755,34 @@ func (ObclusterService *ObclusterService) HasUnitInZone(zone string) (exist bool err = oceanbaseDb.Table(DBA_OB_UNITS).Where("ZONE = ?", zone).Count(&count).Error return count > 0, err } + +func (obclusterService *ObclusterService) CreateProxyroUser(password string) error { + oceanbaseDb, err := oceanbasedb.GetInstance() + if err != nil { + return err + } + + sqlText := fmt.Sprintf("CREATE USER IF NOT EXISTS `%s`@`%s`", constant.SYS_USER_PROXYRO, "%") + if password != "" { + sqlText += fmt.Sprintf(" IDENTIFIED BY '%s'", strings.ReplaceAll(password, "'", "'\"'\"'")) + } + if err = oceanbaseDb.Exec(sqlText).Error; err != nil { + return err + } + if err := oceanbaseDb.Exec(fmt.Sprintf("GRANT SELECT ON oceanbase.* TO %s", constant.SYS_USER_PROXYRO)).Error; err != nil { + return err + } + return nil +} + +func (obclusterService *ObclusterService) GetRsListStr() (rsListStr string, err error) { + oceanbaseDb, err := oceanbasedb.GetInstance() + if err != nil { + return "", err + } + err = oceanbaseDb.Table(GV_OB_PARAMETERS). + Select("VALUE"). + Where("NAME = ?", "rootservice_list"). + Scan(&rsListStr).Error + return +} diff --git a/agent/service/obproxy/enter.go b/agent/service/obproxy/enter.go new file mode 100644 index 00000000..14483e2c --- /dev/null +++ b/agent/service/obproxy/enter.go @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +type ObproxyService struct{} + +const ( + OBPROXY_PROXYRO_USERNAME = "proxyro" + + GV_OB_PARAMETERS = "oceanbase.GV$OB_PARAMETERS" +) diff --git a/agent/service/obproxy/obproxy.go b/agent/service/obproxy/obproxy.go new file mode 100644 index 00000000..7710c7eb --- /dev/null +++ b/agent/service/obproxy/obproxy.go @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package obproxy + +import ( + "encoding/hex" + "fmt" + "mime/multipart" + "regexp" + "strconv" + "strings" + + "github.com/cavaliergopher/rpm" + "github.com/oceanbase/obshell/agent/constant" + "github.com/oceanbase/obshell/agent/errors" + "github.com/oceanbase/obshell/agent/meta" + obproxydb "github.com/oceanbase/obshell/agent/repository/db/obproxy" + sqlitedb "github.com/oceanbase/obshell/agent/repository/db/sqlite" + "github.com/oceanbase/obshell/agent/secure" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/oceanbase/obshell/agent/repository/model/bo" + "github.com/oceanbase/obshell/agent/repository/model/sqlite" +) + +func (obproxyService *ObproxyService) SetSysPassword(password string) (err error) { + return obproxyService.SetGlobalConfig(constant.OBPROXY_CONFIG_OBPROXY_SYS_PASSWORD, password) +} + +func (obproxyService *ObproxyService) SetProxyroPassword(password string) error { + return obproxyService.SetGlobalConfig(constant.OBPROXY_CONFIG_PROXYRO_PASSWORD, password) +} + +func (*ObproxyService) SetGlobalConfig(name string, value string) error { + db, err := obproxydb.GetObproxyInstance() + if err != nil { + return err + } + + if err := db.Exec(fmt.Sprintf("ALTER proxyconfig SET %s = %s ", name, value)).Error; err != nil { + return err + } + return nil +} + +func (*ObproxyService) GetObproxyVersion() (version string, err error) { + db, err := obproxydb.GetObproxyInstance() + if err != nil { + return + } + var proxyInfo bo.ObproxyInfo + if err = db.Raw("show proxyinfo binary").Scan(&proxyInfo).Error; err != nil { + return "", err + } + // parse obproxy version + re := regexp.MustCompile(`\d+\.\d+\.\d+\.\d+-\d+`) + version = re.FindString(proxyInfo.Info) + return version, err +} + +func (*ObproxyService) GetGlobalConfig(name string) (value string, err error) { + db, err := obproxydb.GetObproxyInstance() + if err != nil { + return + } + var proxyConfig bo.ProxyConfig + err = db.Raw(fmt.Sprintf("show proxyconfig like '%s'", name)).Scan(&proxyConfig).Error + return proxyConfig.Value, err +} + +func (obproxyService *ObproxyService) UpdateSqlPort(sqlPort int) (err error) { + if err := obproxyService.UpdateObproxyInfo(constant.OBPROXY_INFO_SQL_PORT, strconv.Itoa(sqlPort)); err != nil { + return err + } + meta.OBPROXY_SQL_PORT = sqlPort + return nil +} + +func (obproxyService *ObproxyService) UpdateObproxySysPassword(obproxySysPassword string) (err error) { + encryptPwd, err := secure.Encrypt(obproxySysPassword) + if err != nil { + return err + } + if err := obproxyService.UpdateObproxyInfo(constant.OBPROXY_INFO_OBPROXY_SYS_PASSWORD, encryptPwd); err != nil { + return err + } + meta.OBPROXY_SYS_PWD = obproxySysPassword + return nil +} + +func (obproxyService *ObproxyService) UpdateObproxyInfo(name string, value string) (err error) { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return + } + obproxyInfo := &sqlite.ObproxyInfo{ + Name: name, + Value: value, + } + return db.Model(obproxyInfo).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "name"}}, + DoUpdates: clause.AssignmentColumns([]string{"value"}), + }).Create(obproxyInfo).Error +} + +func (*ObproxyService) GetObclusterName(db *gorm.DB) (name string, err error) { + err = db.Table(GV_OB_PARAMETERS).Where("name = ?", "cluster").Select("value").Scan(&name).Error + return +} + +func (*ObproxyService) ClearObproxyInfo() (err error) { + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return + } + return db.Delete(&sqlite.ObproxyInfo{}).Error +} + +func (*ObproxyService) DumpUpgradePkgInfoAndChunkTx(rpmPkg *rpm.Package, file multipart.File) (pkgInfo *sqlite.UpgradePkgInfo, err error) { + payloadSize := uint64(rpmPkg.Signature.GetTag(1000).Int64()) + chunkCount := payloadSize / constant.CHUNK_SIZE + if payloadSize%constant.CHUNK_SIZE != 0 { + chunkCount++ + } + pkgInfo = &sqlite.UpgradePkgInfo{ + Name: rpmPkg.Name(), + Version: rpmPkg.Version(), + ReleaseDistribution: rpmPkg.Release(), + Distribution: strings.Split(rpmPkg.Release(), ".")[1], + Release: strings.Split(rpmPkg.Release(), ".")[0], + Architecture: rpmPkg.Architecture(), + Size: rpmPkg.Size(), + ChunkCount: int(chunkCount), + PayloadSize: payloadSize, + Md5: hex.EncodeToString(rpmPkg.Signature.Tags[1004].Bytes()), + } + + db, err := sqlitedb.GetSqliteInstance() + if err != nil { + return nil, err + } + err = db.Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&sqlite.UpgradePkgInfo{}).Create(&pkgInfo).Error; err != nil { + return err + } + chunkBuffer := make([]byte, constant.CHUNK_SIZE) + _, err = file.Seek(0, 0) + if err != nil { + return errors.Wrap(err, "Seek failed") + } + for i := 0; i < pkgInfo.ChunkCount; i++ { + n, err := file.Read(chunkBuffer) + if err != nil { + return err + } + record := &sqlite.UpgradePkgChunk{ + PkgId: pkgInfo.PkgId, + ChunkId: i, + ChunkCount: pkgInfo.ChunkCount, + Chunk: chunkBuffer[:n]} + if err = tx.Model(&sqlite.UpgradePkgChunk{}).Create(record).Error; err != nil { + return err + } + } + return nil + }) + return +} diff --git a/agent/service/task/convert.go b/agent/service/task/convert.go index 957b6799..89bc541e 100644 --- a/agent/service/task/convert.go +++ b/agent/service/task/convert.go @@ -71,7 +71,7 @@ func (s *taskService) convertNodeInstance(bo *bo.NodeInstance) (*task.Node, erro return nil, err } - return task.NewNodeWithId(bo.Id, bo.Name, bo.Type, bo.State, bo.Operator, bo.StructName, ctx, s.isLocal, bo.StartTime, bo.EndTime), nil + return task.NewNodeWithId(bo.Id, bo.Name, int(bo.DagId), bo.Type, bo.State, bo.Operator, bo.StructName, ctx, s.isLocal, bo.StartTime, bo.EndTime), nil } // convertSubTaskInstance convert SubTaskInstance to task.ExecutableTask. diff --git a/agent/service/task/dag.go b/agent/service/task/dag.go index db1eeb04..1894b9ab 100644 --- a/agent/service/task/dag.go +++ b/agent/service/task/dag.go @@ -56,7 +56,7 @@ func (s *taskService) GetDagDetail(dagId int64) (dagDetailDTO *task.DagDetailDTO return nil, err } - nodeDetailDTO, err := getNodeDetail(s, nodes[i]) + nodeDetailDTO, err := getNodeDetail(s, nodes[i], dag.GetDagType()) if err != nil { return nil, err } @@ -65,12 +65,12 @@ func (s *taskService) GetDagDetail(dagId int64) (dagDetailDTO *task.DagDetailDTO return dagDetailDTO, nil } -func getNodeDetail(service TaskServiceInterface, node *task.Node) (nodeDetailDTO *task.NodeDetailDTO, err error) { - nodeDetailDTO = task.NewNodeDetailDTO(node) +func getNodeDetail(service TaskServiceInterface, node *task.Node, dagType string) (nodeDetailDTO *task.NodeDetailDTO, err error) { + nodeDetailDTO = task.NewNodeDetailDTO(node, dagType) subTasks := node.GetSubTasks() n := len(subTasks) for i := 0; i < n; i++ { - taskDetailDTO, err := getSubTaskDetail(service, subTasks[i]) + taskDetailDTO, err := getSubTaskDetail(service, subTasks[i], dagType) if err != nil { return nil, err } @@ -79,8 +79,8 @@ func getNodeDetail(service TaskServiceInterface, node *task.Node) (nodeDetailDTO return } -func getSubTaskDetail(service TaskServiceInterface, subTask task.ExecutableTask) (taskDetailDTO *task.TaskDetailDTO, err error) { - taskDetailDTO = task.NewTaskDetailDTO(subTask) +func getSubTaskDetail(service TaskServiceInterface, subTask task.ExecutableTask, dagType string) (taskDetailDTO *task.TaskDetailDTO, err error) { + taskDetailDTO = task.NewTaskDetailDTO(subTask, dagType) if subTask.IsRunning() || subTask.IsFinished() { taskDetailDTO.TaskLogs, err = service.GetSubTaskLogsByTaskID(subTask.GetID()) } @@ -140,25 +140,29 @@ func (s *taskService) FindLastMaintenanceDag() (*task.Dag, error) { return dag, err } -func (s *taskService) GetDagIDBySubTaskId(taskID int64) (dagID int64, err error) { +// notice: GetDagBySubTaskId will occur error if the task is remote. +func (s *taskService) GetDagBySubTaskId(taskID int64) (*task.Dag, error) { db, err := s.getDbInstance() if err != nil { - return + return nil, err } var nodeID int64 if err = db.Model(s.getSubTaskModel()).Select("node_id").Where("id=?", taskID).First(&nodeID).Error; err != nil { - return + return nil, err } - err = db.Model(s.getNodeModel()).Select("dag_id").Where("id=?", nodeID).First(&dagID).Error - return + var dagID int64 + if err = db.Model(s.getNodeModel()).Select("dag_id").Where("id=?", nodeID).First(&dagID).Error; err != nil { + return nil, err + } + return s.GetDagInstance(dagID) } func (s *taskService) GetDagGenericIDBySubTaskId(taskID int64) (dagGenericID string, err error) { - var dagID int64 - if dagID, err = s.GetDagIDBySubTaskId(taskID); err != nil { + dag, err := s.GetDagBySubTaskId(taskID) + if err != nil { return } - dagGenericID = task.ConvertIDToGenericID(dagID, s.isLocal) + dagGenericID = task.ConvertIDToGenericID(dag.GetID(), s.isLocal, dag.GetDagType()) return } @@ -245,6 +249,7 @@ func (s *taskService) newDagInstanceBO(template *task.Template, ctx *task.TaskCo return &bo.DagInstance{ Name: template.Name, Stage: 1, + Type: template.Type, MaxStage: len(template.GetNodes()), State: task.READY, Operator: task.RUN, diff --git a/agent/service/task/interface.go b/agent/service/task/interface.go index b1570f32..4e422f11 100644 --- a/agent/service/task/interface.go +++ b/agent/service/task/interface.go @@ -102,6 +102,8 @@ type SubTaskServiceInterface interface { GetLocalTaskInstanceByRemoteTaskId(int64) (*sqlite.SubtaskInstance, error) + GetDagBySubTaskId(taskId int64) (*task.Dag, error) + GetSubTasks(*task.Node) ([]task.ExecutableTask, error) GetSubTaskByTaskID(int64) (task.ExecutableTask, error) diff --git a/agent/service/task/status_maintainer.go b/agent/service/task/status_maintainer.go index 9c4d3786..add6e262 100644 --- a/agent/service/task/status_maintainer.go +++ b/agent/service/task/status_maintainer.go @@ -119,8 +119,8 @@ func (maintainer *clusterStatusMaintainer) UpdateMaintenanceTask(tx *gorm.DB, da } if lock.DagID > 0 && dag.IsFail() && dag.GetID() != lock.DagID { - gid := task.ConvertIDToGenericID(dag.GetID(), false) - oldGid := task.ConvertIDToGenericID(lock.DagID, false) + gid := task.ConvertIDToGenericID(dag.GetID(), false, "") + oldGid := task.ConvertIDToGenericID(lock.DagID, false, "") return fmt.Errorf("%s has already executed task %s. '%s: %s' cannot be executed. Please submit a new request", lock.LockName, oldGid, gid, dag.GetName()) } @@ -188,12 +188,28 @@ type agentStatusMaintainer struct { } func (maintainer *agentStatusMaintainer) setStatus(tx *gorm.DB, newStatus int, oldStatus int) error { - resp := tx.Model(&sqlite.OcsInfo{}).Where("name=? and value=?", constant.OCS_INFO_STATUS, oldStatus).Update("value", strconv.Itoa(newStatus)) - if resp.Error != nil { - return resp.Error + var resp *gorm.DB + if newStatus == task.OBPROXY_MAINTENACE { + resp = tx.Model(&sqlite.ObproxyInfo{}).Where("name=? and value=?", constant.OBPROXY_INFO_STATUS, oldStatus).Update("value", strconv.Itoa(newStatus)) + if resp.Error != nil { + return resp.Error + } + } else { + resp = tx.Model(&sqlite.OcsInfo{}).Where("name=? and value=?", constant.OCS_INFO_STATUS, oldStatus).Update("value", strconv.Itoa(newStatus)) + if resp.Error != nil { + return resp.Error + } } if resp.RowsAffected == 0 { - return fmt.Errorf("failed to start maintenance: agent status is not %d", oldStatus) + if newStatus == task.NOT_UNDER_MAINTENANCE { + var nowStatus int + if err := tx.Set("gorm:query_option", "FOR UPDATE").Model(&sqlite.OcsInfo{}).Select("value").Where("name=?", "status").First(&nowStatus).Error; err != nil { + return err + } else if nowStatus == newStatus { + return nil + } + } + return fmt.Errorf("failed to set status to %d: agent status is not %d", newStatus, oldStatus) } return nil } @@ -202,7 +218,7 @@ func (maintainer *agentStatusMaintainer) StartMaintenance(tx *gorm.DB, dag task. if !dag.IsMaintenance() { return nil } - return maintainer.setStatus(tx, task.GLOBAL_MAINTENANCE, task.NOT_UNDER_MAINTENANCE) + return maintainer.setStatus(tx, dag.GetMaintenanceType(), task.NOT_UNDER_MAINTENANCE) } func (maintainer *agentStatusMaintainer) UpdateMaintenanceTask(tx *gorm.DB, dag *task.Dag) error { @@ -210,7 +226,14 @@ func (maintainer *agentStatusMaintainer) UpdateMaintenanceTask(tx *gorm.DB, dag } func (maintainer *agentStatusMaintainer) StopMaintenance(tx *gorm.DB, dag task.Maintainer) error { - return maintainer.setStatus(tx, task.NOT_UNDER_MAINTENANCE, task.GLOBAL_MAINTENANCE) + switch dag.GetMaintenanceType() { + case task.GLOBAL_MAINTENANCE: + return maintainer.setStatus(tx, task.NOT_UNDER_MAINTENANCE, task.GLOBAL_MAINTENANCE) + case task.OBPROXY_MAINTENACE: + return maintainer.setStatus(tx, task.NOT_UNDER_MAINTENANCE, task.OBPROXY_MAINTENACE) + default: + return nil + } } func (maintainer *agentStatusMaintainer) IsRunning() (bool, error) { diff --git a/agent/service/task/sub_task.go b/agent/service/task/sub_task.go index 5aa12b68..48e65fbb 100644 --- a/agent/service/task/sub_task.go +++ b/agent/service/task/sub_task.go @@ -182,7 +182,7 @@ func (s *taskService) StartSubTask(subtask task.ExecutableTask) error { } else if taskInstanceBO.ExecuteTimes != subtask.GetExecuteTimes() { return fmt.Errorf("failed to start task: sub task %d execute times is %d now", subtask.GetID(), taskInstanceBO.ExecuteTimes) } else if taskInstanceBO.ExecuterAgentIp != subtask.GetExecuteAgent().Ip || taskInstanceBO.ExecuterAgentPort != subtask.GetExecuteAgent().Port { - return fmt.Errorf("failed to start task: sub task %d execute agent is %s:%d now", subtask.GetID(), taskInstanceBO.ExecuterAgentIp, taskInstanceBO.ExecuterAgentPort) + return fmt.Errorf("failed to start task: sub task %d execute agent is %s now", subtask.GetID(), meta.NewAgentInfo(taskInstanceBO.ExecuterAgentIp, taskInstanceBO.ExecuterAgentPort).String()) } } subtask.SetState(taskInstanceBO.State) diff --git a/agent/service/tenant/enter.go b/agent/service/tenant/enter.go index 5dc9fef4..73d2db64 100644 --- a/agent/service/tenant/enter.go +++ b/agent/service/tenant/enter.go @@ -28,6 +28,7 @@ const ( DBA_OB_UNIT_CONFIGS = "oceanbase.DBA_OB_UNIT_CONFIGS" DBA_OB_CLUSTER_EVENT_HISTORY = "oceanbase.DBA_OB_CLUSTER_EVENT_HISTORY" DBA_RECYCLEBIN = "oceanbase.DBA_RECYCLEBIN" + DBA_OB_USERS = "oceanbase.DBA_OB_USERS" CDB_OB_SYS_VARIABLES = "oceanbase.CDB_OB_SYS_VARIABLES" CDB_OB_ARCHIVELOG = "oceanbase.CDB_OB_ARCHIVELOG" diff --git a/agent/service/tenant/tenant.go b/agent/service/tenant/tenant.go index 76bef2f4..016e1dac 100644 --- a/agent/service/tenant/tenant.go +++ b/agent/service/tenant/tenant.go @@ -313,6 +313,29 @@ func (t *TenantService) ModifyTenantRootPassword(tenantName string, oldPwd strin return nil } +func (t *TenantService) SetTenantVariablesWithTenant(tenantName, password string, variables map[string]interface{}) error { + tempDb, err := oceanbasedb.LoadGormWithTenant(tenantName, password) + if err != nil { + return errors.Occur(errors.ErrUnexpected, err.Error()) + } + defer func() { + db, _ := tempDb.DB() + if db != nil { + db.Close() + } + }() + variablesSql := "" + for k, v := range variables { + if val, ok := v.(string); ok { + variablesSql += fmt.Sprintf(", GLOBAL "+k+"= `%v`", val) + } else { + variablesSql += fmt.Sprintf(", GLOBAL "+k+"= %v", v) + } + } + sqlText := fmt.Sprintf("SET %s", variablesSql[1:]) + return tempDb.Exec(sqlText).Error +} + func (t *TenantService) AlterTenantPrimaryZone(tenantName string, primaryZone string) error { db, err := oceanbasedb.GetInstance() if err != nil { diff --git a/agent/service/tenant/user.go b/agent/service/tenant/user.go new file mode 100644 index 00000000..ab2ffd5a --- /dev/null +++ b/agent/service/tenant/user.go @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tenant + +import ( + "fmt" + "strings" + + "github.com/oceanbase/obshell/param" + "gorm.io/gorm" +) + +func (t *TenantService) IsUserExist(db *gorm.DB, userName string) (bool, error) { + var count int64 + err := db.Table(DBA_OB_USERS).Where("USER_NAME = ?", userName).Count(&count).Error + return count > 0, err +} + +func (t *TenantService) CreateUser(db *gorm.DB, userName, password, hostName string) error { + sql := fmt.Sprintf("CREATE USER IF NOT EXISTS `%s`@`%s` IDENTIFIED BY '%s'", userName, hostName, strings.ReplaceAll(password, "'", "'\"'\"'")) + return db.Exec(sql).Error +} + +func (t *TenantService) GrantGlobalPrivileges(db *gorm.DB, userName, hostName string, privilege []string) error { + sql := fmt.Sprintf("GRANT %s ON *.* TO `%s`@`%s`", strings.Join(privilege, ","), userName, hostName) + return db.Exec(sql).Error +} + +func (t *TenantService) GrantDbPrivileges(db *gorm.DB, userName, hostName string, privilege param.DbPrivilegeParam) error { + sql := fmt.Sprintf("GRANT %s ON `%s`.* TO `%s`@`%s`", strings.Join(privilege.Privileges, ","), privilege.DbName, userName, hostName) + return db.Exec(sql).Error +} + +func (t *TenantService) DropUser(db *gorm.DB, userName string) error { + sql := fmt.Sprintf("DROP USER `%s`", userName) + return db.Exec(sql).Error +} diff --git a/client/cmd/cluster/enter.go b/client/cmd/cluster/enter.go index a7edd740..23d667ef 100644 --- a/client/cmd/cluster/enter.go +++ b/client/cmd/cluster/enter.go @@ -53,15 +53,17 @@ const ( // CMD_INIT represents the "init" command used to initialize the cluster. CMD_INIT = "init" // Flags for the "init" command. - FLAG_PASSWORD = "rootpassword" - FLAG_PASSWORD_ALIAS = "rp" - FLAG_CLUSTER_NAME = "cluster_name" - FLAG_CLUSTER_NAME_SH = "n" - FLAG_CLUSTER_ID = "cluster_id" - FLAG_CLUSTER_ID_SH = "i" - FLAG_RS_LIST = "rs_list" - FLAG_RS_LIST_ALIAS = "rs" - FLAG_IMPORT_SCRIPT = "import_script" + FLAG_PASSWORD = "rootpassword" + FLAG_PASSWORD_ALIAS = "rp" + FLAG_CLUSTER_NAME = "cluster_name" + FLAG_CLUSTER_NAME_SH = "n" + FLAG_CLUSTER_ID = "cluster_id" + FLAG_CLUSTER_ID_SH = "i" + FLAG_RS_LIST = "rs_list" + FLAG_RS_LIST_ALIAS = "rs" + FLAG_IMPORT_SCRIPT = "import_script" + FLAG_CREATE_PROXYRO_USER = "create_proxyro_user" + FLAG_PROXYRO_PASSWORD = "proxyro_password" // CMD_START represents the "start" command used to start observers. CMD_START = "start" diff --git a/client/cmd/cluster/init.go b/client/cmd/cluster/init.go index e7c7c8d6..3ed242af 100644 --- a/client/cmd/cluster/init.go +++ b/client/cmd/cluster/init.go @@ -36,9 +36,11 @@ import ( ) type ClusterInitFlags struct { - password string - verbose bool - importScript bool + password string + verbose bool + importScript bool + createProxyroUser bool + proxyroPassword string ObserverConfigFlags } @@ -77,6 +79,8 @@ func newInitCmd() *cobra.Command { initCmd.VarsPs(&opts.optStr, []string{FLAG_OPT_STR, FLAG_OPT_STR_SH}, "", "Additional parameters for the observer, use the format key=value for each configuration, separated by commas.", false) initCmd.VarsPs(&opts.rsList, []string{FLAG_RS_LIST, FLAG_RS_LIST_ALIAS}, "", "Root service list", false) initCmd.VarsPs(&opts.importScript, []string{FLAG_IMPORT_SCRIPT}, false, "Import the observer's scripts for sys tenant.", false) + initCmd.VarsPs(&opts.createProxyroUser, []string{FLAG_CREATE_PROXYRO_USER}, false, "Create the default user 'proxyro'.", false) + initCmd.VarsPs(&opts.proxyroPassword, []string{FLAG_PROXYRO_PASSWORD}, "", "Password for the default user 'proxyro'.", false) initCmd.VarsPs(&opts.verbose, []string{clientconst.FLAG_VERBOSE, clientconst.FLAG_VERBOSE_SH}, false, "Activate verbose output", false) @@ -108,7 +112,9 @@ func clusterInit(cmd *cobra.Command, flags *ClusterInitFlags) error { func buildInitParams(flags *ClusterInitFlags) *param.ObInitParam { return ¶m.ObInitParam{ - ImportScript: flags.importScript, + ImportScript: flags.importScript, + CreateProxyroUser: flags.createProxyroUser, + ProxyroPassword: flags.proxyroPassword, } } diff --git a/client/cmd/cluster/join.go b/client/cmd/cluster/join.go index dfce73fc..d52644c6 100644 --- a/client/cmd/cluster/join.go +++ b/client/cmd/cluster/join.go @@ -17,7 +17,6 @@ package cluster import ( - "regexp" "strconv" "strings" @@ -29,6 +28,7 @@ import ( "github.com/oceanbase/obshell/agent/errors" "github.com/oceanbase/obshell/agent/executor/ob" ocsagentlog "github.com/oceanbase/obshell/agent/log" + "github.com/oceanbase/obshell/agent/meta" "github.com/oceanbase/obshell/client/command" clientconst "github.com/oceanbase/obshell/client/constant" cmdlib "github.com/oceanbase/obshell/client/lib/cmd" @@ -93,7 +93,7 @@ func agentJoin(cmd *cobra.Command, flags *AgentJoinFlags) error { } stdio.StopLoading() - targetAgent, err := NewAgentByString(flags.server) + targetAgent, err := meta.ConvertAddressToAgentInfo(flags.server) if err != nil { return err } @@ -172,11 +172,7 @@ func isValidRsList(rsList string) bool { servers := strings.Split(rsList, ";") for _, server := range servers { if server != "" { - arr := strings.Split(server, ":") - if len(arr) != 3 { - return false - } - if !isValidIp(arr[0]) || !isValidPortStr(arr[1]) || !isValidPortStr(arr[2]) { + if _, err := meta.ConvertAddressToAgentInfo(server); err != nil { return false } } @@ -184,26 +180,6 @@ func isValidRsList(rsList string) bool { return true } -func isValidIp(ip string) bool { - ipRegexp := regexp.MustCompile(`^(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)$`) - return ipRegexp.MatchString(ip) -} - -func isValidPortStr(port string) bool { - if port == "" { - return true - } - p, err := strconv.Atoi(port) - if err != nil { - return false - } - return isValidPort(p) -} - -func isValidPort(port int) bool { - return port > 1024 && port < 65536 -} - func isValidLogLevel(level string) bool { if level == "" { return true @@ -221,14 +197,14 @@ func checkServerConfigFlags(config map[string]string) error { stdio.Verbose("Check whether the configs is valid") if mysqlPort, ok := config[constant.CONFIG_MYSQL_PORT]; ok { stdio.Verbosef("Check mysql port: %s", mysqlPort) - if !isValidPortStr(mysqlPort) { + if !utils.IsValidPort(mysqlPort) { return errors.Errorf("Invalid port: %s. Port number should be in the range [1024, 65535].", mysqlPort) } } if rpcPort, ok := config[constant.CONFIG_RPC_PORT]; ok { stdio.Verbosef("Check rpc port: %s", rpcPort) - if !isValidPortStr(rpcPort) { + if !utils.IsValidPort(rpcPort) { return errors.Errorf("Invalid port: %s. Port number should be in the range [1024, 65535].", rpcPort) } } diff --git a/client/cmd/cluster/remove.go b/client/cmd/cluster/remove.go index 9627fc54..0fc9cb7e 100644 --- a/client/cmd/cluster/remove.go +++ b/client/cmd/cluster/remove.go @@ -25,6 +25,7 @@ import ( "github.com/oceanbase/obshell/agent/errors" "github.com/oceanbase/obshell/agent/lib/http" ocsagentlog "github.com/oceanbase/obshell/agent/log" + "github.com/oceanbase/obshell/agent/meta" "github.com/oceanbase/obshell/client/command" clientconst "github.com/oceanbase/obshell/client/constant" cmdlib "github.com/oceanbase/obshell/client/lib/cmd" @@ -71,7 +72,7 @@ func newRemoveCmd() *cobra.Command { } func agentRemove(flags *AgentRemoveFlags) error { - targetAgent, err := NewAgentByString(flags.server) + targetAgent, err := meta.ConvertAddressToAgentInfo(flags.server) if err != nil { return err } diff --git a/client/cmd/cluster/scale_in.go b/client/cmd/cluster/scale_in.go index f86ef128..79b175d4 100644 --- a/client/cmd/cluster/scale_in.go +++ b/client/cmd/cluster/scale_in.go @@ -27,6 +27,7 @@ import ( "github.com/oceanbase/obshell/agent/errors" "github.com/oceanbase/obshell/agent/lib/http" ocsagentlog "github.com/oceanbase/obshell/agent/log" + "github.com/oceanbase/obshell/agent/meta" "github.com/oceanbase/obshell/client/command" clientconst "github.com/oceanbase/obshell/client/constant" "github.com/oceanbase/obshell/client/global" @@ -110,7 +111,7 @@ func clusterScaleIn(cmd *cobra.Command, flags *ClusterScaleInFlags) (err error) } func deleteServer(server string, forceKill bool) (*task.DagDetailDTO, error) { - targetAgentInfo, err := NewAgentByString(server) + targetAgentInfo, err := meta.ConvertAddressToAgentInfo(server) if err != nil { return nil, err } diff --git a/client/cmd/cluster/scale_out.go b/client/cmd/cluster/scale_out.go index 4f3e6193..0768b9dd 100644 --- a/client/cmd/cluster/scale_out.go +++ b/client/cmd/cluster/scale_out.go @@ -18,8 +18,6 @@ package cluster import ( "fmt" - "strconv" - "strings" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -107,7 +105,7 @@ func clusterScaleOut(cmd *cobra.Command, flags *ClusterScaleOutFlags) (err error return err } - targetAgentInfo, err := NewAgentByString(flags.agent) + targetAgentInfo, err := meta.ConvertAddressToAgentInfo(flags.agent) if err != nil { return err } @@ -173,29 +171,3 @@ func buildScaleOutParam(flags *ClusterScaleOutFlags) (*param.ScaleOutParam, erro func scaleOutCmdExample() string { return ` obshell cluster scale-out -s 192.168.1.1:2886 -z zone1 --rp ****` } - -func NewAgentByString(str string) (*meta.AgentInfo, error) { - stdio.Verbosef("Parse target agent info from string: %s", str) - info := strings.Split(str, ":") - if !isValidIp(info[0]) { - return nil, errors.Errorf("Invalid ip address: %s", info[0]) - } - //If the observer provides a port number, use the port number, - //otherwise use the default port number 2886 - agent := &meta.AgentInfo{ - Ip: info[0], - Port: constant.DEFAULT_AGENT_PORT, - } - if len(info) > 1 { - if info[1] == "" { - return nil, errors.Errorf("Invalid server format: '%s:'", info[0]) - } - port, err := strconv.Atoi(info[1]) - if err != nil || !isValidPortStr(info[1]) { - return nil, errors.Errorf("Invalid port: %s. Port number should be in the range [1024, 65535].", info[1]) - } - agent.Port = port - } - stdio.Verbosef("Parsed target agent info: %v", agent) - return agent, nil -} diff --git a/client/cmd/cluster/start.go b/client/cmd/cluster/start.go index e08bdcb3..72d9c0e6 100644 --- a/client/cmd/cluster/start.go +++ b/client/cmd/cluster/start.go @@ -121,7 +121,7 @@ func clusterStart(flags *ClusterStartFlags) (err error) { } if flags.server == "" && flags.zone == "" && !flags.global { - flags.server = fmt.Sprintf("%s:%d", agentStatus.Agent.GetIp(), agentStatus.Agent.GetPort()) + flags.server = agentStatus.Agent.String() } stdio.Verbosef("current my agent is %s", agentStatus.Agent.GetIdentity()) diff --git a/client/cmd/cluster/stop.go b/client/cmd/cluster/stop.go index f12d6f0f..c9ae302f 100644 --- a/client/cmd/cluster/stop.go +++ b/client/cmd/cluster/stop.go @@ -116,7 +116,7 @@ func clusterStop(flags *ClusterStopFlags) (err error) { } if flags.server == "" && flags.zone == "" && !flags.global { - flags.server = fmt.Sprintf("%s:%d", agentStatus.Agent.GetIp(), agentStatus.Agent.GetPort()) + flags.server = agentStatus.Agent.String() } if err = CheckAllAgentMaintenance(); err != nil { @@ -169,11 +169,9 @@ func callEmerTypeApi(uri string, param interface{}) (err error) { signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) go func() { sig := <-sigChan - stdio.StopLoading() stdio.Printf("\nReceived signal: %v", sig) stdio.Info("try to cancel the task, please wait...") if err := dagHandler.CancelDag(); err != nil { - stdio.StopLoading() stdio.Warnf("Failed to cancel the task: %s", err.Error()) os.Exit(1) } diff --git a/client/cmd/cluster/take_over.go b/client/cmd/cluster/take_over.go index 4f830948..c959256d 100644 --- a/client/cmd/cluster/take_over.go +++ b/client/cmd/cluster/take_over.go @@ -24,10 +24,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/oceanbase/obshell/client/global" - "github.com/oceanbase/obshell/client/lib/http" - "github.com/oceanbase/obshell/client/lib/stdio" - "github.com/oceanbase/obshell/client/utils/api" "github.com/oceanbase/obshell/agent/cmd/admin" "github.com/oceanbase/obshell/agent/cmd/daemon" "github.com/oceanbase/obshell/agent/constant" @@ -36,6 +32,10 @@ import ( "github.com/oceanbase/obshell/agent/lib/path" "github.com/oceanbase/obshell/agent/meta" "github.com/oceanbase/obshell/agent/repository/db/oceanbase" + "github.com/oceanbase/obshell/client/global" + "github.com/oceanbase/obshell/client/lib/http" + "github.com/oceanbase/obshell/client/lib/stdio" + "github.com/oceanbase/obshell/client/utils/api" ) func handleIfInTakeoverProcess() error { @@ -98,38 +98,34 @@ func getServersForEmecStart(flags *ClusterStartFlags) (servers []string, err err } // getServersByInputAndConf takes a ClusterStartFlags structure and a list of server addresses paired with their RPC ports, -func getServersByInputAndConf(flags *ClusterStartFlags, serversWithRpcPort [][2]string) (servers []string, err error) { +func getServersByInputAndConf(flags *ClusterStartFlags, serversWithRpcPort []meta.AgentInfoInterface) (servers []string, err error) { if getScopeType(&flags.scopeFlags) == ob.SCOPE_GLOBAL { for _, server := range serversWithRpcPort { - servers = append(servers, server[0]) + servers = append(servers, server.GetIp()) } return } // If Server scope is specified, perform detailed validation. - var items []string inputServers := strings.Split(strings.TrimSpace(flags.server), ",") for _, inputServer := range inputServers { - items = strings.Split(inputServer, ":") - if len(items) != 2 { - return nil, errors.Errorf("invalid server format: %s", inputServer) - } - if items[1] != fmt.Sprint(constant.DEFAULT_AGENT_PORT) { - return nil, errors.Errorf("unsupported port: %s in emergency case", items[1]) + inputServerInfo, err := meta.ConvertAddressToAgentInfo(inputServer) + if err != nil { + return nil, errors.Errorf("invalid server '%s'", inputServerInfo) } // Check if the server with the default port is present in the configuration. var found bool for _, server := range serversWithRpcPort { - if server[0] == items[0] { + if server.GetIp() == inputServerInfo.GetIp() { found = true break } } if !found { - return nil, errors.Errorf("server %s is not in the ob conf", items[0]) + return nil, errors.Errorf("server %s is not in the ob conf", inputServerInfo.GetIp()) } - servers = append(servers, items[0]) + servers = append(servers, inputServerInfo.GetIp()) } log.Info("servers to start ", servers) return @@ -226,7 +222,8 @@ func sshStartRemoteAgentForTakeOver(server string, agentPort int, sshFlags SSHFl } defer SSHClient.Close() - cmd := fmt.Sprintf(`export OB_ROOT_PASSWORD='%s';%s cluster start -s '%s:%d'`, os.Getenv(constant.OB_ROOT_PASSWORD), path.ObshellBinPath(), server, agentPort) + agentInfo := meta.NewAgentInfo(server, agentPort) + cmd := fmt.Sprintf(`export OB_ROOT_PASSWORD='%s';%s cluster start -s '%s'`, os.Getenv(constant.OB_ROOT_PASSWORD), path.ObshellBinPath(), agentInfo.String()) if msg, err := SSHClient.Exec(cmd); err != nil { errCh <- errors.Wrapf(err, "failed to start remote agent on %s, error msg: %s", server, string(msg)) return diff --git a/client/cmd/recyclebin/tenant/enter.go b/client/cmd/recyclebin/tenant/enter.go index 6cf64d21..5d06e4e6 100644 --- a/client/cmd/recyclebin/tenant/enter.go +++ b/client/cmd/recyclebin/tenant/enter.go @@ -34,7 +34,7 @@ const ( FLAG_NEW_NAME_SH = "n" FLAG_NEW_NAME = "new_name" - // obshell recyclebin tenatn show + // obshell recyclebin tenant show CMD_SHOW = "show" ) diff --git a/client/cmd/tenant/variable/enter.go b/client/cmd/tenant/variable/enter.go index d44408d5..a1319032 100644 --- a/client/cmd/tenant/variable/enter.go +++ b/client/cmd/tenant/variable/enter.go @@ -42,6 +42,8 @@ const ( // obshell tenant variable set CMD_SET = "set" + + FLAG_TENANT_PASSWORD = "tenant_password" ) func NewVariableCmd() *cobra.Command { @@ -112,6 +114,7 @@ func showVariable(cmd *cobra.Command, tenant string, variable string) error { func newSetCmd() *cobra.Command { var verbose bool + var tenantPassword string setCmd := command.NewCommand(&cobra.Command{ Use: CMD_SET, Short: "Set speciaic variables.", @@ -130,7 +133,7 @@ func newSetCmd() *cobra.Command { cmd.SilenceUsage = true ocsagentlog.InitLogger(config.DefaultClientLoggerConifg()) stdio.SetVerboseMode(verbose) - if err := setVariable(cmd, args[0], args[1]); err != nil { + if err := setVariable(cmd, args[0], args[1], tenantPassword); err != nil { stdio.LoadFailedWithoutMsg() stdio.Error(err.Error()) return err @@ -141,17 +144,19 @@ func newSetCmd() *cobra.Command { }) setCmd.Annotations = map[string]string{clientconst.ANNOTATION_ARGS: " "} setCmd.VarsPs(&verbose, []string{clientconst.FLAG_VERBOSE, clientconst.FLAG_VERBOSE_SH}, false, "Activate verbose output", false) + setCmd.VarsPs(&tenantPassword, []string{FLAG_TENANT_PASSWORD}, "", "Tenant password", false) return setCmd.Command } -func setVariable(cmd *cobra.Command, tenant string, str string) error { +func setVariable(cmd *cobra.Command, tenant string, str string, tenantPassword string) error { variables, err := parameter.BuildVariableOrParameterMap(str) if err != nil { cmd.SilenceUsage = false return err } params := param.SetTenantVariablesParam{ - Variables: variables, + Variables: variables, + TenantPassword: tenantPassword, } stdio.StartLoading("set tenant variables") if err := api.CallApiWithMethod(http.PUT, constant.URI_TENANT_API_PREFIX+"/"+tenant+constant.URI_VARIABLES, params, nil); err != nil { diff --git a/client/lib/http/ssh.go b/client/lib/http/ssh.go index ab355974..f4620e47 100644 --- a/client/lib/http/ssh.go +++ b/client/lib/http/ssh.go @@ -17,16 +17,17 @@ package http import ( - "fmt" "io" "os" osuser "os/user" "path/filepath" "strconv" - "github.com/oceanbase/obshell/utils" "github.com/pkg/errors" "golang.org/x/crypto/ssh" + + "github.com/oceanbase/obshell/agent/meta" + "github.com/oceanbase/obshell/utils" ) const DEFALUT_SSH_PORT = 22 @@ -218,5 +219,7 @@ func newClient(config *SSHClient, auth ...ssh.AuthMethod) (*ssh.Client, error) { Auth: auth, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - return ssh.Dial("tcp", fmt.Sprintf("%s:%d", config.Host, config.Port), conf) + + server := meta.NewAgentInfo(config.Host, config.Port) + return ssh.Dial("tcp", server.String(), conf) } diff --git a/cmd/main.go b/cmd/main.go index 4359bc91..a9708811 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -18,7 +18,9 @@ package main import ( "os" + "os/signal" "runtime" + "syscall" "github.com/spf13/cobra" @@ -37,6 +39,7 @@ import ( "github.com/oceanbase/obshell/client/cmd/tenant" "github.com/oceanbase/obshell/client/cmd/unit" "github.com/oceanbase/obshell/client/command" + "github.com/oceanbase/obshell/client/lib/stdio" ) func main() { @@ -71,6 +74,12 @@ func main() { } agentcmd.PreHandler() + + defer func() { + stdio.LoadFailedWithoutMsg() + }() + go gracefulTermination() + if err := cmds.Execute(); err != nil { os.Exit(-1) } @@ -88,3 +97,15 @@ func newCmd() *cobra.Command { cmd.SetHelpCommand(&cobra.Command{Hidden: true}) return cmd.Command } + +func gracefulTermination() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + sig := <-sigs + stdio.StopLoading() + if sig == syscall.SIGINT { + os.Exit(130) + } else { + os.Exit(143) + } +} diff --git a/go.mod b/go.mod index 3372a6d8..07b87056 100644 --- a/go.mod +++ b/go.mod @@ -23,12 +23,11 @@ require ( github.com/spf13/pflag v1.0.5 github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.0 - github.com/swaggo/swag v1.16.2 github.com/ulikunitz/xz v0.5.11 - golang.org/x/crypto v0.27.0 - golang.org/x/sys v0.25.0 - golang.org/x/term v0.24.0 - golang.org/x/text v0.18.0 + golang.org/x/crypto v0.35.0 + golang.org/x/sys v0.30.0 + golang.org/x/term v0.29.0 + golang.org/x/text v0.22.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v2 v2.4.0 gorm.io/driver/mysql v1.5.1-0.20230509030346-3715c134c25b @@ -44,7 +43,8 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/mitchellh/mapstructure v1.4.3 // indirect github.com/mozillazg/go-httpheader v0.2.1 // indirect - golang.org/x/sync v0.8.0 // indirect + github.com/swaggo/swag v1.16.2 // indirect + golang.org/x/sync v0.11.0 // indirect golang.org/x/time v0.3.0 // indirect ) @@ -96,7 +96,7 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 golang.org/x/mod v0.21.0 // indirect - golang.org/x/net v0.29.0 // indirect + golang.org/x/net v0.36.0 // indirect golang.org/x/tools v0.25.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 4f9a3659..dda54c12 100644 --- a/go.sum +++ b/go.sum @@ -220,8 +220,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= +golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -236,13 +236,13 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= +golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -261,16 +261,16 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= -golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= -golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= +golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= +golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -278,8 +278,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/param/agent.go b/param/agent.go index 0e08cfaf..89293d79 100644 --- a/param/agent.go +++ b/param/agent.go @@ -24,8 +24,9 @@ import ( ) type JoinApiParam struct { - AgentInfo meta.AgentInfo `json:"agentInfo" binding:"required"` - ZoneName string `json:"zoneName" binding:"required"` + AgentInfo meta.AgentInfo `json:"agentInfo" binding:"required"` + ZoneName string `json:"zoneName" binding:"required"` + MasterPassword string `json:"masterPassword"` } type JoinMasterParam struct { @@ -43,3 +44,12 @@ type AllAgentsSyncData struct { AllAgents []oceanbase.AllAgent `json:"all_agents" binding:"required"` LastSyncTime time.Time `json:"last_sync_time" binding:"required"` } + +type SetAgentPasswordParam struct { + Password string `json:"password" binding:"required"` +} + +type AddTokenParam struct { + AgentInfo meta.AgentInfo `json:"agentInfo" binding:"required"` + Token string `json:"token" binding:"required"` +} diff --git a/param/ob.go b/param/ob.go index 14593465..42f6b192 100644 --- a/param/ob.go +++ b/param/ob.go @@ -53,6 +53,7 @@ type ScaleOutParam struct { type ClusterScaleOutParam struct { ScaleOutParam + TargetAgentPassword string `json:"targetAgentPassword"` } type LocalScaleOutParam struct { @@ -74,7 +75,9 @@ type ClusterScaleInParam struct { } type ObInitParam struct { - ImportScript bool `json:"import_script"` + ImportScript bool `json:"import_script"` + CreateProxyroUser bool `json:"create_proxyro_user"` + ProxyroPassword string `json:"proxyro_password"` } type ObStopParam struct { diff --git a/param/obproxy.go b/param/obproxy.go new file mode 100644 index 00000000..9f7eeb86 --- /dev/null +++ b/param/obproxy.go @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package param + +type AddObproxyParam struct { + Name string `json:"name"` + HomePath string `json:"home_path" binding:"required"` + SqlPort *int `json:"sql_port"` // Default to 2883. + RpcPort *int `json:"rpc_port"` // Default to 2884. + ExporterPort *int `json:"exporter_port"` // Default to 2885. + ProxyroPassword string `json:"proxyro_password"` + ObproxySysPassword string `json:"obproxy_sys_password"` + RsList *string `json:"rs_list"` + ConfigUrl *string `json:"config_url"` + Parameters map[string]string `json:"parameters"` +} + +type UpgradeObproxyParam struct { + Version string `json:"version" binding:"required"` + Release string `json:"release" binding:"required"` + UpgradeDir string `json:"upgrade_dir"` +} diff --git a/param/tenant.go b/param/tenant.go index de483ec7..2972d705 100644 --- a/param/tenant.go +++ b/param/tenant.go @@ -27,7 +27,7 @@ type CreateTenantParam struct { Collation string `json:"collation"` ReadOnly bool `json:"read_only"` // Default to false. Comment string `json:"comment"` // Messages. - Variables map[string]interface{} `json:"variables"` // Teantn global variables. + Variables map[string]interface{} `json:"variables"` // Tenant global variables. Parameters map[string]interface{} `json:"parameters"` // Tenant parameters. Scenario string `json:"scenario"` // Tenant scenario. ImportScript bool `json:"import_script"` // whether to import script. @@ -92,7 +92,8 @@ type SetTenantParametersParam struct { } type SetTenantVariablesParam struct { - Variables map[string]interface{} `json:"variables" binding:"required"` + Variables map[string]interface{} `json:"variables" binding:"required"` + TenantPassword string `json:"tenant_password"` } // Task Param @@ -102,3 +103,21 @@ type CreateResourcePoolTaskParam struct { UnitConfigName string UnitNum int } + +type CreateUserParam struct { + UserName string `json:"user_name" binding:"required"` + Password string `json:"password" binding:"required"` + RootPassword string `json:"root_password"` + GlobalPrivileges []string `json:"global_privileges"` + DbPrivileges []DbPrivilegeParam `json:"db_privileges"` + HostName string `json:"host_name"` +} + +type DbPrivilegeParam struct { + DbName string `json:"db_name" binding:"required"` + Privileges []string `json:"privileges" binding:"required"` +} + +type DropUserParam struct { + RootPassword string `json:"root_password"` +} diff --git a/utils/address.go b/utils/address.go new file mode 100644 index 00000000..390c06e2 --- /dev/null +++ b/utils/address.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "net" + "strconv" +) + +func IsValidIp(ip string) bool { + return net.ParseIP(ip) != nil +} + +func IsValidPort(port string) bool { + if port == "" { + return true + } + p, err := strconv.Atoi(port) + if err != nil { + return false + } + return IsValidPortValue(p) +} + +func IsValidPortValue(p int) bool { + return p > 1024 && p < 65536 +} diff --git a/utils/path.go b/utils/path.go index f1e16420..7d85d5e3 100644 --- a/utils/path.go +++ b/utils/path.go @@ -18,6 +18,7 @@ package utils import ( "fmt" + "io" "os" "path/filepath" "regexp" @@ -67,3 +68,37 @@ func CheckPathExistAndValid(path string) error { } return CheckPathValid(path) } + +// CheckDirExists checks if the provided filesystem path exists and is a dir. +// It returns an error if the path does not exist or if the path is not a dir. +func CheckDirExists(dir string) error { + fileInfo, err := os.Stat(dir) + if err != nil { + if os.IsNotExist(err) { + return err + } + return errors.Wrapf(err, "failed to stat path %s", dir) + } + + if !fileInfo.IsDir() { + return errors.Errorf("path '%s' is not a directory", dir) + } + return nil +} + +func CheckDirEmpty(path string) error { + dir, err := os.Open(path) + if err != nil { + return errors.Wrapf(err, "failed to open directory %s", path) + } + defer dir.Close() + + _, err = dir.Readdir(1) + if err == io.EOF { + return nil // Directory is empty + } + if err != nil { + return errors.Wrapf(err, "failed to read directory %s", path) + } + return errors.Errorf("directory '%s' is not empty", path) +} diff --git a/utils/port.go b/utils/port.go new file mode 100644 index 00000000..40817c72 --- /dev/null +++ b/utils/port.go @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "fmt" + "net" +) + +func IsPortFree(port int) (bool, error) { + ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + return false, nil + } + ln.Close() + return true, nil +} diff --git a/utils/sha1.go b/utils/sha1.go new file mode 100644 index 00000000..fa52d43a --- /dev/null +++ b/utils/sha1.go @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "crypto/sha1" + "encoding/hex" +) + +func Sha1(input string) string { + hasher := sha1.New() + hasher.Write([]byte(input)) + hashBytes := hasher.Sum(nil) + hashString := hex.EncodeToString(hashBytes) + return hashString +}