diff --git a/src/UnitOfWork/IRepository.cs b/src/UnitOfWork/IRepository.cs index d69252a..75f4f2a 100644 --- a/src/UnitOfWork/IRepository.cs +++ b/src/UnitOfWork/IRepository.cs @@ -166,6 +166,7 @@ TResult GetFirstOrDefault(Expression> selector, /// A function to order elements. /// A function to include navigation properties /// true to disable changing tracking; otherwise, false. Default to true. + /// A to observe while waiting for the task to complete. /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// Ex: This method defaults to a read-only, no-tracking query. @@ -173,6 +174,7 @@ Task GetFirstOrDefaultAsync(Expression> Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, + CancellationToken cancellationToken = default, bool disableTracking = true, bool ignoreQueryFilters = false); @@ -182,6 +184,7 @@ Task GetFirstOrDefaultAsync(Expression> /// A function to test each element for a condition. /// A function to order elements. /// A function to include navigation properties + /// A to observe while waiting for the task to complete. /// true to disable changing tracking; otherwise, false. Default to true. /// Ignore query filters /// An that contains elements that satisfy the condition specified by . @@ -189,6 +192,7 @@ Task GetFirstOrDefaultAsync(Expression> Task GetFirstOrDefaultAsync(Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, + CancellationToken cancellationToken = default, bool disableTracking = true, bool ignoreQueryFilters = false); @@ -220,7 +224,7 @@ Task GetFirstOrDefaultAsync(Expression> predicate = /// The values of the primary key for the entity to be found. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous find operation. The task result contains the found entity or null. - ValueTask FindAsync(object[] keyValues, CancellationToken cancellationToken); + ValueTask FindAsync(CancellationToken cancellationToken = default, params object[] keyValues); /// /// Gets all entities. This method is not recommended @@ -247,8 +251,9 @@ IQueryable GetAll(Expression> predicate = null, /// /// Gets all entities. This method is not recommended /// + /// A to observe while waiting for the task to complete. /// The . - Task> GetAllAsync(); + Task> GetAllAsync(CancellationToken cancellationToken = default); /// /// Gets all entities. This method is not recommended @@ -256,6 +261,7 @@ IQueryable GetAll(Expression> predicate = null, /// A function to test each element for a condition. /// A function to order elements. /// A function to include navigation properties + /// A to observe while waiting for the task to complete. /// true to disable changing tracking; otherwise, false. Default to true. /// Ignore query filters /// An that contains elements that satisfy the condition specified by . @@ -263,6 +269,7 @@ IQueryable GetAll(Expression> predicate = null, Task> GetAllAsync(Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, + CancellationToken cancellationToken = default, bool disableTracking = true, bool ignoreQueryFilters = false); @@ -277,8 +284,9 @@ Task> GetAllAsync(Expression> predicate = nul /// Gets async the count based on a predicate. /// /// + /// A to observe while waiting for the task to complete. /// - Task CountAsync(Expression> predicate = null); + Task CountAsync(Expression> predicate = null, CancellationToken cancellationToken = default); /// /// Gets the long count based on a predicate. @@ -291,8 +299,9 @@ Task> GetAllAsync(Expression> predicate = nul /// Gets async the long count based on a predicate. /// /// + /// A to observe while waiting for the task to complete. /// - Task LongCountAsync(Expression> predicate = null); + Task LongCountAsync(Expression> predicate = null, CancellationToken cancellationToken = default); /// /// Gets the max based on a predicate. @@ -306,9 +315,10 @@ Task> GetAllAsync(Expression> predicate = nul /// Gets the async max based on a predicate. /// /// - /// /// + /// + /// A to observe while waiting for the task to complete. /// decimal - Task MaxAsync(Expression> predicate = null, Expression> selector = null); + Task MaxAsync(Expression> predicate = null, Expression> selector = null, CancellationToken cancellationToken = default); /// /// Gets the min based on a predicate. @@ -323,8 +333,9 @@ Task> GetAllAsync(Expression> predicate = nul /// /// /// + /// A to observe while waiting for the task to complete. /// decimal - Task MinAsync(Expression> predicate = null, Expression> selector = null); + Task MinAsync(Expression> predicate = null, Expression> selector = null, CancellationToken cancellationToken = default); /// /// Gets the average based on a predicate. @@ -335,12 +346,13 @@ Task> GetAllAsync(Expression> predicate = nul decimal Average (Expression> predicate = null, Expression> selector = null); /// - /// Gets the async average based on a predicate. - /// - /// - /// /// - /// decimal - Task AverageAsync(Expression> predicate = null, Expression> selector = null); + /// Gets the async average based on a predicate. + /// + /// + /// + /// A to observe while waiting for the task to complete. + /// decimal + Task AverageAsync(Expression> predicate = null, Expression> selector = null, CancellationToken cancellationToken = default); /// /// Gets the sum based on a predicate. @@ -354,9 +366,10 @@ Task> GetAllAsync(Expression> predicate = nul /// Gets the async sum based on a predicate. /// /// - /// /// + /// + /// A to observe while waiting for the task to complete. /// decimal - Task SumAsync (Expression> predicate = null, Expression> selector = null); + Task SumAsync (Expression> predicate = null, Expression> selector = null, CancellationToken cancellationToken = default); /// /// Gets the Exists record based on a predicate. @@ -368,8 +381,9 @@ Task> GetAllAsync(Expression> predicate = nul /// Gets the Async Exists record based on a predicate. /// /// + /// A to observe while waiting for the task to complete. /// - Task ExistsAsync(Expression> selector = null); + Task ExistsAsync(Expression> selector = null, CancellationToken cancellationToken = default); /// /// Inserts a new entity synchronously. diff --git a/src/UnitOfWork/IUnitOfWork.cs b/src/UnitOfWork/IUnitOfWork.cs index c78ca54..1bcdac4 100644 --- a/src/UnitOfWork/IUnitOfWork.cs +++ b/src/UnitOfWork/IUnitOfWork.cs @@ -4,6 +4,8 @@ // //----------------------------------------------------------------------- +using System.Threading; + namespace Arch.EntityFrameworkCore.UnitOfWork { using System; @@ -44,8 +46,9 @@ public interface IUnitOfWork : IDisposable /// Asynchronously saves all changes made in this unit of work to the database. /// /// True if save changes ensure auto record the change history. + /// A to observe while waiting for the task to complete. /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. - Task SaveChangesAsync(bool ensureAutoHistory = false); + Task SaveChangesAsync(bool ensureAutoHistory = false, CancellationToken cancellationToken = default); /// /// Executes the specified raw SQL command. diff --git a/src/UnitOfWork/IUnitOfWorkOfT.cs b/src/UnitOfWork/IUnitOfWorkOfT.cs index 6ea0c7e..7401a8a 100644 --- a/src/UnitOfWork/IUnitOfWorkOfT.cs +++ b/src/UnitOfWork/IUnitOfWorkOfT.cs @@ -1,5 +1,6 @@ // Copyright (c) Arch team. All rights reserved. +using System.Threading; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore; @@ -20,7 +21,8 @@ public interface IUnitOfWork : IUnitOfWork where TContext : DbContext /// /// True if save changes ensure auto record the change history. /// An optional array. + /// A to observe while waiting for the task to complete. /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. - Task SaveChangesAsync(bool ensureAutoHistory = false, params IUnitOfWork[] unitOfWorks); + Task SaveChangesAsync(bool ensureAutoHistory = false, CancellationToken cancellationToken = default, params IUnitOfWork[] unitOfWorks); } } diff --git a/src/UnitOfWork/Repository.cs b/src/UnitOfWork/Repository.cs index b358d3e..b11b1b2 100644 --- a/src/UnitOfWork/Repository.cs +++ b/src/UnitOfWork/Repository.cs @@ -379,6 +379,7 @@ public virtual TEntity GetFirstOrDefault(Expression> predica public virtual async Task GetFirstOrDefaultAsync(Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, + CancellationToken cancellationToken = default, bool disableTracking = true, bool ignoreQueryFilters = false) { @@ -406,11 +407,11 @@ public virtual async Task GetFirstOrDefaultAsync(Expression GetFirstOrDefaultAsync(Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, + CancellationToken cancellationToken = default, bool disableTracking = true, bool ignoreQueryFilters = false) { IQueryable query = _dbSet; @@ -495,11 +497,11 @@ public virtual async Task GetFirstOrDefaultAsync(Expression GetFirstOrDefaultAsync(Expression /// The values of the primary key for the entity to be found. - /// A that represents the asynchronous insert operation. - public virtual ValueTask FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues); + /// A to observe while waiting for the task to complete. + /// A that represents the asynchronous find operation. The task result contains the found entity or null. + public virtual ValueTask FindAsync(CancellationToken cancellationToken = default, params object[] keyValues) => _dbSet.FindAsync(keyValues, cancellationToken); /// /// Finds an entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. /// /// The values of the primary key for the entity to be found. - /// A to observe while waiting for the task to complete. /// A that represents the asynchronous find operation. The task result contains the found entity or null. - public virtual ValueTask FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken); + public virtual ValueTask FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues); + /// /// Gets the count based on a predicate. @@ -554,16 +557,17 @@ public virtual int Count(Expression> predicate = null) /// Gets async the count based on a predicate. /// /// + /// A to observe while waiting for the task to complete. /// - public virtual async Task CountAsync(Expression> predicate = null) + public virtual async Task CountAsync(Expression> predicate = null, CancellationToken cancellationToken = default) { if (predicate == null) { - return await _dbSet.CountAsync(); + return await _dbSet.CountAsync(cancellationToken); } else { - return await _dbSet.CountAsync(predicate); + return await _dbSet.CountAsync(predicate, cancellationToken); } } @@ -588,16 +592,17 @@ public virtual long LongCount(Expression> predicate = null) /// Gets async the long count based on a predicate. /// /// + /// A to observe while waiting for the task to complete. /// - public virtual async Task LongCountAsync(Expression> predicate = null) + public virtual async Task LongCountAsync(Expression> predicate = null, CancellationToken cancellationToken = default) { if (predicate == null) { - return await _dbSet.LongCountAsync(); + return await _dbSet.LongCountAsync(cancellationToken); } else { - return await _dbSet.LongCountAsync(predicate); + return await _dbSet.LongCountAsync(predicate, cancellationToken); } } @@ -619,14 +624,15 @@ public virtual T Max(Expression> predicate = null, Expres /// Gets the async max based on a predicate. /// /// - /// /// + /// + /// A to observe while waiting for the task to complete. /// decimal - public virtual async Task MaxAsync(Expression> predicate = null, Expression> selector = null) + public virtual async Task MaxAsync(Expression> predicate = null, Expression> selector = null, CancellationToken cancellationToken = default) { if (predicate == null) - return await _dbSet.MaxAsync(selector); + return await _dbSet.MaxAsync(selector, cancellationToken); else - return await _dbSet.Where(predicate).MaxAsync(selector); + return await _dbSet.Where(predicate).MaxAsync(selector, cancellationToken); } /// @@ -647,14 +653,15 @@ public virtual T Min(Expression> predicate = null, Expres /// Gets the async min based on a predicate. /// /// - /// /// + /// + /// A to observe while waiting for the task to complete. /// decimal - public virtual async Task MinAsync(Expression> predicate = null, Expression> selector = null) + public virtual async Task MinAsync(Expression> predicate = null, Expression> selector = null, CancellationToken cancellationToken = default) { if (predicate == null) - return await _dbSet.MinAsync(selector); + return await _dbSet.MinAsync(selector, cancellationToken); else - return await _dbSet.Where(predicate).MinAsync(selector); + return await _dbSet.Where(predicate).MinAsync(selector, cancellationToken); } /// @@ -675,14 +682,15 @@ public virtual decimal Average(Expression> predicate = null, /// Gets the async average based on a predicate. /// /// - /// /// + /// + /// A to observe while waiting for the task to complete. /// decimal - public virtual async Task AverageAsync(Expression> predicate = null, Expression> selector = null) + public virtual async Task AverageAsync(Expression> predicate = null, Expression> selector = null, CancellationToken cancellationToken = default) { if (predicate == null) - return await _dbSet.AverageAsync(selector); + return await _dbSet.AverageAsync(selector, cancellationToken); else - return await _dbSet.Where(predicate).AverageAsync(selector); + return await _dbSet.Where(predicate).AverageAsync(selector, cancellationToken); } /// @@ -703,14 +711,15 @@ public virtual decimal Sum(Expression> predicate = null, Exp /// Gets the async sum based on a predicate. /// /// - /// /// + /// + /// A to observe while waiting for the task to complete. /// decimal - public virtual async Task SumAsync(Expression> predicate = null, Expression> selector = null) + public virtual async Task SumAsync(Expression> predicate = null, Expression> selector = null, CancellationToken cancellationToken = default) { if (predicate == null) - return await _dbSet.SumAsync(selector); + return await _dbSet.SumAsync(selector, cancellationToken); else - return await _dbSet.Where(predicate).SumAsync(selector); + return await _dbSet.Where(predicate).SumAsync(selector, cancellationToken); } /// @@ -733,16 +742,17 @@ public bool Exists(Expression> selector = null) /// Gets the async exists based on a predicate. /// /// + /// A to observe while waiting for the task to complete. /// - public async Task ExistsAsync(Expression> selector = null) + public async Task ExistsAsync(Expression> selector = null, CancellationToken cancellationToken = default) { if (selector == null) { - return await _dbSet.AnyAsync(); + return await _dbSet.AnyAsync(cancellationToken); } else { - return await _dbSet.AnyAsync(selector); + return await _dbSet.AnyAsync(selector, cancellationToken); } } /// @@ -876,10 +886,11 @@ public virtual void Delete(object id) /// /// Gets all entities. This method is not recommended /// + /// A to observe while waiting for the task to complete. /// The . - public async Task> GetAllAsync() + public async Task> GetAllAsync(CancellationToken cancellationToken = default) { - return await _dbSet.ToListAsync(); + return await _dbSet.ToListAsync(cancellationToken); } /// @@ -888,6 +899,7 @@ public async Task> GetAllAsync() /// A function to test each element for a condition. /// A function to order elements. /// A function to include navigation properties + /// A to observe while waiting for the task to complete. /// true to disable changing tracking; otherwise, false. Default to true. /// Ignore query filters /// An that contains elements that satisfy the condition specified by . @@ -895,6 +907,7 @@ public async Task> GetAllAsync() public async Task> GetAllAsync(Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, + CancellationToken cancellationToken = default, bool disableTracking = true, bool ignoreQueryFilters = false) { IQueryable query = _dbSet; @@ -921,11 +934,11 @@ public async Task> GetAllAsync(Expression> pr if (orderBy != null) { - return await orderBy(query).ToListAsync(); + return await orderBy(query).ToListAsync(cancellationToken); } else { - return await query.ToListAsync(); + return await query.ToListAsync(cancellationToken); } } diff --git a/src/UnitOfWork/UnitOfWork.cs b/src/UnitOfWork/UnitOfWork.cs index 6695746..cf998aa 100644 --- a/src/UnitOfWork/UnitOfWork.cs +++ b/src/UnitOfWork/UnitOfWork.cs @@ -5,6 +5,7 @@ using System.Data; using System.Linq; using System.Text.RegularExpressions; +using System.Threading; using System.Threading.Tasks; using System.Transactions; using Microsoft.EntityFrameworkCore; @@ -138,15 +139,16 @@ public int SaveChanges(bool ensureAutoHistory = false) /// Asynchronously saves all changes made in this unit of work to the database. /// /// True if save changes ensure auto record the change history. + /// A to observe while waiting for the task to complete. /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. - public async Task SaveChangesAsync(bool ensureAutoHistory = false) + public async Task SaveChangesAsync(bool ensureAutoHistory = false, CancellationToken cancellationToken = default) { if (ensureAutoHistory) { _context.EnsureAutoHistory(); } - return await _context.SaveChangesAsync(); + return await _context.SaveChangesAsync(cancellationToken); } /// @@ -154,18 +156,19 @@ public async Task SaveChangesAsync(bool ensureAutoHistory = false) /// /// True if save changes ensure auto record the change history. /// An optional array. + /// A to observe while waiting for the task to complete. /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. - public async Task SaveChangesAsync(bool ensureAutoHistory = false, params IUnitOfWork[] unitOfWorks) + public async Task SaveChangesAsync(bool ensureAutoHistory = false, CancellationToken cancellationToken = default, params IUnitOfWork[] unitOfWorks) { using (var ts = new TransactionScope()) { var count = 0; foreach (var unitOfWork in unitOfWorks) { - count += await unitOfWork.SaveChangesAsync(ensureAutoHistory); + count += await unitOfWork.SaveChangesAsync(ensureAutoHistory, cancellationToken); } - count += await SaveChangesAsync(ensureAutoHistory); + count += await SaveChangesAsync(ensureAutoHistory, cancellationToken); ts.Complete();