Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package com.jetbrains.rd.framework

import com.jetbrains.rd.framework.WireAddress.Companion.toSocketAddress
import com.jetbrains.rd.framework.base.WireBase
import com.jetbrains.rd.framework.util.getInputStream
import com.jetbrains.rd.framework.util.getOutputStream
import com.jetbrains.rd.util.*
import com.jetbrains.rd.util.lifetime.Lifetime
import com.jetbrains.rd.util.lifetime.isAlive
Expand All @@ -15,6 +18,11 @@ import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.net.*
import java.nio.channels.AsynchronousCloseException
import java.nio.channels.ClosedChannelException
import java.nio.channels.ServerSocketChannel
import java.nio.channels.SocketChannel
import java.nio.file.Path
import java.time.Duration
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
Expand Down Expand Up @@ -75,7 +83,7 @@ class SocketWire {
abstract class Base protected constructor(val id: String, private val lifetime: Lifetime, scheduler: IScheduler) : WireBase() {

protected val logger: Logger = getLogger(this::class)
val socketProvider = OptProperty<Socket>()
val socketProvider = OptProperty<SocketChannel>()

private lateinit var output : OutputStream
private lateinit var socketInput : InputStream
Expand Down Expand Up @@ -111,12 +119,12 @@ class SocketWire {

connected.advise(lifetime) { heartbeatAlive.value = it }

socketProvider.advise(lifetime) { socket ->
socketProvider.advise(lifetime) { socketChannel ->

logger.debug { "$id : connected" }

output = socket.outputStream
socketInput = socket.inputStream.buffered()
output = socketChannel.getOutputStream()
socketInput = socketChannel.getInputStream().buffered()
pkgInput = PkgInputStream(socketInput)

sendBuffer.reprocessUnacknowledged()
Expand All @@ -127,12 +135,12 @@ class SocketWire {
scheduler.queue { connected.value = true }

try {
receiverProc(socket)
receiverProc(socketChannel)
} finally {
scheduler.queue { connected.value = false }
heartbeatJob.cancel()
sendBuffer.pause(disconnectedPauseReason)
catchAndDrop { socket.close() }
catchAndDrop { socketChannel.close() }
}
}
}
Expand Down Expand Up @@ -175,14 +183,9 @@ class SocketWire {
}


private fun receiverProc(socket: Socket) {
private fun receiverProc(socket: SocketChannel) {
while (lifetime.isAlive) {
try {
if (!socket.isConnected) {
logger.debug { "Stop receive messages because socket disconnected" }
break
}

if (!readMsg()) {
logger.debug { "$id: Connection was gracefully shutdown" }
break
Expand Down Expand Up @@ -299,7 +302,11 @@ class SocketWire {
synchronized(socketSendLock) {
output.write(ackPkgHeader.getArray(), 0, pkg_header_len)
}
} catch (ex: SocketException) {
}
catch(ex: ClosedChannelException) {
logger.warn { "$id: Exception raised during ACK, seqn = $seqn" }
}
catch (ex: SocketException) {
logger.warn { "$id: Exception raised during ACK, seqn = $seqn" }
}
}
Expand Down Expand Up @@ -345,7 +352,9 @@ class SocketWire {
}
} catch (ex: SocketException) {
sendBuffer.pause(disconnectedPauseReason)

}
catch (ex: IOException) {
sendBuffer.pause(disconnectedPauseReason)
}
}

Expand Down Expand Up @@ -384,26 +393,39 @@ class SocketWire {
}


class Client(lifetime : Lifetime, scheduler: IScheduler, port : Int, optId: String? = null, hostAddress: InetAddress = InetAddress.getLoopbackAddress()) : Base(optId ?:"ClientSocket", lifetime, scheduler) {
class Client internal constructor(lifetime : Lifetime, scheduler: IScheduler, endpoint: SocketAddress, optId: String? = null) : Base(optId ?:"ClientSocket", lifetime, scheduler) {

constructor(lifetime : Lifetime, scheduler: IScheduler, wireAddress: WireAddress, optId: String? = null) : this(lifetime, scheduler, wireAddress.toSocketAddress(), optId)

constructor(
lifetime: Lifetime,
scheduler: IScheduler,
port: Int,
optId: String? = null,
hostAddress: InetAddress = InetAddress.getLoopbackAddress()
) : this(lifetime, scheduler, InetSocketAddress(hostAddress, port), optId)

init {

var socket : Socket? = null
var socket : SocketChannel? = null
val thread = thread(name = id, isDaemon = true) {
try {
var lastReportedErrorHash = 0
while (lifetime.isAlive) {
try {
val s = Socket()
s.tcpNoDelay = true
val s = when (endpoint) {
is InetSocketAddress -> SocketChannel.open().apply { setOption(StandardSocketOptions.TCP_NODELAY, true) }
is UnixDomainSocketAddress -> SocketChannel.open(StandardProtocolFamily.UNIX)
else -> throw IllegalArgumentException("Only InetSocketAddress and UnixDomainSocketAddress are supported, got: $endpoint")
}

// On windows connect will try to send SYN 3 times with interval of 500ms (total time is 1second)
// Connect timeout doesn't work if it's more than 1 second. But we don't need it because we can close socket any moment.

//https://stackoverflow.com/questions/22417228/prevent-tcp-socket-connection-retries
//HKLM\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\TcpMaxConnectRetransmissions
logger.debug { "$id : connecting to $hostAddress:$port" }
s.connect(InetSocketAddress(hostAddress, port))
logger.debug { "$id : connecting to $endpoint" }
s.connect(endpoint)

synchronized(lock) {
if (!lifetime.isAlive) {
Expand All @@ -426,12 +448,12 @@ class SocketWire {
if (logger.isEnabled(LogLevel.Debug)) {
logger.log(
LogLevel.Debug,
"$id: connection error for endpoint $hostAddress:$port.",
"$id: connection error for endpoint $endpoint.",
e
)
}
} else {
logger.debug { "$id: connection error for endpoint $hostAddress:$port (${e.message})." }
logger.debug { "$id: connection error for endpoint $endpoint (${e.message})." }
}

val shouldReconnect = synchronized(lock) {
Expand All @@ -450,7 +472,11 @@ class SocketWire {

} catch (ex: SocketException) {
logger.info {"$id: closed with exception: $ex"}
} catch (ex: Throwable) {
}
catch (ex: ClosedChannelException) {
logger.info {"$id: closed with exception: $ex"}
}
catch (ex: Throwable) {
logger.error("$id: unhandled exception.", ex)
} finally {
logger.debug { "$id: terminated." }
Expand Down Expand Up @@ -482,34 +508,54 @@ class SocketWire {
}


class Server internal constructor(lifetime : Lifetime, scheduler: IScheduler, ss: ServerSocket, optId: String? = null, allowReconnect: Boolean) : Base(optId ?:"ServerSocket", lifetime, scheduler) {
val port : Int = ss.localPort
class Server internal constructor(lifetime : Lifetime, scheduler: IScheduler, ss: ServerSocketChannel, optId: String? = null, allowReconnect: Boolean) : Base(optId ?:"ServerSocket", lifetime, scheduler) {
val wireAddress: WireAddress = WireAddress.fromSocketAddress(ss.localAddress)

@Deprecated("Use wireAddress instead")
val port : Int = (wireAddress as? WireAddress.TcpAddress)?.port ?: -1

companion object {
internal fun createServerSocket(lifetime: Lifetime, port : Int?, allowRemoteConnections: Boolean) : ServerSocket {
internal fun createServerSocket(lifetime: Lifetime, port : Int?, allowRemoteConnections: Boolean) : ServerSocketChannel {
val address = if (allowRemoteConnections) InetAddress.getByName("0.0.0.0") else InetAddress.getByName("127.0.0.1")
val portToBind = port ?: 0
val res = ServerSocket()
res.reuseAddress = true
val res = ServerSocketChannel.open().apply { setOption(StandardSocketOptions.SO_REUSEADDR, true) }
res.bind(InetSocketAddress(address, portToBind), 0)
lifetime.onTermination {
res.close()
}
return res
}

internal fun createServerSocket(lifetime: Lifetime, endpoint: SocketAddress) : ServerSocketChannel {
val socketChannel = when (endpoint) {
is InetSocketAddress -> ServerSocketChannel.open().apply { setOption(StandardSocketOptions.SO_REUSEADDR, true) }
is UnixDomainSocketAddress -> ServerSocketChannel.open(StandardProtocolFamily.UNIX)
else -> throw IllegalArgumentException("Only InetSocketAddress and UnixDomainSocketAddress are supported, got: $endpoint")
}

socketChannel.bind(endpoint, 0)
lifetime.onTermination {
socketChannel.close()
}
return socketChannel
}
}

constructor (lifetime : Lifetime, scheduler: IScheduler, port : Int?, optId: String? = null, allowRemoteConnections: Boolean = false) : this(lifetime, scheduler, createServerSocket(lifetime, port, allowRemoteConnections), optId, allowReconnect = true)
constructor (lifetime : Lifetime, scheduler: IScheduler, wireAddress: WireAddress, optId: String? = null) : this(lifetime, scheduler, wireAddress.toSocketAddress(), optId)

internal constructor (lifetime : Lifetime, scheduler: IScheduler, endpoint: SocketAddress, optId: String? = null) : this(lifetime, scheduler, createServerSocket(lifetime, endpoint), optId, allowReconnect = true)

init {
var socket : Socket? = null
var socket : SocketChannel? = null
val thread = thread(name = id, isDaemon = true) {
logger.catch {
while (lifetime.isAlive) {
try {
logger.debug { "$id: listening ${ss.localSocketAddress}" }
logger.debug { "$id: listening ${wireAddress}" }
val s = ss.accept() //could be terminated by close
s.tcpNoDelay = true
if(s.localAddress is InetSocketAddress)
s.setOption(StandardSocketOptions.TCP_NODELAY, true)

synchronized(lock) {
if (!lifetime.isAlive) {
Expand All @@ -520,19 +566,22 @@ class SocketWire {
socket = s
}


socketProvider.set(s)
} catch (ex: SocketException) {
logger.debug { "$id closed with exception: $ex" }
} catch (ex: Exception) {
logger.error("$id closed with exception", ex)
}
catch (ex: Exception) {
when(ex) {
is SocketException, is AsynchronousCloseException, is ClosedChannelException -> {
logger.debug { "$id closed with exception: $ex" }
}
else -> logger.error("$id closed with exception", ex)
}
}

if (!allowReconnect) {
logger.debug { "$id: finished listening on ${ss.localSocketAddress}." }
logger.debug { "$id: finished listening on ${wireAddress}." }
break
} else {
logger.debug { "$id: waiting for reconnection on ${ss.localSocketAddress}." }
logger.debug { "$id: waiting for reconnection on ${wireAddress}." }
}
}
}
Expand Down Expand Up @@ -580,7 +629,7 @@ class SocketWire {

init {
val ss = Server.createServerSocket(lifetime, port, allowRemoteConnections)
localPort = ss.localPort
localPort = (ss.localAddress as? InetSocketAddress)?.port ?: -1

fun rec() {
lifetime.executeIfAlive {
Expand All @@ -595,15 +644,23 @@ class SocketWire {

rec()
}


}

}

sealed class WireAddress {
data class TcpAddress(val address: InetAddress, val port: Int): WireAddress()
data class UnixAddress(val path: Path): WireAddress()

//todo remove
val IWire.serverPort: Int get() {
val serverSocketWire = this as? SocketWire.Server ?: throw IllegalArgumentException("You must use SocketWire.Server to get server port")
return serverSocketWire.port
}
internal companion object {
fun fromSocketAddress(socketAddress: SocketAddress): WireAddress = when (socketAddress) {
is InetSocketAddress -> TcpAddress(socketAddress.address, socketAddress.port)
is UnixDomainSocketAddress -> UnixAddress(socketAddress.path)
else -> error("Unknown socket address type: $socketAddress")
}

fun WireAddress.toSocketAddress() = when (this) {
is TcpAddress -> InetSocketAddress(address, port)
is UnixAddress -> UnixDomainSocketAddress.of(path)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
package com.jetbrains.rd.framework.util

import java.net.InetAddress
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.net.InetSocketAddress
import java.net.ServerSocket
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.channels.ReadableByteChannel
import java.nio.channels.SocketChannel
import java.nio.channels.WritableByteChannel

object NetUtils {
private fun isPortFree(port: Int?): Boolean {
Expand Down Expand Up @@ -52,3 +59,36 @@ data class PortPair private constructor(val senderPort: Int, val receiverPort: I
}
}

fun SocketChannel.getInputStream(): InputStream {
val ch = this
return Channels.newInputStream(object : ReadableByteChannel {
override fun read(dst: ByteBuffer?): Int {
return ch.read(dst)
}

override fun close() {
ch.close()
}

override fun isOpen(): Boolean {
return ch.isOpen
}
})
}

fun SocketChannel.getOutputStream(): OutputStream {
val ch = this
return Channels.newOutputStream(object : WritableByteChannel {
override fun write(src: ByteBuffer?): Int {
return ch.write(src)
}

override fun close() {
ch.close()
}

override fun isOpen(): Boolean {
return ch.isOpen
}
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package com.jetbrains.rd.framework.test.cases.wire

import com.jetbrains.rd.framework.IProtocol
import com.jetbrains.rd.framework.SocketWire
import com.jetbrains.rd.framework.serverPort
import com.jetbrains.rd.util.Logger
import com.jetbrains.rd.util.error
import com.jetbrains.rd.util.getLogger
Expand Down Expand Up @@ -52,7 +51,7 @@ class SocketProxy internal constructor(val id: String, val lifetime: Lifetime, p
}

internal constructor(id: String, lifetime: Lifetime, protocol: IProtocol) :
this(id, lifetime, protocol.wire.serverPort)
this(id, lifetime, (protocol.wire as SocketWire.Server).port)

fun start() {
fun setSocketOptions(acceptedClient: Socket) {
Expand All @@ -62,7 +61,7 @@ class SocketProxy internal constructor(val id: String, val lifetime: Lifetime, p
try {
logger.info { "Creating proxies for server and client..." }
proxyServer = Socket(InetAddress.getLoopbackAddress(), serverPort).also { lifetime.onTermination(it) }
proxyClient = SocketWire.Server.createServerSocket(lifetime, 0, false).also { lifetime.onTermination(it) }
proxyClient = SocketWire.Server.createServerSocket(lifetime, 0, false).also { lifetime.onTermination(it) }.socket() // TODO: Remove .socket() call and replace SocketServer to ServerSocketChannel

setSocketOptions(proxyServer)
_port = proxyClient.localPort
Expand Down
Loading
Loading