diff --git a/src/Persistence/EntityFramework/AccountRepository.cs b/src/Persistence/EntityFramework/AccountRepository.cs
index f5f226025..2b72dc494 100644
--- a/src/Persistence/EntityFramework/AccountRepository.cs
+++ b/src/Persistence/EntityFramework/AccountRepository.cs
@@ -28,30 +28,35 @@ public AccountRepository(IContextAwareRepositoryProvider repositoryProvider, ILo
}
///
- public override async ValueTask GetByIdAsync(Guid id, CancellationToken cancellationToken = default)
+ public override async ValueTask GetByIdAsync(Guid id, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
- (this.RepositoryProvider as ICacheAwareRepositoryProvider)?.EnsureCachesForCurrentGameConfiguration();
+ using var ownedContext = context is null ? this.GetContext(null) : null;
+ var origin = context ?? ownedContext!;
- using var context = this.GetContext();
- await context.Context.Database.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
+ if (origin.Context is EntityDataContext { CurrentGameConfiguration: not null })
+ {
+ this.RepositoryProvider.EnsureCachesForCurrentGameConfiguration(origin);
+ }
+
+ await origin.Context.Database.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
try
{
- var accountEntry = context.Context.ChangeTracker.Entries().FirstOrDefault(a => a.Entity.Id == id);
+ var accountEntry = origin.Context.ChangeTracker.Entries().FirstOrDefault(a => a.Entity.Id == id);
var account = accountEntry?.Entity;
if (account is null || accountEntry?.References.Any(reference => !reference.IsLoaded) is true)
{
if (account is not null)
{
- context.Detach(account);
+ origin.Detach(account);
}
var objectLoader = new AccountJsonObjectLoader();
- account = await objectLoader.LoadObjectAsync(id, context.Context, cancellationToken).ConfigureAwait(false);
- if (account != null && !(context.Context.Entry(account) is { } entry && entry.State != EntityState.Detached))
+ account = await objectLoader.LoadObjectAsync(id, origin.Context, cancellationToken).ConfigureAwait(false);
+ if (account != null && !(origin.Context.Entry(account) is { } entry && entry.State != EntityState.Detached))
{
- context.Context.Attach(account);
+ origin.Context.Attach(account);
}
}
@@ -59,7 +64,7 @@ public AccountRepository(IContextAwareRepositoryProvider repositoryProvider, ILo
}
finally
{
- await context.Context.Database.CloseConnectionAsync().ConfigureAwait(false);
+ await origin.Context.Database.CloseConnectionAsync().ConfigureAwait(false);
}
}
@@ -67,23 +72,25 @@ public AccountRepository(IContextAwareRepositoryProvider repositoryProvider, ILo
/// Gets the account by character name.
///
/// The character name.
+ /// The originating context.
/// The cancellation token.
///
/// The account; otherwise, null.
///
- internal async ValueTask GetAccountByCharacterNameAsync(string characterName, CancellationToken cancellationToken = default)
+ internal async ValueTask GetAccountByCharacterNameAsync(string characterName, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
- using var context = this.GetContext();
- var accountInfo = await context.Context.Set()
+ using var ownedContext = context is null ? this.GetContext(null) : null;
+ var origin = context ?? ownedContext!;
+ var accountInfo = await origin.Context.Set()
.AsNoTracking()
.FirstOrDefaultAsync(a => a.RawCharacters.Any(c => c.Name == characterName), cancellationToken)
.ConfigureAwait(false);
if (accountInfo != null)
{
- return await this.GetByIdAsync(accountInfo.Id, cancellationToken).ConfigureAwait(false);
+ return await this.GetByIdAsync(accountInfo.Id, origin, cancellationToken).ConfigureAwait(false);
}
return null;
@@ -94,14 +101,16 @@ public AccountRepository(IContextAwareRepositoryProvider repositoryProvider, ILo
///
/// The login name.
/// The password.
+ /// The originating context.
/// The cancellation token.
///
/// The account, if the password is correct. Otherwise, null.
///
- internal async ValueTask GetAccountByLoginNameAsync(string loginName, string password, CancellationToken cancellationToken = default)
+ internal async ValueTask GetAccountByLoginNameAsync(string loginName, string password, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
- using var context = this.GetContext();
- return await this.LoadAccountByLoginNameByJsonQueryAsync(loginName, password, context, cancellationToken).ConfigureAwait(false);
+ using var ownedContext = context is null ? this.GetContext(null) : null;
+ var origin = context ?? ownedContext!;
+ return await this.LoadAccountByLoginNameByJsonQueryAsync(loginName, password, origin, cancellationToken).ConfigureAwait(false);
}
///
@@ -109,14 +118,16 @@ public AccountRepository(IContextAwareRepositoryProvider repositoryProvider, ILo
///
/// The login name.
/// The password.
+ /// The originating context.
/// The cancellation token.
/// The if credentials are valid; otherwise, null.
- internal async ValueTask AuthenticateAsync(string loginName, string password, CancellationToken cancellationToken = default)
+ internal async ValueTask AuthenticateAsync(string loginName, string password, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
- using var context = this.GetContext();
+ using var ownedContext = context is null ? this.GetContext(null) : null;
+ var origin = context ?? ownedContext!;
cancellationToken.ThrowIfCancellationRequested();
- var accountInfo = await context.Context.Set()
+ var accountInfo = await origin.Context.Set()
.Where(a => a.LoginName == loginName)
.Select(a => new { a.PasswordHash, a.State })
.AsNoTracking()
@@ -134,39 +145,41 @@ public AccountRepository(IContextAwareRepositoryProvider repositoryProvider, ILo
/// Gets the account by login name.
///
/// The login name.
+ /// The originating context.
/// The cancellation token.
///
/// The account, if exists. Otherwise, null.
///
- internal async ValueTask GetAccountByLoginNameAsync(string loginName, CancellationToken cancellationToken = default)
+ internal async ValueTask GetAccountByLoginNameAsync(string loginName, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
- using var context = this.GetContext();
+ using var ownedContext = context is null ? this.GetContext(null) : null;
+ var origin = context ?? ownedContext!;
- var accountInfo = await context.Context.Set()
+ var accountInfo = await origin.Context.Set()
.Select(a => new { a.Id, a.LoginName })
.AsNoTracking()
.FirstOrDefaultAsync(a => a.LoginName == loginName, cancellationToken).ConfigureAwait(false);
if (accountInfo != null)
{
- return await this.GetByIdAsync(accountInfo.Id, cancellationToken).ConfigureAwait(false);
+ return await this.GetByIdAsync(accountInfo.Id, origin, cancellationToken).ConfigureAwait(false);
}
return null;
}
- private async ValueTask LoadAccountByLoginNameByJsonQueryAsync(string loginName, string password, EntityFrameworkContextBase context, CancellationToken cancellationToken)
+ private async ValueTask LoadAccountByLoginNameByJsonQueryAsync(string loginName, string password, EntityFrameworkContextBase origin, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
- var accountInfo = await context.Context.Set()
+ var accountInfo = await origin.Context.Set()
.Select(a => new { a.Id, a.LoginName, a.PasswordHash })
.AsNoTracking()
.FirstOrDefaultAsync(a => a.LoginName == loginName, cancellationToken).ConfigureAwait(false);
if (accountInfo != null && BCrypt.Verify(password, accountInfo.PasswordHash))
{
- return await this.GetByIdAsync(accountInfo.Id, cancellationToken).ConfigureAwait(false);
+ return await this.GetByIdAsync(accountInfo.Id, origin, cancellationToken).ConfigureAwait(false);
}
return null;
diff --git a/src/Persistence/EntityFramework/CacheAwareRepositoryProvider.cs b/src/Persistence/EntityFramework/CacheAwareRepositoryProvider.cs
index 64cb0aa6d..e64857699 100644
--- a/src/Persistence/EntityFramework/CacheAwareRepositoryProvider.cs
+++ b/src/Persistence/EntityFramework/CacheAwareRepositoryProvider.cs
@@ -1,4 +1,4 @@
-//
+//
// Licensed under the MIT License. See LICENSE file in the project root for full license information.
//
@@ -21,7 +21,7 @@ internal class CacheAwareRepositoryProvider : ICacheAwareRepositoryProvider, ICo
{
private readonly ILoggerFactory _loggerFactory;
- private readonly IRepositoryProvider _nonCachingRepositoryProvider;
+ private readonly IContextAwareRepositoryProvider _nonCachingRepositoryProvider;
private CachingRepositoryProvider _cachingRepositoryProvider;
@@ -29,63 +29,89 @@ internal class CacheAwareRepositoryProvider : ICacheAwareRepositoryProvider, ICo
/// Initializes a new instance of the class.
///
/// The logger factory.
- /// The configuration change publisher.
+ /// The configuration change listener.
public CacheAwareRepositoryProvider(ILoggerFactory loggerFactory, IConfigurationChangeListener? configurationChangeListener)
{
this._loggerFactory = loggerFactory;
this._cachingRepositoryProvider = new CachingRepositoryProvider(loggerFactory, this);
- this._nonCachingRepositoryProvider = new NonCachingRepositoryProvider(loggerFactory, this, configurationChangeListener, this.ContextStack);
+ this._nonCachingRepositoryProvider = new NonCachingRepositoryProvider(loggerFactory, this, configurationChangeListener);
}
///
- public IContextStack ContextStack { get; } = new ContextStack();
+ public IRepository? GetRepository(Type objectType)
+ {
+ return this._cachingRepositoryProvider.GetRepository(objectType)
+ ?? this._nonCachingRepositoryProvider.GetRepository(objectType);
+ }
///
- public IRepository? GetRepository(Type objectType)
+ public IRepository? GetRepository()
+ where T : class
+ {
+ return this._cachingRepositoryProvider.GetRepository()
+ ?? this._nonCachingRepositoryProvider.GetRepository();
+ }
+
+ ///
+ public TRepository? GetRepository()
+ where T : class
+ where TRepository : IRepository
{
- if (this.ContextStack.GetCurrentContext() is EntityFrameworkContextBase { Context: ITypedContext editContext }
+ return this._cachingRepositoryProvider.GetRepository()
+ ?? this._nonCachingRepositoryProvider.GetRepository();
+ }
+
+ ///
+ public IRepository? GetRepository(Type objectType, EntityFrameworkContextBase? context)
+ {
+ if (context is { Context: ITypedContext editContext }
&& editContext.IsIncluded(objectType))
{
return this._nonCachingRepositoryProvider.GetRepository(objectType);
}
- return this._cachingRepositoryProvider.GetRepository(objectType)
- ?? this._nonCachingRepositoryProvider.GetRepository(objectType);
+ return this.GetRepository(objectType);
}
///
- public IRepository? GetRepository()
+ public IRepository? GetRepository(EntityFrameworkContextBase? context)
where T : class
{
- if (this.ContextStack.GetCurrentContext() is EntityFrameworkContextBase { Context: ITypedContext editContext }
+ if (context is { Context: ITypedContext editContext }
&& (editContext.IsIncluded(typeof(T)) || editContext.IsIncluded(typeof(T).BaseType!)))
{
return this._nonCachingRepositoryProvider.GetRepository();
}
- return this._cachingRepositoryProvider.GetRepository()
- ?? this._nonCachingRepositoryProvider.GetRepository();
+ return this.GetRepository();
}
///
- public TRepository? GetRepository()
+ public TRepository? GetRepository(EntityFrameworkContextBase? context)
where T : class
where TRepository : IRepository
{
- if (this.ContextStack.GetCurrentContext() is EntityFrameworkContextBase { Context: ITypedContext editContext }
+ if (context is { Context: ITypedContext editContext }
&& (editContext.IsIncluded(typeof(T)) || editContext.IsIncluded(typeof(T).BaseType!)))
{
return this._nonCachingRepositoryProvider.GetRepository();
}
- return this._cachingRepositoryProvider.GetRepository()
- ?? this._nonCachingRepositoryProvider.GetRepository();
+ return this.GetRepository();
+ }
+
+ ///
+ public void EnsureCachesForCurrentGameConfiguration(EntityFrameworkContextBase context)
+ {
+ this._cachingRepositoryProvider.EnsureCachesForCurrentGameConfiguration(context);
}
///
public void EnsureCachesForCurrentGameConfiguration()
{
- this._cachingRepositoryProvider.EnsureCachesForCurrentGameConfiguration();
+ // The caches are ensured per originating context (see the overload taking a context)
+ // and lazily on access. Without an ambient context, there is no "current" configuration
+ // to ensure here.
}
///
@@ -117,4 +143,4 @@ public async ValueTask UpdateCachedInstanceAsync(object changedInstance)
}
}
}
-}
\ No newline at end of file
+}
diff --git a/src/Persistence/EntityFramework/CachedRepository{T}.cs b/src/Persistence/EntityFramework/CachedRepository{T}.cs
index 9c2d938ac..50a845404 100644
--- a/src/Persistence/EntityFramework/CachedRepository{T}.cs
+++ b/src/Persistence/EntityFramework/CachedRepository{T}.cs
@@ -5,13 +5,14 @@
namespace MUnique.OpenMU.Persistence.EntityFramework;
using System.Collections;
+using System.Linq;
using System.Threading;
///
/// A repository which caches all of its data in memory.
///
/// The type of the business object.
-public class CachedRepository : IRepository
+public class CachedRepository : IRepository, IContextAwareRepository
where T : class, IIdentifiable
{
private readonly IDictionary _cache;
@@ -42,7 +43,18 @@ async ValueTask IRepository.GetAllAsync(CancellationToken cancellat
}
///
- public async ValueTask> GetAllAsync(CancellationToken cancellationToken = default)
+ public ValueTask> GetAllAsync(CancellationToken cancellationToken = default)
+ {
+ return this.GetAllAsync(null, cancellationToken);
+ }
+
+ ///
+ /// Gets all objects, using the given originating context to load them from the base repository.
+ ///
+ /// The originating context, or null.
+ /// The cancellation token.
+ /// All objects of the repository.
+ internal async ValueTask> GetAllAsync(EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
if (this._allLoaded)
{
@@ -62,7 +74,9 @@ public async ValueTask> GetAllAsync(CancellationToken cancellatio
this._loading = true;
try
{
- IEnumerable values = await this.BaseRepository.GetAllAsync(cancellationToken).ConfigureAwait(false);
+ IEnumerable values = this.BaseRepository is IContextAwareRepository contextAware
+ ? (await contextAware.GetAllAsync(context, cancellationToken).ConfigureAwait(false)).Cast()
+ : await this.BaseRepository.GetAllAsync(cancellationToken).ConfigureAwait(false);
foreach (var obj in values)
{
if (!this._cache.ContainsKey(obj.Id))
@@ -82,9 +96,21 @@ public async ValueTask> GetAllAsync(CancellationToken cancellatio
}
///
- public async ValueTask GetByIdAsync(Guid id, CancellationToken cancellationToken = default)
+ public ValueTask GetByIdAsync(Guid id, CancellationToken cancellationToken = default)
{
- await this.GetAllAsync(cancellationToken).ConfigureAwait(false);
+ return this.GetByIdAsync(id, null, cancellationToken);
+ }
+
+ ///
+ /// Gets an object by identifier, using the given originating context to load the data.
+ ///
+ /// The identifier.
+ /// The originating context, or null.
+ /// The cancellation token.
+ /// The object with the identifier.
+ internal async ValueTask GetByIdAsync(Guid id, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
+ {
+ await this.GetAllAsync(context, cancellationToken).ConfigureAwait(false);
this._cache.TryGetValue(id, out var result);
return result;
}
@@ -95,6 +121,18 @@ public async ValueTask> GetAllAsync(CancellationToken cancellatio
return await this.GetByIdAsync(id, cancellationToken).ConfigureAwait(false);
}
+ ///
+ async ValueTask IContextAwareRepository.GetAllAsync(EntityFrameworkContextBase? context, CancellationToken cancellationToken)
+ {
+ return await this.GetAllAsync(context, cancellationToken).ConfigureAwait(false);
+ }
+
+ ///
+ async ValueTask