Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ public void KeyMatchTest()
var user = "SYSDBA";
var password = "masterkey";
var client = new Srp256Client();
var salt = client.GetSalt();
var salt = Srp256Client.GetSalt();
var serverKeyPair = client.ServerSeed(user, password, salt);
var serverSessionKey = client.GetServerSessionKey(user, password, salt, client.PublicKey, serverKeyPair.Item1, serverKeyPair.Item2);
var serverSessionKey = Srp256Client.GetServerSessionKey(user, password, salt, client.PublicKey, serverKeyPair.Item1, serverKeyPair.Item2);
client.ClientProof(user, password, salt, serverKeyPair.Item1);
Assert.AreEqual(serverSessionKey.ToString(), client.SessionKey.ToString());
}
Expand Down
4 changes: 2 additions & 2 deletions src/FirebirdSql.Data.FirebirdClient.Tests/SrpClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ public void KeyMatchTest()
var user = "SYSDBA";
var password = "masterkey";
var client = new SrpClient();
var salt = client.GetSalt();
var salt = SrpClient.GetSalt();
var serverKeyPair = client.ServerSeed(user, password, salt);
var serverSessionKey = client.GetServerSessionKey(user, password, salt, client.PublicKey, serverKeyPair.Item1, serverKeyPair.Item2);
var serverSessionKey = SrpClient.GetServerSessionKey(user, password, salt, client.PublicKey, serverKeyPair.Item1, serverKeyPair.Item2);
client.ClientProof(user, password, salt, serverKeyPair.Item1);
Assert.AreEqual(serverSessionKey.ToString(), client.SessionKey.ToString());
}
Expand Down
96 changes: 59 additions & 37 deletions src/FirebirdSql.Data.FirebirdClient/Client/Managed/AuthBlock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ sealed class AuthBlock

public bool WireCryptInitialized { get; private set; }

private const byte SEPARATOR_BYTE = (byte)',';

public AuthBlock(GdsConnection connection, string user, string password, WireCryptOption wireCrypt)
{
_srp256 = new Srp256Client();
Expand All @@ -68,66 +70,72 @@ public byte[] UserIdentificationData()
{
using (var result = new MemoryStream(256))
{
Span<byte> scratchpad = stackalloc byte[258];
var userString = Environment.GetEnvironmentVariable("USERNAME") ?? Environment.GetEnvironmentVariable("USER") ?? string.Empty;
var user = Encoding.UTF8.GetBytes(userString);
result.WriteByte(IscCodes.CNCT_user);
result.WriteByte((byte)user.Length);
result.Write(user, 0, user.Length);

var host = Encoding.UTF8.GetBytes(Dns.GetHostName());
result.WriteByte(IscCodes.CNCT_host);
result.WriteByte((byte)host.Length);
result.Write(host, 0, host.Length);
WriteUserIdentificationParams(result, scratchpad, IscCodes.CNCT_user, userString);
var hostName = Dns.GetHostName();
WriteUserIdentificationParams(result, scratchpad, IscCodes.CNCT_host, hostName);

result.WriteByte(IscCodes.CNCT_user_verification);
result.WriteByte(0);

if (!string.IsNullOrEmpty(User))
{
var login = Encoding.UTF8.GetBytes(User);
result.WriteByte(IscCodes.CNCT_login);
result.WriteByte((byte)login.Length);
result.Write(login, 0, login.Length);

var pluginNameBytes = Encoding.UTF8.GetBytes(_srp256.Name);
result.WriteByte(IscCodes.CNCT_plugin_name);
result.WriteByte((byte)pluginNameBytes.Length);
result.Write(pluginNameBytes, 0, pluginNameBytes.Length);
var specificData = Encoding.UTF8.GetBytes(_srp256.PublicKeyHex);
WriteMultiPartHelper(result, IscCodes.CNCT_specific_data, specificData);
WriteUserIdentificationParams(result, scratchpad, IscCodes.CNCT_login, User);
WriteUserIdentificationParams(result, scratchpad, IscCodes.CNCT_plugin_name, _srp256.Name);

var len = Encoding.UTF8.GetBytes(_srp256.PublicKeyHex, scratchpad);
WriteMultiPartHelper(result, IscCodes.CNCT_specific_data, scratchpad[..len]);

var plugins = string.Join(",", new[] { _srp256.Name, _srp.Name });
var pluginsBytes = Encoding.UTF8.GetBytes(plugins);
result.WriteByte(IscCodes.CNCT_plugin_list);
result.WriteByte((byte)pluginsBytes.Length);
result.Write(pluginsBytes, 0, pluginsBytes.Length);
WriteUserIdentificationParams(result, scratchpad, IscCodes.CNCT_plugin_list, _srp256.Name, _srp.Name);

result.WriteByte(IscCodes.CNCT_client_crypt);
result.WriteByte(4);
result.Write(TypeEncoder.EncodeInt32(WireCryptOptionValue(WireCrypt)), 0, 4);
if (!BitConverter.TryWriteBytes(scratchpad, IPAddress.NetworkToHostOrder(WireCryptOptionValue(WireCrypt))))
{
throw new InvalidOperationException("Failed to write wire crypt option bytes.");
}
result.Write(scratchpad[..4]);
}
else
{
var pluginNameBytes = Encoding.UTF8.GetBytes(_sspi.Name);
result.WriteByte(IscCodes.CNCT_plugin_name);
result.WriteByte((byte)pluginNameBytes.Length);
result.Write(pluginNameBytes, 0, pluginNameBytes.Length);
WriteUserIdentificationParams(result, scratchpad, IscCodes.CNCT_plugin_name, _sspi.Name);

var specificData = _sspi.InitializeClientSecurity();
WriteMultiPartHelper(result, IscCodes.CNCT_specific_data, specificData);

result.WriteByte(IscCodes.CNCT_plugin_list);
result.WriteByte((byte)pluginNameBytes.Length);
result.Write(pluginNameBytes, 0, pluginNameBytes.Length);
WriteUserIdentificationParams(result, scratchpad, IscCodes.CNCT_plugin_list, _sspi.Name);

result.WriteByte(IscCodes.CNCT_client_crypt);
result.WriteByte(4);
result.Write(TypeEncoder.EncodeInt32(IscCodes.WIRE_CRYPT_DISABLED), 0, 4);
if (!BitConverter.TryWriteBytes(scratchpad, IPAddress.NetworkToHostOrder(IscCodes.WIRE_CRYPT_DISABLED)))
{
throw new InvalidOperationException("Failed to write wire crypt disabled bytes.");
}
result.Write(scratchpad[..4]);
}

scratchpad.Clear();
return result.ToArray();
}
}

static void WriteUserIdentificationParams(MemoryStream result, Span<byte> scratchpad, byte type, params ReadOnlySpan<string> strings)
{
scratchpad[0] = type;
int len = 2;
if(strings.Length > 0)
{
len += Encoding.UTF8.GetBytes(strings[0], scratchpad[len..]);
for(int i = 1; i < strings.Length; i++)
{
scratchpad[len++] = SEPARATOR_BYTE;
len += Encoding.UTF8.GetBytes(strings[i], scratchpad[len..]);
}
}
scratchpad[1] = (byte)(len - 2);
result.Write(scratchpad[..len]);
}

public void SendContAuthToBuffer()
{
Connection.Xdr.Write(IscCodes.op_cont_auth);
Expand Down Expand Up @@ -309,7 +317,21 @@ void ReleaseAuth()
_sspi = null;
}

static void WriteMultiPartHelper(Stream stream, byte code, byte[] data)
static void WriteMultiPartHelper(MemoryStream stream, byte code, byte[] data)
{
const int MaxLength = 255 - 1;
var part = 0;
for (var i = 0; i < data.Length; i += MaxLength) {
stream.WriteByte(code);
var length = Math.Min(data.Length - i, MaxLength);
stream.WriteByte((byte)(length + 1));
stream.WriteByte((byte)part);
stream.Write(data, i, length);
part++;
}
}

static void WriteMultiPartHelper(MemoryStream stream, byte code, ReadOnlySpan<byte> data)
{
const int MaxLength = 255 - 1;
var part = 0;
Expand All @@ -319,7 +341,7 @@ static void WriteMultiPartHelper(Stream stream, byte code, byte[] data)
var length = Math.Min(data.Length - i, MaxLength);
stream.WriteByte((byte)(length + 1));
stream.WriteByte((byte)part);
stream.Write(data, i, length);
stream.Write(data[i..(i+length)]);
part++;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

//$Authors = Jiri Cincura (jiri@cincura.net)

using System;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -36,12 +38,31 @@ public int Read(byte[] buffer, int offset, int count)
{
return _stream.Read(buffer, offset, count);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int Read(Span<byte> buffer, int offset, int count)
{
return _stream.Read(buffer[offset..(offset+count)]);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ValueTask<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
{
return new ValueTask<int>(_stream.ReadAsync(buffer, offset, count, cancellationToken));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ValueTask<int> ReadAsync(Memory<byte> buffer, int offset, int count, CancellationToken cancellationToken = default)
{
return _stream.ReadAsync(buffer.Slice(offset, count), cancellationToken);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Write(ReadOnlySpan<byte> buffer)
{
_stream.Write(buffer);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Write(byte[] buffer, int offset, int count)
{
Expand All @@ -53,6 +74,12 @@ public ValueTask WriteAsync(byte[] buffer, int offset, int count, CancellationTo
return new ValueTask(_stream.WriteAsync(buffer, offset, count, cancellationToken));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, int offset, int count, CancellationToken cancellationToken = default)
{
return _stream.WriteAsync(buffer.Slice(offset, count), cancellationToken);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Flush()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,39 @@ public int Read(byte[] buffer, int offset, int count)
var dataLength = ReadFromInputBuffer(buffer, offset, count);
return dataLength;
}

public int Read(Span<byte> buffer, int offset, int count)
{
if (_inputBuffer.Count < count)
{
var readBuffer = _readBuffer;
int read;
try
{
read = _dataProvider.Read(readBuffer, 0, readBuffer.Length);
}
catch (IOException)
{
IOFailed = true;
throw;
}
if (read != 0)
{
if (_decryptor != null)
{
_decryptor.ProcessBytes(readBuffer, 0, read, readBuffer, 0);
}
if (_decompressor != null)
{
read = HandleDecompression(readBuffer, read);
readBuffer = _compressionBuffer;
}
WriteToInputBuffer(readBuffer, read);
}
}
var dataLength = ReadFromInputBuffer(buffer, offset, count);
return dataLength;
}
public async ValueTask<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
{
if (_inputBuffer.Count < count)
Expand Down Expand Up @@ -120,6 +153,24 @@ public async ValueTask<int> ReadAsync(byte[] buffer, int offset, int count, Canc
return dataLength;
}

public async ValueTask<int> ReadAsync(Memory<byte> buffer, int offset, int count, CancellationToken cancellationToken = default)
{
var rented = new byte[count];
try
{
var read = await ReadAsync(rented, 0, count, cancellationToken).ConfigureAwait(false);
rented.AsSpan(0, read).CopyTo(buffer.Span.Slice(offset, read));
return read;
}
finally { }
}

public void Write(ReadOnlySpan<byte> buffer)
{
foreach (var b in buffer)
_outputBuffer.Enqueue(b);
}

public void Write(byte[] buffer, int offset, int count)
{
for (var i = offset; i < count; i++)
Expand All @@ -132,6 +183,14 @@ public ValueTask WriteAsync(byte[] buffer, int offset, int count, CancellationTo
return ValueTask.CompletedTask;
}

public ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, int offset, int count, CancellationToken cancellationToken = default)
{
var span = buffer.Span.Slice(offset, count);
foreach (var b in span)
_outputBuffer.Enqueue(b);
return ValueTask.CompletedTask;
}

public void Flush()
{
var buffer = _outputBuffer.ToArray();
Expand Down Expand Up @@ -206,6 +265,15 @@ int ReadFromInputBuffer(byte[] buffer, int offset, int count)
return read;
}

int ReadFromInputBuffer(Span<byte> buffer, int offset, int count)
{
var read = Math.Min(count, _inputBuffer.Count);
for (var i = 0; i < read; i++) {
buffer[offset+i] = _inputBuffer.Dequeue();
}
return read;
}

void WriteToInputBuffer(byte[] data, int count)
{
for (var i = 0; i < count; i++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ public void Identify(string database)
break;
}

if (AuthBlock.ServerKeys.Any())
if (AuthBlock.ServerKeys.Length > 0)
{
AuthBlock.SendWireCryptToBuffer();
Xdr.Flush();
Expand Down Expand Up @@ -330,7 +330,7 @@ await Xdr.ReadBooleanAsync(cancellationToken).ConfigureAwait(false),
break;
}

if (AuthBlock.ServerKeys.Any())
if (AuthBlock.ServerKeys.Length > 0)
{
await AuthBlock.SendWireCryptToBufferAsync(cancellationToken).ConfigureAwait(false);
await Xdr.FlushAsync(cancellationToken).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

//$Authors = Jiri Cincura (jiri@cincura.net)

using System;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -23,10 +24,14 @@ namespace FirebirdSql.Data.Client.Managed;
interface IDataProvider
{
int Read(byte[] buffer, int offset, int count);
int Read(Span<byte> buffer, int offset, int count);
ValueTask<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
ValueTask<int> ReadAsync(Memory<byte> buffer, int offset, int count, CancellationToken cancellationToken = default);

void Write(ReadOnlySpan<byte> buffer);
void Write(byte[] buffer, int offset, int count);
ValueTask WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, int offset, int count, CancellationToken cancellationToken = default);

void Flush();
ValueTask FlushAsync(CancellationToken cancellationToken = default);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,18 @@ namespace FirebirdSql.Data.Client.Managed;
interface IXdrReader
{
byte[] ReadBytes(byte[] buffer, int count);
void ReadBytes(Span<byte> buffer, int count);
ValueTask ReadBytesAsync(Memory<byte> buffer, int count, CancellationToken cancellationToken = default);
ValueTask<byte[]> ReadBytesAsync(byte[] buffer, int count, CancellationToken cancellationToken = default);

byte[] ReadOpaque(int length);
void ReadOpaque(Span<byte> buffer, int length);
ValueTask ReadOpaqueAsync(Memory<byte> buffer, int length, CancellationToken cancellationToken = default);
ValueTask<byte[]> ReadOpaqueAsync(int length, CancellationToken cancellationToken = default);

byte[] ReadBuffer();
void ReadBuffer(Span<byte> buffer);
ValueTask ReadBufferAsync(Memory<byte> buffer, CancellationToken cancellationToken = default);
ValueTask<byte[]> ReadBufferAsync(CancellationToken cancellationToken = default);

string ReadString();
Expand Down
Loading