From 94e0d54c31b94772e29c448eef5a0ed86ba182b1 Mon Sep 17 00:00:00 2001 From: Tom Date: Mon, 21 Oct 2024 20:44:20 +0000 Subject: [PATCH] Added policy messages for WS. Fixed DB changes via stores. Updated chat voices messages via WS to use stores. --- Models/ChatterVoice.cs | 2 +- Models/Policy.cs | 12 --- Requests/CreatePolicy.cs | 49 +++++++++ Requests/CreateTTSUser.cs | 3 +- Requests/GetConnections.cs | 2 +- Requests/GetPolicies.cs | 28 +++++ Requests/RequestManager.cs | 15 ++- Requests/RequestResult.cs | 4 +- Requests/UpdatePolicy.cs | 48 +++++++++ Requests/UpdateTTSUser.cs | 3 +- Services/DatabaseService.cs | 5 +- Socket/Handlers/EmoteDetailsHandler.cs | 2 - Startup.cs | 21 ++-- Store/ChatterStore.cs | 10 +- Store/GroupSaveSqlGenerator.cs | 10 +- Store/PolicyStore.cs | 25 ++--- Store/UserStore.cs | 6 +- Store/VoiceStore.cs | 6 +- db/Database.cs | 141 +++++++++++-------------- 19 files changed, 255 insertions(+), 137 deletions(-) delete mode 100644 Models/Policy.cs create mode 100644 Requests/CreatePolicy.cs create mode 100644 Requests/GetPolicies.cs create mode 100644 Requests/UpdatePolicy.cs diff --git a/Models/ChatterVoice.cs b/Models/ChatterVoice.cs index 2a4f55d..82a993d 100644 --- a/Models/ChatterVoice.cs +++ b/Models/ChatterVoice.cs @@ -2,7 +2,7 @@ namespace HermesSocketServer.Models { public class ChatterVoice { - public string ChatterId { get; set; } + public long ChatterId { get; set; } public string UserId { get; set; } public string VoiceId { get; set; } } diff --git a/Models/Policy.cs b/Models/Policy.cs deleted file mode 100644 index 76c9d58..0000000 --- a/Models/Policy.cs +++ /dev/null @@ -1,12 +0,0 @@ -namespace HermesSocketServer.Store -{ - public class Policy - { - public string Id { get; set; } - public string UserId { get; set; } - public string GroupId { get; set; } - public string Path { get; set; } - public int Usage { get; set; } - public TimeSpan Span { get; set; } - } -} \ No newline at end of file diff --git a/Requests/CreatePolicy.cs b/Requests/CreatePolicy.cs new file mode 100644 index 0000000..a760666 --- /dev/null +++ b/Requests/CreatePolicy.cs @@ -0,0 +1,49 @@ +using HermesSocketServer.Models; +using HermesSocketServer.Services; +using ILogger = Serilog.ILogger; + +namespace HermesSocketServer.Requests +{ + public class CreatePolicy : IRequest + { + public string Name => "create_policy"; + public string[] RequiredKeys => ["groupId", "path", "count", "span"]; + private ChannelManager _channels; + private ILogger _logger; + + public CreatePolicy(ChannelManager channels, ILogger logger) + { + _channels = channels; + _logger = logger; + } + + public async Task Grant(string sender, IDictionary? data) + { + var id = Guid.NewGuid(); + string groupId = data["groupId"].ToString()!; + string path = data["path"].ToString()!; + int count = int.Parse(data["count"].ToString()!); + int span = int.Parse(data["span"].ToString()!); + + var policy = new PolicyMessage() + { + Id = id, + UserId = sender, + GroupId = Guid.Parse(groupId), + Path = path, + Usage = count, + Span = span, + }; + + var channel = _channels.Get(sender); + bool result = channel.Policies.Set(id.ToString(), policy); + + if (result) + { + _logger.Information($"Added policy to channel [policy id: {id}][group id: {groupId}][path: {path}][count: {count}][span: {span}][channel: {sender}]"); + return RequestResult.Successful(policy); + } + return RequestResult.Failed("Something went wrong when updating the cache."); + } + } +} \ No newline at end of file diff --git a/Requests/CreateTTSUser.cs b/Requests/CreateTTSUser.cs index 82d89fa..1ec9115 100644 --- a/Requests/CreateTTSUser.cs +++ b/Requests/CreateTTSUser.cs @@ -8,7 +8,6 @@ namespace HermesSocketServer.Requests public class CreateTTSUser : IRequest { public string Name => "create_tts_user"; - public string[] RequiredKeys => ["chatter", "voice"]; private ChannelManager _channels; private Database _database; @@ -40,7 +39,7 @@ namespace HermesSocketServer.Requests bool result = channel.Chatters.Set(chatterId.ToString(), new ChatterVoice() { UserId = sender, - ChatterId = chatterId.ToString(), + ChatterId = chatterId, VoiceId = data["voice"].ToString()! }); diff --git a/Requests/GetConnections.cs b/Requests/GetConnections.cs index 3539f58..b8863a5 100644 --- a/Requests/GetConnections.cs +++ b/Requests/GetConnections.cs @@ -35,7 +35,7 @@ namespace HermesSocketServer.Requests Default = sql.GetBoolean(7) }) ); - return RequestResult.Successful(connections, false); + return RequestResult.Successful(connections, notifyClientsOnAccount: false); } } } \ No newline at end of file diff --git a/Requests/GetPolicies.cs b/Requests/GetPolicies.cs new file mode 100644 index 0000000..0840a08 --- /dev/null +++ b/Requests/GetPolicies.cs @@ -0,0 +1,28 @@ +using HermesSocketServer.Services; +using ILogger = Serilog.ILogger; + +namespace HermesSocketServer.Requests +{ + public class GetPolicies : IRequest + { + public string Name => "get_policies"; + public string[] RequiredKeys => []; + private ChannelManager _channels; + private ILogger _logger; + + public GetPolicies(ChannelManager channels, ILogger logger) + { + _channels = channels; + _logger = logger; + } + + public async Task Grant(string sender, IDictionary? data) + { + var channel = _channels.Get(sender); + var results = channel.Policies.Get().Values; + + _logger.Information($"Fetched policies for channel [policy size: {results.Count}][channel: {sender}]"); + return RequestResult.Successful(results, notifyClientsOnAccount: false); + } + } +} \ No newline at end of file diff --git a/Requests/RequestManager.cs b/Requests/RequestManager.cs index 18a0daa..715df50 100644 --- a/Requests/RequestManager.cs +++ b/Requests/RequestManager.cs @@ -1,6 +1,5 @@ using HermesSocketLibrary.Socket.Data; using HermesSocketServer.Services; -using Serilog; namespace HermesSocketServer.Requests { @@ -21,11 +20,17 @@ namespace HermesSocketServer.Requests public async Task Grant(string sender, RequestMessage? message) { if (message == null || message.Type == null) + { + _logger.Debug($"Request type does not exist [id: {message.RequestId}][nounce: {message.Nounce}]"); return RequestResult.Failed("Request type does not exist."); - + } + var channel = _channels.Get(sender); if (channel == null) + { + _logger.Debug($"Channel does not exist [id: {message.RequestId}][nounce: {message.Nounce}]"); return RequestResult.Failed("Channel does not exist."); + } if (!_requests.TryGetValue(message.Type, out IRequest? request) || request == null) { @@ -36,12 +41,18 @@ namespace HermesSocketServer.Requests if (request.RequiredKeys.Any()) { if (message.Data == null) + { + _logger.Debug($"Request is lacking data entries [id: {message.RequestId}][nounce: {message.Nounce}]"); return RequestResult.Failed($"Request is lacking data entries."); + } foreach (var key in request.RequiredKeys) { if (!message.Data.ContainsKey(key)) + { + _logger.Debug($"Request is missing '{key}' in its data entries [id: {message.RequestId}][nounce: {message.Nounce}]"); return RequestResult.Failed($"Request is missing '{key}' in its data entries."); + } } } diff --git a/Requests/RequestResult.cs b/Requests/RequestResult.cs index f685dd5..a1548dc 100644 --- a/Requests/RequestResult.cs +++ b/Requests/RequestResult.cs @@ -15,12 +15,12 @@ namespace HermesSocketServer.Requests public static RequestResult Successful(object? result, bool notifyClientsOnAccount = true) { - return RequestResult.Successful(result, notifyClientsOnAccount); + return new RequestResult(true, result, notifyClientsOnAccount); } public static RequestResult Failed(string error, bool notifyClientsOnAccount = true) { - return RequestResult.Successful(error, notifyClientsOnAccount); + return new RequestResult(false, error, notifyClientsOnAccount); } } } \ No newline at end of file diff --git a/Requests/UpdatePolicy.cs b/Requests/UpdatePolicy.cs new file mode 100644 index 0000000..2b0b0f9 --- /dev/null +++ b/Requests/UpdatePolicy.cs @@ -0,0 +1,48 @@ +using HermesSocketServer.Models; +using HermesSocketServer.Services; +using ILogger = Serilog.ILogger; + +namespace HermesSocketServer.Requests +{ + public class UpdatePolicy : IRequest + { + public string Name => "update_policy"; + public string[] RequiredKeys => ["id", "groupId", "path", "count", "span"]; + private ChannelManager _channels; + private ILogger _logger; + + public UpdatePolicy(ChannelManager channels, ILogger logger) + { + _channels = channels; + _logger = logger; + } + + public async Task Grant(string sender, IDictionary? data) + { + var id = Guid.Parse(data["id"].ToString()!); + string groupId = data["groupId"].ToString()!; + string path = data["path"].ToString()!; + int count = int.Parse(data["count"].ToString()!); + int span = int.Parse(data["span"].ToString()!); + + var channel = _channels.Get(sender)!; + bool result = channel.Policies.Set(id.ToString(), new PolicyMessage() + { + Id = id, + UserId = sender, + GroupId = Guid.Parse(groupId), + Path = path, + Usage = count, + Span = span, + }); + + if (result) + { + var policy = channel.Policies.Get(id.ToString()); + _logger.Information($"Updated policy to channel [policy id: {id}][group id: {groupId}][path: {path}][count: {count}][span: {span}][channel: {sender}]"); + return RequestResult.Successful(policy); + } + return RequestResult.Failed("Something went wrong when updating the cache."); + } + } +} \ No newline at end of file diff --git a/Requests/UpdateTTSUser.cs b/Requests/UpdateTTSUser.cs index 6d1b07a..2ffeb71 100644 --- a/Requests/UpdateTTSUser.cs +++ b/Requests/UpdateTTSUser.cs @@ -28,6 +28,7 @@ namespace HermesSocketServer.Requests if (long.TryParse(data["chatter"].ToString(), out long chatterId)) data["chatter"] = chatterId; data["voice"] = data["voice"].ToString(); + data["user"] = sender; var check = await _database.ExecuteScalar("SELECT state FROM \"TtsVoiceState\" WHERE \"userId\" = @user AND \"ttsVoiceId\" = @voice", data) ?? false; if ((check is not bool state || !state) && chatterId != _configuration.Tts.OwnerId) @@ -37,7 +38,7 @@ namespace HermesSocketServer.Requests var result = channel.Chatters.Set(chatterId.ToString(), new ChatterVoice() { UserId = sender, - ChatterId = chatterId.ToString(), + ChatterId = chatterId, VoiceId = data["voice"].ToString()! }); if (result) diff --git a/Services/DatabaseService.cs b/Services/DatabaseService.cs index eca05d5..39f12ee 100644 --- a/Services/DatabaseService.cs +++ b/Services/DatabaseService.cs @@ -1,3 +1,4 @@ +using HermesSocketLibrary.db; using HermesSocketServer.Store; namespace HermesSocketServer.Services @@ -29,15 +30,15 @@ namespace HermesSocketServer.Services await Task.Run(async () => { await Task.Delay(TimeSpan.FromSeconds(_configuration.Database.SaveDelayInSeconds)); - + while (true) { await Task.WhenAll([ _voices.Save(), _users.Save(), _channels.Save(), - Task.Delay(TimeSpan.FromSeconds(_configuration.Database.SaveDelayInSeconds)), ]); + await Task.Delay(TimeSpan.FromSeconds(_configuration.Database.SaveDelayInSeconds)); } }); } diff --git a/Socket/Handlers/EmoteDetailsHandler.cs b/Socket/Handlers/EmoteDetailsHandler.cs index 87c1707..6c859ba 100644 --- a/Socket/Handlers/EmoteDetailsHandler.cs +++ b/Socket/Handlers/EmoteDetailsHandler.cs @@ -70,8 +70,6 @@ namespace HermesSocketServer.Socket.Handlers } } } - - } } } \ No newline at end of file diff --git a/Startup.cs b/Startup.cs index 035c9c7..10c8a88 100644 --- a/Startup.cs +++ b/Startup.cs @@ -84,14 +84,6 @@ s.AddSingleton(); s.AddSingleton(); // Request handlers -s.AddSingleton(); -s.AddSingleton(); -s.AddSingleton(); -s.AddSingleton(); -s.AddSingleton(); -s.AddSingleton(); -s.AddSingleton(); -s.AddSingleton(); s.AddSingleton(); s.AddSingleton(); s.AddSingleton(); @@ -100,8 +92,19 @@ s.AddSingleton(); s.AddSingleton(); s.AddSingleton(); s.AddSingleton(); -s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); s.AddSingleton(); +s.AddSingleton(); +s.AddSingleton(); // Managers s.AddSingleton(); diff --git a/Store/ChatterStore.cs b/Store/ChatterStore.cs index f98ffbc..a79eec4 100644 --- a/Store/ChatterStore.cs +++ b/Store/ChatterStore.cs @@ -32,8 +32,8 @@ namespace HermesSocketServer.Store string sql = $"SELECT \"chatterId\", \"ttsVoiceId\" FROM \"TtsChatVoice\" WHERE \"userId\" = @user"; await _database.Execute(sql, data, (reader) => { - string chatterId = reader.GetInt64(0).ToString(); - _store.Add(chatterId, new ChatterVoice() + var chatterId = reader.GetInt64(0); + _store.Add(chatterId.ToString(), new ChatterVoice() { UserId = _userId, ChatterId = chatterId, @@ -70,7 +70,7 @@ namespace HermesSocketServer.Store } _logger.Debug($"TtsChatVoice - Adding {count} rows to database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } if (_modified.Any()) { @@ -81,7 +81,7 @@ namespace HermesSocketServer.Store _modified.Clear(); } _logger.Debug($"TtsChatVoice - Modifying {count} rows in database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } if (_deleted.Any()) { @@ -92,7 +92,7 @@ namespace HermesSocketServer.Store _deleted.Clear(); } _logger.Debug($"TtsChatVoice - Deleting {count} rows from database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } return true; } diff --git a/Store/GroupSaveSqlGenerator.cs b/Store/GroupSaveSqlGenerator.cs index e57100f..6bfe991 100644 --- a/Store/GroupSaveSqlGenerator.cs +++ b/Store/GroupSaveSqlGenerator.cs @@ -86,7 +86,9 @@ namespace HermesSocketServer.Store .Append("),"); } sb.Remove(sb.Length - 1, 1) - .Append($") AS c(\"{string.Join("\", \"", columns)}\") WHERE id = c.id;"); + .Append($") AS c(\"{string.Join("\", \"", columns)}\") WHERE ") + .Append(string.Join(" AND ", keyColumns.Select(c => "t.\"" + c + "\" = c.\"" + c + "\""))) + .Append(";"); return sb.ToString(); } @@ -131,6 +133,12 @@ namespace HermesSocketServer.Store sb.Append("'") .Append(value) .Append("'"); + else if (type == typeof(Guid)) + sb.Append("'") + .Append(value?.ToString()) + .Append("'"); + else if (type == typeof(TimeSpan)) + sb.Append(((TimeSpan)value).TotalMilliseconds); else sb.Append(value); } diff --git a/Store/PolicyStore.cs b/Store/PolicyStore.cs index 1b591cc..1cc6e53 100644 --- a/Store/PolicyStore.cs +++ b/Store/PolicyStore.cs @@ -1,13 +1,14 @@ using HermesSocketLibrary.db; +using HermesSocketServer.Models; namespace HermesSocketServer.Store { - public class PolicyStore : GroupSaveStore + public class PolicyStore : GroupSaveStore { private readonly string _userId; private readonly Database _database; private readonly Serilog.ILogger _logger; - private readonly GroupSaveSqlGenerator _generator; + private readonly GroupSaveSqlGenerator _generator; public PolicyStore(string userId, Database database, Serilog.ILogger logger) : base(logger) @@ -25,7 +26,7 @@ namespace HermesSocketServer.Store { "count", "Usage" }, { "timespan", "Span" }, }; - _generator = new GroupSaveSqlGenerator(ctp); + _generator = new GroupSaveSqlGenerator(ctp); } public override async Task Load() @@ -34,25 +35,25 @@ namespace HermesSocketServer.Store string sql = $"SELECT id, \"groupId\", path, count, timespan FROM \"GroupPermissionPolicy\" WHERE \"userId\" = @user"; await _database.Execute(sql, data, (reader) => { - string id = reader.GetString(0).ToString(); - _store.Add(id, new Policy() + var id = reader.GetGuid(0); + _store.Add(id.ToString(), new PolicyMessage() { Id = id, UserId = _userId, - GroupId = reader.GetString(1), + GroupId = reader.GetGuid(1), Path = reader.GetString(2), Usage = reader.GetInt32(3), - Span = TimeSpan.FromMilliseconds(reader.GetInt32(4)), + Span = reader.GetInt32(4), }); }); _logger.Information($"Loaded {_store.Count} policies from database."); } - protected override void OnInitialAdd(string key, Policy value) + protected override void OnInitialAdd(string key, PolicyMessage value) { } - protected override void OnInitialModify(string key, Policy value) + protected override void OnInitialModify(string key, PolicyMessage value) { } @@ -75,7 +76,7 @@ namespace HermesSocketServer.Store } _logger.Debug($"GroupPermissionPolicy - Adding {count} rows to database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } if (_modified.Any()) { @@ -86,7 +87,7 @@ namespace HermesSocketServer.Store _modified.Clear(); } _logger.Debug($"GroupPermissionPolicy - Modifying {count} rows in database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } if (_deleted.Any()) { @@ -97,7 +98,7 @@ namespace HermesSocketServer.Store _deleted.Clear(); } _logger.Debug($"GroupPermissionPolicy - Deleting {count} rows from database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } return true; } diff --git a/Store/UserStore.cs b/Store/UserStore.cs index 029434d..267f244 100644 --- a/Store/UserStore.cs +++ b/Store/UserStore.cs @@ -70,7 +70,7 @@ namespace HermesSocketServer.Store _added.Clear(); } _logger.Debug($"User - Adding {count} rows to database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } if (_modified.Any()) { @@ -81,7 +81,7 @@ namespace HermesSocketServer.Store _modified.Clear(); } _logger.Debug($"User - Modifying {count} rows in database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } if (_deleted.Any()) { @@ -92,7 +92,7 @@ namespace HermesSocketServer.Store _deleted.Clear(); } _logger.Debug($"User - Deleting {count} rows from database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } return true; } diff --git a/Store/VoiceStore.cs b/Store/VoiceStore.cs index d9bc048..8fbc8b5 100644 --- a/Store/VoiceStore.cs +++ b/Store/VoiceStore.cs @@ -73,7 +73,7 @@ namespace HermesSocketServer.Store } _logger.Debug($"TtsVoice - Adding {count} rows to database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } if (_modified.Any()) { @@ -84,7 +84,7 @@ namespace HermesSocketServer.Store _modified.Clear(); } _logger.Debug($"TtsVoice - Modifying {count} rows in database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } if (_deleted.Any()) { @@ -95,7 +95,7 @@ namespace HermesSocketServer.Store _deleted.Clear(); } _logger.Debug($"TtsVoice - Deleting {count} rows from database: {sql}"); - await _database.ExecuteScalar(sql); + await _database.ExecuteScalarTransaction(sql); } return true; } diff --git a/db/Database.cs b/db/Database.cs index e8a2dc4..b60d428 100644 --- a/db/Database.cs +++ b/db/Database.cs @@ -5,9 +5,7 @@ namespace HermesSocketLibrary.db { public class Database { - private NpgsqlDataSource _source; - private ServerConfiguration _configuration; - + private readonly NpgsqlDataSource _source; public NpgsqlDataSource DataSource { get => _source; } @@ -19,111 +17,96 @@ namespace HermesSocketLibrary.db public async Task Execute(string sql, IDictionary? values, Action reading) { - using (var connection = await _source.OpenConnectionAsync()) + await using var connection = await _source.OpenConnectionAsync(); + await using var command = new NpgsqlCommand(sql, connection); + if (values != null) { - using (var command = new NpgsqlCommand(sql, connection)) - { - if (values != null) - { - foreach (var entry in values) - command.Parameters.AddWithValue(entry.Key, entry.Value); - } - await command.PrepareAsync(); + foreach (var entry in values) + command.Parameters.AddWithValue(entry.Key, entry.Value); + } + await command.PrepareAsync(); - using (var reader = await command.ExecuteReaderAsync()) - { - while (await reader.ReadAsync()) - { - reading(reader); - } - } - } + await using var reader = await command.ExecuteReaderAsync(); + while (await reader.ReadAsync()) + { + reading(reader); } } public async Task Execute(string sql, Action action, Action reading) { - using (var connection = await _source.OpenConnectionAsync()) - { - using (var command = new NpgsqlCommand(sql, connection)) - { - action(command); - await command.PrepareAsync(); + await using var connection = await _source.OpenConnectionAsync(); + await using var command = new NpgsqlCommand(sql, connection); + action(command); + await command.PrepareAsync(); - using (var reader = await command.ExecuteReaderAsync()) - { - while (await reader.ReadAsync()) - { - reading(reader); - } - } - } + await using var reader = await command.ExecuteReaderAsync(); + while (await reader.ReadAsync()) + { + reading(reader); } } public async Task Execute(string sql, IDictionary? values) { - using (var connection = await _source.OpenConnectionAsync()) + await using var connection = await _source.OpenConnectionAsync(); + await using var command = new NpgsqlCommand(sql, connection); + if (values != null) { - using (var command = new NpgsqlCommand(sql, connection)) - { - if (values != null) - { - foreach (var entry in values) - command.Parameters.AddWithValue(entry.Key, entry.Value); - } - await command.PrepareAsync(); - - return await command.ExecuteNonQueryAsync(); - } + foreach (var entry in values) + command.Parameters.AddWithValue(entry.Key, entry.Value); } + await command.PrepareAsync(); + return await command.ExecuteNonQueryAsync(); } public async Task Execute(string sql, Action prepare) { - using (var connection = await _source.OpenConnectionAsync()) - { - using (var command = new NpgsqlCommand(sql, connection)) - { - prepare(command); - await command.PrepareAsync(); - - return await command.ExecuteNonQueryAsync(); - } - } + await using var connection = await _source.OpenConnectionAsync(); + await using var command = new NpgsqlCommand(sql, connection); + prepare(command); + await command.PrepareAsync(); + return await command.ExecuteNonQueryAsync(); } public async Task ExecuteScalar(string sql, IDictionary? values = null) { - using (var connection = await _source.OpenConnectionAsync()) + await using var connection = await _source.OpenConnectionAsync(); + await using var command = new NpgsqlCommand(sql, connection); + if (values != null) { - using (var command = new NpgsqlCommand(sql, connection)) - { - if (values != null) - { - foreach (var entry in values) - command.Parameters.AddWithValue(entry.Key, entry.Value); - } - - await command.PrepareAsync(); - - return await command.ExecuteScalarAsync(); - } + foreach (var entry in values) + command.Parameters.AddWithValue(entry.Key, entry.Value); } + + await command.PrepareAsync(); + return await command.ExecuteScalarAsync(); + } + + public async Task ExecuteScalarTransaction(string sql, IDictionary? values = null) + { + await using var connection = await _source.OpenConnectionAsync(); + await using var transaction = await connection.BeginTransactionAsync(); + await using var command = new NpgsqlCommand(sql, connection, transaction); + if (values != null) + { + foreach (var entry in values) + command.Parameters.AddWithValue(entry.Key, entry.Value); + } + + await command.PrepareAsync(); + var results = await command.ExecuteScalarAsync(); + await transaction.CommitAsync(); + return results; } public async Task ExecuteScalar(string sql, Action action) { - using (var connection = await _source.OpenConnectionAsync()) - { - using (var command = new NpgsqlCommand(sql, connection)) - { - action(command); - await command.PrepareAsync(); - - return await command.ExecuteScalarAsync(); - } - } + await using var connection = await _source.OpenConnectionAsync(); + await using var command = new NpgsqlCommand(sql, connection); + action(command); + await command.PrepareAsync(); + return await command.ExecuteScalarAsync(); } } } \ No newline at end of file