Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 40 additions & 27 deletions src/Persistence/EntityFramework/AccountRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,62 +28,69 @@ public AccountRepository(IContextAwareRepositoryProvider repositoryProvider, ILo
}

/// <inheritdoc />
public override async ValueTask<Account?> GetByIdAsync(Guid id, CancellationToken cancellationToken = default)
public override async ValueTask<Account?> GetByIdAsync(Guid id, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

(this.RepositoryProvider as ICacheAwareRepositoryProvider)?.EnsureCachesForCurrentGameConfiguration();
using var ownedContext = context is null ? this.GetContext(null) : null;
var origin = context ?? ownedContext!;

using var context = this.GetContext();
await context.Context.Database.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
if (origin.Context is EntityDataContext { CurrentGameConfiguration: not null })
{
this.RepositoryProvider.EnsureCachesForCurrentGameConfiguration(origin);
}

await origin.Context.Database.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
try
{
var accountEntry = context.Context.ChangeTracker.Entries<Account>().FirstOrDefault(a => a.Entity.Id == id);
var accountEntry = origin.Context.ChangeTracker.Entries<Account>().FirstOrDefault(a => a.Entity.Id == id);
var account = accountEntry?.Entity;
if (account is null || accountEntry?.References.Any(reference => !reference.IsLoaded) is true)
{
if (account is not null)
{
context.Detach(account);
origin.Detach(account);
}

var objectLoader = new AccountJsonObjectLoader();
account = await objectLoader.LoadObjectAsync<Account>(id, context.Context, cancellationToken).ConfigureAwait(false);
if (account != null && !(context.Context.Entry(account) is { } entry && entry.State != EntityState.Detached))
account = await objectLoader.LoadObjectAsync<Account>(id, origin.Context, cancellationToken).ConfigureAwait(false);
if (account != null && !(origin.Context.Entry(account) is { } entry && entry.State != EntityState.Detached))
{
context.Context.Attach(account);
origin.Context.Attach(account);
}
}

return account;
}
finally
{
await context.Context.Database.CloseConnectionAsync().ConfigureAwait(false);
await origin.Context.Database.CloseConnectionAsync().ConfigureAwait(false);
}
}

/// <summary>
/// Gets the account by character name.
/// </summary>
/// <param name="characterName">The character name.</param>
/// <param name="context">The originating context.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>
/// The account; otherwise, null.
/// </returns>
internal async ValueTask<DataModel.Entities.Account?> GetAccountByCharacterNameAsync(string characterName, CancellationToken cancellationToken = default)
internal async ValueTask<DataModel.Entities.Account?> GetAccountByCharacterNameAsync(string characterName, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

using var context = this.GetContext();
var accountInfo = await context.Context.Set<Account>()
using var ownedContext = context is null ? this.GetContext(null) : null;
var origin = context ?? ownedContext!;
var accountInfo = await origin.Context.Set<Account>()
.AsNoTracking()
.FirstOrDefaultAsync(a => a.RawCharacters.Any(c => c.Name == characterName), cancellationToken)
.ConfigureAwait(false);

if (accountInfo != null)
{
return await this.GetByIdAsync(accountInfo.Id, cancellationToken).ConfigureAwait(false);
return await this.GetByIdAsync(accountInfo.Id, origin, cancellationToken).ConfigureAwait(false);
}

return null;
Expand All @@ -94,29 +101,33 @@ public AccountRepository(IContextAwareRepositoryProvider repositoryProvider, ILo
/// </summary>
/// <param name="loginName">The login name.</param>
/// <param name="password">The password.</param>
/// <param name="context">The originating context.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>
/// The account, if the password is correct. Otherwise, null.
/// </returns>
internal async ValueTask<DataModel.Entities.Account?> GetAccountByLoginNameAsync(string loginName, string password, CancellationToken cancellationToken = default)
internal async ValueTask<DataModel.Entities.Account?> GetAccountByLoginNameAsync(string loginName, string password, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
using var context = this.GetContext();
return await this.LoadAccountByLoginNameByJsonQueryAsync(loginName, password, context, cancellationToken).ConfigureAwait(false);
using var ownedContext = context is null ? this.GetContext(null) : null;
var origin = context ?? ownedContext!;
return await this.LoadAccountByLoginNameByJsonQueryAsync(loginName, password, origin, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Authenticates the account by login name and password, returning minimal state data without loading the full account.
/// </summary>
/// <param name="loginName">The login name.</param>
/// <param name="password">The password.</param>
/// <param name="context">The originating context.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The <see cref="DataModel.Entities.AccountState"/> if credentials are valid; otherwise, null.</returns>
internal async ValueTask<DataModel.Entities.AccountState?> AuthenticateAsync(string loginName, string password, CancellationToken cancellationToken = default)
internal async ValueTask<DataModel.Entities.AccountState?> AuthenticateAsync(string loginName, string password, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
using var context = this.GetContext();
using var ownedContext = context is null ? this.GetContext(null) : null;
var origin = context ?? ownedContext!;
cancellationToken.ThrowIfCancellationRequested();

var accountInfo = await context.Context.Set<Account>()
var accountInfo = await origin.Context.Set<Account>()
.Where(a => a.LoginName == loginName)
.Select(a => new { a.PasswordHash, a.State })
.AsNoTracking()
Expand All @@ -134,39 +145,41 @@ public AccountRepository(IContextAwareRepositoryProvider repositoryProvider, ILo
/// Gets the account by login name.
/// </summary>
/// <param name="loginName">The login name.</param>
/// <param name="context">The originating context.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>
/// The account, if exists. Otherwise, null.
/// </returns>
internal async ValueTask<DataModel.Entities.Account?> GetAccountByLoginNameAsync(string loginName, CancellationToken cancellationToken = default)
internal async ValueTask<DataModel.Entities.Account?> GetAccountByLoginNameAsync(string loginName, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
using var context = this.GetContext();
using var ownedContext = context is null ? this.GetContext(null) : null;
var origin = context ?? ownedContext!;

var accountInfo = await context.Context.Set<Account>()
var accountInfo = await origin.Context.Set<Account>()
.Select(a => new { a.Id, a.LoginName })
.AsNoTracking()
.FirstOrDefaultAsync(a => a.LoginName == loginName, cancellationToken).ConfigureAwait(false);

if (accountInfo != null)
{
return await this.GetByIdAsync(accountInfo.Id, cancellationToken).ConfigureAwait(false);
return await this.GetByIdAsync(accountInfo.Id, origin, cancellationToken).ConfigureAwait(false);
}

return null;
}

private async ValueTask<Account?> LoadAccountByLoginNameByJsonQueryAsync(string loginName, string password, EntityFrameworkContextBase context, CancellationToken cancellationToken)
private async ValueTask<Account?> LoadAccountByLoginNameByJsonQueryAsync(string loginName, string password, EntityFrameworkContextBase origin, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

var accountInfo = await context.Context.Set<Account>()
var accountInfo = await origin.Context.Set<Account>()
.Select(a => new { a.Id, a.LoginName, a.PasswordHash })
.AsNoTracking()
.FirstOrDefaultAsync(a => a.LoginName == loginName, cancellationToken).ConfigureAwait(false);

if (accountInfo != null && BCrypt.Verify(password, accountInfo.PasswordHash))
{
return await this.GetByIdAsync(accountInfo.Id, cancellationToken).ConfigureAwait(false);
return await this.GetByIdAsync(accountInfo.Id, origin, cancellationToken).ConfigureAwait(false);
}

return null;
Expand Down
64 changes: 45 additions & 19 deletions src/Persistence/EntityFramework/CacheAwareRepositoryProvider.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// <copyright file="CacheAwareRepositoryProvider.cs" company="MUnique">
// <copyright file="CacheAwareRepositoryProvider.cs" company="MUnique">
// Licensed under the MIT License. See LICENSE file in the project root for full license information.
// </copyright>

Expand All @@ -21,71 +21,97 @@ internal class CacheAwareRepositoryProvider : ICacheAwareRepositoryProvider, ICo
{
private readonly ILoggerFactory _loggerFactory;

private readonly IRepositoryProvider _nonCachingRepositoryProvider;
private readonly IContextAwareRepositoryProvider _nonCachingRepositoryProvider;

private CachingRepositoryProvider _cachingRepositoryProvider;

/// <summary>
/// Initializes a new instance of the <see cref="CacheAwareRepositoryProvider"/> class.
/// </summary>
/// <param name="loggerFactory">The logger factory.</param>
/// <param name="configurationChangePublisher">The configuration change publisher.</param>
/// <param name="configurationChangeListener">The configuration change listener.</param>
public CacheAwareRepositoryProvider(ILoggerFactory loggerFactory, IConfigurationChangeListener? configurationChangeListener)
{
this._loggerFactory = loggerFactory;
this._cachingRepositoryProvider = new CachingRepositoryProvider(loggerFactory, this);
this._nonCachingRepositoryProvider = new NonCachingRepositoryProvider(loggerFactory, this, configurationChangeListener, this.ContextStack);
this._nonCachingRepositoryProvider = new NonCachingRepositoryProvider(loggerFactory, this, configurationChangeListener);
}

/// <inheritdoc />
public IContextStack ContextStack { get; } = new ContextStack();
public IRepository? GetRepository(Type objectType)
{
return this._cachingRepositoryProvider.GetRepository(objectType)
?? this._nonCachingRepositoryProvider.GetRepository(objectType);
}

/// <inheritdoc />
public IRepository? GetRepository(Type objectType)
public IRepository<T>? GetRepository<T>()
where T : class
{
return this._cachingRepositoryProvider.GetRepository<T>()
?? this._nonCachingRepositoryProvider.GetRepository<T>();
}

/// <inheritdoc />
public TRepository? GetRepository<T, TRepository>()
where T : class
where TRepository : IRepository
{
if (this.ContextStack.GetCurrentContext() is EntityFrameworkContextBase { Context: ITypedContext editContext }
return this._cachingRepositoryProvider.GetRepository<T, TRepository>()
?? this._nonCachingRepositoryProvider.GetRepository<T, TRepository>();
}

/// <inheritdoc />
public IRepository? GetRepository(Type objectType, EntityFrameworkContextBase? context)
{
if (context is { Context: ITypedContext editContext }
&& editContext.IsIncluded(objectType))
{
return this._nonCachingRepositoryProvider.GetRepository(objectType);
}

return this._cachingRepositoryProvider.GetRepository(objectType)
?? this._nonCachingRepositoryProvider.GetRepository(objectType);
return this.GetRepository(objectType);
}

/// <inheritdoc />
public IRepository<T>? GetRepository<T>()
public IRepository<T>? GetRepository<T>(EntityFrameworkContextBase? context)
where T : class
{
if (this.ContextStack.GetCurrentContext() is EntityFrameworkContextBase { Context: ITypedContext editContext }
if (context is { Context: ITypedContext editContext }
&& (editContext.IsIncluded(typeof(T)) || editContext.IsIncluded(typeof(T).BaseType!)))
{
return this._nonCachingRepositoryProvider.GetRepository<T>();
}

return this._cachingRepositoryProvider.GetRepository<T>()
?? this._nonCachingRepositoryProvider.GetRepository<T>();
return this.GetRepository<T>();
}

/// <inheritdoc />
public TRepository? GetRepository<T, TRepository>()
public TRepository? GetRepository<T, TRepository>(EntityFrameworkContextBase? context)
where T : class
where TRepository : IRepository
{
if (this.ContextStack.GetCurrentContext() is EntityFrameworkContextBase { Context: ITypedContext editContext }
if (context is { Context: ITypedContext editContext }
&& (editContext.IsIncluded(typeof(T)) || editContext.IsIncluded(typeof(T).BaseType!)))
{
return this._nonCachingRepositoryProvider.GetRepository<T, TRepository>();
}

return this._cachingRepositoryProvider.GetRepository<T, TRepository>()
?? this._nonCachingRepositoryProvider.GetRepository<T, TRepository>();
return this.GetRepository<T, TRepository>();
}

/// <inheritdoc />
public void EnsureCachesForCurrentGameConfiguration(EntityFrameworkContextBase context)
{
this._cachingRepositoryProvider.EnsureCachesForCurrentGameConfiguration(context);
}

/// <inheritdoc />
public void EnsureCachesForCurrentGameConfiguration()
{
this._cachingRepositoryProvider.EnsureCachesForCurrentGameConfiguration();
// The caches are ensured per originating context (see the overload taking a context)
// and lazily on access. Without an ambient context, there is no "current" configuration
// to ensure here.
}

/// <inheritdoc />
Expand Down Expand Up @@ -117,4 +143,4 @@ public async ValueTask UpdateCachedInstanceAsync(object changedInstance)
}
}
}
}
}
48 changes: 43 additions & 5 deletions src/Persistence/EntityFramework/CachedRepository{T}.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
namespace MUnique.OpenMU.Persistence.EntityFramework;

using System.Collections;
using System.Linq;
using System.Threading;

/// <summary>
/// A repository which caches all of its data in memory.
/// </summary>
/// <typeparam name="T">The type of the business object.</typeparam>
public class CachedRepository<T> : IRepository<T>
public class CachedRepository<T> : IRepository<T>, IContextAwareRepository
where T : class, IIdentifiable
{
private readonly IDictionary<Guid, T> _cache;
Expand Down Expand Up @@ -42,7 +43,18 @@ async ValueTask<IEnumerable> IRepository.GetAllAsync(CancellationToken cancellat
}

/// <inheritdoc/>
public async ValueTask<IEnumerable<T>> GetAllAsync(CancellationToken cancellationToken = default)
public ValueTask<IEnumerable<T>> GetAllAsync(CancellationToken cancellationToken = default)
{
return this.GetAllAsync(null, cancellationToken);
}

/// <summary>
/// Gets all objects, using the given originating context to load them from the base repository.
/// </summary>
/// <param name="context">The originating context, or <c>null</c>.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>All objects of the repository.</returns>
internal async ValueTask<IEnumerable<T>> GetAllAsync(EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
if (this._allLoaded)
{
Expand All @@ -62,7 +74,9 @@ public async ValueTask<IEnumerable<T>> GetAllAsync(CancellationToken cancellatio
this._loading = true;
try
{
IEnumerable<T> values = await this.BaseRepository.GetAllAsync(cancellationToken).ConfigureAwait(false);
IEnumerable<T> values = this.BaseRepository is IContextAwareRepository contextAware
? (await contextAware.GetAllAsync(context, cancellationToken).ConfigureAwait(false)).Cast<T>()
: await this.BaseRepository.GetAllAsync(cancellationToken).ConfigureAwait(false);
foreach (var obj in values)
{
if (!this._cache.ContainsKey(obj.Id))
Expand All @@ -82,9 +96,21 @@ public async ValueTask<IEnumerable<T>> GetAllAsync(CancellationToken cancellatio
}

/// <inheritdoc/>
public async ValueTask<T?> GetByIdAsync(Guid id, CancellationToken cancellationToken = default)
public ValueTask<T?> GetByIdAsync(Guid id, CancellationToken cancellationToken = default)
{
await this.GetAllAsync(cancellationToken).ConfigureAwait(false);
return this.GetByIdAsync(id, null, cancellationToken);
}

/// <summary>
/// Gets an object by identifier, using the given originating context to load the data.
/// </summary>
/// <param name="id">The identifier.</param>
/// <param name="context">The originating context, or <c>null</c>.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The object with the identifier.</returns>
internal async ValueTask<T?> GetByIdAsync(Guid id, EntityFrameworkContextBase? context, CancellationToken cancellationToken = default)
{
await this.GetAllAsync(context, cancellationToken).ConfigureAwait(false);
this._cache.TryGetValue(id, out var result);
return result;
}
Expand All @@ -95,6 +121,18 @@ public async ValueTask<IEnumerable<T>> GetAllAsync(CancellationToken cancellatio
return await this.GetByIdAsync(id, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc/>
async ValueTask<IEnumerable> IContextAwareRepository.GetAllAsync(EntityFrameworkContextBase? context, CancellationToken cancellationToken)
{
return await this.GetAllAsync(context, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc/>
async ValueTask<object?> IContextAwareRepository.GetByIdAsync(Guid id, EntityFrameworkContextBase? context, CancellationToken cancellationToken)
{
return await this.GetByIdAsync(id, context, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc/>
public async ValueTask<bool> DeleteAsync(object obj)
{
Expand Down
Loading
Loading