Skip to content

Commit c0248ba

Browse files
authored
Add auth to V3 monitoring search cursor URLs (#10115)
* Add auth to V3 monitoring search cursor URLs * Use -clientId for managed identity auth on search cursors * Allow UseManagedIdentity to be set per search instance
1 parent b1c2839 commit c0248ba

12 files changed

Lines changed: 464 additions & 120 deletions

File tree

src/Catalog/AzureBlobCursor.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using System.Diagnostics;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using Azure;
9+
using Azure.Storage.Blobs;
10+
using Azure.Storage.Blobs.Models;
11+
using Newtonsoft.Json.Linq;
12+
13+
namespace NuGet.Services.Metadata.Catalog
14+
{
15+
public class AzureBlobCursor : ReadCursor
16+
{
17+
private readonly BlobClient _blobClient;
18+
19+
public AzureBlobCursor(BlobClient blobClient)
20+
{
21+
_blobClient = blobClient ?? throw new ArgumentNullException(nameof(blobClient));
22+
}
23+
24+
public override async Task LoadAsync(CancellationToken cancellationToken)
25+
{
26+
BlobDownloadResult downloadResult;
27+
try
28+
{
29+
downloadResult = await _blobClient.DownloadContentAsync(cancellationToken);
30+
}
31+
catch (RequestFailedException ex)
32+
{
33+
Trace.TraceError("AzureBlobCursor.LoadAsync: error {0} {1}", ex.Status, _blobClient.Uri.AbsoluteUri);
34+
throw;
35+
}
36+
37+
var json = downloadResult.Content.ToString();
38+
39+
JObject obj = JObject.Parse(json);
40+
Value = obj["value"].ToObject<DateTime>();
41+
42+
Trace.TraceInformation("AzureBlobCursor.LoadAsync: {0:O} {1}", Value, _blobClient.Uri.AbsoluteUri);
43+
}
44+
}
45+
}

src/Ng/Arguments.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) .NET Foundation. All rights reserved.
1+
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using NuGet.Services.Metadata.Catalog.Monitoring;
@@ -99,6 +99,19 @@ public static class Arguments
9999
/// </summary>
100100
public const string SearchCursorUriPrefix = "searchCursorUri-";
101101

102+
/// <summary>
103+
/// The argument prefix for the cursor SAS token of a <see cref="SearchEndpoint"/> cursor.
104+
/// This is used in conjunction with the <see cref="SearchBaseUriPrefix"/> argument with same suffix for authentication with blob storage.
105+
/// </summary>
106+
public const string SearchCursorSasValuePrefix = "searchCursorSasValue-";
107+
108+
/// <summary>
109+
/// The argument prefix to enable using a token credential for a <see cref="SearchEndpoint"/> cursor.
110+
/// This is used in conjunction with the <see cref="SearchBaseUriPrefix"/> argument with same suffix for authentication with blob storage.
111+
/// If <see cref="ClientId"/> is specified, a managed identity credential will be used. Otherwise, a default Azure credential will be used.
112+
/// </summary>
113+
public const string SearchCursorUseManagedIdentityPrefix = "searchCursorUseManagedIdentity-";
114+
102115
/// <summary>
103116
/// The argument prefix for the base URL of a <see cref="SearchEndpoint"/>. There should be the same number of
104117
/// <see cref="SearchBaseUriPrefix"/> parameters passed as <see cref="SearchCursorUriPrefix"/> with the same
@@ -186,4 +199,4 @@ public static class Arguments
186199
public const string CursorFile = "cursorFile";
187200
#endregion
188201
}
189-
}
202+
}

src/Ng/CommandHelpers.cs

Lines changed: 96 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
using System.Net;
99
using System.Net.Http;
1010
using System.Security.Cryptography.X509Certificates;
11+
using Azure;
12+
using Azure.Core;
13+
using Azure.Identity;
1114
using Azure.Storage.Blobs;
1215
using Azure.Storage.Queues;
1316
using Microsoft.Extensions.Logging;
@@ -28,6 +31,7 @@ namespace Ng
2831
{
2932
public static class CommandHelpers
3033
{
34+
private const string DefaultStorageSuffix = "core.windows.net";
3135
private static readonly int DefaultKeyVaultSecretCachingTimeout = 60 * 60 * 6; // 6 hours;
3236
private static readonly HashSet<string> NotInjectedKeys = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
3337
{
@@ -211,16 +215,14 @@ private static CatalogStorageFactory CreateStorageFactoryImpl(
211215

212216
if (Arguments.AzureStorageType.Equals(storageType, StringComparison.InvariantCultureIgnoreCase))
213217
{
214-
var storageAccountName = arguments.GetOrThrow<string>(argumentNameMap[Arguments.StorageAccountName]);
215218
var storageContainer = arguments.GetOrThrow<string>(argumentNameMap[Arguments.StorageContainer]);
216219
var storagePath = arguments.GetOrDefault<string>(argumentNameMap[Arguments.StoragePath]);
217-
var storageSuffix = arguments.GetOrDefault(argumentNameMap[Arguments.StorageSuffix], "core.windows.net");
218220
var storageOperationMaxExecutionTime = MaxExecutionTime(arguments.GetOrDefault<int>(argumentNameMap[Arguments.StorageOperationMaxExecutionTimeInSeconds]));
219221
var storageServerTimeout = MaxExecutionTime(arguments.GetOrDefault<int>(argumentNameMap[Arguments.StorageServerTimeoutInSeconds]));
220222
var storageUseServerSideCopy = arguments.GetOrDefault<bool>(argumentNameMap[Arguments.StorageUseServerSideCopy]);
221223
var storageInitializeContainer = arguments.GetOrDefault(argumentNameMap[Arguments.StorageInitializeContainer], defaultValue: true);
222224

223-
BlobServiceClient account = GetBlobServiceClient(storageAccountName, storageSuffix, arguments, argumentNameMap);
225+
BlobServiceClient account = GetBlobServiceClient(arguments, argumentNameMap);
224226

225227
return new CatalogAzureStorageFactory(
226228
account,
@@ -283,19 +285,28 @@ public static Func<HttpMessageHandler> GetHttpMessageHandlerFactory(
283285

284286
public static EndpointConfiguration GetEndpointConfiguration(IDictionary<string, string> arguments)
285287
{
288+
var clientId = arguments.GetOrDefault<string>(Arguments.ClientId);
289+
286290
var registrationCursorUri = arguments.GetOrThrow<Uri>(Arguments.RegistrationCursorUri);
287291
var flatContainerCursorUri = arguments.GetOrThrow<Uri>(Arguments.FlatContainerCursorUri);
288292

289-
var instanceNameToSearchBaseUri = GetSuffixToUri(arguments, Arguments.SearchBaseUriPrefix);
290-
var instanceNameToSearchCursorUri = GetSuffixToUri(arguments, Arguments.SearchCursorUriPrefix);
293+
var instanceNameToSearchBaseUri = GetSuffixToValue<Uri>(arguments, Arguments.SearchBaseUriPrefix);
294+
var instanceNameToSearchCursorUri = GetSuffixToValue<Uri>(arguments, Arguments.SearchCursorUriPrefix);
295+
var instanceNameToSearchCursorSasValue = GetSuffixToValue<string>(arguments, Arguments.SearchCursorSasValuePrefix);
296+
var instanceNameToSearchCursorUseManagedIdentity = GetSuffixToValue<bool>(arguments, Arguments.SearchCursorUseManagedIdentityPrefix);
291297
var instanceNameToSearchConfig = new Dictionary<string, SearchEndpointConfiguration>();
298+
292299
foreach (var pair in instanceNameToSearchBaseUri)
293300
{
294301
var instanceName = pair.Key;
295302

296303
// Find all cursors with an instance name starting with the search base URI instance name. We do this
297304
// because there may be multiple potential cursors representing the state of a search service.
298-
var matchingCursors = instanceNameToSearchCursorUri.Keys.Where(x => x.StartsWith(instanceName)).ToList();
305+
var matchingCursors = instanceNameToSearchCursorUri
306+
.Keys
307+
.Where(x => x.StartsWith(instanceName))
308+
.OrderBy(x => x)
309+
.ToList();
299310

300311
if (!matchingCursors.Any())
301312
{
@@ -304,9 +315,51 @@ public static EndpointConfiguration GetEndpointConfiguration(IDictionary<string,
304315
$"-{Arguments.SearchCursorUriPrefix}{instanceName}* arguments.");
305316
}
306317

307-
instanceNameToSearchConfig[instanceName] = new SearchEndpointConfiguration(
308-
matchingCursors.Select(x => instanceNameToSearchCursorUri[x]).ToList(),
309-
pair.Value);
318+
var cursors = new List<SearchCursorConfiguration>();
319+
320+
foreach (var suffix in matchingCursors)
321+
{
322+
var cursorUri = instanceNameToSearchCursorUri[suffix];
323+
SearchCursorCredentialType credentialType;
324+
325+
BlobClient blobClient = null;
326+
if (instanceNameToSearchCursorUseManagedIdentity.TryGetValue(suffix, out var useManagedIdentity)
327+
&& useManagedIdentity)
328+
{
329+
TokenCredential credential;
330+
if (string.IsNullOrEmpty(clientId))
331+
{
332+
credential = new DefaultAzureCredential();
333+
credentialType = SearchCursorCredentialType.DefaultAzureCredential;
334+
}
335+
else
336+
{
337+
credential = new ManagedIdentityCredential(clientId);
338+
credentialType = SearchCursorCredentialType.ManagedIdentityCredential;
339+
}
340+
341+
blobClient = new BlobClient(cursorUri, credential);
342+
}
343+
else if (instanceNameToSearchCursorSasValue.TryGetValue(suffix, out var sas))
344+
{
345+
if (sas.StartsWith("?"))
346+
{
347+
// workaround for https://github.com/Azure/azure-sdk-for-net/issues/44373
348+
sas = sas.Substring(1);
349+
}
350+
351+
blobClient = new BlobClient(cursorUri, new AzureSasCredential(sas));
352+
credentialType = SearchCursorCredentialType.AzureSasCredential;
353+
}
354+
else
355+
{
356+
credentialType = SearchCursorCredentialType.Anonymous;
357+
}
358+
359+
cursors.Add(new SearchCursorConfiguration(cursorUri, blobClient, credentialType));
360+
}
361+
362+
instanceNameToSearchConfig[instanceName] = new SearchEndpointConfiguration(cursors, pair.Value);
310363

311364
foreach (var key in matchingCursors)
312365
{
@@ -329,13 +382,13 @@ public static EndpointConfiguration GetEndpointConfiguration(IDictionary<string,
329382
instanceNameToSearchConfig);
330383
}
331384

332-
private static Dictionary<string, Uri> GetSuffixToUri(IDictionary<string, string> arguments, string prefix)
385+
private static Dictionary<string, T> GetSuffixToValue<T>(IDictionary<string, string> arguments, string prefix)
333386
{
334-
var suffixToUri = new Dictionary<string, Uri>();
387+
var suffixToUri = new Dictionary<string, T>();
335388
foreach (var key in arguments.Keys.Where(x => x.StartsWith(prefix)))
336389
{
337390
var suffix = key.Substring(prefix.Length);
338-
suffixToUri[suffix] = arguments.GetOrThrow<Uri>(key);
391+
suffixToUri[suffix] = arguments.GetOrThrow<T>(key);
339392
}
340393

341394
return suffixToUri;
@@ -354,10 +407,8 @@ public static IStorageQueue<T> CreateStorageQueue<T>(IDictionary<string, string>
354407

355408
if (Arguments.AzureStorageType.Equals(storageType, StringComparison.InvariantCultureIgnoreCase))
356409
{
357-
var storageAccountName = arguments.GetOrThrow<string>(Arguments.StorageAccountName);
358410
var storageQueueName = arguments.GetOrDefault<string>(Arguments.StorageQueueName);
359-
360-
QueueServiceClient account = GetQueueServiceClient(storageAccountName, endpointSuffix: null, arguments, ArgumentNames);
411+
QueueServiceClient account = GetQueueServiceClient(arguments, ArgumentNames);
361412
return new StorageQueue<T>(new AzureStorageQueue(account, storageQueueName),
362413
new JsonMessageSerializer<T>(JsonSerializerUtility.SerializerSettings), version);
363414
}
@@ -367,8 +418,34 @@ public static IStorageQueue<T> CreateStorageQueue<T>(IDictionary<string, string>
367418
}
368419
}
369420

370-
private static BlobServiceClient GetBlobServiceClient(string storageAccountName, string endpointSuffix, IDictionary<string, string> arguments, IDictionary<string, string> argumentNameMap)
421+
private static BlobServiceClient GetBlobServiceClient(
422+
IDictionary<string, string> arguments,
423+
IDictionary<string, string> argumentNameMap)
424+
{
425+
string connectionString = GetConnectionString(arguments, argumentNameMap, "BlobEndpoint", "blob");
426+
return new BlobServiceClient(connectionString);
427+
}
428+
429+
private static QueueServiceClient GetQueueServiceClient(
430+
IDictionary<string, string> arguments,
431+
IDictionary<string, string> argumentNameMap)
432+
{
433+
string connectionString = GetConnectionString(arguments, argumentNameMap, "QueueEndpoint", "queue");
434+
return new QueueServiceClient(connectionString, new QueueClientOptions
435+
{
436+
// We use base64 encoding for compatibility with the older SDK
437+
MessageEncoding = QueueMessageEncoding.Base64,
438+
});
439+
}
440+
441+
private static string GetConnectionString(
442+
IDictionary<string, string> arguments,
443+
IDictionary<string, string> argumentNameMap,
444+
string endpointKey,
445+
string endpointDomain)
371446
{
447+
var storageAccountName = arguments.GetOrThrow<string>(argumentNameMap[Arguments.StorageAccountName]);
448+
var storageSuffix = arguments.GetOrDefault(argumentNameMap[Arguments.StorageSuffix], DefaultStorageSuffix);
372449
var storageKeyValue = arguments.GetOrDefault<string>(argumentNameMap[Arguments.StorageKeyValue]);
373450

374451
string connectionString;
@@ -382,34 +459,14 @@ private static BlobServiceClient GetBlobServiceClient(string storageAccountName,
382459
storageSasValue = storageSasValue.Substring(1);
383460
}
384461

385-
connectionString = $"BlobEndpoint=https://{storageAccountName}.blob.{endpointSuffix}/;SharedAccessSignature={storageSasValue}";
386-
}
387-
else
388-
{
389-
connectionString = $"DefaultEndpointsProtocol=https;AccountName={storageAccountName};AccountKey={storageKeyValue};EndpointSuffix={endpointSuffix}";
390-
}
391-
392-
return new BlobServiceClient(connectionString);
393-
}
394-
395-
private static QueueServiceClient GetQueueServiceClient(string storageAccountName, string endpointSuffix, IDictionary<string, string> arguments, IDictionary<string, string> argumentNameMap)
396-
{
397-
var storageKeyValue = arguments.GetOrDefault<string>(argumentNameMap[Arguments.StorageKeyValue]);
398-
399-
string connectionString;
400-
401-
if (string.IsNullOrEmpty(storageKeyValue))
402-
{
403-
var storageSasValue = arguments.GetOrThrow<string>(argumentNameMap[Arguments.StorageSasValue]);
404-
connectionString = $"BlobEndpoint=https://{storageAccountName}.blob.{endpointSuffix}/;SharedAccessSignature={storageSasValue}";
462+
connectionString = $"{endpointKey}=https://{storageAccountName}.{endpointDomain}.{storageSuffix}/;SharedAccessSignature={storageSasValue}";
405463
}
406464
else
407465
{
408-
connectionString = $"DefaultEndpointsProtocol=https;AccountName={storageAccountName};AccountKey={storageKeyValue};EndpointSuffix={endpointSuffix}";
466+
connectionString = $"DefaultEndpointsProtocol=https;AccountName={storageAccountName};AccountKey={storageKeyValue};EndpointSuffix={storageSuffix}";
409467
}
410468

411-
return new QueueServiceClient(connectionString);
469+
return connectionString;
412470
}
413-
414471
}
415472
}

src/NuGet.Services.Configuration/DictionaryExtensions.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) .NET Foundation. All rights reserved.
1+
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
@@ -35,7 +35,10 @@ public static class DictionaryExtensions
3535
/// <exception cref="ArgumentException">Thrown when the value associated with the key in the dictionary is null or empty.</exception>
3636
public static T GetOrThrow<T>(this IDictionary<string, string> dictionary, string key)
3737
{
38-
var value = dictionary[key];
38+
if (!dictionary.TryGetValue(key, out var value))
39+
{
40+
throw new KeyNotFoundException($"Key {key} was not found in the dictionary.");
41+
}
3942

4043
if (string.IsNullOrEmpty(value))
4144
{

src/NuGet.Services.Metadata.Catalog.Monitoring/Utility/ContainerBuilderExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) .NET Foundation. All rights reserved.
1+
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
@@ -77,7 +77,7 @@ private static void RegisterSearchEndpoints(this ContainerBuilder builder, Endpo
7777
builder
7878
.Register(c => new SearchEndpoint(
7979
pair.Key,
80-
pair.Value.CursorUris,
80+
pair.Value.Cursors,
8181
pair.Value.BaseUri,
8282
c.Resolve<Func<HttpMessageHandler>>()))
8383
.As<IEndpoint>()
@@ -240,4 +240,4 @@ private static void RegisterResourceProvider<TProvider>(this ContainerBuilder bu
240240
.Keyed<INuGetResourceProvider>(type);
241241
}
242242
}
243-
}
243+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using Azure.Storage.Blobs;
6+
7+
namespace NuGet.Services.Metadata.Catalog.Monitoring
8+
{
9+
public enum SearchCursorCredentialType
10+
{
11+
Anonymous = 1,
12+
AzureSasCredential,
13+
DefaultAzureCredential,
14+
ManagedIdentityCredential,
15+
}
16+
17+
public sealed class SearchCursorConfiguration
18+
{
19+
public SearchCursorConfiguration(Uri cursorUri, BlobClient blobClient, SearchCursorCredentialType credentialType)
20+
{
21+
CursorUri = cursorUri;
22+
BlobClient = blobClient;
23+
CredentialType = credentialType;
24+
}
25+
26+
public Uri CursorUri { get; }
27+
public BlobClient BlobClient { get; }
28+
public SearchCursorCredentialType CredentialType { get; }
29+
}
30+
}

0 commit comments

Comments
 (0)