Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 66 additions & 13 deletions catalog/aws.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package catalog

import (
"context"
"fmt"
"strconv"
"strings"
Expand All @@ -9,7 +10,9 @@ import (

x "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/awserr"
"github.com/aws/aws-sdk-go-v2/service/ecs"
sd "github.com/aws/aws-sdk-go-v2/service/servicediscovery"
"github.com/hashicorp/consul-aws/subcommand"
"github.com/hashicorp/go-hclog"
)

Expand All @@ -29,7 +32,7 @@ type namespace struct {

type aws struct {
lock sync.RWMutex
client *sd.ServiceDiscovery
client *sd.Client
log hclog.Logger
namespace namespace
services map[string]service
Expand Down Expand Up @@ -57,6 +60,12 @@ func (a *aws) sync(consul *consul, stop, stopped chan struct{}) {
consul.log.Info("created", "count", fmt.Sprintf("%d", count))
}

create = tagsNeedUpdate(a.getServices(), consul.getServices())
count = consul.create(create)
if count > 0 {
consul.log.Info("updated", "count", fmt.Sprintf("%d", count))
}

remove := onlyInFirst(consul.getServices(), a.getServices())
count = consul.remove(remove)
if count > 0 {
Expand All @@ -70,7 +79,7 @@ func (a *aws) sync(consul *consul, stop, stopped chan struct{}) {

func (a *aws) fetchNamespace(id string) (*sd.Namespace, error) {
req := a.client.GetNamespaceRequest(&sd.GetNamespaceInput{Id: x.String(id)})
resp, err := req.Send()
resp, err := req.Send(context.Background())
if err != nil {
return nil, err
}
Expand All @@ -85,9 +94,9 @@ func (a *aws) fetchServices() ([]sd.ServiceSummary, error) {
Values: []string{a.namespace.id},
}},
})
p := req.Paginate()
p := sd.NewListServicesPaginator(req)
services := []sd.ServiceSummary{}
for p.Next() {
for p.Next(context.Background()) {
services = append(services, p.CurrentPage().Services...)
}
return services, p.Err()
Expand Down Expand Up @@ -154,6 +163,19 @@ func (a *aws) fetch() error {
}
s.nodes = nodes

s.tags = make(map[string]string)
for _, nodes := range s.nodes {
for _, node := range nodes {
tags, err := a.discoverTags(node.awsID, node.attributes)
if err != nil {
a.log.Error("cannot discover tags", "error", err)
}
for k, v := range tags {
s.tags[k] = v
}
}
}

healths, err := a.fetchHealths(s.awsID)
if err != nil {
a.log.Error("cannot fetch healths", "error", err)
Expand Down Expand Up @@ -230,8 +252,8 @@ func (a *aws) fetchHealths(id string) (map[string]health, error) {
ServiceId: &id,
})
result := map[string]health{}
p := req.Paginate()
for p.Next() {
p := sd.NewGetInstancesHealthStatusPaginator(req)
for p.Next(context.Background()) {
for id, health := range p.CurrentPage().Status {
result[id] = statusFromAWS(health)
}
Expand Down Expand Up @@ -273,9 +295,9 @@ func (a *aws) fetchNodes(id string) ([]sd.InstanceSummary, error) {
req := a.client.ListInstancesRequest(&sd.ListInstancesInput{
ServiceId: &id,
})
p := req.Paginate()
p := sd.NewListInstancesPaginator(req)
nodes := []sd.InstanceSummary{}
for p.Next() {
for p.Next(context.Background()) {
nodes = append(nodes, p.CurrentPage().Instances...)
}
return nodes, p.Err()
Expand All @@ -287,7 +309,7 @@ func (a *aws) discoverNodes(name string) ([]sd.InstanceSummary, error) {
NamespaceName: x.String(a.namespace.name),
ServiceName: x.String(name),
})
resp, err := req.Send()
resp, err := req.Send(context.Background())
if err != nil {
return nil, err
}
Expand All @@ -298,6 +320,37 @@ func (a *aws) discoverNodes(name string) ([]sd.InstanceSummary, error) {
return nodes, nil
}

func (a *aws) discoverTags(id string, attributes map[string]string) (map[string]string, error) {
tags := map[string]string{}
ecsClusterName := attributes["ECS_CLUSTER_NAME"]
ecsServiceName := attributes["ECS_SERVICE_NAME"]
ecsTaskDefinitionFamily := attributes["ECS_TASK_DEFINITION_FAMILY"]
// If this is an ECS service we look for tags in the ECS task
if ecsClusterName != "" && ecsServiceName != "" && ecsTaskDefinitionFamily != "" {
config, err := subcommand.AWSConfig()
if err != nil {
return tags, err
}
client := ecs.New(config)
input := &ecs.DescribeTasksInput{
Cluster: &ecsClusterName,
Tasks: []string{id},
Include: []ecs.TaskField{ecs.TaskFieldTags},
}
req := client.DescribeTasksRequest(input)
tasks, err := req.Send(context.Background())
if err != nil {
return tags, err
}
for _, task := range tasks.Tasks {
for _, t := range task.Tags {
tags[*t.Key] = *t.Value
}
}
}
return tags, nil
}

func (a *aws) getServices() map[string]service {
a.lock.RLock()
copy := a.services
Expand Down Expand Up @@ -341,7 +394,7 @@ func (a *aws) create(services map[string]service) int {
}
}
req := a.client.CreateServiceRequest(&input)
resp, err := req.Send()
resp, err := req.Send(context.Background())
if err != nil {
if err, ok := err.(awserr.Error); ok {
switch err.Code() {
Expand Down Expand Up @@ -370,7 +423,7 @@ func (a *aws) create(services map[string]service) int {
Attributes: attributes,
InstanceId: &instanceID,
})
_, err := req.Send()
_, err := req.Send(context.Background())
if err != nil {
a.log.Error("cannot create nodes", "error", err.Error())
}
Expand Down Expand Up @@ -412,7 +465,7 @@ func (a *aws) remove(services map[string]service) int {
ServiceId: &serviceID,
InstanceId: &id,
})
_, err := req.Send()
_, err := req.Send(context.Background())
if err != nil {
a.log.Error("cannot remove instance", "error", err.Error())
}
Expand All @@ -434,7 +487,7 @@ func (a *aws) remove(services map[string]service) int {
req := a.client.DeleteServiceRequest(&sd.DeleteServiceInput{
Id: &s.awsID,
})
_, err := req.Send()
_, err := req.Send(context.Background())
if err != nil {
a.log.Error("cannot remove services", "name", k, "id", s.awsID, "error", err.Error())
} else {
Expand Down
2 changes: 1 addition & 1 deletion catalog/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

type mSDClint struct {
sdi.ServiceDiscoveryAPI
sdi.ClientAPI
}

func (m *mSDClint) CreateServiceRequest(input *sd.CreateServiceRequest) (*sd.CreateServiceOutput, error) {
Expand Down
15 changes: 10 additions & 5 deletions catalog/consul.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,12 @@ func (c *consul) fetch(waitIndex uint64) (uint64, error) {
func (c *consul) transformServices(cservices map[string][]string) map[string]service {
services := make(map[string]service, len(cservices))
for k, tags := range cservices {
s := service{id: k, name: k, consulID: k}
s := service{id: k, name: k, consulID: k, tags: map[string]string{}}
for _, t := range tags {
if t == ConsulAWSTag {
s.fromAWS = true
break
} else if parts := strings.SplitN(t, ":", 2); len(parts) == 2 {
s.tags[parts[0]] = parts[1]
}
}
if s.fromAWS {
Expand Down Expand Up @@ -292,7 +293,7 @@ func (c *consul) create(services map[string]service) int {
for h, nodes := range s.nodes {
for _, n := range nodes {
wg.Add(1)
go func(ns, k, name, h string, n node) {
go func(ns, k, name, h string, n node, tags map[string]string) {
defer wg.Done()
id := id(k, h, n.port)
meta := map[string]string{}
Expand All @@ -302,10 +303,14 @@ func (c *consul) create(services map[string]service) int {
meta[ConsulSourceKey] = ConsulAWSTag
meta[ConsulAWSNS] = ns
meta[ConsulAWSID] = n.awsID
consulTags := []string{ConsulAWSTag}
for k, v := range tags {
consulTags = append(consulTags, fmt.Sprintf("%s:%s", k, v))
}
service := api.AgentService{
ID: id,
Service: name,
Tags: []string{ConsulAWSTag},
Tags: consulTags,
Address: h,
Meta: meta,
}
Expand All @@ -326,7 +331,7 @@ func (c *consul) create(services map[string]service) int {
c.setNode(k, h, n.port, n)
count++
}
}(s.awsNamespace, k, name, h, n)
}(s.awsNamespace, k, name, h, n, s.tags)
}
}
for awsID, h := range s.healths {
Expand Down
5 changes: 4 additions & 1 deletion catalog/consul_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ func TestConsulRekeyHealths(t *testing.T) {
func TestConsulTransformServices(t *testing.T) {
c := consul{awsPrefix: "aws_"}
services := map[string][]string{"s1": {"abc"}, "aws_s2": {ConsulAWSTag}}
expected := map[string]service{"s1": {id: "s1", name: "s1", consulID: "s1"}, "s2": {id: "aws_s2", name: "s2", consulID: "aws_s2", fromAWS: true}}
expected := map[string]service{
"s1": {id: "s1", name: "s1", consulID: "s1", tags: map[string]string{}},
"s2": {id: "aws_s2", name: "s2", consulID: "aws_s2", fromAWS: true, tags: map[string]string{}},
}

require.Equal(t, expected, c.transformServices(services))
}
Expand Down
16 changes: 16 additions & 0 deletions catalog/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package catalog

import (
"fmt"
"reflect"
"strconv"
"strings"
)
Expand All @@ -24,6 +25,7 @@ type service struct {
awsID string
consulID string
awsNamespace string
tags map[string]string
}

type node struct {
Expand Down Expand Up @@ -126,3 +128,17 @@ func onlyInFirst(servicesA, servicesB map[string]service) map[string]service {
}
return result
}

func tagsNeedUpdate(servicesA, servicesB map[string]service) map[string]service {
result := map[string]service{}
for k, sa := range servicesA {
sb, ok := servicesB[k]
if !ok {
continue
}
if !reflect.DeepEqual(sa.tags, sb.tags) {
result[k] = sa
}
}
return onlyInFirst(result, map[string]service{})
}
2 changes: 1 addition & 1 deletion catalog/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

// Sync aws->consul and vice versa.
func Sync(toAWS, toConsul bool, namespaceID, consulPrefix, awsPrefix, awsPullInterval string, awsDNSTTL int64, stale bool, awsClient *sd.ServiceDiscovery, consulClient *api.Client, stop, stopped chan struct{}) {
func Sync(toAWS, toConsul bool, namespaceID, consulPrefix, awsPrefix, awsPullInterval string, awsDNSTTL int64, stale bool, awsClient *sd.Client, consulClient *api.Client, stop, stopped chan struct{}) {
defer close(stopped)
log := hclog.Default().Named("sync")
consul := consul{
Expand Down
25 changes: 13 additions & 12 deletions catalog/sync_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package catalog

import (
"context"
"fmt"
"os"
"testing"
Expand Down Expand Up @@ -142,7 +143,7 @@ func deleteServiceInConsul(c *api.Client, id string) {
c.Catalog().Deregister(&api.CatalogDeregistration{Node: ConsulAWSNodeName, ServiceID: id}, nil)
}

func createServiceInAWS(a *sd.ServiceDiscovery, namespaceID, name string) (string, error) {
func createServiceInAWS(a *sd.Client, namespaceID, name string) (string, error) {
ttl := int64(60)
input := sd.CreateServiceInput{
Name: &name,
Expand All @@ -156,14 +157,14 @@ func createServiceInAWS(a *sd.ServiceDiscovery, namespaceID, name string) (strin
},
}
req := a.CreateServiceRequest(&input)
resp, err := req.Send()
resp, err := req.Send(context.Background())
if err != nil {
return "", err
}
return *resp.Service.Id, nil
}

func createInstanceInAWS(a *sd.ServiceDiscovery, serviceID string) error {
func createInstanceInAWS(a *sd.Client, serviceID string) error {
req := a.RegisterInstanceRequest(&sd.RegisterInstanceInput{
ServiceId: &serviceID,
InstanceId: &serviceID,
Expand All @@ -173,21 +174,21 @@ func createInstanceInAWS(a *sd.ServiceDiscovery, serviceID string) error {
"FUBAR": "BARFU",
},
})
_, err := req.Send()
_, err := req.Send(context.Background())
return err
}

func deleteInstanceInAWS(a *sd.ServiceDiscovery, id string) error {
func deleteInstanceInAWS(a *sd.Client, id string) error {
req := a.DeregisterInstanceRequest(&sd.DeregisterInstanceInput{ServiceId: &id, InstanceId: &id})
_, err := req.Send()
_, err := req.Send(context.Background())
return err
}

func deleteServiceInAWS(a *sd.ServiceDiscovery, id string) error {
func deleteServiceInAWS(a *sd.Client, id string) error {
var err error
for i := 0; i < 50; i++ {
req := a.DeleteServiceRequest(&sd.DeleteServiceInput{Id: &id})
_, err = req.Send()
_, err = req.Send(context.Background())
if err != nil {
time.Sleep(100 * time.Millisecond)
} else {
Expand Down Expand Up @@ -239,7 +240,7 @@ func checkForImportedAWSService(c *api.Client, name, namespaceID, serviceID stri
return fmt.Errorf("shrug")
}

func checkForImportedConsulService(a *sd.ServiceDiscovery, namespaceID, name string, repeat int) error {
func checkForImportedConsulService(a *sd.Client, namespaceID, name string, repeat int) error {
for i := 0; i < repeat; i++ {
req := a.ListServicesRequest(&sd.ListServicesInput{
Filters: []sd.ServiceFilter{{
Expand All @@ -248,8 +249,8 @@ func checkForImportedConsulService(a *sd.ServiceDiscovery, namespaceID, name str
Values: []string{namespaceID},
}},
})
p := req.Paginate()
for p.Next() {
p := sd.NewListServicesPaginator(req)
for p.Next(context.Background()) {
for _, s := range p.CurrentPage().Services {
if *s.Name == name {
if !(s.Description != nil || *s.Description == awsServiceDescription) {
Expand All @@ -260,7 +261,7 @@ func checkForImportedConsulService(a *sd.ServiceDiscovery, namespaceID, name str
ireq := a.ListInstancesRequest(&sd.ListInstancesInput{
ServiceId: s.Id,
})
out, err := ireq.Send()
out, err := ireq.Send(context.Background())
if err != nil {
continue
}
Expand Down
Loading