Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions auth/kbase_auth_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"
)

Expand Down Expand Up @@ -58,7 +59,6 @@ type KBaseAuthServerOption func(*KBaseAuthServerConfig)
// options can be used to modify the default configuration, and are typically
// used for testing with a mock server
func NewKBaseAuthServer(accessToken string, options ...KBaseAuthServerOption) (*KBaseAuthServer, error) {

// set up default configuration
cfg := KBaseAuthServerConfig{
BaseURL: kbaseURL,
Expand Down Expand Up @@ -165,7 +165,12 @@ func kbaseAuthError(response *http.Response) error {
if mErr == nil {
var result kbaseAuthErrorResponse
mErr = json.Unmarshal(body, &result)
if mErr == nil {
if mErr != nil {
if strings.Contains(string(body), "cloudflare") {
err = fmt.Errorf("KBase Auth error (%d): %s", response.StatusCode,
"Authenticator is protected by a Cloudflare challenge")
}
} else {
if len(result.Message) > 0 {
err = fmt.Errorf("KBase Auth error (%d): %s", response.StatusCode,
result.Message)
Expand Down
217 changes: 143 additions & 74 deletions databases/jdp/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ type Config struct {
}

type StagingRequest struct {
// JDP staging request ID
Id int
// JDP staging request IDs (batched because the endpoint has a maximum limit)
Ids []int
// time of staging request (for purging)
Time time.Time
}
Expand Down Expand Up @@ -136,7 +136,9 @@ func (db *Database) Search(orcid string, params databases.SearchParameters) (dat
pageNumber, pageSize := pageNumberAndSize(params.Pagination.Offset, params.Pagination.MaxNum)

p := url.Values{}
p.Add("q", params.Query)
if params.Query != "" {
p.Add("q", params.Query)
}
switch params.Status {
case databases.SearchFileStatusStaged:
p.Add(`ff[file_status]`, "RESTORED")
Expand Down Expand Up @@ -171,20 +173,13 @@ func (db *Database) Search(orcid string, params databases.SearchParameters) (dat
}, err
}

func (db *Database) Descriptors(orcid string, fileIds []string) ([]map[string]any, error) {
// strip the "JDP:" prefix from our files and create a mapping from IDs to
// their original order so we can hand back metadata accordingly
func (db *Database) fetchDescriptors(orcid string, fileIds []string, batchSize int) ([]map[string]any, error) {
// strip the "JDP:" prefix from our files
strippedFileIds := make([]string, len(fileIds))
indexForId := make(map[string]int)
for i, fileId := range fileIds {
strippedFileIds[i] = strings.TrimPrefix(fileId, "JDP:")
indexForId[strippedFileIds[i]] = i
}

// NOTE: the JDP search/by_file_ids/ endpoint (unofficial, undocumented!) only seems to
// NOTE: accept around 50 file IDs at a time, so we have to batch our requests

batchSize := 50
numBatches := len(strippedFileIds) / batchSize
if numBatches*batchSize < len(strippedFileIds) {
numBatches++
Expand Down Expand Up @@ -214,33 +209,60 @@ func (db *Database) Descriptors(orcid string, fileIds []string) ([]map[string]an
return nil, err
}

// get a de-duped list of descriptors
// get a de-duped list of descriptors (with JDP: prefixes reinstated on IDs)
batchDescriptors, err := descriptorsFromResponseBody(body, nil)
if err != nil {
return nil, err
}
descriptors = append(descriptors, batchDescriptors...)
}
return descriptors, nil
}

func (db *Database) Descriptors(orcid string, fileIds []string) ([]map[string]any, error) {
// NOTE: The JDP search/by_file_ids/ endpoint (unofficial, undocumented!) only seems to
// NOTE: accept around 50 file IDs at a time, so we have to batch our requests.
descriptors, err := db.fetchDescriptors(orcid, fileIds, 50)
if err != nil {
return nil, err
}

// reorder the descriptors to match that of the requested file IDs, and track file IDs that aren't
// matched to descriptors
descriptorsByFileId := make(map[string]map[string]any)
for _, descriptor := range descriptors {
descriptorsByFileId[descriptor["id"].(string)] = descriptor
}

// if any file IDs don't have corresponding descriptors, find out which ones and issue an error
// NOTE: Evidently sometimes the search/by_file_ids endpoint doesn't return all of the
// NOTE: relevant information (???), so we make a list of missing descriptors and attempt
// NOTE: to fetch them again. If that attempt fails, we emit an error.
if len(descriptorsByFileId) < len(fileIds) {
missingResources := make([]string, 0)
missingFileIds := make([]string, 0)
for _, fileId := range fileIds {
if _, found := descriptorsByFileId[fileId]; !found {
missingResources = append(missingResources, fileId)
missingFileIds = append(missingFileIds, fileId)
}
}
if recoveredDescriptors, err := db.fetchDescriptors(orcid, missingFileIds, 50); err == nil {
for _, descriptor := range recoveredDescriptors {
descriptorsByFileId[descriptor["id"].(string)] = descriptor
}
if len(recoveredDescriptors) < len(missingFileIds) { // didn't get them all!
missingFileIds = make([]string, 0)
for _, fileId := range fileIds {
if _, found := descriptorsByFileId[fileId]; !found {
missingFileIds = append(missingFileIds, fileId)
}
}
} else {
// got 'em!
missingFileIds = nil
}
}
if len(missingResources) > 0 {
if len(missingFileIds) > 0 {
slices.Sort(missingFileIds)
return nil, &databases.ResourcesNotFoundError{
Database: "JDP",
ResourceIds: missingResources,
ResourceIds: missingFileIds,
}
}
}
Expand All @@ -256,9 +278,7 @@ func (db *Database) EndpointNames() []string {
return []string{db.EndpointName}
}

func (db *Database) StageFiles(orcid string, fileIds []string) (uuid.UUID, error) {
var xferId uuid.UUID

func (db *Database) requestArchivedFiles(orcid string, fileIds []string, batchSize int) ([]int, error) {
// construct a POST request to restore archived files with the given IDs
type RestoreRequest struct {
Ids []string `json:"ids"`
Expand All @@ -275,72 +295,116 @@ func (db *Database) StageFiles(orcid string, fileIds []string) (uuid.UUID, error
}
}

data, err := json.Marshal(RestoreRequest{
Ids: fileIdsWithoutPrefix,
SendEmail: false,
ApiVersion: "2",
IncludePrivateData: 1, // we need this just in case!
})
if err != nil {
return xferId, err
numBatches := len(fileIds) / batchSize
if numBatches*batchSize < len(fileIds) {
numBatches += 1
}

// NOTE: The slash in the resource is all-important for POST requests to
// NOTE: the JDP!!
body, err := db.post("request_archived_files/", orcid, bytes.NewReader(data))
if err != nil {
switch e := err.(type) {
case *databases.ResourcesNotFoundError:
e.ResourceIds = fileIds
requestIds := make([]int, numBatches)
for i := range numBatches {
begin := i * batchSize
end := min((i+1)*batchSize, len(fileIdsWithoutPrefix))
data, err := json.Marshal(RestoreRequest{
Ids: fileIdsWithoutPrefix[begin:end],
SendEmail: false,
ApiVersion: "2",
IncludePrivateData: 1, // we need this just in case!
})
if err != nil {
return nil, err
}

// NOTE: The slash in the resource is all-important for POST requests to the JDP!!
body, err := db.post("request_archived_files/", orcid, bytes.NewReader(data))
if err != nil {
switch e := err.(type) {
case *databases.ResourcesNotFoundError:
e.ResourceIds = fileIds
}
return nil, err
}

type RestoreResponse struct {
RequestId int `json:"request_id"`
}
return xferId, err
}

type RestoreResponse struct {
RequestId int `json:"request_id"`
var jdpResp RestoreResponse
err = json.Unmarshal(body, &jdpResp)
if err != nil {
return nil, err
}
requestIds[i] = jdpResp.RequestId
}

var jdpResp RestoreResponse
err = json.Unmarshal(body, &jdpResp)
return requestIds, nil
}

func (db *Database) StageFiles(orcid string, fileIds []string) (uuid.UUID, error) {
var xferId uuid.UUID

// NOTE: the relevant endpoint seems to return a 404 whenever it gets too many file IDs,
// NOTE: so we batch requests in sets of 1000
requestIds, err := db.requestArchivedFiles(orcid, fileIds, 1000)
if err != nil {
return xferId, err
}
slog.Debug(fmt.Sprintf("Requested %d archived files from JDP (request ID: %d)",
len(fileIds), jdpResp.RequestId))

slog.Debug(fmt.Sprintf("Requested %d archived files from JDP (request IDs: %v)",
len(fileIds), requestIds))
xferId = uuid.New()
db.StagingRequests[xferId] = StagingRequest{
Id: jdpResp.RequestId,
Ids: requestIds,
Time: time.Now(),
}
return xferId, err
}

func (db *Database) StagingStatus(id uuid.UUID) (databases.StagingStatus, error) {
statusForString := map[string]databases.StagingStatus{
"new": databases.StagingStatusActive,
"pending": databases.StagingStatusActive,
"ready": databases.StagingStatusSucceeded,
"failed": databases.StagingStatusFailed,
}
db.pruneStagingRequests()
if request, found := db.StagingRequests[id]; found {
resource := fmt.Sprintf("request_archived_files/requests/%d", request.Id)
body, err := db.get(resource, url.Values{})
if err != nil {
return databases.StagingStatusUnknown, err
}
type JDPResult struct {
Status string `json:"status"` // "new", "pending", or "ready"
}
var jdpResult JDPResult
err = json.Unmarshal(body, &jdpResult)
if err != nil {
return databases.StagingStatusUnknown, err
}
statusForString := map[string]databases.StagingStatus{
"new": databases.StagingStatusActive,
"pending": databases.StagingStatusActive,
"ready": databases.StagingStatusSucceeded,
}
if status, ok := statusForString[jdpResult.Status]; ok {
slog.Debug(fmt.Sprintf("Queried JDP for staging status of transfer with staging ID %s (request ID: %d): %s", id, request.Id, jdpResult.Status))
return status, nil
var status databases.StagingStatus
var statusStr string
for _, requestId := range request.Ids {
resource := fmt.Sprintf("request_archived_files/requests/%d", requestId)
body, err := db.get(resource, url.Values{})
if err != nil {
return databases.StagingStatusUnknown, err
}
type JDPResult struct {
Status string `json:"status"` // "new", "pending", "ready", or "failed"
}
var jdpResult JDPResult
err = json.Unmarshal(body, &jdpResult)
if err != nil {
return databases.StagingStatusUnknown, err
}
if requestStatus, ok := statusForString[jdpResult.Status]; ok {
if status == databases.StagingStatusUnknown { // first status encountered
status = requestStatus
statusStr = jdpResult.Status
} else {
if requestStatus != status { // status update
if requestStatus != databases.StagingStatusSucceeded {
status = requestStatus
statusStr = jdpResult.Status
}
}
}
if status == databases.StagingStatusFailed { // one failure sinks them all
break
}
} else {
return databases.StagingStatusUnknown, fmt.Errorf("unrecognized JDP staging status string: %s", jdpResult.Status)
}
}
return databases.StagingStatusUnknown, fmt.Errorf("unrecognized JDP staging status string: %s", jdpResult.Status)
slog.Debug(fmt.Sprintf("Queried JDP for staging status of transfer with staging ID %s (request IDs: %v): %s", id, request.Ids, statusStr))
return status, nil
} else {
slog.Info(fmt.Sprintf("No staging request found for transfer with staging ID %s", id.String()))
return databases.StagingStatusUnknown, nil
Expand Down Expand Up @@ -658,10 +722,6 @@ func (db *Database) post(resource, orcid string, body io.Reader) ([]byte, error)
case 200, 201, 204:
defer resp.Body.Close()
return io.ReadAll(resp.Body)
case 404:
return nil, &databases.ResourcesNotFoundError{
Database: "JDP",
}
case 503:
return nil, &databases.UnavailableError{
Database: "jdp",
Expand Down Expand Up @@ -809,6 +869,15 @@ func (db Database) addSpecificSearchParameters(params map[string]any, p *url.Val
}
}
p.Add(name, fmt.Sprintf("%d", value))
case "datasets": // specific JGI datasets requested
var value string
if value, ok = jsonValue.(string); !ok {
return &databases.InvalidSearchParameter{
Database: "JDP",
Message: "Invalid JDP dataset(s) requested (must be comma-delimited string)",
}
}
p.Add(name, strings.TrimSpace(value))
case "extra": // comma-separated additional fields requested
var value string
if value, ok = jsonValue.(string); !ok {
Expand Down Expand Up @@ -844,7 +913,7 @@ func (db *Database) pruneStagingRequests() {
for uuid, request := range db.StagingRequests {
requestAge := time.Since(request.Time)
if requestAge > db.DeleteAfter {
slog.Debug(fmt.Sprintf("Pruning staging request with staging ID %s (request ID: %d): age (%s) exceeds limit (%s)", uuid.String(), request.Id, requestAge.String(), db.DeleteAfter.String()))
slog.Debug(fmt.Sprintf("Pruning staging request with staging ID %s (request IDs: %v): age (%s) exceeds limit (%s)", uuid.String(), request.Ids, requestAge.String(), db.DeleteAfter.String()))
delete(db.StagingRequests, uuid)
}
}
Expand Down
Loading
Loading