diff --git a/index.js b/index.js index 40c86805..31a09481 100644 --- a/index.js +++ b/index.js @@ -12,7 +12,6 @@ const Server = require('./lib/server') const connect = require('./lib/connect') const { FIREWALL, BOOTSTRAP_NODES, KNOWN_NODES, COMMANDS } = require('./lib/constants') const { hash, createKeyPair } = require('./lib/crypto') -const { decode } = require('hypercore-id-encoding') const RawStreamSet = require('./lib/raw-stream-set') const ConnectionPool = require('./lib/connection-pool') const { STREAM_NOT_CONNECTED } = require('./lib/errors') @@ -78,7 +77,7 @@ class HyperDHT extends DHT { static DEFAULTS = DEFAULTS connect(remotePublicKey, opts) { - return connect(this, decode(remotePublicKey), opts) + return connect(this, remotePublicKey, opts) } createServer(opts, onconnection) { diff --git a/lib/connect.js b/lib/connect.js index bad3c896..3dc18fdc 100644 --- a/lib/connect.js +++ b/lib/connect.js @@ -28,14 +28,22 @@ const { RELAY_ABORTED, SUSPENDED } = require('./errors') +const { decode } = require('hypercore-id-encoding') +const HyperDHTAddress = require('hyperdht-address') module.exports = function connect(dht, publicKey, opts = {}) { const pool = opts.pool || null + const { key, nodes: providedNodes } = HyperDHTAddress.decode( + b4a.isBuffer(publicKey) ? publicKey : decode(publicKey) + ) + publicKey = key + if (pool && pool.has(publicKey)) return pool.get(publicKey) publicKey = unslab(publicKey) + opts.relayAddresses = opts.relayAddresses || providedNodes || [] const keyPair = opts.keyPair || dht.defaultKeyPair const relayThrough = selectRelay(opts.relayThrough || null) const encryptedSocket = (opts.createSecretStream || defaultCreateSecretStream)(true, null, { @@ -58,7 +66,7 @@ module.exports = function connect(dht, publicKey, opts = {}) { id, dht, session: dht.session(), - relayAddresses: opts.relayAddresses || [], + relayAddresses: opts.relayAddresses, remoteRelayAddresses: [], pool, round: 0, @@ -315,36 +323,51 @@ async function holepunch(c, opts) { } } +async function connectThroughNodes(c, addresses, socket) { + for (const address of addresses) { + if (isDone(c) || c.connect) return + + c.remoteRelayAddresses.push(address) + await connectThroughNode(c, address, socket) + } +} + async function findAndConnect(c, opts) { let attempts = 0 - let closestNodes = opts.relayAddresses && opts.relayAddresses.length ? opts.relayAddresses : null + let relayAddresses = + opts.relayAddresses && opts.relayAddresses.length ? opts.relayAddresses : null - if (!closestNodes) { + if (!relayAddresses) { const cachedRelayAddresses = c.dht._relayAddressesCache.get(c.id) - if (cachedRelayAddresses) closestNodes = cachedRelayAddresses + if (cachedRelayAddresses) relayAddresses = cachedRelayAddresses } if (c.dht._persistent) { // check if we know the route ourself... const route = c.dht._router.get(c.target) if (route && route.relay !== null) { - closestNodes = [{ host: route.relay.host, port: route.relay.port }] + relayAddresses = [{ host: route.relay.host, port: route.relay.port }] } } // 2 is how many parallel connect attempts we want to do, we can make this configurable - const sem = new Semaphore(2) + const preConnect = relayAddresses !== null && relayAddresses.length > 0 + const sem = new Semaphore(preConnect ? 3 : 2) const signal = sem.signal.bind(sem) - const tries = closestNodes !== null ? 2 : 1 + const tries = relayAddresses !== null ? 2 : 1 + + if (preConnect) { + await sem.wait() + connectThroughNodes(c, relayAddresses, null).then(signal, signal) + } try { for (let i = 0; i < tries && !isDone(c) && !c.connect; i++) { c.query = c.dht.findPeer(c.target, { hash: false, session: c.session, - closestNodes, - onlyClosestNodes: closestNodes !== null, - retries: closestNodes ? 1 : 3 + nodes: relayAddresses, + retries: 3 }) for await (const data of c.query) { @@ -356,12 +379,18 @@ async function findAndConnect(c, opts) { break } + // Skip node already run via preConnect + if (preConnect && relayAddresses && isRelayAddress(relayAddresses, data)) { + sem.signal() + continue + } + c.remoteRelayAddresses.push(data.from) attempts++ connectThroughNode(c, data.from, null).then(signal, signal) } - closestNodes = null + relayAddresses = null if (attempts > 0) await sem.flush() } @@ -847,4 +876,12 @@ function selectRelay(relayThrough) { return relayThrough } +function isRelayAddress(relayAddresses, data) { + for (const node of relayAddresses) { + if (node.host === data.from.host && node.port === data.from.port) return true + } + + return false +} + function noop() {} diff --git a/package.json b/package.json index 8eb19986..1e5cc477 100644 --- a/package.json +++ b/package.json @@ -35,6 +35,7 @@ "dht-rpc": "^6.15.1", "hypercore-crypto": "^3.3.0", "hypercore-id-encoding": "^1.2.0", + "hyperdht-address": "^1.0.1", "noise-curve-ed": "^2.0.0", "noise-handshake": "^4.0.0", "record-cache": "^1.1.1", diff --git a/test/all.js b/test/all.js index 6c5c24f7..a74f2446 100644 --- a/test/all.js +++ b/test/all.js @@ -8,6 +8,7 @@ async function runTests() { test.pause() await import('./announces.js') + await import('./cache.js') await import('./connections.js') await import('./holepuncher.js') await import('./lifecycle.js') diff --git a/test/cache.js b/test/cache.js new file mode 100644 index 00000000..6909dfbb --- /dev/null +++ b/test/cache.js @@ -0,0 +1,50 @@ +const test = require('brittle') +const { swarm } = require('./helpers') +const HyperDHTAddress = require('hyperdht-address') + +test('cache - key with nodes', async function (t) { + t.plan(3) + + const [a, b] = await swarm(t) + const ts = t.test('server') + + ts.plan(2) + + const server = a.createServer(function (socket) { + ts.pass('server side opened') + + socket.once('end', function () { + ts.pass('server side ended') + socket.end() + }) + }) + + await server.listen() + + { + const tn = t.test('client w/nodes') + tn.plan(2) + + const target = HyperDHTAddress.encode(server.publicKey, server.relayAddresses) + + const socket = b.connect(target) + + socket.once('open', function () { + tn.pass('client side opened') + }) + + socket.once('end', function () { + tn.pass('client side ended') + }) + + socket.end() + + await tn + } + + server.on('close', function () { + t.pass('server closed') + }) + + await server.close() +}) diff --git a/test/pool.js b/test/pool.js index bcc8b1c2..6712e1ea 100644 --- a/test/pool.js +++ b/test/pool.js @@ -46,12 +46,16 @@ test('connection pool, server side', async function (t) { const open = t.test('open') open.plan(2) + let atLeastOneOpen = false { const socket = b.connect(server.publicKey) socket .on('open', () => { + if (atLeastOneOpen) return + open.pass('1st stream opened') + atLeastOneOpen = true }) .on('error', () => { open.pass('1st stream errored') @@ -62,7 +66,10 @@ test('connection pool, server side', async function (t) { const socket = b.connect(server.publicKey) socket .on('open', () => { + if (atLeastOneOpen) return + open.pass('2nd stream opened') + atLeastOneOpen = true }) .on('error', () => { open.pass('2nd stream errored') @@ -71,6 +78,7 @@ test('connection pool, server side', async function (t) { } await open + t.ok(atLeastOneOpen, 'verify one client opened') await server.close() })