11// Copyright (c) .NET Foundation. All rights reserved.
22// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33
4+ using System ;
45using System . Collections . Generic ;
56using System . ComponentModel . Design ;
7+ using System . Data . SqlClient ;
68using System . Diagnostics . Tracing ;
79using System . Threading . Tasks ;
10+ using Microsoft . Extensions . DependencyInjection ;
811using Microsoft . Extensions . Logging ;
12+ using Microsoft . Extensions . Options ;
13+ using NuGet . Jobs . Configuration ;
14+ using NuGet . Services . KeyVault ;
15+ using NuGet . Services . Sql ;
916
1017namespace NuGet . Jobs
1118{
@@ -22,6 +29,7 @@ protected JobBase(EventSource jobEventSource)
2229 {
2330 JobName = GetType ( ) . ToString ( ) ;
2431 _jobEventSource = jobEventSource ;
32+ SqlConnectionFactories = new Dictionary < string , ISqlConnectionFactory > ( ) ;
2533 }
2634
2735 public string JobName { get ; private set ; }
@@ -30,14 +38,180 @@ protected JobBase(EventSource jobEventSource)
3038
3139 protected ILogger Logger { get ; private set ; }
3240
41+ private Dictionary < string , ISqlConnectionFactory > SqlConnectionFactories { get ; }
42+
3343 public void SetLogger ( ILoggerFactory loggerFactory , ILogger logger )
3444 {
3545 LoggerFactory = loggerFactory ;
3646 Logger = logger ;
3747 }
3848
49+ /// <summary>
50+ /// Initialize the job, provided the service container and configuration.
51+ /// </summary>
3952 public abstract void Init ( IServiceContainer serviceContainer , IDictionary < string , string > jobArgsDictionary ) ;
4053
54+ /// <summary>
55+ /// Run the job.
56+ /// </summary>
4157 public abstract Task Run ( ) ;
58+
59+
60+ /// <summary>
61+ /// Test connection early to fail fast, and log connection diagnostics.
62+ /// </summary>
63+ private async Task TestConnection ( string name , ISqlConnectionFactory connectionFactory )
64+ {
65+ try
66+ {
67+ using ( var connection = await connectionFactory . OpenAsync ( ) )
68+ using ( var cmd = new SqlCommand ( "SELECT CONCAT(CURRENT_USER, '/', SYSTEM_USER)" , connection ) )
69+ {
70+ var result = cmd . ExecuteScalar ( ) ;
71+ var user = result . ToString ( ) ;
72+ Logger . LogInformation ( "Verified CreateSqlConnectionAsync({name}) connects to database {DataSource}/{InitialCatalog} as {User}" ,
73+ name , connectionFactory . DataSource , connectionFactory . InitialCatalog , user ) ;
74+ }
75+ }
76+ catch ( Exception e )
77+ {
78+ Logger . LogError ( 0 , e , "Failed to connect to database {DataSource}/{InitialCatalog}" ,
79+ connectionFactory . DataSource , connectionFactory . InitialCatalog ) ;
80+
81+ throw ;
82+ }
83+ }
84+
85+ public SqlConnectionStringBuilder GetDatabaseRegistration < T > ( )
86+ where T : IDbConfiguration
87+ {
88+ if ( SqlConnectionFactories . TryGetValue ( GetDatabaseKey < T > ( ) , out var connectionFactory ) )
89+ {
90+ return ( ( AzureSqlConnectionFactory ) connectionFactory ) . SqlConnectionStringBuilder ;
91+ }
92+
93+ return null ;
94+ }
95+
96+ /// <summary>
97+ /// Initializes an <see cref="ISqlConnectionFactory"/>, for use by validation jobs.
98+ /// </summary>
99+ /// <returns>ConnectionStringBuilder, used for diagnostics.</returns>
100+ public SqlConnectionStringBuilder RegisterDatabase < T > (
101+ IServiceProvider serviceProvider ,
102+ bool testConnection = true )
103+ where T : IDbConfiguration
104+ {
105+ if ( serviceProvider == null )
106+ {
107+ throw new ArgumentNullException ( nameof ( serviceProvider ) ) ;
108+ }
109+
110+ var secretInjector = serviceProvider . GetRequiredService < ISecretInjector > ( ) ;
111+ var connectionString = serviceProvider . GetRequiredService < IOptionsSnapshot < T > > ( ) . Value . ConnectionString ;
112+ var connectionFactory = new AzureSqlConnectionFactory ( connectionString , secretInjector ) ;
113+
114+ return RegisterDatabase ( GetDatabaseKey < T > ( ) , connectionString , testConnection , secretInjector ) ;
115+ }
116+
117+ /// <summary>
118+ /// Initializes an <see cref="ISqlConnectionFactory"/>, for use by non-validation jobs.
119+ /// </summary>
120+ /// <returns>ConnectionStringBuilder, used for diagnostics.</returns>
121+ public SqlConnectionStringBuilder RegisterDatabase (
122+ IServiceContainer serviceContainer ,
123+ IDictionary < string , string > jobArgsDictionary ,
124+ string connectionStringArgName ,
125+ bool testConnection = true )
126+ {
127+ if ( serviceContainer == null )
128+ {
129+ throw new ArgumentNullException ( nameof ( serviceContainer ) ) ;
130+ }
131+
132+ if ( jobArgsDictionary == null )
133+ {
134+ throw new ArgumentNullException ( nameof ( jobArgsDictionary ) ) ;
135+ }
136+
137+ if ( string . IsNullOrEmpty ( connectionStringArgName ) )
138+ {
139+ throw new ArgumentException ( "Argument cannot be null or empty." , nameof ( connectionStringArgName ) ) ;
140+ }
141+
142+ var secretInjector = ( ISecretInjector ) serviceContainer . GetService ( typeof ( ISecretInjector ) ) ;
143+ var connectionString = JobConfigurationManager . GetArgument ( jobArgsDictionary , connectionStringArgName ) ;
144+
145+ return RegisterDatabase ( connectionStringArgName , connectionString , testConnection , secretInjector ) ;
146+ }
147+
148+ /// <summary>
149+ /// Register a job database at initialization time. Each call should overwrite any existing
150+ /// registration because <see cref="JobRunner"/> calls <see cref="Init"/> on every iteration.
151+ /// </summary>
152+ /// <returns>ConnectionStringBuilder, used for diagnostics.</returns>
153+ private SqlConnectionStringBuilder RegisterDatabase (
154+ string name ,
155+ string connectionString ,
156+ bool testConnection ,
157+ ISecretInjector secretInjector )
158+ {
159+ var connectionFactory = new AzureSqlConnectionFactory ( connectionString , secretInjector , Logger ) ;
160+ SqlConnectionFactories [ name ] = connectionFactory ;
161+
162+ if ( testConnection )
163+ {
164+ Task . Run ( ( ) => TestConnection ( name , connectionFactory ) ) . Wait ( ) ;
165+ }
166+
167+ return connectionFactory . SqlConnectionStringBuilder ;
168+ }
169+
170+ private static string GetDatabaseKey < T > ( )
171+ {
172+ return typeof ( T ) . Name ;
173+ }
174+
175+ /// <summary>
176+ /// Create a SqlConnection, for use by validation jobs.
177+ /// </summary>
178+ public Task < SqlConnection > CreateSqlConnectionAsync < T > ( )
179+ where T : IDbConfiguration
180+ {
181+ var name = GetDatabaseKey < T > ( ) ;
182+ if ( ! SqlConnectionFactories . ContainsKey ( name ) )
183+ {
184+ throw new InvalidOperationException ( $ "Database { name } has not been registered.") ;
185+ }
186+
187+ return SqlConnectionFactories [ name ] . CreateAsync ( ) ;
188+ }
189+
190+ /// <summary>
191+ /// Synchronous creation of a SqlConnection, for use by validation jobs.
192+ /// </summary>
193+ public SqlConnection CreateSqlConnection < T > ( )
194+ where T : IDbConfiguration
195+ {
196+ return Task . Run ( ( ) => CreateSqlConnectionAsync < T > ( ) ) . Result ;
197+ }
198+
199+ /// <summary>
200+ /// Creates and opens a SqlConnection, for use by non-validation jobs.
201+ /// </summary>
202+ public Task < SqlConnection > OpenSqlConnectionAsync ( string connectionStringArgName )
203+ {
204+ if ( string . IsNullOrEmpty ( connectionStringArgName ) )
205+ {
206+ throw new ArgumentException ( "Argument cannot be null or empty." , nameof ( connectionStringArgName ) ) ;
207+ }
208+
209+ if ( ! SqlConnectionFactories . ContainsKey ( connectionStringArgName ) )
210+ {
211+ throw new InvalidOperationException ( $ "Database { connectionStringArgName } has not been registered.") ;
212+ }
213+
214+ return SqlConnectionFactories [ connectionStringArgName ] . OpenAsync ( ) ;
215+ }
42216 }
43217}
0 commit comments