diff --git a/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/SocketWire.kt b/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/SocketWire.kt index 1a136bf8e..cf01231ae 100644 --- a/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/SocketWire.kt +++ b/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/SocketWire.kt @@ -1,6 +1,9 @@ package com.jetbrains.rd.framework +import com.jetbrains.rd.framework.WireAddress.Companion.toSocketAddress import com.jetbrains.rd.framework.base.WireBase +import com.jetbrains.rd.framework.util.getInputStream +import com.jetbrains.rd.framework.util.getOutputStream import com.jetbrains.rd.util.* import com.jetbrains.rd.util.lifetime.Lifetime import com.jetbrains.rd.util.lifetime.isAlive @@ -15,6 +18,11 @@ import java.io.IOException import java.io.InputStream import java.io.OutputStream import java.net.* +import java.nio.channels.AsynchronousCloseException +import java.nio.channels.ClosedChannelException +import java.nio.channels.ServerSocketChannel +import java.nio.channels.SocketChannel +import java.nio.file.Path import java.time.Duration import java.util.concurrent.ExecutorService import java.util.concurrent.Executors @@ -75,7 +83,7 @@ class SocketWire { abstract class Base protected constructor(val id: String, private val lifetime: Lifetime, scheduler: IScheduler) : WireBase() { protected val logger: Logger = getLogger(this::class) - val socketProvider = OptProperty() + val socketProvider = OptProperty() private lateinit var output : OutputStream private lateinit var socketInput : InputStream @@ -111,12 +119,12 @@ class SocketWire { connected.advise(lifetime) { heartbeatAlive.value = it } - socketProvider.advise(lifetime) { socket -> + socketProvider.advise(lifetime) { socketChannel -> logger.debug { "$id : connected" } - output = socket.outputStream - socketInput = socket.inputStream.buffered() + output = socketChannel.getOutputStream() + socketInput = socketChannel.getInputStream().buffered() pkgInput = PkgInputStream(socketInput) sendBuffer.reprocessUnacknowledged() @@ -127,12 +135,12 @@ class SocketWire { scheduler.queue { connected.value = true } try { - receiverProc(socket) + receiverProc(socketChannel) } finally { scheduler.queue { connected.value = false } heartbeatJob.cancel() sendBuffer.pause(disconnectedPauseReason) - catchAndDrop { socket.close() } + catchAndDrop { socketChannel.close() } } } } @@ -175,14 +183,9 @@ class SocketWire { } - private fun receiverProc(socket: Socket) { + private fun receiverProc(socket: SocketChannel) { while (lifetime.isAlive) { try { - if (!socket.isConnected) { - logger.debug { "Stop receive messages because socket disconnected" } - break - } - if (!readMsg()) { logger.debug { "$id: Connection was gracefully shutdown" } break @@ -299,7 +302,11 @@ class SocketWire { synchronized(socketSendLock) { output.write(ackPkgHeader.getArray(), 0, pkg_header_len) } - } catch (ex: SocketException) { + } + catch(ex: ClosedChannelException) { + logger.warn { "$id: Exception raised during ACK, seqn = $seqn" } + } + catch (ex: SocketException) { logger.warn { "$id: Exception raised during ACK, seqn = $seqn" } } } @@ -345,7 +352,9 @@ class SocketWire { } } catch (ex: SocketException) { sendBuffer.pause(disconnectedPauseReason) - + } + catch (ex: IOException) { + sendBuffer.pause(disconnectedPauseReason) } } @@ -384,26 +393,39 @@ class SocketWire { } - class Client(lifetime : Lifetime, scheduler: IScheduler, port : Int, optId: String? = null, hostAddress: InetAddress = InetAddress.getLoopbackAddress()) : Base(optId ?:"ClientSocket", lifetime, scheduler) { + class Client internal constructor(lifetime : Lifetime, scheduler: IScheduler, endpoint: SocketAddress, optId: String? = null) : Base(optId ?:"ClientSocket", lifetime, scheduler) { + + constructor(lifetime : Lifetime, scheduler: IScheduler, wireAddress: WireAddress, optId: String? = null) : this(lifetime, scheduler, wireAddress.toSocketAddress(), optId) + + constructor( + lifetime: Lifetime, + scheduler: IScheduler, + port: Int, + optId: String? = null, + hostAddress: InetAddress = InetAddress.getLoopbackAddress() + ) : this(lifetime, scheduler, InetSocketAddress(hostAddress, port), optId) init { - var socket : Socket? = null + var socket : SocketChannel? = null val thread = thread(name = id, isDaemon = true) { try { var lastReportedErrorHash = 0 while (lifetime.isAlive) { try { - val s = Socket() - s.tcpNoDelay = true + val s = when (endpoint) { + is InetSocketAddress -> SocketChannel.open().apply { setOption(StandardSocketOptions.TCP_NODELAY, true) } + is UnixDomainSocketAddress -> SocketChannel.open(StandardProtocolFamily.UNIX) + else -> throw IllegalArgumentException("Only InetSocketAddress and UnixDomainSocketAddress are supported, got: $endpoint") + } // On windows connect will try to send SYN 3 times with interval of 500ms (total time is 1second) // Connect timeout doesn't work if it's more than 1 second. But we don't need it because we can close socket any moment. //https://stackoverflow.com/questions/22417228/prevent-tcp-socket-connection-retries //HKLM\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\TcpMaxConnectRetransmissions - logger.debug { "$id : connecting to $hostAddress:$port" } - s.connect(InetSocketAddress(hostAddress, port)) + logger.debug { "$id : connecting to $endpoint" } + s.connect(endpoint) synchronized(lock) { if (!lifetime.isAlive) { @@ -426,12 +448,12 @@ class SocketWire { if (logger.isEnabled(LogLevel.Debug)) { logger.log( LogLevel.Debug, - "$id: connection error for endpoint $hostAddress:$port.", + "$id: connection error for endpoint $endpoint.", e ) } } else { - logger.debug { "$id: connection error for endpoint $hostAddress:$port (${e.message})." } + logger.debug { "$id: connection error for endpoint $endpoint (${e.message})." } } val shouldReconnect = synchronized(lock) { @@ -450,7 +472,11 @@ class SocketWire { } catch (ex: SocketException) { logger.info {"$id: closed with exception: $ex"} - } catch (ex: Throwable) { + } + catch (ex: ClosedChannelException) { + logger.info {"$id: closed with exception: $ex"} + } + catch (ex: Throwable) { logger.error("$id: unhandled exception.", ex) } finally { logger.debug { "$id: terminated." } @@ -482,34 +508,54 @@ class SocketWire { } - class Server internal constructor(lifetime : Lifetime, scheduler: IScheduler, ss: ServerSocket, optId: String? = null, allowReconnect: Boolean) : Base(optId ?:"ServerSocket", lifetime, scheduler) { - val port : Int = ss.localPort + class Server internal constructor(lifetime : Lifetime, scheduler: IScheduler, ss: ServerSocketChannel, optId: String? = null, allowReconnect: Boolean) : Base(optId ?:"ServerSocket", lifetime, scheduler) { + val wireAddress: WireAddress = WireAddress.fromSocketAddress(ss.localAddress) + + @Deprecated("Use wireAddress instead") + val port : Int = (wireAddress as? WireAddress.TcpAddress)?.port ?: -1 companion object { - internal fun createServerSocket(lifetime: Lifetime, port : Int?, allowRemoteConnections: Boolean) : ServerSocket { + internal fun createServerSocket(lifetime: Lifetime, port : Int?, allowRemoteConnections: Boolean) : ServerSocketChannel { val address = if (allowRemoteConnections) InetAddress.getByName("0.0.0.0") else InetAddress.getByName("127.0.0.1") val portToBind = port ?: 0 - val res = ServerSocket() - res.reuseAddress = true + val res = ServerSocketChannel.open().apply { setOption(StandardSocketOptions.SO_REUSEADDR, true) } res.bind(InetSocketAddress(address, portToBind), 0) lifetime.onTermination { res.close() } return res } + + internal fun createServerSocket(lifetime: Lifetime, endpoint: SocketAddress) : ServerSocketChannel { + val socketChannel = when (endpoint) { + is InetSocketAddress -> ServerSocketChannel.open().apply { setOption(StandardSocketOptions.SO_REUSEADDR, true) } + is UnixDomainSocketAddress -> ServerSocketChannel.open(StandardProtocolFamily.UNIX) + else -> throw IllegalArgumentException("Only InetSocketAddress and UnixDomainSocketAddress are supported, got: $endpoint") + } + + socketChannel.bind(endpoint, 0) + lifetime.onTermination { + socketChannel.close() + } + return socketChannel + } } constructor (lifetime : Lifetime, scheduler: IScheduler, port : Int?, optId: String? = null, allowRemoteConnections: Boolean = false) : this(lifetime, scheduler, createServerSocket(lifetime, port, allowRemoteConnections), optId, allowReconnect = true) + constructor (lifetime : Lifetime, scheduler: IScheduler, wireAddress: WireAddress, optId: String? = null) : this(lifetime, scheduler, wireAddress.toSocketAddress(), optId) + + internal constructor (lifetime : Lifetime, scheduler: IScheduler, endpoint: SocketAddress, optId: String? = null) : this(lifetime, scheduler, createServerSocket(lifetime, endpoint), optId, allowReconnect = true) init { - var socket : Socket? = null + var socket : SocketChannel? = null val thread = thread(name = id, isDaemon = true) { logger.catch { while (lifetime.isAlive) { try { - logger.debug { "$id: listening ${ss.localSocketAddress}" } + logger.debug { "$id: listening ${wireAddress}" } val s = ss.accept() //could be terminated by close - s.tcpNoDelay = true + if(s.localAddress is InetSocketAddress) + s.setOption(StandardSocketOptions.TCP_NODELAY, true) synchronized(lock) { if (!lifetime.isAlive) { @@ -520,19 +566,22 @@ class SocketWire { socket = s } - socketProvider.set(s) - } catch (ex: SocketException) { - logger.debug { "$id closed with exception: $ex" } - } catch (ex: Exception) { - logger.error("$id closed with exception", ex) + } + catch (ex: Exception) { + when(ex) { + is SocketException, is AsynchronousCloseException, is ClosedChannelException -> { + logger.debug { "$id closed with exception: $ex" } + } + else -> logger.error("$id closed with exception", ex) + } } if (!allowReconnect) { - logger.debug { "$id: finished listening on ${ss.localSocketAddress}." } + logger.debug { "$id: finished listening on ${wireAddress}." } break } else { - logger.debug { "$id: waiting for reconnection on ${ss.localSocketAddress}." } + logger.debug { "$id: waiting for reconnection on ${wireAddress}." } } } } @@ -580,7 +629,7 @@ class SocketWire { init { val ss = Server.createServerSocket(lifetime, port, allowRemoteConnections) - localPort = ss.localPort + localPort = (ss.localAddress as? InetSocketAddress)?.port ?: -1 fun rec() { lifetime.executeIfAlive { @@ -595,15 +644,23 @@ class SocketWire { rec() } - - } - } +sealed class WireAddress { + data class TcpAddress(val address: InetAddress, val port: Int): WireAddress() + data class UnixAddress(val path: Path): WireAddress() -//todo remove -val IWire.serverPort: Int get() { - val serverSocketWire = this as? SocketWire.Server ?: throw IllegalArgumentException("You must use SocketWire.Server to get server port") - return serverSocketWire.port -} + internal companion object { + fun fromSocketAddress(socketAddress: SocketAddress): WireAddress = when (socketAddress) { + is InetSocketAddress -> TcpAddress(socketAddress.address, socketAddress.port) + is UnixDomainSocketAddress -> UnixAddress(socketAddress.path) + else -> error("Unknown socket address type: $socketAddress") + } + + fun WireAddress.toSocketAddress() = when (this) { + is TcpAddress -> InetSocketAddress(address, port) + is UnixAddress -> UnixDomainSocketAddress.of(path) + } + } +} \ No newline at end of file diff --git a/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/util/NetUtils.kt b/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/util/NetUtils.kt index b6c68e346..f24ada61d 100644 --- a/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/util/NetUtils.kt +++ b/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/util/NetUtils.kt @@ -1,8 +1,15 @@ package com.jetbrains.rd.framework.util -import java.net.InetAddress +import java.io.IOException +import java.io.InputStream +import java.io.OutputStream import java.net.InetSocketAddress import java.net.ServerSocket +import java.nio.ByteBuffer +import java.nio.channels.Channels +import java.nio.channels.ReadableByteChannel +import java.nio.channels.SocketChannel +import java.nio.channels.WritableByteChannel object NetUtils { private fun isPortFree(port: Int?): Boolean { @@ -52,3 +59,36 @@ data class PortPair private constructor(val senderPort: Int, val receiverPort: I } } +fun SocketChannel.getInputStream(): InputStream { + val ch = this + return Channels.newInputStream(object : ReadableByteChannel { + override fun read(dst: ByteBuffer?): Int { + return ch.read(dst) + } + + override fun close() { + ch.close() + } + + override fun isOpen(): Boolean { + return ch.isOpen + } + }) +} + +fun SocketChannel.getOutputStream(): OutputStream { + val ch = this + return Channels.newOutputStream(object : WritableByteChannel { + override fun write(src: ByteBuffer?): Int { + return ch.write(src) + } + + override fun close() { + ch.close() + } + + override fun isOpen(): Boolean { + return ch.isOpen + } + }) +} \ No newline at end of file diff --git a/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketProxy.kt b/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketProxy.kt index af8136d18..61c0a3a86 100644 --- a/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketProxy.kt +++ b/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketProxy.kt @@ -2,7 +2,6 @@ package com.jetbrains.rd.framework.test.cases.wire import com.jetbrains.rd.framework.IProtocol import com.jetbrains.rd.framework.SocketWire -import com.jetbrains.rd.framework.serverPort import com.jetbrains.rd.util.Logger import com.jetbrains.rd.util.error import com.jetbrains.rd.util.getLogger @@ -52,7 +51,7 @@ class SocketProxy internal constructor(val id: String, val lifetime: Lifetime, p } internal constructor(id: String, lifetime: Lifetime, protocol: IProtocol) : - this(id, lifetime, protocol.wire.serverPort) + this(id, lifetime, (protocol.wire as SocketWire.Server).port) fun start() { fun setSocketOptions(acceptedClient: Socket) { @@ -62,7 +61,7 @@ class SocketProxy internal constructor(val id: String, val lifetime: Lifetime, p try { logger.info { "Creating proxies for server and client..." } proxyServer = Socket(InetAddress.getLoopbackAddress(), serverPort).also { lifetime.onTermination(it) } - proxyClient = SocketWire.Server.createServerSocket(lifetime, 0, false).also { lifetime.onTermination(it) } + proxyClient = SocketWire.Server.createServerSocket(lifetime, 0, false).also { lifetime.onTermination(it) }.socket() // TODO: Remove .socket() call and replace SocketServer to ServerSocketChannel setSocketOptions(proxyServer) _port = proxyClient.localPort diff --git a/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketProxyTest.kt b/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketProxyTest.kt index dee2c03f8..fa4f32b20 100644 --- a/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketProxyTest.kt +++ b/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketProxyTest.kt @@ -22,18 +22,18 @@ class SocketProxyTest : TestBase() { val proxyLifetimeDefinition = lifetime.createNested() val proxyLifetime = proxyLifetimeDefinition.lifetime - val serverProtocol = SocketWireTest.server(lifetime) + val serverProtocol = SocketWireTestBase.server(lifetime, 0) val proxy = SocketProxy("TestProxy", proxyLifetime, serverProtocol).apply { start() } Thread.sleep(DefaultTimeoutMs) - val clientProtocol = SocketWireTest.client(lifetime, proxy.port) + val clientProtocol = SocketWireTestBase.client(lifetime, proxy.port) val sp = RdSignal().static(1) - sp.bindTopLevel(lifetime, serverProtocol, SocketWireTest.top) + sp.bindTopLevel(lifetime, serverProtocol, SocketWireTestBase.top) val cp = RdSignal().static(1) - cp.bindTopLevel(lifetime, clientProtocol, SocketWireTest.top) + cp.bindTopLevel(lifetime, clientProtocol, SocketWireTestBase.top) val serverLog = mutableListOf() val clientLog = mutableListOf() diff --git a/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketWireTest.kt b/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketWireTestBase.kt similarity index 81% rename from rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketWireTest.kt rename to rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketWireTestBase.kt index 1fe6f7a62..ae4b521c1 100644 --- a/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketWireTest.kt +++ b/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/wire/SocketWireTestBase.kt @@ -1,6 +1,7 @@ package com.jetbrains.rd.framework.test.cases.wire import com.jetbrains.rd.framework.* +import com.jetbrains.rd.framework.WireAddress.Companion.toSocketAddress import com.jetbrains.rd.framework.base.bindTopLevel import com.jetbrains.rd.framework.base.static import com.jetbrains.rd.framework.impl.RdOptionalProperty @@ -17,13 +18,18 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ValueSource import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.SocketAddress +import java.net.UnixDomainSocketAddress import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.TimeoutException +import kotlin.io.path.createTempFile +import kotlin.io.path.deleteIfExists -class SocketWireTest : TestBase() { +abstract class SocketWireTestBase : TestBase() { - private fun RdOptionalProperty.waitAndAssert(expected: T, prev: T? = null) { + protected fun RdOptionalProperty.waitAndAssert(expected: T, prev: T? = null) { val start = System.currentTimeMillis() while ((System.currentTimeMillis() - start) < timeoutToWaitConditionMs && valueOrNull != expected) Thread.sleep(100) @@ -34,25 +40,31 @@ class SocketWireTest : TestBase() { companion object { internal const val top = "top" - internal fun server(lifetime: Lifetime, port: Int? = null): Protocol { - return Protocol("Server", Serializers(), Identities(IdKind.Server), TestScheduler, + internal fun server(lifetime: Lifetime, port: Int): Protocol { + return Protocol("Server", Serializers(MarshallersProvider.Dummy), Identities(IdKind.Server), TestScheduler, SocketWire.Server(lifetime, TestScheduler, port, "TestServer"), lifetime ) } internal fun client(lifetime: Lifetime, serverProtocol: Protocol): Protocol { - return Protocol("Client", Serializers(), Identities(IdKind.Client), TestScheduler, + return Protocol("Client", Serializers(MarshallersProvider.Dummy), Identities(IdKind.Client), TestScheduler, SocketWire.Client(lifetime, - TestScheduler, (serverProtocol.wire as SocketWire.Server).port, "TestClient"), lifetime + TestScheduler, (serverProtocol.wire as SocketWire.Server).wireAddress.toSocketAddress(), "TestClient"), lifetime ) } internal fun client(lifetime: Lifetime, port: Int): Protocol { - return Protocol("Client", Serializers(), Identities(IdKind.Client), TestScheduler, + return Protocol("Client", Serializers(MarshallersProvider.Dummy), Identities(IdKind.Client), TestScheduler, SocketWire.Client(lifetime, TestScheduler, port, "TestClient"), lifetime ) } + + internal fun server(lifetime: Lifetime, address: SocketAddress): Protocol { + return Protocol("Server", Serializers(MarshallersProvider.Dummy), Identities(IdKind.Server), TestScheduler, + SocketWire.Server(lifetime, TestScheduler, address, "TestServer"), lifetime + ) + } } lateinit var socketLifetime: Lifetime @@ -63,9 +75,11 @@ class SocketWireTest : TestBase() { // ConsoleLoggerFactory.minLevelToLog = LogLevel.Trace } + abstract protected fun serverProvider(lifetime: Lifetime): Protocol + @Test fun TestBasicRun() { - val serverProtocol = server(socketLifetime) + val serverProtocol = serverProvider(socketLifetime) val clientProtocol = client(socketLifetime, serverProtocol) val sp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, serverProtocol, top) } @@ -80,7 +94,7 @@ class SocketWireTest : TestBase() { @Test fun TestOrdering() { - val serverProtocol = server(socketLifetime) + val serverProtocol = serverProvider(socketLifetime) val clientProtocol = client(socketLifetime, serverProtocol) val sp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, serverProtocol, top) } @@ -102,7 +116,7 @@ class SocketWireTest : TestBase() { @Disabled @Test fun TestDisconnect() { - val serverProtocol = server(socketLifetime) + val serverProtocol = serverProvider(socketLifetime) val clientProtocol = client(socketLifetime, serverProtocol) val sp = RdSignal().static(1).apply { bindTopLevel(lifetime, serverProtocol, top) } @@ -147,7 +161,7 @@ class SocketWireTest : TestBase() { @Test fun TestDdos() { - val serverProtocol = server(socketLifetime) + val serverProtocol = serverProvider(socketLifetime) val clientProtocol = client(socketLifetime, serverProtocol) val sp = RdSignal().static(1).apply { bindTopLevel(lifetime, serverProtocol, top) } @@ -172,7 +186,7 @@ class SocketWireTest : TestBase() { @Test fun TestBigBuffer() { - val serverProtocol = server(socketLifetime) + val serverProtocol = serverProvider(socketLifetime) val clientProtocol = client(socketLifetime, serverProtocol) val sp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, serverProtocol, top) } @@ -185,42 +199,20 @@ class SocketWireTest : TestBase() { cp.waitAndAssert("".padStart(100000, '3'), "1") } - - @Test - fun TestRunWithSlowpokeServer() { - - val port = NetUtils.findFreePort(0) - val clientProtocol = client(socketLifetime, port) - - - val cp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, clientProtocol, top) } - - cp.set(1) - - Thread.sleep(2000) - - val serverProtocol = server(socketLifetime, port) - val sp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, serverProtocol, top) } - - val prev = sp.valueOrNull - cp.set(4) - sp.waitAndAssert(4, prev) - } - @Test fun TestServerWithoutClient() { - server(socketLifetime) + serverProvider(socketLifetime) } @Test fun TestServerWithoutClientWithDelay() { - server(socketLifetime) + serverProvider(socketLifetime) Thread.sleep(100) } @Test fun TestServerWithoutClientWithDelayAndMessages() { - val protocol = server(socketLifetime) + val protocol = serverProvider(socketLifetime) Thread.sleep(100) val sp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, protocol, top) } @@ -229,28 +221,6 @@ class SocketWireTest : TestBase() { Thread.sleep(50) } - @Test - fun TestClientWithoutServer() { - client(socketLifetime, NetUtils.findFreePort(0)) - } - - @Test - fun TestClientWithoutServerWithDelay() { - client(socketLifetime, NetUtils.findFreePort(0)) - Thread.sleep(100) - } - - @Test - fun TestClientWithoutServerWithDelayAndMessages() { - val clientProtocol = client(socketLifetime, NetUtils.findFreePort(0)) - - val cp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, clientProtocol, top) } - - cp.set(1) - cp.set(2) - Thread.sleep(50) - } - // @Test // fun testReentrantWrites() { // val serverProtocol = server(socketLifetime) @@ -281,13 +251,11 @@ class SocketWireTest : TestBase() { @Test fun testRemoteSocket() { val serverSocket = SocketWire.Server(lifetime, TestScheduler, 0, allowRemoteConnections = true) - val clientSocket = SocketWire.Client(lifetime, TestScheduler, serverSocket.port, hostAddress = InetAddress.getLocalHost()) + val clientSocket = SocketWire.Client(lifetime, TestScheduler, (serverSocket.wireAddress as WireAddress.TcpAddress).port, hostAddress = InetAddress.getLocalHost()) assertTrue(spinUntil(60000L) { clientSocket.connected.value }) } - - @Test fun testSocketFactory() { val spinTimeoutMs = 5000L @@ -315,11 +283,89 @@ class SocketWireTest : TestBase() { spinUntil { factory.size == 0 } } +// @BeforeClass +// fun beforeClass() { +// setupLogHandler { +// if (it.getLevel() == Level.ERROR) { +// System.err.println(it.message) +// it.throwableInformation?.throwable?.printStackTrace() +// } +// } +// } +// +// private fun setupLogHandler(name: String = "default", action: (LoggingEvent) -> Unit) { +// val rootLogger = org.apache.log4j.Logger.getRootLogger() +// rootLogger.removeAppender("default") +// rootLogger.addAppender(object : AppenderSkeleton() { +// init { +// setName(name) +// } +// +// override fun append(event: LoggingEvent) { +// action(event) +// } +// +// override fun close() {} +// +// override fun requiresLayout(): Boolean { +// return false +// } +// }) +// } +} + + +class SocketWireTcpTest: SocketWireTestBase() { + override fun serverProvider(lifetime: Lifetime): Protocol = server(lifetime, 0) + + @Test + fun TestClientWithoutServerWithDelayAndMessages() { + val clientProtocol = client(socketLifetime, NetUtils.findFreePort(0)) + + val cp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, clientProtocol, top) } + + cp.set(1) + cp.set(2) + Thread.sleep(50) + } + + @Test + fun TestRunWithSlowpokeServer() { + + val port = NetUtils.findFreePort(0) + val clientProtocol = client(socketLifetime, port) + + + val cp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, clientProtocol, top) } + + cp.set(1) + + Thread.sleep(2000) + + val serverProtocol = server(socketLifetime, port) + val sp = RdOptionalProperty().static(1).apply { bindTopLevel(lifetime, serverProtocol, top) } + + val prev = sp.valueOrNull + cp.set(4) + sp.waitAndAssert(4, prev) + } + + @Test + fun TestClientWithoutServer() { + client(socketLifetime, NetUtils.findFreePort(0)) + } + + @Test + fun TestClientWithoutServerWithDelay() { + client(socketLifetime, NetUtils.findFreePort(0)) + Thread.sleep(100) + } + @ParameterizedTest @ValueSource(booleans = [true, false]) fun testPacketLoss(isClientToServer: Boolean) { Lifetime.using { lifetime -> - val serverProtocol = server(lifetime) + val serverProtocol = serverProvider(lifetime) val serverWire = serverProtocol.wire val proxy = SocketProxy("TestProxy", lifetime, serverProtocol) @@ -360,35 +406,23 @@ class SocketWireTest : TestBase() { assertTrue(clientWire.heartbeatAlive.value) } } +} +class SocketWireUnixDomainSocketTest: SocketWireTestBase() { -// @BeforeClass -// fun beforeClass() { -// setupLogHandler { -// if (it.getLevel() == Level.ERROR) { -// System.err.println(it.message) -// it.throwableInformation?.throwable?.printStackTrace() -// } -// } -// } -// -// private fun setupLogHandler(name: String = "default", action: (LoggingEvent) -> Unit) { -// val rootLogger = org.apache.log4j.Logger.getRootLogger() -// rootLogger.removeAppender("default") -// rootLogger.addAppender(object : AppenderSkeleton() { -// init { -// setName(name) -// } -// -// override fun append(event: LoggingEvent) { -// action(event) -// } -// -// override fun close() {} -// -// override fun requiresLayout(): Boolean { -// return false -// } -// }) -// } -} + private fun createRandomUnixSocketAddress() = createTempFile("rd_test_socket_").let { + it.deleteIfExists() + UnixDomainSocketAddress.of(it) + } + + override fun serverProvider(lifetime: Lifetime): Protocol = server(lifetime, createRandomUnixSocketAddress()) + + @Test + fun testUnixDomainSocket() { + val endpoint = createRandomUnixSocketAddress() + SocketWire.Server(lifetime, TestScheduler, endpoint) + val clientSocket = SocketWire.Client(lifetime, TestScheduler, endpoint) + + assertTrue(spinUntil(60000L) { clientSocket.connected.value }) + } +} \ No newline at end of file diff --git a/rd-net/Lifetimes/Lifetimes/Lifetime.cs b/rd-net/Lifetimes/Lifetimes/Lifetime.cs index f3920f35b..c316c3e2b 100644 --- a/rd-net/Lifetimes/Lifetimes/Lifetime.cs +++ b/rd-net/Lifetimes/Lifetimes/Lifetime.cs @@ -126,7 +126,7 @@ internal LifetimeDefinition Definition internal void AssertInitialized() { - if (!Mode.IsAssertion) return; + /*if (!Mode.IsAssertion) return; // TODO: FIX THIS if (LogErrorIfLifetimeIsNotInitialized && IsUninitialized) { @@ -134,7 +134,7 @@ internal void AssertInitialized() "This may cause a memory leak as default(Lifetime) is treated as Eternal. " + "Please provide a properly initialized Lifetime or use `Lifetime?` if you need to handle both cases. " + "Use Lifetime.Eternal explicitly if that behavior is intended."); - } + }*/ } //ctor diff --git a/rd-net/Rd.sln.DotSettings b/rd-net/Rd.sln.DotSettings index b6eb34fec..ba5985552 100644 --- a/rd-net/Rd.sln.DotSettings +++ b/rd-net/Rd.sln.DotSettings @@ -1,4 +1,7 @@  UTF <Policy Inspect="True" Prefix="my" Suffix="" Style="AaBb" /> - <Policy Inspect="True" Prefix="our" Suffix="" Style="AaBb" /> \ No newline at end of file + <Policy Inspect="True" Prefix="our" Suffix="" Style="AaBb" /> + <Policy><Descriptor Staticness="Instance" AccessRightKinds="Private" Description="Instance fields (private)"><ElementKinds><Kind Name="FIELD" /><Kind Name="READONLY_FIELD" /></ElementKinds></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="my" Suffix="" Style="AaBb" /></Policy> + <Policy><Descriptor Staticness="Static" AccessRightKinds="Private" Description="Static fields (private)"><ElementKinds><Kind Name="FIELD" /></ElementKinds></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="our" Suffix="" Style="AaBb" /></Policy> + True \ No newline at end of file diff --git a/rd-net/RdFramework.Reflection/BindableChildrenUtil.cs b/rd-net/RdFramework.Reflection/BindableChildrenUtil.cs index a665b4703..cfe172c1b 100644 --- a/rd-net/RdFramework.Reflection/BindableChildrenUtil.cs +++ b/rd-net/RdFramework.Reflection/BindableChildrenUtil.cs @@ -24,7 +24,7 @@ internal static class BindableChildrenUtil internal static void PrettyPrint(PrettyPrinter p, IReflectionBindable instance) { - Action prettyPrinter; + Action? prettyPrinter; lock (ourPrettyPrintersLock) { ourPrettyPrinters.TryGetValue(instance.GetType(), out prettyPrinter); @@ -65,7 +65,7 @@ internal static void PrettyPrint(PrettyPrinter p, IReflectionBindable instance) internal static void FillBindableFields(IReflectionBindable instance) { var type = instance.GetType(); - Action fillBindableFields; + Action? fillBindableFields; lock (ourFillBindableChildren) { ourFillBindableChildren.TryGetValue(type, out fillBindableFields); diff --git a/rd-net/RdFramework.Reflection/BuiltInSerializers.cs b/rd-net/RdFramework.Reflection/BuiltInSerializers.cs index e47eb50cb..e99fb50c5 100644 --- a/rd-net/RdFramework.Reflection/BuiltInSerializers.cs +++ b/rd-net/RdFramework.Reflection/BuiltInSerializers.cs @@ -244,8 +244,8 @@ public static bool HasBuiltInFields(TypeInfo t) Assertion.Fail($"Invalid BuiltIn serializer for type {typeInfo}. Static field 'Read' with type {typeof(CtxReadDelegate<>).ToString(true)} not found"); if (writeField == null) Assertion.Fail($"Invalid BuiltIn serializer for type {typeInfo}. Static field 'Write' with type {typeof(CtxWriteDelegate<>).ToString(true)} not found"); - var reader = readField.GetValue(null); - var writer = writeField.GetValue(null); + var reader = readField.GetValue(null)!; + var writer = writeField.GetValue(null)!; return new SerializerPair(reader, writer); } diff --git a/rd-net/RdFramework.Reflection/CollectionSerializers.cs b/rd-net/RdFramework.Reflection/CollectionSerializers.cs index bcaf2ab55..0504523c2 100644 --- a/rd-net/RdFramework.Reflection/CollectionSerializers.cs +++ b/rd-net/RdFramework.Reflection/CollectionSerializers.cs @@ -22,7 +22,7 @@ public static SerializerPair CreateListSerializerPair(SerializerPair itemSeri } public static SerializerPair CreateDictionarySerializerPair( - SerializerPair keySerializer, SerializerPair valueSerializer) + SerializerPair keySerializer, SerializerPair valueSerializer) where TKey : notnull { var read = CreateReadDictionary(keySerializer, valueSerializer); @@ -53,7 +53,7 @@ public static SerializerPair CreateDictionarySerializerPair( } public static SerializerPair CreateReadOnlyDictionarySerializerPair( - SerializerPair keySerializer, SerializerPair valueSerializer) + SerializerPair keySerializer, SerializerPair valueSerializer) where TKey : notnull { #if NET35 throw new NotSupportedException(); @@ -88,7 +88,7 @@ public static SerializerPair CreateReadOnlyDictionarySerializerPair?> CreateReadDictionary( - SerializerPair keySerializer, SerializerPair valueSerializer) + SerializerPair keySerializer, SerializerPair valueSerializer) where TKey : notnull { CtxReadDelegate?> read = (context, reader) => { diff --git a/rd-net/RdFramework.Reflection/ProxyGenerator.cs b/rd-net/RdFramework.Reflection/ProxyGenerator.cs index cde603d64..87e074a1c 100644 --- a/rd-net/RdFramework.Reflection/ProxyGenerator.cs +++ b/rd-net/RdFramework.Reflection/ProxyGenerator.cs @@ -83,7 +83,7 @@ public struct FakeTuple { public ProxyGenerator(bool allowSave = false) { myAllowSave = allowSave; -#if NETSTANDARD +#if NETSTANDARD || NETCOREAPP myAssemblyBuilder = new Lazy(() => AssemblyBuilder.DefineDynamicAssembly(new AssemblyName(DynamicAssemblyName), AssemblyBuilderAccess.Run)); myModuleBuilder = new Lazy(() => myAssemblyBuilder.Value.DefineDynamicModule(DynamicAssemblyName)); #else @@ -325,14 +325,15 @@ private void ImplementProperty(TypeBuilderContext ctx, PropertyInfo propertyInfo throw new Exception("Setter for properties in proxy interface is prohibited due to unclear semantic"); } - if (propertyInfo.GetGetMethod() != null) + var methodInfo = propertyInfo.GetGetMethod(); + if (methodInfo != null) { - var getMethod = typebuilder.DefineMethod(propertyInfo.GetGetMethod().Name, MethodAttributes.Final | MethodAttributes.Virtual | MethodAttributes.Private, type, EmptyArray.Instance); + var getMethod = typebuilder.DefineMethod(methodInfo.Name, MethodAttributes.Final | MethodAttributes.Virtual | MethodAttributes.Private, type, EmptyArray.Instance); var il = getMethod.GetILGenerator(); il.Emit(OpCodes.Ldarg_0); il.Emit(OpCodes.Ldfld, field); il.Emit(OpCodes.Ret); - typebuilder.DefineMethodOverride(getMethod, propertyInfo.GetGetMethod()); + typebuilder.DefineMethodOverride(getMethod, methodInfo); } } @@ -636,7 +637,7 @@ internal class ProxyGeneratorMembers // ReSharper disable once PossibleNullReferenceException public readonly MethodInfo EternalLifetimeGet = typeof(Lifetime) - .GetProperty(nameof(Lifetime.Eternal), BindingFlags.Static | BindingFlags.Public) + .GetProperty(nameof(Lifetime.Eternal), BindingFlags.Static | BindingFlags.Public)! .GetGetMethod() .NotNull(nameof(EternalLifetimeGet)); diff --git a/rd-net/RdFramework.Reflection/ProxyGeneratorCache.cs b/rd-net/RdFramework.Reflection/ProxyGeneratorCache.cs index ed812a0b8..6cfbc7641 100644 --- a/rd-net/RdFramework.Reflection/ProxyGeneratorCache.cs +++ b/rd-net/RdFramework.Reflection/ProxyGeneratorCache.cs @@ -15,7 +15,7 @@ public class ProxyGeneratorCache : IProxyGenerator private sealed class TokenComparer : IComparer { public static IComparer Instance { get; } = new TokenComparer(); - public int Compare(MethodInfo x, MethodInfo y) => (x?.MetadataToken ?? -1).CompareTo(y?.MetadataToken ?? -1); + public int Compare(MethodInfo? x, MethodInfo? y) => (x?.MetadataToken ?? -1).CompareTo(y?.MetadataToken ?? -1); } public ProxyGeneratorCache(ProxyGenerator generator) diff --git a/rd-net/RdFramework.Reflection/RdFramework.Reflection.csproj b/rd-net/RdFramework.Reflection/RdFramework.Reflection.csproj index 94e0e498a..dccb85858 100644 --- a/rd-net/RdFramework.Reflection/RdFramework.Reflection.csproj +++ b/rd-net/RdFramework.Reflection/RdFramework.Reflection.csproj @@ -1,7 +1,7 @@  - netstandard2.0;net35;net472 + netstandard2.0;net35;net472;net6.0; JetBrains.RdFramework.Reflection JetBrains.Rd.Reflection diff --git a/rd-net/RdFramework.Reflection/ReflectionRdActivator.cs b/rd-net/RdFramework.Reflection/ReflectionRdActivator.cs index f2d588496..62900d616 100644 --- a/rd-net/RdFramework.Reflection/ReflectionRdActivator.cs +++ b/rd-net/RdFramework.Reflection/ReflectionRdActivator.cs @@ -148,7 +148,7 @@ private object ActivateRd(Type type) object instance; try { - instance = Activator.CreateInstance(implementingType); + instance = Activator.CreateInstance(implementingType)!; } catch (MissingMethodException e) { @@ -451,10 +451,10 @@ public static string GetTypeName(Type type) { var rpcInterface = ReflectionSerializerVerifier.GetRpcInterface(type.GetTypeInfo()); if (rpcInterface != null) - return rpcInterface.AssemblyQualifiedName; + return rpcInterface.AssemblyQualifiedName!; } - return typename; + return typename!; } } } \ No newline at end of file diff --git a/rd-net/RdFramework.Reflection/ReflectionSerializerVerifier.cs b/rd-net/RdFramework.Reflection/ReflectionSerializerVerifier.cs index dbfbe016b..bff7d55e0 100644 --- a/rd-net/RdFramework.Reflection/ReflectionSerializerVerifier.cs +++ b/rd-net/RdFramework.Reflection/ReflectionSerializerVerifier.cs @@ -148,7 +148,7 @@ bool IsValidArray() if (!typeInfo.IsArray) return false; if (typeInfo.GetArrayRank() != 1) return false; - var arrayType = typeInfo.GetElementType().GetTypeInfo(); + var arrayType = typeInfo.GetElementType()!.GetTypeInfo(); return IsFieldType(arrayType, false); } diff --git a/rd-net/RdFramework.Reflection/ReflectionSerializers.cs b/rd-net/RdFramework.Reflection/ReflectionSerializers.cs index 11bf16eaa..ea35f0c97 100644 --- a/rd-net/RdFramework.Reflection/ReflectionSerializers.cs +++ b/rd-net/RdFramework.Reflection/ReflectionSerializers.cs @@ -226,11 +226,13 @@ private void RegisterModelSerializer() object instance; if (isScalar) { +#pragma warning disable SYSLIB0050 instance = FormatterServices.GetUninitializedObject(type); +#pragma warning restore SYSLIB0050 } else { - instance = Activator.CreateInstance(type); + instance = Activator.CreateInstance(type)!; } var bindableInstance = instance as IRdBindable; diff --git a/rd-net/RdFramework.Reflection/ScalarCollectionExtension.cs b/rd-net/RdFramework.Reflection/ScalarCollectionExtension.cs index 8a1bef542..6e6a461ea 100644 --- a/rd-net/RdFramework.Reflection/ScalarCollectionExtension.cs +++ b/rd-net/RdFramework.Reflection/ScalarCollectionExtension.cs @@ -44,7 +44,7 @@ public static void AttachCollectionSerializers(ReflectionSerializers self) } else if (type.IsArray) { - var result = (SerializerPair)ReflectionUtil.InvokeStaticGeneric(typeof(ScalarCollectionExtension), nameof(CreateArraySerializer), type.GetElementType(), new object[] { self })!; + var result = (SerializerPair)ReflectionUtil.InvokeStaticGeneric(typeof(ScalarCollectionExtension), nameof(CreateArraySerializer), type.GetElementType()!, self)!; self.Register(type, result); } }); diff --git a/rd-net/RdFramework.Reflection/ScalarSerializer.cs b/rd-net/RdFramework.Reflection/ScalarSerializer.cs index 7ebc4fb2b..e8cfb06a4 100644 --- a/rd-net/RdFramework.Reflection/ScalarSerializer.cs +++ b/rd-net/RdFramework.Reflection/ScalarSerializer.cs @@ -124,7 +124,9 @@ private SerializerPair CreateCustomScalar(ISerializersSource serializers) if (allowNullable && !unsafeReader.ReadNullness()) return default; +#pragma warning disable SYSLIB0050 object instance = FormatterServices.GetUninitializedObject(typeof(T)); +#pragma warning restore SYSLIB0050 try { @@ -247,7 +249,7 @@ private SerializerPair CreateValueTupleSerializer(ISerializersSource serializ } var type = typeInfo.AsType(); - CtxReadDelegate readerDelegate = (ctx, unsafeReader) => + CtxReadDelegate readerDelegate = (ctx, unsafeReader) => { // todo: consider using IL emit var activatorArgs = new object[argumentTypes.Length]; @@ -258,7 +260,7 @@ private SerializerPair CreateValueTupleSerializer(ISerializersSource serializ } var instance = Activator.CreateInstance(type, activatorArgs); - return (T) instance; + return (T?) instance; }; CtxWriteDelegate writerDelegate = (ctx, unsafeWriter, value) => diff --git a/rd-net/RdFramework.Reflection/SerializerPair.cs b/rd-net/RdFramework.Reflection/SerializerPair.cs index feb34d468..44f7a690e 100644 --- a/rd-net/RdFramework.Reflection/SerializerPair.cs +++ b/rd-net/RdFramework.Reflection/SerializerPair.cs @@ -95,10 +95,10 @@ void WriterDelegate(SerializationCtx ctx, UnsafeWriter writer, T value) => void WriterDelegateStatic(SerializationCtx ctx, UnsafeWriter writer, T value) => writeMethod.Invoke(null, new object?[] { ctx, writer, value, }); - T ReaderDelegate(SerializationCtx ctx, UnsafeReader reader) => - (T) readMethod.Invoke(null, new object?[] { ctx, reader }); + T? ReaderDelegate(SerializationCtx ctx, UnsafeReader reader) => + (T?) readMethod.Invoke(null, new object?[] { ctx, reader }); - CtxReadDelegate ctxReadDelegate = ReaderDelegate; + CtxReadDelegate ctxReadDelegate = ReaderDelegate; CtxWriteDelegate ctxWriteDelegate = writeMethod.IsStatic ? WriterDelegateStatic : WriterDelegate; return new SerializerPair(ctxReadDelegate, ctxWriteDelegate); } @@ -111,13 +111,13 @@ private static SerializerPair CreateFromMethodsImpl1(MethodInfo readMethod, M void WriterDelegate(SerializationCtx ctx, UnsafeWriter writer, T value) => writeMethod.Invoke(null, new object?[] {ctx, writer, value}); - T ReaderDelegate(SerializationCtx ctx, UnsafeReader reader) + T? ReaderDelegate(SerializationCtx ctx, UnsafeReader reader) { - return (T)readMethod.Invoke(null, + return (T?)readMethod.Invoke(null, new[] {ctx, reader, ctxKeyReadDelegate, ctxKeyWriteDelegate}); } - CtxReadDelegate ctxReadDelegate = ReaderDelegate; + CtxReadDelegate ctxReadDelegate = ReaderDelegate; CtxWriteDelegate ctxWriteDelegate = WriterDelegate; return new SerializerPair(ctxReadDelegate, ctxWriteDelegate); } @@ -133,13 +133,13 @@ private static SerializerPair CreateFromMethodsImpl2(MethodInfo readMethod, M void WriterDelegate(SerializationCtx ctx, UnsafeWriter writer, T value) => writeMethod.Invoke(null, new object?[] {ctx, writer, value}); - T ReaderDelegate(SerializationCtx ctx, UnsafeReader reader) + T? ReaderDelegate(SerializationCtx ctx, UnsafeReader reader) { - return (T)readMethod.Invoke(null, + return (T?)readMethod.Invoke(null, new[] {ctx, reader, ctxKeyReadDelegate, ctxKeyWriteDelegate, ctxValueReadDelegate, ctxValueWriteDelegate}); } - CtxReadDelegate ctxReadDelegate = ReaderDelegate; + CtxReadDelegate ctxReadDelegate = ReaderDelegate; CtxWriteDelegate ctxWriteDelegate = WriterDelegate; return new SerializerPair(ctxReadDelegate, ctxWriteDelegate); } @@ -153,7 +153,7 @@ public static SerializerPair FromMarshaller(IBuiltInMarshaller marshaller) private static SerializerPair CreateFromNonProtocolMethodsT(MethodInfo readMethod, MethodInfo writeMethod) { - Assertion.Require(readMethod.IsStatic, $"Read method should be static ({readMethod.DeclaringType.ToString(true)})"); + Assertion.Require(readMethod.IsStatic, $"Read method should be static ({readMethod.DeclaringType?.ToString(true)})"); void WriterDelegate(SerializationCtx ctx, UnsafeWriter writer, T value) { @@ -174,7 +174,7 @@ void WriterDelegateStatic(SerializationCtx ctx, UnsafeWriter writer, T value) if (!typeof(T).IsValueType && !reader.ReadNullness()) return default; - return (T) readMethod.Invoke(null, new object[] {reader}); + return (T?) readMethod.Invoke(null, new object[] {reader}); } CtxReadDelegate ctxReadDelegate = ReaderDelegate; diff --git a/rd-net/RdFramework.Reflection/SerializerReflectionUtil.cs b/rd-net/RdFramework.Reflection/SerializerReflectionUtil.cs index 229d62550..6e1811ee6 100644 --- a/rd-net/RdFramework.Reflection/SerializerReflectionUtil.cs +++ b/rd-net/RdFramework.Reflection/SerializerReflectionUtil.cs @@ -74,8 +74,9 @@ private static IEnumerable GetFields(Type type, Type baseType) yield return field; // private fields only being returned for the current type - while ((type = type.BaseType) != baseType && type != null) + while (type.BaseType != baseType && type.BaseType != null) { + type = type.BaseType; // but protected fields are returned in first step foreach (var baseField in type.GetFields(BindingFlags.Instance | BindingFlags.NonPublic)) if (baseField.IsPrivate) @@ -85,7 +86,7 @@ private static IEnumerable GetFields(Type type, Type baseType) internal static SerializerPair ConvertPair(SerializerPair serializers, Type desiredType) { - return (SerializerPair)ourConvertSerializerPair.MakeGenericMethod(serializers.Writer.GetType().GetGenericArguments()[0], desiredType).Invoke(null, new object[] { serializers }); + return (SerializerPair)ourConvertSerializerPair.MakeGenericMethod(serializers.Writer.GetType().GetGenericArguments()[0], desiredType).Invoke(null, new object[] { serializers })!; } private static readonly MethodInfo ourConvertSerializerPair = typeof(SerializerReflectionUtil).GetTypeInfo().GetMethod(nameof(ConvertPairGeneric), BindingFlags.Static | BindingFlags.NonPublic)!; @@ -113,7 +114,7 @@ internal static CtxReadDelegate ConvertReader(object reader) var genericTypedRead = ourConvertTypedCtxRead.MakeGenericMethod(reader.GetType().GetGenericArguments()[0], typeof(object)); var result = genericTypedRead.Invoke(null, new[] { reader }); - return (CtxReadDelegate)result; + return (CtxReadDelegate)result!; } internal static CtxWriteDelegate ConvertWriter(object writer) @@ -121,7 +122,7 @@ internal static CtxWriteDelegate ConvertWriter(object writer) if (writer is CtxWriteDelegate objWriter) return objWriter; - return (CtxWriteDelegate)ourConvertTypedCtxWrite.MakeGenericMethod(writer.GetType().GetGenericArguments()[0], typeof(TOut)).Invoke(null, new[] { writer }); + return (CtxWriteDelegate)ourConvertTypedCtxWrite.MakeGenericMethod(writer.GetType().GetGenericArguments()[0], typeof(TOut)).Invoke(null, new[] { writer })!; } private static readonly MethodInfo ourConvertTypedCtxRead = typeof(SerializerReflectionUtil).GetTypeInfo().GetMethod(nameof(CtxReadTypedToObject), BindingFlags.Static | BindingFlags.NonPublic)!; diff --git a/rd-net/RdFramework/Base/IRdBindable.cs b/rd-net/RdFramework/Base/IRdBindable.cs index 069f82ca8..58bfc7894 100644 --- a/rd-net/RdFramework/Base/IRdBindable.cs +++ b/rd-net/RdFramework/Base/IRdBindable.cs @@ -24,7 +24,7 @@ public static class RdDynamicEx { public static IProtocol GetProtoOrThrow(this IRdDynamic dynamic) { - return dynamic.TryGetProto() ?? throw new ProtocolNotBoundException(dynamic.ToString()); + return dynamic.TryGetProto() ?? throw new ProtocolNotBoundException(dynamic.ToString() ?? "'dynamic.ToString() was null'"); } } @@ -322,7 +322,7 @@ public static void PrintEx(this object? me, PrettyPrinter printer) break; } default: - printer.Print(me.ToString()); + printer.Print(me.ToString() ?? ""); break; } } diff --git a/rd-net/RdFramework/Base/RdBindableBase.cs b/rd-net/RdFramework/Base/RdBindableBase.cs index a6e4a0209..97b42d265 100644 --- a/rd-net/RdFramework/Base/RdBindableBase.cs +++ b/rd-net/RdFramework/Base/RdBindableBase.cs @@ -297,7 +297,7 @@ private T GetOrCreateExtension(string name, bool highPriorityExtension, Func< // NOTE: dummy implementation which prevents WPF from hanging the viewmodel forever on reflection property descriptor fabricated change events: // when it sees PropertyChanged, it does not look for property descriptor events - public virtual event PropertyChangedEventHandler PropertyChanged { add { } remove { } } + public virtual event PropertyChangedEventHandler? PropertyChanged { add { } remove { } } } public enum BindState diff --git a/rd-net/RdFramework/Impl/EndPointWrapper.cs b/rd-net/RdFramework/Impl/EndPointWrapper.cs new file mode 100644 index 000000000..96ad53d0f --- /dev/null +++ b/rd-net/RdFramework/Impl/EndPointWrapper.cs @@ -0,0 +1,107 @@ +using System; +using System.IO; +using System.Net; +using System.Net.Sockets; + +namespace JetBrains.Rd.Impl; + +public abstract class EndPointWrapper +{ + public AddressFamily AddressFamily { get; private set; } + public SocketType SocketType { get; private set; } + public ProtocolType ProtocolType { get; private set; } + + public abstract EndPoint ToEndPoint(); + + private EndPointWrapper() {} + + public class IPEndpointWrapper : EndPointWrapper + { + public IPEndPoint IpEndPoint { get; } + public IPAddress LocalAddress { get; } + public int LocalPort { get; } + + public IPEndpointWrapper(IPEndPoint endPoint) + { + IpEndPoint = endPoint; + AddressFamily = AddressFamily.InterNetwork; + SocketType = SocketType.Stream; + ProtocolType = ProtocolType.Tcp; + LocalAddress = endPoint.Address; + LocalPort = endPoint.Port; + } + + public override EndPoint ToEndPoint() + { + return IpEndPoint; + } + } + + public class UnixEndpointWrapper : EndPointWrapper + { + public string LocalPath { get; private set; } + +#if NET6_0_OR_GREATER + public UnixDomainSocketEndPoint UnixEndPoint { get; } +#endif + + public UnixEndpointWrapper(UnixSocketConnectionParams connectionParams) : this(connectionParams.Path) {} + + public UnixEndpointWrapper(string path) + { +#if NET6_0_OR_GREATER + UnixEndPoint = new UnixDomainSocketEndPoint(path); +#endif + AddressFamily = AddressFamily.Unix; + SocketType = SocketType.Stream; + ProtocolType = ProtocolType.Unspecified; + LocalPath = path; + } + + public override EndPoint ToEndPoint() + { +#if NET6_0_OR_GREATER + return UnixEndPoint; +#else + throw new NotSupportedException("Unix domain sockets are not supported on this platform"); +#endif + } + } + + + public static IPEndpointWrapper CreateIpEndPoint(IPAddress? address = null, int? port = null) + { + var address1 = address ?? IPAddress.Loopback; + var port1 = port ?? 0; + return new IPEndpointWrapper(new IPEndPoint(address1, port1)); + } + + public static UnixEndpointWrapper CreateUnixEndPoint(UnixSocketConnectionParams connectionParams) + { + return new UnixEndpointWrapper(connectionParams); + } + + public static UnixEndpointWrapper CreateUnixEndPoint(string? path) + { + return new UnixEndpointWrapper(path ?? Path.GetTempFileName()); + } + + public static EndPointWrapper FromEndPoint(EndPoint endPoint) + { + if (endPoint is IPEndPoint ipEndPoint) + return new IPEndpointWrapper(ipEndPoint); +#if NET6_0_OR_GREATER + if (endPoint is UnixDomainSocketEndPoint unixEndPoint) + return new UnixEndpointWrapper(unixEndPoint.ToString()!); +#endif + throw new NotSupportedException($"Unknown endpoint type: {endPoint.GetType()}"); + } + +#if NET6_0_OR_GREATER + public static bool AreUnixSocketsSupported => true; +#else + public static bool AreUnixSocketsSupported => false; +#endif + + public record struct UnixSocketConnectionParams(string Path); +} \ No newline at end of file diff --git a/rd-net/RdFramework/Impl/RdEntitiesRegistrar.cs b/rd-net/RdFramework/Impl/RdEntitiesRegistrar.cs index 9a29f0241..a95159454 100644 --- a/rd-net/RdFramework/Impl/RdEntitiesRegistrar.cs +++ b/rd-net/RdFramework/Impl/RdEntitiesRegistrar.cs @@ -17,7 +17,7 @@ internal void Register(Lifetime lifetime, RdId rdId, IRdDynamic dynamic) myMap.BlockingAddUnique(lifetime, myMap, rdId, dynamic); } - public bool TryGetEntity(RdId rdId, out IRdDynamic entity) + public bool TryGetEntity(RdId rdId, out IRdDynamic? entity) { lock (myMap) { diff --git a/rd-net/RdFramework/Impl/RdPerContextMap.cs b/rd-net/RdFramework/Impl/RdPerContextMap.cs index 683968fbd..a0b828cab 100644 --- a/rd-net/RdFramework/Impl/RdPerContextMap.cs +++ b/rd-net/RdFramework/Impl/RdPerContextMap.cs @@ -59,7 +59,7 @@ protected override void PreInit(Lifetime lifetime, IProtocol proto) if (!cookie.Succeed) return; - value.WithId(RdId.Mix(contextValue.ToString())); + value.WithId(RdId.Mix(contextValue.ToString() ?? "")); value.PreBind(contextValueLifetime, this, $"[{contextValue.ToString()}]"); } diff --git a/rd-net/RdFramework/Impl/RdSecureString.cs b/rd-net/RdFramework/Impl/RdSecureString.cs index ddfc0f8b6..ad664c727 100644 --- a/rd-net/RdFramework/Impl/RdSecureString.cs +++ b/rd-net/RdFramework/Impl/RdSecureString.cs @@ -26,7 +26,7 @@ public bool Equals(RdSecureString other) return string.Equals(Contents, other.Contents); } - public override bool Equals(object obj) + public override bool Equals(object? obj) { if (ReferenceEquals(null, obj)) return false; return obj is RdSecureString && Equals((RdSecureString) obj); diff --git a/rd-net/RdFramework/Impl/Serializers.cs b/rd-net/RdFramework/Impl/Serializers.cs index f8f58f3e8..02acc0918 100644 --- a/rd-net/RdFramework/Impl/Serializers.cs +++ b/rd-net/RdFramework/Impl/Serializers.cs @@ -257,7 +257,7 @@ public void Register(CtxReadDelegate reader, CtxWriteDelegate writer, l bool TryGetReader(RdId rdId, out CtxReadDelegate readDelegate) { lock (myLock) - return myReaders.TryGetValue(rdId, out readDelegate); + return myReaders.TryGetValue(rdId, out readDelegate!); } #if !NET35 myBackgroundRegistrar.Join(); diff --git a/rd-net/RdFramework/Impl/SocketWire.cs b/rd-net/RdFramework/Impl/SocketWire.cs index fbef69a70..1326b2bd3 100644 --- a/rd-net/RdFramework/Impl/SocketWire.cs +++ b/rd-net/RdFramework/Impl/SocketWire.cs @@ -1,4 +1,5 @@ using System; +using System.IO; using System.Net; using System.Net.Sockets; using System.Threading; @@ -113,7 +114,7 @@ protected Base(string id, Lifetime lifetime, IScheduler scheduler) private Timer StartHeartbeat() { var timer = new Timer(HeartBeatInterval.TotalMilliseconds) { AutoReset = false }; - void OnTimedEvent(object sender, ElapsedEventArgs e) + void OnTimedEvent(object? sender, ElapsedEventArgs e) { Ping(); timer.Start(); @@ -458,7 +459,7 @@ protected override void SendPkg(UnsafeWriter.Cookie cookie) //It's a kind of magic... protected static void SetSocketOptions(Socket s) { - s.NoDelay = true; + if (s.ProtocolType == ProtocolType.Tcp) s.NoDelay = true; // if (!TimeoutForbidden()) // s.ReceiveTimeout = TimeoutMs; //sometimes shutdown and close doesn't lead Receive to throw exception @@ -509,8 +510,11 @@ protected void AddTerminationActions(Thread receiverThread) ); } - public int Port { get; protected set; } + public EndPointWrapper? ConnectionEndPoint { get; protected set; } + + [Obsolete("Use ConnectionEndPoint instead")] + public int? Port => (ConnectionEndPoint as EndPointWrapper.IPEndpointWrapper)?.LocalPort; protected virtual bool AcceptHandshake(Socket socket) { @@ -522,11 +526,17 @@ protected virtual bool AcceptHandshake(Socket socket) public class Client : Base { public Client(Lifetime lifetime, IScheduler scheduler, int port, string? optId = null) : - this(lifetime, scheduler, new IPEndPoint(IPAddress.Loopback, port), optId) {} - - public Client(Lifetime lifetime, IScheduler scheduler, IPEndPoint endPoint, string? optId = null) : + this(lifetime, scheduler, EndPointWrapper.CreateIpEndPoint(IPAddress.Loopback, port), optId) {} + +#if NET6_0_OR_GREATER + public Client(Lifetime lifetime, IScheduler scheduler, EndPointWrapper.UnixSocketConnectionParams connectionParams, string? optId = null) : + this(lifetime, scheduler, EndPointWrapper.CreateUnixEndPoint(connectionParams), optId) {} +#endif + + public Client(Lifetime lifetime, IScheduler scheduler, EndPointWrapper endPointWrapper, string? optId = null) : base("ClientSocket-"+(optId ?? ""), lifetime, scheduler) { + ConnectionEndPoint = endPointWrapper; var thread = new Thread(() => { try @@ -538,12 +548,12 @@ public Client(Lifetime lifetime, IScheduler scheduler, IPEndPoint endPoint, stri { try { - var s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + var s = new Socket(endPointWrapper.AddressFamily, endPointWrapper.SocketType, endPointWrapper.ProtocolType); Socket = s; SetSocketOptions(s); - Log.Verbose("{0}: connecting to {1}.", Id, endPoint); - s.Connect(endPoint); + Log.Verbose("{0}: connecting to {1}.", Id, endPointWrapper.ToEndPoint()); + s.Connect(endPointWrapper.ToEndPoint()); lock (Lock) { @@ -569,11 +579,11 @@ public Client(Lifetime lifetime, IScheduler scheduler, IPEndPoint endPoint, stri { lastReportedErrorHash = errorHashCode; if (Log.IsVersboseEnabled()) - Log.Verbose(ex, $"{Id}: connection error for endpoint \"{endPoint}\"."); + Log.Verbose(ex, $"{Id}: connection error for endpoint \"{endPointWrapper.ToEndPoint()}\"."); } else { - Log.Verbose("{0}: connection error for endpoint \"{1}\" ({2}).", Id, endPoint, ex.Message); + Log.Verbose("{0}: connection error for endpoint \"{1}\" ({2}).", Id, endPointWrapper.ToEndPoint(), ex.Message); } lock (Lock) @@ -613,15 +623,15 @@ public Client(Lifetime lifetime, IScheduler scheduler, IPEndPoint endPoint, stri public class Server : Base { - public Server(Lifetime lifetime, IScheduler scheduler, IPEndPoint? endPoint = null, string? optId = null) : this(lifetime, scheduler, optId) + public Server(Lifetime lifetime, IScheduler scheduler, EndPointWrapper? endPointWrapper = null, string? optId = null) : this(lifetime, scheduler, optId) { - var serverSocket = CreateServerSocket(endPoint); + var serverSocket = CreateServerSocket(endPointWrapper); StartServerSocket(lifetime, serverSocket); lifetime.OnTermination(() => { - ourStaticLog.Verbose("closing server socket"); + ourStaticLog.Verbose("closing server socket."); CloseSocket(serverSocket); } ); @@ -639,17 +649,22 @@ public Server(Lifetime lifetime, IScheduler scheduler, Socket serverSocket, stri private Server(Lifetime lifetime, IScheduler scheduler, string? optId = null) : base("ServerSocket-"+(optId ?? ""), lifetime, scheduler) {} - public static Socket CreateServerSocket(IPEndPoint? endPoint) + public static Socket CreateServerSocket(EndPointWrapper? endPointWrapper) { - Protocol.InitLogger.Verbose("Creating server socket on endpoint: {0}", endPoint); + Protocol.InitLogger.Verbose("Creating server socket on endpoint: {0}", endPointWrapper?.ToEndPoint()); + // by default we will use IPEndpoint + endPointWrapper ??= EndPointWrapper.CreateIpEndPoint(IPAddress.Loopback, 0); - var serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + var serverSocket = new Socket(endPointWrapper.AddressFamily, endPointWrapper.SocketType, endPointWrapper.ProtocolType); SetSocketOptions(serverSocket); - endPoint = endPoint ?? new IPEndPoint(IPAddress.Loopback, 0); - serverSocket.Bind(endPoint); + if (endPointWrapper is EndPointWrapper.UnixEndpointWrapper unixEndpointWrapper) + { + if (File.Exists(unixEndpointWrapper.LocalPath)) File.Delete(unixEndpointWrapper.LocalPath); + } + serverSocket.Bind(endPointWrapper.ToEndPoint()); serverSocket.Listen(1); - Protocol.InitLogger.Verbose("Server socket created, listening started on endpoint: {0}", endPoint); + Protocol.InitLogger.Verbose("Server socket created, listening started on endpoint: {0}", endPointWrapper.ToEndPoint()); return serverSocket; } @@ -657,19 +672,30 @@ public static Socket CreateServerSocket(IPEndPoint? endPoint) private void StartServerSocket(Lifetime lifetime, Socket serverSocket) { if (serverSocket == null) throw new ArgumentNullException(nameof(serverSocket)); - Port = ((IPEndPoint) serverSocket.LocalEndPoint).Port; - Log.Verbose("{0} : started, port: {1}", Id, Port); + // To get the actual endpoint, we should take it from the socket, + // because it's possible to bind it to port = 0 and let the "service provider" to assign an available port. + ConnectionEndPoint = serverSocket.LocalEndPoint is {} endPoint ? EndPointWrapper.FromEndPoint(endPoint) : null; + + Log.Verbose("{0} : started, port: {1}", Id, ConnectionEndPoint); var thread = new Thread(() => { +#if NET6_0_OR_GREATER + Log.Catch(async () => +#else Log.Catch(() => +#endif { while (lifetime.IsAlive) { try { - Log.Verbose("{0} : accepting, port: {1}", Id, Port); + Log.Verbose("{0} : accepting, port: {1}", Id, ConnectionEndPoint); +#if NET6_0_OR_GREATER + var s = await serverSocket.AcceptAsync(lifetime); +#else var s = serverSocket.Accept(); +#endif lock (Lock) { if (!lifetime.IsAlive) @@ -701,6 +727,10 @@ private void StartServerSocket(Lifetime lifetime, Socket serverSocket) { Log.Verbose("{0}: ObjectDisposedException with message {1}", Id, e.Message); } + catch (OperationCanceledException e) + { + Log.Verbose("{0} : OperationCanceledException with message {1}", Id, e.Message); + } catch (Exception e) { Log.Error(e, Id); @@ -736,33 +766,34 @@ public void Deconstruct(out IScheduler scheduler, out string? id) } } - - - public class ServerFactory { - [PublicAPI] public readonly int LocalPort; + [PublicAPI] public readonly int? LocalPort; + + public readonly EndPointWrapper ConnectionEndPoint; + [PublicAPI] public readonly IViewableSet Connected = new ViewableSet(); - public ServerFactory(Lifetime lifetime, IScheduler scheduler, IPEndPoint? endpoint = null) - : this(lifetime, () => new WireParameters(scheduler, null), endpoint) {} + public ServerFactory(Lifetime lifetime, IScheduler scheduler, EndPointWrapper? endpointWrapper = null) + : this(lifetime, () => new WireParameters(scheduler, null), endpointWrapper) {} public ServerFactory( Lifetime lifetime, Func wireParametersFactory, - IPEndPoint? endpoint = null + EndPointWrapper? endpointWrapper = null ) { - var serverSocket = Server.CreateServerSocket(endpoint); + var serverSocket = Server.CreateServerSocket(endpointWrapper); var serverSocketLifetimeDef = new LifetimeDefinition(lifetime); serverSocketLifetimeDef.Lifetime.OnTermination(() => { ourStaticLog.Verbose("closing server socket"); Base.CloseSocket(serverSocket); }); - LocalPort = ((IPEndPoint) serverSocket.LocalEndPoint).Port; + LocalPort = (serverSocket.LocalEndPoint as IPEndPoint)?.Port; + ConnectionEndPoint = endpointWrapper ?? EndPointWrapper.CreateIpEndPoint(IPAddress.Loopback, LocalPort ?? 0); void Rec() { diff --git a/rd-net/RdFramework/RdFramework.csproj b/rd-net/RdFramework/RdFramework.csproj index 6d2e119ab..9c5ee759e 100644 --- a/rd-net/RdFramework/RdFramework.csproj +++ b/rd-net/RdFramework/RdFramework.csproj @@ -1,7 +1,7 @@  - netstandard2.0;net35;net472 + netstandard2.0;net35;net472;net6.0; JetBrains.RdFramework JetBrains.Rd diff --git a/rd-net/RdFramework/Tasks/RdFault.cs b/rd-net/RdFramework/Tasks/RdFault.cs index 757f4e44e..cb95c63f3 100644 --- a/rd-net/RdFramework/Tasks/RdFault.cs +++ b/rd-net/RdFramework/Tasks/RdFault.cs @@ -11,9 +11,9 @@ namespace JetBrains.Rd.Tasks [Serializable] public class RdFault : Exception { - public string ReasonTypeFqn { get; private set; } - public string ReasonText { get; private set; } - public string ReasonMessage { get; private set; } + public string? ReasonTypeFqn { get; private set; } + public string? ReasonText { get; private set; } + public string? ReasonMessage { get; private set; } public RdFault(Exception inner) : base(inner.Message, inner) { @@ -29,15 +29,18 @@ public RdFault(Exception inner) : base(inner.Message, inner) } } - [SecurityPermission(SecurityAction.Demand, SerializationFormatter = true)] + // [SecurityPermission(SecurityAction.Demand, SerializationFormatter = true)] + [Obsolete("Obsolete")] protected RdFault(SerializationInfo info, StreamingContext context) : base(info, context) { ReasonTypeFqn = info.GetString(nameof(ReasonTypeFqn)); ReasonText = info.GetString(nameof(ReasonText)); ReasonMessage = info.GetString(nameof(ReasonMessage)); } - + +#if !NET6_0_OR_GREATER [SecurityPermission(SecurityAction.Demand, SerializationFormatter = true)] +#endif public override void GetObjectData(SerializationInfo info, StreamingContext context) { info.AddValue(nameof(ReasonTypeFqn), ReasonTypeFqn); diff --git a/rd-net/RdFramework/Text/Intrinsics/TextBufferVersion.cs b/rd-net/RdFramework/Text/Intrinsics/TextBufferVersion.cs index 2fb33b983..c13e610d9 100644 --- a/rd-net/RdFramework/Text/Intrinsics/TextBufferVersion.cs +++ b/rd-net/RdFramework/Text/Intrinsics/TextBufferVersion.cs @@ -33,7 +33,7 @@ public bool Equals(TextBufferVersion other) return !left.Equals(right); } - public override bool Equals(object obj) + public override bool Equals(object? obj) { if (ReferenceEquals(null, obj)) return false; return obj is TextBufferVersion && Equals((TextBufferVersion) obj); diff --git a/rd-net/RdFramework/Util/ConcurrentSet.cs b/rd-net/RdFramework/Util/ConcurrentSet.cs index f3834b264..547379f54 100644 --- a/rd-net/RdFramework/Util/ConcurrentSet.cs +++ b/rd-net/RdFramework/Util/ConcurrentSet.cs @@ -8,7 +8,7 @@ internal class ConcurrentSet : #if NET35 ICollection #else - ISet + ISet where T : notnull #endif { private readonly ConcurrentDictionary myDictionary = new ConcurrentDictionary(); diff --git a/rd-net/RdFramework/WireEx.cs b/rd-net/RdFramework/WireEx.cs index f772f40d4..2f552e668 100644 --- a/rd-net/RdFramework/WireEx.cs +++ b/rd-net/RdFramework/WireEx.cs @@ -11,7 +11,10 @@ public static int GetServerPort(this IWire wire) var serverSocketWire = wire as SocketWire.Server; if (serverSocketWire == null) throw new ArgumentException("You must use SocketWire.Server to get server port"); - return serverSocketWire.Port; + var port = (serverSocketWire.ConnectionEndPoint as EndPointWrapper.IPEndpointWrapper)?.LocalPort; + if (!port.HasValue) + throw new ArgumentException("You must use SocketWire.Server with connection over TCP to get server port"); + return port.Value; } public static void Send(this IWire wire, RdId id, Action writer) diff --git a/rd-net/Test.Cross/Test.Cross.csproj b/rd-net/Test.Cross/Test.Cross.csproj index afc6426bf..73f5b25bc 100644 --- a/rd-net/Test.Cross/Test.Cross.csproj +++ b/rd-net/Test.Cross/Test.Cross.csproj @@ -1,7 +1,7 @@  - netcoreapp3.1 + netcoreapp3.1;net8.0 Test.RdCross CrossTests AnyCPU diff --git a/rd-net/Test.RdFramework/Reflection/data/Generated/RefExt.cs b/rd-net/Test.RdFramework/Reflection/data/Generated/RefExt.cs index 2ba6f0e76..592ef4811 100644 --- a/rd-net/Test.RdFramework/Reflection/data/Generated/RefExt.cs +++ b/rd-net/Test.RdFramework/Reflection/data/Generated/RefExt.cs @@ -344,7 +344,7 @@ [NotNull] string @string public static new CtxWriteDelegate Write = (ctx, writer, value) => { - writer.WriteString(value.String); + writer.Write(value.String); }; //constants @@ -425,7 +425,7 @@ [NotNull] string openString public static new CtxWriteDelegate Write = (ctx, writer, value) => { - writer.WriteString(value.OpenString); + writer.Write(value.OpenString); }; //constants @@ -524,7 +524,7 @@ [NotNull] string field { value.RdId.Write(writer); RdProperty.Write(ctx, writer, value._String); - writer.WriteString(value.Field); + writer.Write(value.Field); }; //constants @@ -592,7 +592,7 @@ [NotNull] string field { value.RdId.Write(writer); RdProperty.Write(ctx, writer, value._String); - writer.WriteString(value.Field); + writer.Write(value.Field); }; //constants @@ -657,8 +657,8 @@ [NotNull] string openString public static new CtxWriteDelegate Write = (ctx, writer, value) => { - writer.WriteString(value.OpenString); - writer.WriteString(value.OpenDerivedString); + writer.Write(value.OpenString); + writer.Write(value.OpenDerivedString); }; //constants @@ -739,8 +739,8 @@ [NotNull] string openString public static new CtxWriteDelegate Write = (ctx, writer, value) => { - writer.WriteString(value.OpenDerivedString); - writer.WriteString(value.OpenString); + writer.Write(value.OpenDerivedString); + writer.Write(value.OpenString); }; //constants @@ -818,7 +818,7 @@ [NotNull] string openString public static new CtxWriteDelegate Write = (ctx, writer, value) => { - writer.WriteString(value.OpenString); + writer.Write(value.OpenString); }; //constants diff --git a/rd-net/Test.RdFramework/SocketProxyTest.cs b/rd-net/Test.RdFramework/SocketProxyTest.cs index cbc0aa6f7..82f572faf 100644 --- a/rd-net/Test.RdFramework/SocketProxyTest.cs +++ b/rd-net/Test.RdFramework/SocketProxyTest.cs @@ -26,19 +26,19 @@ public void TestSimple() { SynchronousScheduler.Instance.SetActive(lifetime); - var serverProtocol = SocketWireTest.Server(lifetime); + var (serverProtocol, _) = SocketWireIpEndpointTest.CreateServer(lifetime); var proxy = new SocketProxy("TestProxy", proxyLifetime, serverProtocol).With(socketProxy => socketProxy.Start()); - Thread.Sleep(SocketWireTest.DefaultTimeout); + Thread.Sleep(SocketWireIpEndpointTest.DefaultTimeout); - var clientProtocol = SocketWireTest.Client(lifetime, proxy.Port); + var clientProtocol = SocketWireIpEndpointTest.CreateClient(lifetime, proxy.Port); var sp = NewRdSignal().Static(1); - sp.BindTopLevel(lifetime, serverProtocol, SocketWireTest.Top); + sp.BindTopLevel(lifetime, serverProtocol, SocketWireIpEndpointTest.Top); var cp = NewRdSignal().Static(1); - cp.BindTopLevel(lifetime, clientProtocol, SocketWireTest.Top); + cp.BindTopLevel(lifetime, clientProtocol, SocketWireIpEndpointTest.Top); var serverLog = new List(); var clientLog = new List(); diff --git a/rd-net/Test.RdFramework/SocketWireIpEndpointTest.cs b/rd-net/Test.RdFramework/SocketWireIpEndpointTest.cs new file mode 100644 index 000000000..d135855e9 --- /dev/null +++ b/rd-net/Test.RdFramework/SocketWireIpEndpointTest.cs @@ -0,0 +1,147 @@ +#if !NET35 +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using JetBrains.Collections.Viewable; +using JetBrains.Diagnostics; +using JetBrains.Diagnostics.Internal; +using JetBrains.Lifetimes; +using JetBrains.Rd; +using JetBrains.Rd.Impl; +using NUnit.Framework; + +namespace Test.RdFramework; + +[TestFixture] +public class SocketWireIpEndpointTest : SocketWireTestBase +{ + internal override int GetPortOrPath() + { + var l = new TcpListener(IPAddress.Loopback, 0); + l.Start(); + int port = ((IPEndPoint) l.LocalEndpoint).Port; + l.Stop(); + return port; + } + + internal override (IProtocol ServerProtocol, int portOrPath) Server(Lifetime lifetime, int port = 0) => CreateServer(lifetime, port); + + internal static (IProtocol ServerProtocol, int portOrPath) CreateServer(Lifetime lifetime, int port = 0) + { + var id = "TestServer"; + var endPointWrapper = EndPointWrapper.CreateIpEndPoint(IPAddress.Loopback, port); + var server = new SocketWire.Server(lifetime, SynchronousScheduler.Instance, endPointWrapper, id); + var protocol = new Protocol(id, new Serializers(), new Identities(IdKind.Server), SynchronousScheduler.Instance, server, lifetime); + return (protocol, server.Port!.Value); + } + + internal override IProtocol Client(Lifetime lifetime, int port) => CreateClient(lifetime, port); + + internal static IProtocol CreateClient(Lifetime lifetime, int port) + { + var id = "TestClient"; + var client = new SocketWire.Client(lifetime, SynchronousScheduler.Instance, port, id); + return new Protocol(id, new Serializers(), new Identities(IdKind.Server), SynchronousScheduler.Instance, client, lifetime); + } + + internal override EndPointWrapper CreateEndpointWrapper() => EndPointWrapper.CreateIpEndPoint(); + + internal IProtocol Client(Lifetime lifetime, IProtocol serverProtocol) + { + // ReSharper disable once PossibleNullReferenceException + // ReSharper disable once PossibleInvalidOperationException + return Client(lifetime, (serverProtocol.Wire as SocketWire.Server).Port.Value); + } + + internal override (IProtocol ServerProtocol, IProtocol ClientProtocol) CreateServerClient(Lifetime lifetime) + { + var (serverProtocol, _) = Server(lifetime); + var clientProtocol = Client(lifetime, serverProtocol); + return (serverProtocol, clientProtocol); + } + + [TestCase(true)] + [TestCase(false)] + public void TestPacketLoss(bool isClientToServer) + { + using (Log.UsingLogFactory(new TextWriterLogFactory(Console.Out, LoggingLevel.TRACE))) + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + + var (serverProtocol, _) = Server(lifetime); + var serverWire = (SocketWire.Base) serverProtocol.Wire; + + var proxy = new SocketProxy("TestProxy", lifetime, serverProtocol); + proxy.Start(); + + var clientProtocol = Client(lifetime, proxy.Port); + var clientWire = (SocketWire.Base) clientProtocol.Wire; + + Thread.Sleep(DefaultTimeout); + + if (isClientToServer) + proxy.StopClientToServerMessaging(); + else + proxy.StopServerToClientMessaging(); + + var detectionTimeoutTicks = ((SocketWire.Base) clientProtocol.Wire).HeartBeatInterval.Ticks * + (SocketWire.Base.MaximumHeartbeatDelay + 3); + var detectionTimeout = TimeSpan.FromTicks(detectionTimeoutTicks); + + Thread.Sleep(detectionTimeout); + + Assert.IsTrue(serverWire.Connected.Value); + Assert.IsTrue(clientWire.Connected.Value); + + Assert.IsFalse(serverWire.HeartbeatAlive.Value); + Assert.IsFalse(clientWire.HeartbeatAlive.Value); + + if (isClientToServer) + proxy.StartClientToServerMessaging(); + else + proxy.StartServerToClientMessaging(); + + Thread.Sleep(detectionTimeout); + + Assert.IsTrue(serverWire.Connected.Value); + Assert.IsTrue(clientWire.Connected.Value); + + Assert.IsTrue(serverWire.HeartbeatAlive.Value); + Assert.IsTrue(clientWire.HeartbeatAlive.Value); + + }); + } + + // [Test] + // [Ignore("Not enough timeout to get the correct test")] + // public void TestStressHeartbeat() + // { + // // using (Log.UsingLogFactory(new TextWriterLogFactory(Console.Out, LoggingLevel.TRACE))) + // Lifetime.Using(lifetime => + // { + // SynchronousScheduler.Instance.SetActive(lifetime); + // + // var interval = TimeSpan.FromMilliseconds(50); + // + // var serverProtocol = Server(lifetime); + // var serverWire = ((SocketWire.Base) serverProtocol.Wire).With(wire => wire.HeartBeatInterval = interval); + // + // var latency = TimeSpan.FromMilliseconds(40); + // var proxy = new SocketProxy("TestProxy", lifetime, serverProtocol) {Latency = latency}; + // proxy.Start(); + // + // var clientProtocol = Client(lifetime, proxy.Port); + // var clientWire = ((SocketWire.Base) clientProtocol.Wire).With(wire => wire.HeartBeatInterval = interval); + // + // Thread.Sleep(DefaultTimeout); + // + // serverWire.HeartbeatAlive.WhenFalse(lifetime, _ => Assert.Fail("Detected false disconnect on server side")); + // clientWire.HeartbeatAlive.WhenFalse(lifetime, _ => Assert.Fail("Detected false disconnect on client side")); + // + // Thread.Sleep(TimeSpan.FromSeconds(50)); + // }); + // } +} +#endif \ No newline at end of file diff --git a/rd-net/Test.RdFramework/SocketWireTest.cs b/rd-net/Test.RdFramework/SocketWireTest.cs deleted file mode 100644 index 40d06ffc4..000000000 --- a/rd-net/Test.RdFramework/SocketWireTest.cs +++ /dev/null @@ -1,503 +0,0 @@ -#if !NET35 - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Net; -using System.Net.Sockets; -using System.Threading; -using JetBrains.Collections.Viewable; -using JetBrains.Core; -using JetBrains.Diagnostics; -using JetBrains.Diagnostics.Internal; -using JetBrains.Lifetimes; -using JetBrains.Rd; -using JetBrains.Rd.Base; -using JetBrains.Rd.Impl; -using JetBrains.Threading; -using NUnit.Framework; -using Test.Lifetimes; - -namespace Test.RdFramework -{ - [TestFixture] - public class SocketWireTest : LifetimesTestBase - { - internal static TimeSpan DefaultTimeout = TimeSpan.FromMilliseconds(100); - - internal const string Top = "top"; - private void WaitAndAssert(RdProperty property, T expected, T prev) - { - WaitAndAssert(property, expected, new Maybe(prev)); - } - - - private void WaitAndAssert(RdProperty property, T expected, Maybe prev = default(Maybe)) - { - var start = Environment.TickCount; - const int timeout = 5000; - while (Environment.TickCount - start < timeout && property.Maybe == prev) Thread.Sleep(10); - if (property.Maybe == prev) - throw new TimeoutException($"Timeout {timeout} ms while waiting for value '{expected}'"); - Assert.AreEqual(expected, property.Value); - } - - - static int FindFreePort() - { - TcpListener l = new TcpListener(IPAddress.Loopback, 0); - l.Start(); - int port = ((IPEndPoint) l.LocalEndpoint).Port; - l.Stop(); - return port; - } - - - internal static IProtocol Server(Lifetime lifetime, int? port = null) - { - var id = "TestServer"; - var server = new SocketWire.Server(lifetime, SynchronousScheduler.Instance, new IPEndPoint(IPAddress.Loopback, port ?? 0), id); - return new Protocol(id, new Serializers(), new Identities(IdKind.Server), SynchronousScheduler.Instance, server, lifetime); - } - - internal static IProtocol Client(Lifetime lifetime, int port) - { - var id = "TestClient"; - var client = new SocketWire.Client(lifetime, SynchronousScheduler.Instance, port, id); - return new Protocol(id, new Serializers(), new Identities(IdKind.Server), SynchronousScheduler.Instance, client, lifetime); - } - - internal static IProtocol Client(Lifetime lifetime, IProtocol serverProtocol) - { - // ReSharper disable once PossibleNullReferenceException - return Client(lifetime, (serverProtocol.Wire as SocketWire.Server).Port); - } - - [Test] - public void TestBasicRun() - { - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - var serverProtocol = Server(lifetime); - var clientProtocol = Client(lifetime, serverProtocol); - - var sp = NewRdProperty().Static(1); - sp.BindTopLevel(lifetime, serverProtocol, Top); - var cp = NewRdProperty().Static(1); - cp.BindTopLevel(lifetime, clientProtocol, Top); - - cp.SetValue(1); - WaitAndAssert(sp, 1); - }); - } - - - [Test] - public void TestOrdering() - { - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - var serverProtocol = Server(lifetime); - var clientProtocol = Client(lifetime, serverProtocol); - - var sp = NewRdProperty().Static(1); - sp.BindTopLevel(lifetime, serverProtocol, Top); - var cp = NewRdProperty().Static(1); - cp.BindTopLevel(lifetime, clientProtocol, Top); - - var log = new List(); - sp.Advise(lifetime, it => log.Add(it)); - sp.SetValue(1); - sp.SetValue(2); - sp.SetValue(3); - sp.SetValue(4); - sp.SetValue(5); - - while (log.Count < 5) Thread.Sleep(10); - CollectionAssert.AreEqual(new[] {1, 2, 3, 4, 5}, log); - }); - } - - - [Test] - public void TestBigBuffer() - { - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - var serverProtocol = Server(lifetime); - var clientProtocol = Client(lifetime, serverProtocol); - - var sp = NewRdProperty().Static(1); - sp.BindTopLevel(lifetime, serverProtocol, Top); - var cp = NewRdProperty().Static(1); - cp.BindTopLevel(lifetime, clientProtocol, Top); - - cp.SetValue("1"); - WaitAndAssert(sp, "1"); - - sp.SetValue(new string('a', 100000)); - WaitAndAssert(cp, new string('a', 100000), "1"); - - cp.SetValue("a"); - WaitAndAssert(sp, "a", new string('a', 100000)); - - cp.SetValue("ab"); - WaitAndAssert(sp, "ab", "a"); - - cp.SetValue("abc"); - WaitAndAssert(sp, "abc", "ab"); - }); - } - - - [Test] - public void TestRunWithSlowpokeServer() - { - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - - var port = FindFreePort(); - var clientProtocol = Client(lifetime, port); - - var cp = NewRdProperty().Static(1); - cp.BindTopLevel(lifetime, clientProtocol, Top); - cp.SetValue(1); - - Thread.Sleep(2000); - var serverProtocol = Server(lifetime, port); - var sp = NewRdProperty().Static(1); - sp.BindTopLevel(lifetime, serverProtocol, Top); - - var prev = sp.Maybe; - - - cp.SetValue(4); - Thread.Sleep(200); - WaitAndAssert(sp, 4, prev); - }); - } - - - [Test] - [Timeout(5000)] - public void TestServerWithoutClient() - { - Lifetime.Using(lifetime => - { - WithLongTimeout(lifetime); - SynchronousScheduler.Instance.SetActive(lifetime); - Server(lifetime); - }); - } - - [Test] - public void TestServerWithoutClientWithDelay() - { - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - Server(lifetime); - Thread.Sleep(100); - }); - } - - [Test] - public void TestServerWithoutClientWithDelayAndMessages() - { - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - var protocol = Server(lifetime); - Thread.Sleep(100); - var p = NewRdProperty().Static(1); - p.BindTopLevel(lifetime, protocol, Top); - p.SetValue(1); - p.SetValue(2); - Thread.Sleep(50); - }); - } - - - [Test] - [Timeout(5000)] - public void TestClientWithoutServer() - { - Lifetime.Using(lifetime => - { - WithLongTimeout(lifetime); - SynchronousScheduler.Instance.SetActive(lifetime); - Client(lifetime, FindFreePort()); - }); - } - - [Test] - public void TestClientWithoutServerWithDelay() - { - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - Client(lifetime, FindFreePort()); - Thread.Sleep(100); - }); - } - - [Test] - public void TestClientWithoutServerWithDelayAndMessages() - { - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - var protocol = Client(lifetime, FindFreePort()); - Thread.Sleep(100); - var p = NewRdProperty().Static(1); - p.BindTopLevel(lifetime, protocol, Top); - p.SetValue(1); - p.SetValue(2); - Thread.Sleep(50); - }); - } - - - [Test, Ignore("https://github.com/JetBrains/rd/issues/69")] - public void TestDisconnect() => TestDisconnectBase((list, i) => list.Add(i)); - - [Test] - public void TestDisconnect_AllowDuplicates() => TestDisconnectBase((list, i) => - { - // values may be duplicated due to asynchronous acknowledgement - if (list.LastOrDefault() < i) - list.Add(i); - }); - - private void TestDisconnectBase(Action, int> advise) - { - var timeout = TimeSpan.FromSeconds(1); - - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - var serverProtocol = Server(lifetime); - var clientProtocol = Client(lifetime, serverProtocol); - - var sp = NewRdSignal().Static(1); - sp.BindTopLevel(lifetime, serverProtocol, Top); - - var cp = NewRdSignal().Static(1); - cp.BindTopLevel(lifetime, clientProtocol, Top); - - var log = new List(); - sp.Advise(lifetime, i => advise(log, i)); - - cp.Fire(1); - cp.Fire(2); - Assert.True(SpinWaitEx.SpinUntil(timeout, () => log.Count == 2)); - Assert.AreEqual(new List {1, 2}, log); - - CloseSocket(clientProtocol); - cp.Fire(3); - cp.Fire(4); - - Assert.True(SpinWaitEx.SpinUntil(timeout, () => log.Count == 4)); - Assert.AreEqual(new List {1, 2, 3, 4}, log); - - CloseSocket(serverProtocol); - cp.Fire(5); - cp.Fire(6); - - Assert.True(SpinWaitEx.SpinUntil(timeout, () => log.Count == 6)); - Assert.AreEqual(new List {1, 2, 3, 4, 5, 6}, log); - }); - } - - [Test] - public void TestReconnect() - { - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - var serverProtocol = Server(lifetime, null); - - var sp = NewRdProperty().Static(1); - sp.BindTopLevel(lifetime, serverProtocol, Top); - sp.IsMaster = false; - - var wire = serverProtocol.Wire as SocketWire.Base; - int clientCount = 0; - wire.NotNull().Connected.WhenTrue(lifetime, _ => - { - clientCount++; - }); - - Assert.AreEqual(0, clientCount); - - Lifetime.Using(lf => - { - var clientProtocol = Client(lf, serverProtocol); - var cp = NewRdProperty().Static(1); - cp.IsMaster = true; - cp.BindTopLevel(lf, clientProtocol, Top); - cp.SetValue(1); - WaitAndAssert(sp, 1); - Assert.AreEqual(1, clientCount); - }); - - - Lifetime.Using(lf => - { - sp = NewRdProperty().Static(2); - sp.BindTopLevel(lifetime, serverProtocol, Top); - - var clientProtocol = Client(lf, serverProtocol); - var cp = NewRdProperty().Static(2); - cp.BindTopLevel(lf, clientProtocol, Top); - cp.SetValue(2); - WaitAndAssert(sp, 2); - Assert.AreEqual(2, clientCount); - }); - - - Lifetime.Using(lf => - { - var clientProtocol = Client(lf, serverProtocol); - var cp = NewRdProperty().Static(2); - cp.BindTopLevel(lf, clientProtocol, Top); - cp.SetValue(3); - WaitAndAssert(sp, 3, 2); - Assert.AreEqual(3, clientCount); - }); - - }); - - } - - [TestCase(true)] - [TestCase(false)] - public void TestPacketLoss(bool isClientToServer) - { - using (Log.UsingLogFactory(new TextWriterLogFactory(Console.Out, LoggingLevel.TRACE))) - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - - var serverProtocol = Server(lifetime); - var serverWire = (SocketWire.Base) serverProtocol.Wire; - - var proxy = new SocketProxy("TestProxy", lifetime, serverProtocol); - proxy.Start(); - - var clientProtocol = Client(lifetime, proxy.Port); - var clientWire = (SocketWire.Base) clientProtocol.Wire; - - Thread.Sleep(DefaultTimeout); - - if (isClientToServer) - proxy.StopClientToServerMessaging(); - else - proxy.StopServerToClientMessaging(); - - var detectionTimeoutTicks = ((SocketWire.Base) clientProtocol.Wire).HeartBeatInterval.Ticks * - (SocketWire.Base.MaximumHeartbeatDelay + 3); - var detectionTimeout = TimeSpan.FromTicks(detectionTimeoutTicks); - - Thread.Sleep(detectionTimeout); - - Assert.IsTrue(serverWire.Connected.Value); - Assert.IsTrue(clientWire.Connected.Value); - - Assert.IsFalse(serverWire.HeartbeatAlive.Value); - Assert.IsFalse(clientWire.HeartbeatAlive.Value); - - if (isClientToServer) - proxy.StartClientToServerMessaging(); - else - proxy.StartServerToClientMessaging(); - - Thread.Sleep(detectionTimeout); - - Assert.IsTrue(serverWire.Connected.Value); - Assert.IsTrue(clientWire.Connected.Value); - - Assert.IsTrue(serverWire.HeartbeatAlive.Value); - Assert.IsTrue(clientWire.HeartbeatAlive.Value); - - }); - } - - [Test] - [Ignore("Not enough timeout to get the correct test")] - public void TestStressHeartbeat() - { - // using (Log.UsingLogFactory(new TextWriterLogFactory(Console.Out, LoggingLevel.TRACE))) - Lifetime.Using(lifetime => - { - SynchronousScheduler.Instance.SetActive(lifetime); - - var interval = TimeSpan.FromMilliseconds(50); - - var serverProtocol = Server(lifetime); - var serverWire = ((SocketWire.Base) serverProtocol.Wire).With(wire => wire.HeartBeatInterval = interval); - - var latency = TimeSpan.FromMilliseconds(40); - var proxy = new SocketProxy("TestProxy", lifetime, serverProtocol) {Latency = latency}; - proxy.Start(); - - var clientProtocol = Client(lifetime, proxy.Port); - var clientWire = ((SocketWire.Base) clientProtocol.Wire).With(wire => wire.HeartBeatInterval = interval); - - Thread.Sleep(DefaultTimeout); - - serverWire.HeartbeatAlive.WhenFalse(lifetime, _ => Assert.Fail("Detected false disconnect on server side")); - clientWire.HeartbeatAlive.WhenFalse(lifetime, _ => Assert.Fail("Detected false disconnect on client side")); - - Thread.Sleep(TimeSpan.FromSeconds(50)); - }); - } - - - - [Test] - public void TestSocketFactory() - { - var sLifetime = new LifetimeDefinition(); - var factory = new SocketWire.ServerFactory(sLifetime.Lifetime, SynchronousScheduler.Instance); - - var lf1 = new LifetimeDefinition(); - new SocketWire.Client(lf1.Lifetime, SynchronousScheduler.Instance, factory.LocalPort); - SpinWaitEx.SpinUntil(() => factory.Connected.Count == 1); - - var lf2 = new LifetimeDefinition(); - new SocketWire.Client(lf2.Lifetime, SynchronousScheduler.Instance, factory.LocalPort); - SpinWaitEx.SpinUntil(() => factory.Connected.Count == 2); - - - lf1.Terminate(); - SpinWaitEx.SpinUntil(() => factory.Connected.Count == 1); - - sLifetime.Terminate(); - SpinWaitEx.SpinUntil(() => factory.Connected.Count == 0); - } - - - private static void CloseSocket(IProtocol protocol) - { - if (!(protocol.Wire is SocketWire.Base socketWire)) - { - Assert.Fail(); - return; - } - - SocketWire.Base.CloseSocket(socketWire.Socket.NotNull()); - } - - private static void WithLongTimeout(Lifetime lifetime) - { - var oldValue = SocketWire.Base.TimeoutMs; - lifetime.Bracket(() => SocketWire.Base.TimeoutMs = 100_000, () => SocketWire.Base.TimeoutMs = oldValue); - } - } -} -#endif \ No newline at end of file diff --git a/rd-net/Test.RdFramework/SocketWireTestBase.cs b/rd-net/Test.RdFramework/SocketWireTestBase.cs new file mode 100644 index 000000000..8b0e2f03d --- /dev/null +++ b/rd-net/Test.RdFramework/SocketWireTestBase.cs @@ -0,0 +1,401 @@ +#if !NET35 + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Threading; +using JetBrains.Collections.Viewable; +using JetBrains.Core; +using JetBrains.Diagnostics; +using JetBrains.Lifetimes; +using JetBrains.Rd; +using JetBrains.Rd.Base; +using JetBrains.Rd.Impl; +using JetBrains.Threading; +using NUnit.Framework; +using Test.Lifetimes; + +namespace Test.RdFramework; + +public abstract class SocketWireTestBase : LifetimesTestBase +{ + internal static TimeSpan DefaultTimeout = TimeSpan.FromMilliseconds(100); + + internal const string Top = "top"; + private void WaitAndAssert(RdProperty property, T expected, T prev) + { + WaitAndAssert(property, expected, new Maybe(prev)); + } + + + private void WaitAndAssert(RdProperty property, T expected, Maybe prev = default(Maybe)) + { + var start = Environment.TickCount; + const int timeout = 5000; + while (Environment.TickCount - start < timeout && property.Maybe == prev) Thread.Sleep(10); + if (property.Maybe == prev) + throw new TimeoutException($"Timeout {timeout} ms while waiting for value '{expected}'"); + Assert.AreEqual(expected, property.Value); + } + + internal abstract (IProtocol ServerProtocol, IProtocol ClientProtocol) CreateServerClient(Lifetime lifetime); + internal abstract T GetPortOrPath(); + internal abstract (IProtocol ServerProtocol, T portOrPath) Server(Lifetime lifetime, T portOrPath = default); + internal abstract IProtocol Client(Lifetime lifetime, T portOrPath); + internal abstract EndPointWrapper CreateEndpointWrapper(); + + [Test] + public void TestBasicRun() + { + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + var (serverProtocol, clientProtocol) = CreateServerClient(lifetime); + + var sp = NewRdProperty().Static(1); + sp.BindTopLevel(lifetime, serverProtocol, Top); + var cp = NewRdProperty().Static(1); + cp.BindTopLevel(lifetime, clientProtocol, Top); + + cp.SetValue(1); + WaitAndAssert(sp, 1); + }); + } + + [Test] + public void TestOrdering() + { + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + var (serverProtocol, clientProtocol) = CreateServerClient(lifetime); + + var sp = NewRdProperty().Static(1); + sp.BindTopLevel(lifetime, serverProtocol, Top); + var cp = NewRdProperty().Static(1); + cp.BindTopLevel(lifetime, clientProtocol, Top); + + var log = new List(); + sp.Advise(lifetime, it => log.Add(it)); + sp.SetValue(1); + sp.SetValue(2); + sp.SetValue(3); + sp.SetValue(4); + sp.SetValue(5); + + while (log.Count < 5) Thread.Sleep(10); + CollectionAssert.AreEqual(new[] {1, 2, 3, 4, 5}, log); + }); + } + + + [Test] + public void TestBigBuffer() + { + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + var (serverProtocol, clientProtocol) = CreateServerClient(lifetime); + + var sp = NewRdProperty().Static(1); + sp.BindTopLevel(lifetime, serverProtocol, Top); + var cp = NewRdProperty().Static(1); + cp.BindTopLevel(lifetime, clientProtocol, Top); + + cp.SetValue("1"); + WaitAndAssert(sp, "1"); + + sp.SetValue(new string('a', 100000)); + WaitAndAssert(cp, new string('a', 100000), "1"); + + cp.SetValue("a"); + WaitAndAssert(sp, "a", new string('a', 100000)); + + cp.SetValue("ab"); + WaitAndAssert(sp, "ab", "a"); + + cp.SetValue("abc"); + WaitAndAssert(sp, "abc", "ab"); + }); + } + + + [Test] + public void TestRunWithSlowpokeServer() + { + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + + var portOrPath = GetPortOrPath(); + var clientProtocol = Client(lifetime, portOrPath); + + var cp = NewRdProperty().Static(1); + cp.BindTopLevel(lifetime, clientProtocol, Top); + cp.SetValue(1); + + Thread.Sleep(2000); + var (serverProtocol, _) = Server(lifetime, portOrPath); + var sp = NewRdProperty().Static(1); + sp.BindTopLevel(lifetime, serverProtocol, Top); + + var prev = sp.Maybe; + + + cp.SetValue(4); + Thread.Sleep(200); + WaitAndAssert(sp, 4, prev); + }); + } + + + [Test] + [Timeout(5000)] + public void TestServerWithoutClient() + { + Lifetime.Using(lifetime => + { + WithLongTimeout(lifetime); + SynchronousScheduler.Instance.SetActive(lifetime); + Server(lifetime); + }); + } + + [Test] + public void TestServerWithoutClientWithDelay() + { + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + Server(lifetime); + Thread.Sleep(100); + }); + } + + [Test] + public void TestServerWithoutClientWithDelayAndMessages() + { + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + var (protocol, _) = Server(lifetime); + Thread.Sleep(100); + var p = NewRdProperty().Static(1); + p.BindTopLevel(lifetime, protocol, Top); + p.SetValue(1); + p.SetValue(2); + Thread.Sleep(50); + }); + } + + + [Test] + [Timeout(5000)] + public void TestClientWithoutServer() + { + Lifetime.Using(lifetime => + { + WithLongTimeout(lifetime); + SynchronousScheduler.Instance.SetActive(lifetime); + Client(lifetime, GetPortOrPath()); + }); + } + + [Test] + public void TestClientWithoutServerWithDelay() + { + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + Client(lifetime, GetPortOrPath()); + Thread.Sleep(100); + }); + } + + [Test] + public void TestClientWithoutServerWithDelayAndMessages() + { + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + var protocol = Client(lifetime, GetPortOrPath()); + Thread.Sleep(100); + var p = NewRdProperty().Static(1); + p.BindTopLevel(lifetime, protocol, Top); + p.SetValue(1); + p.SetValue(2); + Thread.Sleep(50); + }); + } + + + [Test, Ignore("https://github.com/JetBrains/rd/issues/69")] + public void TestDisconnect() => TestDisconnectBase((list, i) => list.Add(i)); + + [Test] + public void TestDisconnect_AllowDuplicates() => TestDisconnectBase((list, i) => + { + // values may be duplicated due to asynchronous acknowledgement + if (list.LastOrDefault() < i) + list.Add(i); + }); + + private void TestDisconnectBase(Action, int> advise) + { + var timeout = TimeSpan.FromSeconds(1); + + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + var (serverProtocol, clientProtocol) = CreateServerClient(lifetime); + + var sp = NewRdSignal().Static(1); + sp.BindTopLevel(lifetime, serverProtocol, Top); + + var cp = NewRdSignal().Static(1); + cp.BindTopLevel(lifetime, clientProtocol, Top); + + var log = new List(); + sp.Advise(lifetime, i => advise(log, i)); + + cp.Fire(1); + cp.Fire(2); + Assert.True(SpinWaitEx.SpinUntil(timeout, () => log.Count == 2)); + Assert.AreEqual(new List {1, 2}, log); + + CloseSocket(clientProtocol); + cp.Fire(3); + cp.Fire(4); + + Assert.True(SpinWaitEx.SpinUntil(timeout, () => log.Count == 4)); + Assert.AreEqual(new List {1, 2, 3, 4}, log); + + CloseSocket(serverProtocol); + cp.Fire(5); + cp.Fire(6); + + Assert.True(SpinWaitEx.SpinUntil(timeout, () => log.Count == 6)); + Assert.AreEqual(new List {1, 2, 3, 4, 5, 6}, log); + }); + } + + [Test] + public void TestReconnect() + { + Lifetime.Using(lifetime => + { + SynchronousScheduler.Instance.SetActive(lifetime); + var (serverProtocol, portOrPath) = Server(lifetime, default); + + var sp = NewRdProperty().Static(1); + sp.BindTopLevel(lifetime, serverProtocol, Top); + sp.IsMaster = false; + + var wire = serverProtocol.Wire as SocketWire.Base; + int clientCount = 0; + wire.NotNull().Connected.WhenTrue(lifetime, _ => + { + clientCount++; + }); + + Assert.AreEqual(0, clientCount); + + Lifetime.Using(lf => + { + var clientProtocol = Client(lf, portOrPath); + var cp = NewRdProperty().Static(1); + cp.IsMaster = true; + cp.BindTopLevel(lf, clientProtocol, Top); + cp.SetValue(1); + WaitAndAssert(sp, 1); + Assert.AreEqual(1, clientCount); + }); + + + Lifetime.Using(lf => + { + sp = NewRdProperty().Static(2); + sp.BindTopLevel(lifetime, serverProtocol, Top); + + var clientProtocol = Client(lf, portOrPath); + var cp = NewRdProperty().Static(2); + cp.BindTopLevel(lf, clientProtocol, Top); + cp.SetValue(2); + WaitAndAssert(sp, 2); + Assert.AreEqual(2, clientCount); + }); + + + Lifetime.Using(lf => + { + var clientProtocol = Client(lf, portOrPath); + var cp = NewRdProperty().Static(2); + cp.BindTopLevel(lf, clientProtocol, Top); + cp.SetValue(3); + WaitAndAssert(sp, 3, 2); + Assert.AreEqual(3, clientCount); + }); + + }); + + } + + + [Test] + public void TestSocketFactory() + { + var sLifetime = new LifetimeDefinition(); + var endPointWrapper = CreateEndpointWrapper(); + var factory = new SocketWire.ServerFactory(sLifetime.Lifetime, SynchronousScheduler.Instance, endPointWrapper); + + var lf1 = new LifetimeDefinition(); + // ReSharper disable once PossibleInvalidOperationException + if (endPointWrapper is EndPointWrapper.IPEndpointWrapper) + { + new SocketWire.Client(lf1.Lifetime, SynchronousScheduler.Instance, factory.LocalPort.Value); + SpinWaitEx.SpinUntil(() => factory.Connected.Count == 1); + + var lf2 = new LifetimeDefinition(); + new SocketWire.Client(lf2.Lifetime, SynchronousScheduler.Instance, factory.LocalPort.Value); + SpinWaitEx.SpinUntil(() => factory.Connected.Count == 2); + } + else + { +#if NET6_0_OR_GREATER + var connectionParams = new EndPointWrapper.UnixSocketConnectionParams { Path = (factory.ConnectionEndPoint as EndPointWrapper.UnixEndpointWrapper)!.LocalPath }; + new SocketWire.Client(lf1.Lifetime, SynchronousScheduler.Instance, connectionParams); + SpinWaitEx.SpinUntil(() => factory.Connected.Count == 1); + + var lf2 = new LifetimeDefinition(); + new SocketWire.Client(lf2.Lifetime, SynchronousScheduler.Instance, connectionParams); + SpinWaitEx.SpinUntil(() => factory.Connected.Count == 2); +#endif + } + + lf1.Terminate(); + SpinWaitEx.SpinUntil(() => factory.Connected.Count == 1); + + sLifetime.Terminate(); + SpinWaitEx.SpinUntil(() => factory.Connected.Count == 0); + } + + + private static void CloseSocket(IProtocol protocol) + { + if (!(protocol.Wire is SocketWire.Base socketWire)) + { + Assert.Fail(); + return; + } + + SocketWire.Base.CloseSocket(socketWire.Socket.NotNull()); + } + + private static void WithLongTimeout(Lifetime lifetime) + { + var oldValue = SocketWire.Base.TimeoutMs; + lifetime.Bracket(() => SocketWire.Base.TimeoutMs = 100_000, () => SocketWire.Base.TimeoutMs = oldValue); + } +} +#endif \ No newline at end of file diff --git a/rd-net/Test.RdFramework/SocketWireUnixEndpointTest.cs b/rd-net/Test.RdFramework/SocketWireUnixEndpointTest.cs new file mode 100644 index 000000000..685df5e12 --- /dev/null +++ b/rd-net/Test.RdFramework/SocketWireUnixEndpointTest.cs @@ -0,0 +1,54 @@ +#if NET6_0_OR_GREATER +using System.IO; +using JetBrains.Collections.Viewable; +using JetBrains.Lifetimes; +using JetBrains.Rd; +using JetBrains.Rd.Impl; +using NUnit.Framework; + +namespace Test.RdFramework; + +[TestFixture] +public class SocketWireUnixEndpointTest : SocketWireTestBase +{ + internal override string GetPortOrPath() => Path.GetTempFileName(); + + internal override (IProtocol ServerProtocol, string portOrPath) Server(Lifetime lifetime, string path = null) + { + var id = "TestServer"; + var connectionsParams = new EndPointWrapper.UnixSocketConnectionParams { Path = path }; + var endPointWrapper = EndPointWrapper.CreateUnixEndPoint(connectionsParams); + var server = new SocketWire.Server(lifetime, SynchronousScheduler.Instance, endPointWrapper, id); + var protocol = new Protocol(id, new Serializers(), new Identities(IdKind.Server), SynchronousScheduler.Instance, server, lifetime); + return (protocol, endPointWrapper.LocalPath); + } + + internal override IProtocol Client(Lifetime lifetime, string path) + { + var id = "TestClient"; + var connectionsParams = new EndPointWrapper.UnixSocketConnectionParams { Path = path }; + var client = new SocketWire.Client(lifetime, SynchronousScheduler.Instance, connectionsParams, id); + return new Protocol(id, new Serializers(), new Identities(IdKind.Server), SynchronousScheduler.Instance, client, lifetime); + } + + internal override EndPointWrapper CreateEndpointWrapper() + { + return EndPointWrapper.CreateUnixEndPoint(null); + } + + // internal IProtocol Client(Lifetime lifetime, IProtocol serverProtocol) + // { + // // ReSharper disable once PossibleNullReferenceException + // // ReSharper disable once PossibleInvalidOperationException + // return Client(lifetime, (serverProtocol.Wire as SocketWire.Server).Port.Value); + // } + + internal override (IProtocol ServerProtocol, IProtocol ClientProtocol) CreateServerClient(Lifetime lifetime) + { + var path = GetPortOrPath(); + var (serverProtocol, _) = Server(lifetime, path); + var clientProtocol = Client(lifetime, path); + return (serverProtocol, clientProtocol); + } +} +#endif \ No newline at end of file diff --git a/rd-net/Test.RdFramework/Test.RdFramework.csproj b/rd-net/Test.RdFramework/Test.RdFramework.csproj index f88fe89b5..ec9fe3b2a 100644 --- a/rd-net/Test.RdFramework/Test.RdFramework.csproj +++ b/rd-net/Test.RdFramework/Test.RdFramework.csproj @@ -1,7 +1,7 @@  - net7.0 + netcoreapp3.1;net8.0 net472;$(TargetFrameworks);net35 Full diff --git a/rd-net/Test.Reflection.App/Program.cs b/rd-net/Test.Reflection.App/Program.cs index bd5f386e4..a331859c4 100644 --- a/rd-net/Test.Reflection.App/Program.cs +++ b/rd-net/Test.Reflection.App/Program.cs @@ -36,7 +36,7 @@ static class Program public static event Action OnChar; - private static readonly IPEndPoint ourIpEndPoint = new IPEndPoint(IPAddress.Loopback, ourPort); + private static readonly EndPointWrapper ourWrapper = EndPointWrapper.CreateIpEndPoint(IPAddress.Loopback, ourPort); public static void StartClient() => Main(new [] {"client"}); public static void StartServer() => Main(new [] {"server"}); @@ -63,13 +63,13 @@ private static void MainLifetime(string[] args, LifetimeDefinition lifetimeDefin if (isServer) { Console.Title = "Server"; - wire = new SocketWire.Server(lifetime, scheduler, ourIpEndPoint); + wire = new SocketWire.Server(lifetime, scheduler, ourWrapper); protocol = new Protocol("Server", reflectionSerializers.Serializers, new Identities(IdKind.Server), scheduler, wire, lifetime); } else { Console.Title = "Client"; - wire = new SocketWire.Client(lifetime, scheduler, ourIpEndPoint); + wire = new SocketWire.Client(lifetime, scheduler, ourWrapper); protocol = new Protocol("Client", reflectionSerializers.Serializers, new Identities(IdKind.Client), scheduler, wire, lifetime); } diff --git a/rd-net/Test.Reflection.App/Test.Reflection.App.csproj b/rd-net/Test.Reflection.App/Test.Reflection.App.csproj index 5fdf0a655..ba201a824 100644 --- a/rd-net/Test.Reflection.App/Test.Reflection.App.csproj +++ b/rd-net/Test.Reflection.App/Test.Reflection.App.csproj @@ -2,7 +2,7 @@ Exe - net472 + net472;net8.0 false LatestMajor disable diff --git a/rd-net/global.json b/rd-net/global.json index 0f1517940..e71d6cc57 100644 --- a/rd-net/global.json +++ b/rd-net/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "7.0.410", + "version": "8.0.100", "rollForward": "major" } }