From 1e16f3b452594ee396c97b9b0771adbeb2d09dca Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 10 Jan 2023 17:22:46 +1300 Subject: [PATCH 1/3] tests: add integration tests for gating of outgoing connections --- core/connmgr/gater.go | 1 - p2p/test/connectiongating/gating_test.go | 266 ++++++++++++++++++ .../mock_connection_gater_test.go | 109 +++++++ 3 files changed, 375 insertions(+), 1 deletion(-) create mode 100644 p2p/test/connectiongating/gating_test.go create mode 100644 p2p/test/connectiongating/mock_connection_gater_test.go diff --git a/core/connmgr/gater.go b/core/connmgr/gater.go index 672aef9528..82fa56a876 100644 --- a/core/connmgr/gater.go +++ b/core/connmgr/gater.go @@ -52,7 +52,6 @@ import ( // DisconnectReasons is that we require stream multiplexing capability to open a // control protocol stream to transmit the message. type ConnectionGater interface { - // InterceptPeerDial tests whether we're permitted to Dial the specified peer. // // This is called by the network.Network implementation when dialling a peer. diff --git a/p2p/test/connectiongating/gating_test.go b/p2p/test/connectiongating/gating_test.go new file mode 100644 index 0000000000..5c0888e174 --- /dev/null +++ b/p2p/test/connectiongating/gating_test.go @@ -0,0 +1,266 @@ +package connectiongating + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/libp2p/go-libp2p" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/p2p/net/swarm" + + "github.com/golang/mock/gomock" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +//go:generate go run github.com/golang/mock/mockgen -package connectiongating -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater + +// This list should contain (at least) one address for every transport we have. +var addrs = []ma.Multiaddr{ + ma.StringCast("/ip4/127.0.0.1/tcp/0"), + ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"), + ma.StringCast("/ip4/127.0.0.1/udp/0/quic"), + ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), + ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"), +} + +func transportName(a ma.Multiaddr) string { + _, tr := ma.SplitLast(a) + return tr.Protocol().Name +} + +func stripCertHash(addr ma.Multiaddr) ma.Multiaddr { + for { + if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err != nil { + break + } + addr, _ = ma.SplitLast(addr) + } + return addr +} + +func TestInterceptPeerDial(t *testing.T) { + for _, a := range addrs { + t.Run(fmt.Sprintf("dialing %s", transportName(a)), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + h1, err := libp2p.New(libp2p.ConnectionGater(connGater)) + require.NoError(t, err) + defer h1.Close() + h2, err := libp2p.New(libp2p.ListenAddrs(a)) + require.NoError(t, err) + defer h2.Close() + require.Len(t, h2.Addrs(), 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + connGater.EXPECT().InterceptPeerDial(h2.ID()) + require.ErrorIs(t, h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}), swarm.ErrGaterDisallowedConnection) + }) + } +} + +func TestInterceptAddrDial(t *testing.T) { + for _, a := range addrs { + t.Run(fmt.Sprintf("dialing %s", transportName(a)), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + h1, err := libp2p.New(libp2p.ConnectionGater(connGater)) + require.NoError(t, err) + defer h1.Close() + h2, err := libp2p.New(libp2p.ListenAddrs(a)) + require.NoError(t, err) + defer h2.Close() + require.Len(t, h2.Addrs(), 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + gomock.InOrder( + connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), + connGater.EXPECT().InterceptAddrDial(h2.ID(), h2.Addrs()[0]), + ) + require.ErrorIs(t, h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}), swarm.ErrNoGoodAddresses) + }) + } +} + +func TestInterceptSecuredOutgoing(t *testing.T) { + for _, a := range addrs { + t.Run(fmt.Sprintf("dialing %s", transportName(a)), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + h1, err := libp2p.New(libp2p.ConnectionGater(connGater)) + require.NoError(t, err) + defer h1.Close() + h2, err := libp2p.New(libp2p.ListenAddrs(a)) + require.NoError(t, err) + defer h2.Close() + require.Len(t, h2.Addrs(), 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + gomock.InOrder( + connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), + connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { + // remove the certhash component from WebTransport addresses + require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String()) + }), + ) + err = h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) + require.Error(t, err) + require.NotErrorIs(t, err, context.DeadlineExceeded) + }) + } +} + +func TestInterceptUpgradedOutgoing(t *testing.T) { + for _, a := range addrs { + t.Run(fmt.Sprintf("dialing %s", transportName(a)), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + h1, err := libp2p.New(libp2p.ConnectionGater(connGater)) + require.NoError(t, err) + defer h1.Close() + h2, err := libp2p.New(libp2p.ListenAddrs(a)) + require.NoError(t, err) + defer h2.Close() + require.Len(t, h2.Addrs(), 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + gomock.InOrder( + connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), + connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { + // remove the certhash component from WebTransport addresses + require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr()) + require.Equal(t, h1.ID(), c.LocalPeer()) + require.Equal(t, h2.ID(), c.RemotePeer()) + })) + err = h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) + require.Error(t, err) + require.NotErrorIs(t, err, context.DeadlineExceeded) + }) + } +} + +func TestInterceptAccept(t *testing.T) { + for _, a := range addrs { + t.Run(fmt.Sprintf("accepting %s", transportName(a)), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + h1, err := libp2p.New() + require.NoError(t, err) + defer h1.Close() + h2, err := libp2p.New( + libp2p.ListenAddrs(a), + libp2p.ConnectionGater(connGater), + ) + require.NoError(t, err) + defer h2.Close() + require.Len(t, h2.Addrs(), 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // The basic host dials the first connection. + connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { + // remove the certhash component from WebTransport addresses + require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + }) + h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) + _, err = h1.NewStream(ctx, h2.ID(), protocol.TestingID) + require.Error(t, err) + require.NotErrorIs(t, err, context.DeadlineExceeded) + }) + } +} + +func TestInterceptSecuredIncoming(t *testing.T) { + for _, a := range addrs { + t.Run(fmt.Sprintf("accepting %s", transportName(a)), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + h1, err := libp2p.New() + require.NoError(t, err) + defer h1.Close() + h2, err := libp2p.New( + libp2p.ListenAddrs(a), + libp2p.ConnectionGater(connGater), + ) + require.NoError(t, err) + defer h2.Close() + require.Len(t, h2.Addrs(), 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + gomock.InOrder( + connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { + // remove the certhash component from WebTransport addresses + require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + }), + ) + h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) + _, err = h1.NewStream(ctx, h2.ID(), protocol.TestingID) + require.Error(t, err) + require.NotErrorIs(t, err, context.DeadlineExceeded) + }) + } +} + +func TestInterceptUpgradedIncoming(t *testing.T) { + for _, a := range addrs { + _, tr := ma.SplitLast(a) + t.Run(fmt.Sprintf("accepting %s", tr.Protocol().Name), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + h1, err := libp2p.New() + require.NoError(t, err) + defer h1.Close() + h2, err := libp2p.New( + libp2p.ListenAddrs(a), + libp2p.ConnectionGater(connGater), + ) + require.NoError(t, err) + defer h2.Close() + require.Len(t, h2.Addrs(), 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + gomock.InOrder( + connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { + // remove the certhash component from WebTransport addresses + require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr()) + require.Equal(t, h1.ID(), c.RemotePeer()) + require.Equal(t, h2.ID(), c.LocalPeer()) + }), + ) + h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) + _, err = h1.NewStream(ctx, h2.ID(), protocol.TestingID) + require.Error(t, err) + require.NotErrorIs(t, err, context.DeadlineExceeded) + }) + } +} diff --git a/p2p/test/connectiongating/mock_connection_gater_test.go b/p2p/test/connectiongating/mock_connection_gater_test.go new file mode 100644 index 0000000000..54be42e563 --- /dev/null +++ b/p2p/test/connectiongating/mock_connection_gater_test.go @@ -0,0 +1,109 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/libp2p/go-libp2p/core/connmgr (interfaces: ConnectionGater) + +// Package connectiongating is a generated GoMock package. +package connectiongating + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + control "github.com/libp2p/go-libp2p/core/control" + network "github.com/libp2p/go-libp2p/core/network" + peer "github.com/libp2p/go-libp2p/core/peer" + multiaddr "github.com/multiformats/go-multiaddr" +) + +// MockConnectionGater is a mock of ConnectionGater interface. +type MockConnectionGater struct { + ctrl *gomock.Controller + recorder *MockConnectionGaterMockRecorder +} + +// MockConnectionGaterMockRecorder is the mock recorder for MockConnectionGater. +type MockConnectionGaterMockRecorder struct { + mock *MockConnectionGater +} + +// NewMockConnectionGater creates a new mock instance. +func NewMockConnectionGater(ctrl *gomock.Controller) *MockConnectionGater { + mock := &MockConnectionGater{ctrl: ctrl} + mock.recorder = &MockConnectionGaterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionGater) EXPECT() *MockConnectionGaterMockRecorder { + return m.recorder +} + +// InterceptAccept mocks base method. +func (m *MockConnectionGater) InterceptAccept(arg0 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAccept", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAccept indicates an expected call of InterceptAccept. +func (mr *MockConnectionGaterMockRecorder) InterceptAccept(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAccept", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAccept), arg0) +} + +// InterceptAddrDial mocks base method. +func (m *MockConnectionGater) InterceptAddrDial(arg0 peer.ID, arg1 multiaddr.Multiaddr) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAddrDial", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAddrDial indicates an expected call of InterceptAddrDial. +func (mr *MockConnectionGaterMockRecorder) InterceptAddrDial(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAddrDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAddrDial), arg0, arg1) +} + +// InterceptPeerDial mocks base method. +func (m *MockConnectionGater) InterceptPeerDial(arg0 peer.ID) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptPeerDial", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptPeerDial indicates an expected call of InterceptPeerDial. +func (mr *MockConnectionGaterMockRecorder) InterceptPeerDial(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptPeerDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptPeerDial), arg0) +} + +// InterceptSecured mocks base method. +func (m *MockConnectionGater) InterceptSecured(arg0 network.Direction, arg1 peer.ID, arg2 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptSecured", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptSecured indicates an expected call of InterceptSecured. +func (mr *MockConnectionGaterMockRecorder) InterceptSecured(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptSecured", reflect.TypeOf((*MockConnectionGater)(nil).InterceptSecured), arg0, arg1, arg2) +} + +// InterceptUpgraded mocks base method. +func (m *MockConnectionGater) InterceptUpgraded(arg0 network.Conn) (bool, control.DisconnectReason) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptUpgraded", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(control.DisconnectReason) + return ret0, ret1 +} + +// InterceptUpgraded indicates an expected call of InterceptUpgraded. +func (mr *MockConnectionGaterMockRecorder) InterceptUpgraded(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptUpgraded", reflect.TypeOf((*MockConnectionGater)(nil).InterceptUpgraded), arg0) +} From 95cb582d9dc0ad57ab3c9401ee96ea5f33b40cd5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 5 Apr 2023 13:20:40 +0900 Subject: [PATCH 2/3] tests: integrate connection gating tests into the transport tests --- .../gating_test.go | 126 +++++------------- .../mock_connection_gater_test.go | 4 +- ...{integration_test.go => transport_test.go} | 9 +- 3 files changed, 46 insertions(+), 93 deletions(-) rename p2p/test/{connectiongating => transport}/gating_test.go (64%) rename p2p/test/{connectiongating => transport}/mock_connection_gater_test.go (97%) rename p2p/test/transport/{integration_test.go => transport_test.go} (97%) diff --git a/p2p/test/connectiongating/gating_test.go b/p2p/test/transport/gating_test.go similarity index 64% rename from p2p/test/connectiongating/gating_test.go rename to p2p/test/transport/gating_test.go index 5c0888e174..5b87d227a9 100644 --- a/p2p/test/connectiongating/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -1,12 +1,10 @@ -package connectiongating +package transport_integration import ( "context" - "fmt" "testing" "time" - "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" @@ -17,21 +15,7 @@ import ( "github.com/stretchr/testify/require" ) -//go:generate go run github.com/golang/mock/mockgen -package connectiongating -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater - -// This list should contain (at least) one address for every transport we have. -var addrs = []ma.Multiaddr{ - ma.StringCast("/ip4/127.0.0.1/tcp/0"), - ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"), - ma.StringCast("/ip4/127.0.0.1/udp/0/quic"), - ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), - ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"), -} - -func transportName(a ma.Multiaddr) string { - _, tr := ma.SplitLast(a) - return tr.Protocol().Name -} +//go:generate go run github.com/golang/mock/mockgen -package transport_integration -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater func stripCertHash(addr ma.Multiaddr) ma.Multiaddr { for { @@ -44,18 +28,14 @@ func stripCertHash(addr ma.Multiaddr) ma.Multiaddr { } func TestInterceptPeerDial(t *testing.T) { - for _, a := range addrs { - t.Run(fmt.Sprintf("dialing %s", transportName(a)), func(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) - h1, err := libp2p.New(libp2p.ConnectionGater(connGater)) - require.NoError(t, err) - defer h1.Close() - h2, err := libp2p.New(libp2p.ListenAddrs(a)) - require.NoError(t, err) - defer h2.Close() + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, ConnGater: connGater}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{}) require.Len(t, h2.Addrs(), 1) ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -67,18 +47,14 @@ func TestInterceptPeerDial(t *testing.T) { } func TestInterceptAddrDial(t *testing.T) { - for _, a := range addrs { - t.Run(fmt.Sprintf("dialing %s", transportName(a)), func(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) - h1, err := libp2p.New(libp2p.ConnectionGater(connGater)) - require.NoError(t, err) - defer h1.Close() - h2, err := libp2p.New(libp2p.ListenAddrs(a)) - require.NoError(t, err) - defer h2.Close() + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, ConnGater: connGater}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{}) require.Len(t, h2.Addrs(), 1) ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -93,18 +69,15 @@ func TestInterceptAddrDial(t *testing.T) { } func TestInterceptSecuredOutgoing(t *testing.T) { - for _, a := range addrs { - t.Run(fmt.Sprintf("dialing %s", transportName(a)), func(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) - h1, err := libp2p.New(libp2p.ConnectionGater(connGater)) - require.NoError(t, err) - defer h1.Close() - h2, err := libp2p.New(libp2p.ListenAddrs(a)) - require.NoError(t, err) - defer h2.Close() + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, ConnGater: connGater}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{}) + require.Len(t, h2.Addrs(), 1) require.Len(t, h2.Addrs(), 1) ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -117,7 +90,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) { require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String()) }), ) - err = h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) + err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) }) @@ -125,18 +98,15 @@ func TestInterceptSecuredOutgoing(t *testing.T) { } func TestInterceptUpgradedOutgoing(t *testing.T) { - for _, a := range addrs { - t.Run(fmt.Sprintf("dialing %s", transportName(a)), func(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) - h1, err := libp2p.New(libp2p.ConnectionGater(connGater)) - require.NoError(t, err) - defer h1.Close() - h2, err := libp2p.New(libp2p.ListenAddrs(a)) - require.NoError(t, err) - defer h2.Close() + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, ConnGater: connGater}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{}) + require.Len(t, h2.Addrs(), 1) require.Len(t, h2.Addrs(), 1) ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -151,7 +121,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { require.Equal(t, h1.ID(), c.LocalPeer()) require.Equal(t, h2.ID(), c.RemotePeer()) })) - err = h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) + err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) }) @@ -159,21 +129,14 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { } func TestInterceptAccept(t *testing.T) { - for _, a := range addrs { - t.Run(fmt.Sprintf("accepting %s", transportName(a)), func(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) - h1, err := libp2p.New() - require.NoError(t, err) - defer h1.Close() - h2, err := libp2p.New( - libp2p.ListenAddrs(a), - libp2p.ConnectionGater(connGater), - ) - require.NoError(t, err) - defer h2.Close() + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{ConnGater: connGater}) require.Len(t, h2.Addrs(), 1) ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -184,7 +147,7 @@ func TestInterceptAccept(t *testing.T) { require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) - _, err = h1.NewStream(ctx, h2.ID(), protocol.TestingID) + _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) }) @@ -192,21 +155,14 @@ func TestInterceptAccept(t *testing.T) { } func TestInterceptSecuredIncoming(t *testing.T) { - for _, a := range addrs { - t.Run(fmt.Sprintf("accepting %s", transportName(a)), func(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) - h1, err := libp2p.New() - require.NoError(t, err) - defer h1.Close() - h2, err := libp2p.New( - libp2p.ListenAddrs(a), - libp2p.ConnectionGater(connGater), - ) - require.NoError(t, err) - defer h2.Close() + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{ConnGater: connGater}) require.Len(t, h2.Addrs(), 1) ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -219,7 +175,7 @@ func TestInterceptSecuredIncoming(t *testing.T) { }), ) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) - _, err = h1.NewStream(ctx, h2.ID(), protocol.TestingID) + _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) }) @@ -227,22 +183,14 @@ func TestInterceptSecuredIncoming(t *testing.T) { } func TestInterceptUpgradedIncoming(t *testing.T) { - for _, a := range addrs { - _, tr := ma.SplitLast(a) - t.Run(fmt.Sprintf("accepting %s", tr.Protocol().Name), func(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) - h1, err := libp2p.New() - require.NoError(t, err) - defer h1.Close() - h2, err := libp2p.New( - libp2p.ListenAddrs(a), - libp2p.ConnectionGater(connGater), - ) - require.NoError(t, err) - defer h2.Close() + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{ConnGater: connGater}) require.Len(t, h2.Addrs(), 1) ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -258,7 +206,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) { }), ) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) - _, err = h1.NewStream(ctx, h2.ID(), protocol.TestingID) + _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) }) diff --git a/p2p/test/connectiongating/mock_connection_gater_test.go b/p2p/test/transport/mock_connection_gater_test.go similarity index 97% rename from p2p/test/connectiongating/mock_connection_gater_test.go rename to p2p/test/transport/mock_connection_gater_test.go index 54be42e563..d6efc8b022 100644 --- a/p2p/test/connectiongating/mock_connection_gater_test.go +++ b/p2p/test/transport/mock_connection_gater_test.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/libp2p/go-libp2p/core/connmgr (interfaces: ConnectionGater) -// Package connectiongating is a generated GoMock package. -package connectiongating +// Package transport_integration is a generated GoMock package. +package transport_integration import ( reflect "reflect" diff --git a/p2p/test/transport/integration_test.go b/p2p/test/transport/transport_test.go similarity index 97% rename from p2p/test/transport/integration_test.go rename to p2p/test/transport/transport_test.go index 578c68d087..bb8912ade2 100644 --- a/p2p/test/transport/integration_test.go +++ b/p2p/test/transport/transport_test.go @@ -13,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/config" + "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" @@ -30,8 +31,9 @@ type TransportTestCase struct { } type TransportTestCaseOpts struct { - NoListen bool - NoRcmgr bool + NoListen bool + NoRcmgr bool + ConnGater connmgr.ConnectionGater } func transformOpts(opts TransportTestCaseOpts) []config.Option { @@ -40,6 +42,9 @@ func transformOpts(opts TransportTestCaseOpts) []config.Option { if opts.NoRcmgr { libp2pOpts = append(libp2pOpts, libp2p.ResourceManager(&network.NullResourceManager{})) } + if opts.ConnGater != nil { + libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater(opts.ConnGater)) + } return libp2pOpts } From e4c09d414cb98723b9902178bd2443c2a9237cea Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 9 May 2023 12:17:03 +0300 Subject: [PATCH 3/3] disable gating tests with race detector --- p2p/test/transport/gating_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index 5b87d227a9..426fc906e5 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -10,6 +10,8 @@ import ( "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/p2p/net/swarm" + "github.com/libp2p/go-libp2p-testing/race" + "github.com/golang/mock/gomock" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" @@ -28,6 +30,9 @@ func stripCertHash(addr ma.Multiaddr) ma.Multiaddr { } func TestInterceptPeerDial(t *testing.T) { + if race.WithRace() { + t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") + } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) @@ -47,6 +52,9 @@ func TestInterceptPeerDial(t *testing.T) { } func TestInterceptAddrDial(t *testing.T) { + if race.WithRace() { + t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") + } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) @@ -69,6 +77,9 @@ func TestInterceptAddrDial(t *testing.T) { } func TestInterceptSecuredOutgoing(t *testing.T) { + if race.WithRace() { + t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") + } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) @@ -98,6 +109,9 @@ func TestInterceptSecuredOutgoing(t *testing.T) { } func TestInterceptUpgradedOutgoing(t *testing.T) { + if race.WithRace() { + t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") + } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) @@ -129,6 +143,9 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { } func TestInterceptAccept(t *testing.T) { + if race.WithRace() { + t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") + } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) @@ -155,6 +172,9 @@ func TestInterceptAccept(t *testing.T) { } func TestInterceptSecuredIncoming(t *testing.T) { + if race.WithRace() { + t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") + } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t) @@ -183,6 +203,9 @@ func TestInterceptSecuredIncoming(t *testing.T) { } func TestInterceptUpgradedIncoming(t *testing.T) { + if race.WithRace() { + t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") + } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { ctrl := gomock.NewController(t)