Skip to content
This repository was archived by the owner on Jul 30, 2024. It is now read-only.

Commit 9535e69

Browse files
authored
Speed up bulk insertions (#530)
The initialization phase of the job has to insert 1.5 million records into the validation DB in batches of 1,000 records. On DEV, each batch was taking 30 - 120 seconds using Entity Framework. This fix uses `SqlBulkCopy` to speed up batches to be sub-second.
1 parent df60067 commit 9535e69

13 files changed

Lines changed: 275 additions & 77 deletions

File tree

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using System.Data.SqlClient;
6+
using System.Threading.Tasks;
7+
using Microsoft.Extensions.Logging;
8+
using NuGet.Jobs.Configuration;
9+
10+
namespace NuGet.Jobs
11+
{
12+
public class DelegateSqlConnectionFactory<TbDbConfiguration> : ISqlConnectionFactory<TbDbConfiguration>
13+
where TbDbConfiguration : IDbConfiguration
14+
{
15+
private readonly Func<Task<SqlConnection>> _connectionFunc;
16+
private readonly ILogger<DelegateSqlConnectionFactory<TbDbConfiguration>> _logger;
17+
18+
public DelegateSqlConnectionFactory(Func<Task<SqlConnection>> connectionFunc, ILogger<DelegateSqlConnectionFactory<TbDbConfiguration>> logger)
19+
{
20+
_connectionFunc = connectionFunc ?? throw new ArgumentNullException(nameof(connectionFunc));
21+
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
22+
}
23+
24+
public Task<SqlConnection> CreateAsync() => _connectionFunc();
25+
26+
public async Task<SqlConnection> OpenAsync()
27+
{
28+
SqlConnection connection = null;
29+
30+
try
31+
{
32+
_logger.LogDebug("Opening SQL connection...");
33+
34+
connection = await _connectionFunc();
35+
36+
await connection.OpenAsync();
37+
38+
_logger.LogDebug("Opened SQL connection");
39+
40+
return connection;
41+
}
42+
catch (Exception e)
43+
{
44+
_logger.LogError(0, e, "Unable to open SQL connection due to exception");
45+
46+
connection?.Dispose();
47+
48+
throw;
49+
}
50+
}
51+
}
52+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System.Data.SqlClient;
5+
using System.Threading.Tasks;
6+
using NuGet.Jobs.Configuration;
7+
8+
namespace NuGet.Jobs
9+
{
10+
/// <summary>
11+
/// A factory to create and open <see cref="SqlConnection"/>s.
12+
/// </summary>
13+
public interface ISqlConnectionFactory
14+
{
15+
/// <summary>
16+
/// Create an unopened SQL connection.
17+
/// </summary>
18+
/// <returns>The unopened SQL connection.</returns>
19+
Task<SqlConnection> CreateAsync();
20+
21+
/// <summary>
22+
/// Create and then open a SQL connection.
23+
/// </summary>
24+
/// <returns>A task that creates and then opens a SQL connection.</returns>
25+
Task<SqlConnection> OpenAsync();
26+
}
27+
28+
/// <summary>
29+
/// A factory to create and open <see cref="SqlConnection"/>s for a specific
30+
/// <see cref="TDbConfiguration"/>. This type can be used to avoid Dependency
31+
/// Injection key bindings.
32+
/// </summary>
33+
/// <typeparam name="TDbConfiguration">The configuration used to create the connection.</typeparam>
34+
public interface ISqlConnectionFactory<TDbConfiguration> : ISqlConnectionFactory
35+
where TDbConfiguration : IDbConfiguration
36+
{
37+
}
38+
}

src/NuGet.Jobs.Common/JobBase.cs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
namespace NuGet.Jobs
1818
{
19+
using ICoreSqlConnectionFactory = NuGet.Services.Sql.ISqlConnectionFactory;
20+
1921
public abstract class JobBase
2022
{
2123
private readonly EventSource _jobEventSource;
@@ -29,7 +31,7 @@ protected JobBase(EventSource jobEventSource)
2931
{
3032
JobName = GetType().ToString();
3133
_jobEventSource = jobEventSource;
32-
SqlConnectionFactories = new Dictionary<string, ISqlConnectionFactory>();
34+
SqlConnectionFactories = new Dictionary<string, ICoreSqlConnectionFactory>();
3335
}
3436

3537
public string JobName { get; private set; }
@@ -38,7 +40,7 @@ protected JobBase(EventSource jobEventSource)
3840

3941
protected ILogger Logger { get; private set; }
4042

41-
private Dictionary<string, ISqlConnectionFactory> SqlConnectionFactories { get; }
43+
private Dictionary<string, ICoreSqlConnectionFactory> SqlConnectionFactories { get; }
4244

4345
public void SetLogger(ILoggerFactory loggerFactory, ILogger logger)
4446
{
@@ -56,11 +58,10 @@ public void SetLogger(ILoggerFactory loggerFactory, ILogger logger)
5658
/// </summary>
5759
public abstract Task Run();
5860

59-
6061
/// <summary>
6162
/// Test connection early to fail fast, and log connection diagnostics.
6263
/// </summary>
63-
private async Task TestConnection(string name, ISqlConnectionFactory connectionFactory)
64+
private async Task TestConnection(string name, ICoreSqlConnectionFactory connectionFactory)
6465
{
6566
try
6667
{
@@ -98,17 +99,17 @@ public SqlConnectionStringBuilder GetDatabaseRegistration<T>()
9899
/// </summary>
99100
/// <returns>ConnectionStringBuilder, used for diagnostics.</returns>
100101
public SqlConnectionStringBuilder RegisterDatabase<T>(
101-
IServiceProvider serviceProvider,
102+
IServiceProvider services,
102103
bool testConnection = true)
103104
where T : IDbConfiguration
104105
{
105-
if (serviceProvider == null)
106+
if (services == null)
106107
{
107-
throw new ArgumentNullException(nameof(serviceProvider));
108+
throw new ArgumentNullException(nameof(services));
108109
}
109110

110-
var secretInjector = serviceProvider.GetRequiredService<ISecretInjector>();
111-
var connectionString = serviceProvider.GetRequiredService<IOptionsSnapshot<T>>().Value.ConnectionString;
111+
var secretInjector = services.GetRequiredService<ISecretInjector>();
112+
var connectionString = services.GetRequiredService<IOptionsSnapshot<T>>().Value.ConnectionString;
112113
var connectionFactory = new AzureSqlConnectionFactory(connectionString, secretInjector);
113114

114115
return RegisterDatabase(GetDatabaseKey<T>(), connectionString, testConnection, secretInjector);
@@ -167,13 +168,13 @@ private SqlConnectionStringBuilder RegisterDatabase(
167168
return connectionFactory.SqlConnectionStringBuilder;
168169
}
169170

170-
private ISqlConnectionFactory GetSqlConnectionFactory<T>()
171+
private ICoreSqlConnectionFactory GetSqlConnectionFactory<T>()
171172
where T : IDbConfiguration
172173
{
173174
return GetSqlConnectionFactory(GetDatabaseKey<T>());
174175
}
175176

176-
private ISqlConnectionFactory GetSqlConnectionFactory(string name)
177+
private ICoreSqlConnectionFactory GetSqlConnectionFactory(string name)
177178
{
178179
if (!SqlConnectionFactories.ContainsKey(name))
179180
{

src/NuGet.Jobs.Common/JsonConfigurationJob.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,27 @@ protected virtual void ConfigureDefaultJobServices(IServiceCollection services,
117117

118118
services.AddSingleton(new TelemetryClient());
119119
services.AddTransient<ITelemetryClient, TelemetryClientWrapper>();
120+
121+
services.AddScoped<ISqlConnectionFactory<GalleryDbConfiguration>>(p =>
122+
{
123+
return new DelegateSqlConnectionFactory<GalleryDbConfiguration>(
124+
CreateSqlConnectionAsync<GalleryDbConfiguration>,
125+
p.GetRequiredService<ILogger<DelegateSqlConnectionFactory<GalleryDbConfiguration>>>());
126+
});
127+
128+
services.AddScoped<ISqlConnectionFactory<StatisticsDbConfiguration>>(p =>
129+
{
130+
return new DelegateSqlConnectionFactory<StatisticsDbConfiguration>(
131+
CreateSqlConnectionAsync<StatisticsDbConfiguration>,
132+
p.GetRequiredService<ILogger<DelegateSqlConnectionFactory<StatisticsDbConfiguration>>>());
133+
});
134+
135+
services.AddScoped<ISqlConnectionFactory<ValidationDbConfiguration>>(p =>
136+
{
137+
return new DelegateSqlConnectionFactory<ValidationDbConfiguration>(
138+
CreateSqlConnectionAsync<ValidationDbConfiguration>,
139+
p.GetRequiredService<ILogger<DelegateSqlConnectionFactory<ValidationDbConfiguration>>>());
140+
});
120141
}
121142

122143
private void ConfigureLibraries(IServiceCollection services)

src/NuGet.Jobs.Common/NuGet.Jobs.Common.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@
5050
<Compile Include="Configuration\StatisticsDbConfiguration.cs" />
5151
<Compile Include="Configuration\ValidationDbConfiguration.cs" />
5252
<Compile Include="Configuration\ValidationStorageConfiguration.cs" />
53+
<Compile Include="DelegateSqlConnectionFactory.cs" />
5354
<Compile Include="Extensions\LoggerExtensions.cs" />
5455
<Compile Include="Extensions\XElementExtensions.cs" />
56+
<Compile Include="ISqlConnectionFactory.cs" />
5557
<Compile Include="JsonConfigurationJob.cs" />
5658
<Compile Include="SecretReader\ISecretReaderFactory.cs" />
5759
<Compile Include="SecretReader\SecretReaderFactory.cs" />

src/NuGet.Services.Revalidate/Job.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Collections.Generic;
66
using System.ComponentModel.Design;
7+
using System.Data.SqlClient;
78
using System.IO;
89
using System.Linq;
910
using System.Threading.Tasks;
@@ -78,7 +79,7 @@ public override async Task Run()
7879

7980
preinstalledPackagesNames.UnionWith(packagesInPath);
8081
}
81-
82+
8283
File.WriteAllText(_preinstalledSetPath, JsonConvert.SerializeObject(preinstalledPackagesNames));
8384

8485
Logger.LogInformation("Rebuilt the preinstalled package set. Found {PreinstalledPackages} package ids", preinstalledPackagesNames.Count);
@@ -139,6 +140,7 @@ protected override void ConfigureJobServices(IServiceCollection services, IConfi
139140
services.AddTransient<ITelemetryClient, TelemetryClientWrapper>();
140141

141142
services.AddTransient<IPackageRevalidationStateService, PackageRevalidationStateService>();
143+
services.AddTransient<IPackageRevalidationInserter, PackageRevalidationInserter>();
142144
services.AddTransient<IRevalidationJobStateService, RevalidationJobStateService>();
143145
services.AddTransient<IRevalidationStateService, RevalidationStateService>();
144146

src/NuGet.Services.Revalidate/NuGet.Services.Revalidate.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,14 @@
5858
<Compile Include="Services\HealthService.cs" />
5959
<Compile Include="Services\IGalleryService.cs" />
6060
<Compile Include="Services\IHealthService.cs" />
61+
<Compile Include="Services\IPackageRevalidationInserter.cs" />
6162
<Compile Include="Services\IRevalidationQueue.cs" />
6263
<Compile Include="Services\IRevalidationJobStateService.cs" />
6364
<Compile Include="Services\IPackageRevalidationStateService.cs" />
6465
<Compile Include="Services\IRevalidationService.cs" />
6566
<Compile Include="Services\IRevalidationThrottler.cs" />
6667
<Compile Include="Services\ISingletonService.cs" />
68+
<Compile Include="Services\PackageRevalidationInserter.cs" />
6769
<Compile Include="Services\RevalidationOperation.cs" />
6870
<Compile Include="Services\RevalidationQueue.cs" />
6971
<Compile Include="Services\RevalidationResult.cs" />
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System.Collections.Generic;
5+
using System.Threading.Tasks;
6+
using NuGet.Services.Validation;
7+
8+
namespace NuGet.Services.Revalidate
9+
{
10+
public interface IPackageRevalidationInserter
11+
{
12+
Task AddPackageRevalidationsAsync(IReadOnlyList<PackageRevalidation> revalidations);
13+
}
14+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Data;
7+
using System.Data.SqlClient;
8+
using System.Threading.Tasks;
9+
using Microsoft.Extensions.Logging;
10+
using NuGet.Jobs;
11+
using NuGet.Jobs.Configuration;
12+
using NuGet.Jobs.Validation;
13+
using NuGet.Services.Validation;
14+
15+
namespace NuGet.Services.Revalidate
16+
{
17+
public class PackageRevalidationInserter : IPackageRevalidationInserter
18+
{
19+
private const string TableName = "[dbo].[PackageRevalidations]";
20+
21+
private const string PackageIdColumn = "PackageId";
22+
private const string PackageNormalizedVersionColumn = "PackageNormalizedVersion";
23+
private const string EnqueuedColumn = "Enqueued";
24+
private const string ValidationTrackingIdColumn = "ValidationTrackingId";
25+
private const string CompletedColumn = "Completed";
26+
27+
private readonly ISqlConnectionFactory<ValidationDbConfiguration> _connectionFactory;
28+
private readonly ILogger<PackageRevalidationInserter> _logger;
29+
30+
public PackageRevalidationInserter(
31+
ISqlConnectionFactory<ValidationDbConfiguration> connectionFactory,
32+
ILogger<PackageRevalidationInserter> logger)
33+
{
34+
_connectionFactory = connectionFactory ?? throw new ArgumentNullException(nameof(connectionFactory));
35+
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
36+
}
37+
38+
public async Task AddPackageRevalidationsAsync(IReadOnlyList<PackageRevalidation> revalidations)
39+
{
40+
_logger.LogDebug("Persisting package revalidations to database...");
41+
42+
var table = PrepareTable(revalidations);
43+
44+
using (var connection = await _connectionFactory.OpenAsync())
45+
{
46+
var bulkCopy = new SqlBulkCopy(
47+
connection,
48+
SqlBulkCopyOptions.TableLock | SqlBulkCopyOptions.FireTriggers | SqlBulkCopyOptions.UseInternalTransaction,
49+
externalTransaction: null);
50+
51+
foreach (DataColumn column in table.Columns)
52+
{
53+
bulkCopy.ColumnMappings.Add(column.ColumnName, column.ColumnName);
54+
}
55+
56+
bulkCopy.DestinationTableName = TableName;
57+
bulkCopy.WriteToServer(table);
58+
}
59+
60+
_logger.LogDebug("Finished persisting package revalidations to database...");
61+
}
62+
63+
private DataTable PrepareTable(IReadOnlyList<PackageRevalidation> revalidations)
64+
{
65+
// Prepare the table.
66+
var table = new DataTable();
67+
68+
table.Columns.Add(PackageIdColumn, typeof(string));
69+
table.Columns.Add(PackageNormalizedVersionColumn, typeof(string));
70+
table.Columns.Add(CompletedColumn, typeof(bool));
71+
72+
var enqueued = table.Columns.Add(EnqueuedColumn, typeof(DateTime));
73+
var trackingId = table.Columns.Add(ValidationTrackingIdColumn, typeof(Guid));
74+
75+
enqueued.AllowDBNull = true;
76+
trackingId.AllowDBNull = true;
77+
78+
// Populate the table.
79+
foreach (var revalidation in revalidations)
80+
{
81+
var row = table.NewRow();
82+
83+
row[PackageIdColumn] = revalidation.PackageId;
84+
row[PackageNormalizedVersionColumn] = revalidation.PackageNormalizedVersion;
85+
row[EnqueuedColumn] = ((object)revalidation.Enqueued) ?? DBNull.Value;
86+
row[ValidationTrackingIdColumn] = ((object)revalidation.ValidationTrackingId) ?? DBNull.Value;
87+
row[CompletedColumn] = revalidation.Completed;
88+
89+
table.Rows.Add(row);
90+
}
91+
92+
return table;
93+
}
94+
}
95+
}

0 commit comments

Comments
 (0)