diff --git a/auth/kbase_auth_server.go b/auth/kbase_auth_server.go index 00358937..57b4bc62 100644 --- a/auth/kbase_auth_server.go +++ b/auth/kbase_auth_server.go @@ -26,6 +26,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" ) @@ -58,7 +59,6 @@ type KBaseAuthServerOption func(*KBaseAuthServerConfig) // options can be used to modify the default configuration, and are typically // used for testing with a mock server func NewKBaseAuthServer(accessToken string, options ...KBaseAuthServerOption) (*KBaseAuthServer, error) { - // set up default configuration cfg := KBaseAuthServerConfig{ BaseURL: kbaseURL, @@ -165,7 +165,12 @@ func kbaseAuthError(response *http.Response) error { if mErr == nil { var result kbaseAuthErrorResponse mErr = json.Unmarshal(body, &result) - if mErr == nil { + if mErr != nil { + if strings.Contains(string(body), "cloudflare") { + err = fmt.Errorf("KBase Auth error (%d): %s", response.StatusCode, + "Authenticator is protected by a Cloudflare challenge") + } + } else { if len(result.Message) > 0 { err = fmt.Errorf("KBase Auth error (%d): %s", response.StatusCode, result.Message) diff --git a/databases/jdp/database.go b/databases/jdp/database.go index c59dc67e..0efdc331 100644 --- a/databases/jdp/database.go +++ b/databases/jdp/database.go @@ -75,8 +75,8 @@ type Config struct { } type StagingRequest struct { - // JDP staging request ID - Id int + // JDP staging request IDs (batched because the endpoint has a maximum limit) + Ids []int // time of staging request (for purging) Time time.Time } @@ -136,7 +136,9 @@ func (db *Database) Search(orcid string, params databases.SearchParameters) (dat pageNumber, pageSize := pageNumberAndSize(params.Pagination.Offset, params.Pagination.MaxNum) p := url.Values{} - p.Add("q", params.Query) + if params.Query != "" { + p.Add("q", params.Query) + } switch params.Status { case databases.SearchFileStatusStaged: p.Add(`ff[file_status]`, "RESTORED") @@ -171,20 +173,13 @@ func (db *Database) Search(orcid string, params databases.SearchParameters) (dat }, err } -func (db *Database) Descriptors(orcid string, fileIds []string) ([]map[string]any, error) { - // strip the "JDP:" prefix from our files and create a mapping from IDs to - // their original order so we can hand back metadata accordingly +func (db *Database) fetchDescriptors(orcid string, fileIds []string, batchSize int) ([]map[string]any, error) { + // strip the "JDP:" prefix from our files strippedFileIds := make([]string, len(fileIds)) - indexForId := make(map[string]int) for i, fileId := range fileIds { strippedFileIds[i] = strings.TrimPrefix(fileId, "JDP:") - indexForId[strippedFileIds[i]] = i } - // NOTE: the JDP search/by_file_ids/ endpoint (unofficial, undocumented!) only seems to - // NOTE: accept around 50 file IDs at a time, so we have to batch our requests - - batchSize := 50 numBatches := len(strippedFileIds) / batchSize if numBatches*batchSize < len(strippedFileIds) { numBatches++ @@ -214,33 +209,60 @@ func (db *Database) Descriptors(orcid string, fileIds []string) ([]map[string]an return nil, err } - // get a de-duped list of descriptors + // get a de-duped list of descriptors (with JDP: prefixes reinstated on IDs) batchDescriptors, err := descriptorsFromResponseBody(body, nil) if err != nil { return nil, err } descriptors = append(descriptors, batchDescriptors...) } + return descriptors, nil +} + +func (db *Database) Descriptors(orcid string, fileIds []string) ([]map[string]any, error) { + // NOTE: The JDP search/by_file_ids/ endpoint (unofficial, undocumented!) only seems to + // NOTE: accept around 50 file IDs at a time, so we have to batch our requests. + descriptors, err := db.fetchDescriptors(orcid, fileIds, 50) + if err != nil { + return nil, err + } - // reorder the descriptors to match that of the requested file IDs, and track file IDs that aren't - // matched to descriptors descriptorsByFileId := make(map[string]map[string]any) for _, descriptor := range descriptors { descriptorsByFileId[descriptor["id"].(string)] = descriptor } - // if any file IDs don't have corresponding descriptors, find out which ones and issue an error + // NOTE: Evidently sometimes the search/by_file_ids endpoint doesn't return all of the + // NOTE: relevant information (???), so we make a list of missing descriptors and attempt + // NOTE: to fetch them again. If that attempt fails, we emit an error. if len(descriptorsByFileId) < len(fileIds) { - missingResources := make([]string, 0) + missingFileIds := make([]string, 0) for _, fileId := range fileIds { if _, found := descriptorsByFileId[fileId]; !found { - missingResources = append(missingResources, fileId) + missingFileIds = append(missingFileIds, fileId) + } + } + if recoveredDescriptors, err := db.fetchDescriptors(orcid, missingFileIds, 50); err == nil { + for _, descriptor := range recoveredDescriptors { + descriptorsByFileId[descriptor["id"].(string)] = descriptor + } + if len(recoveredDescriptors) < len(missingFileIds) { // didn't get them all! + missingFileIds = make([]string, 0) + for _, fileId := range fileIds { + if _, found := descriptorsByFileId[fileId]; !found { + missingFileIds = append(missingFileIds, fileId) + } + } + } else { + // got 'em! + missingFileIds = nil } } - if len(missingResources) > 0 { + if len(missingFileIds) > 0 { + slices.Sort(missingFileIds) return nil, &databases.ResourcesNotFoundError{ Database: "JDP", - ResourceIds: missingResources, + ResourceIds: missingFileIds, } } } @@ -256,9 +278,7 @@ func (db *Database) EndpointNames() []string { return []string{db.EndpointName} } -func (db *Database) StageFiles(orcid string, fileIds []string) (uuid.UUID, error) { - var xferId uuid.UUID - +func (db *Database) requestArchivedFiles(orcid string, fileIds []string, batchSize int) ([]int, error) { // construct a POST request to restore archived files with the given IDs type RestoreRequest struct { Ids []string `json:"ids"` @@ -275,72 +295,116 @@ func (db *Database) StageFiles(orcid string, fileIds []string) (uuid.UUID, error } } - data, err := json.Marshal(RestoreRequest{ - Ids: fileIdsWithoutPrefix, - SendEmail: false, - ApiVersion: "2", - IncludePrivateData: 1, // we need this just in case! - }) - if err != nil { - return xferId, err + numBatches := len(fileIds) / batchSize + if numBatches*batchSize < len(fileIds) { + numBatches += 1 } - // NOTE: The slash in the resource is all-important for POST requests to - // NOTE: the JDP!! - body, err := db.post("request_archived_files/", orcid, bytes.NewReader(data)) - if err != nil { - switch e := err.(type) { - case *databases.ResourcesNotFoundError: - e.ResourceIds = fileIds + requestIds := make([]int, numBatches) + for i := range numBatches { + begin := i * batchSize + end := min((i+1)*batchSize, len(fileIdsWithoutPrefix)) + data, err := json.Marshal(RestoreRequest{ + Ids: fileIdsWithoutPrefix[begin:end], + SendEmail: false, + ApiVersion: "2", + IncludePrivateData: 1, // we need this just in case! + }) + if err != nil { + return nil, err + } + + // NOTE: The slash in the resource is all-important for POST requests to the JDP!! + body, err := db.post("request_archived_files/", orcid, bytes.NewReader(data)) + if err != nil { + switch e := err.(type) { + case *databases.ResourcesNotFoundError: + e.ResourceIds = fileIds + } + return nil, err + } + + type RestoreResponse struct { + RequestId int `json:"request_id"` } - return xferId, err - } - type RestoreResponse struct { - RequestId int `json:"request_id"` + var jdpResp RestoreResponse + err = json.Unmarshal(body, &jdpResp) + if err != nil { + return nil, err + } + requestIds[i] = jdpResp.RequestId } - var jdpResp RestoreResponse - err = json.Unmarshal(body, &jdpResp) + return requestIds, nil +} + +func (db *Database) StageFiles(orcid string, fileIds []string) (uuid.UUID, error) { + var xferId uuid.UUID + + // NOTE: the relevant endpoint seems to return a 404 whenever it gets too many file IDs, + // NOTE: so we batch requests in sets of 1000 + requestIds, err := db.requestArchivedFiles(orcid, fileIds, 1000) if err != nil { return xferId, err } - slog.Debug(fmt.Sprintf("Requested %d archived files from JDP (request ID: %d)", - len(fileIds), jdpResp.RequestId)) + + slog.Debug(fmt.Sprintf("Requested %d archived files from JDP (request IDs: %v)", + len(fileIds), requestIds)) xferId = uuid.New() db.StagingRequests[xferId] = StagingRequest{ - Id: jdpResp.RequestId, + Ids: requestIds, Time: time.Now(), } return xferId, err } func (db *Database) StagingStatus(id uuid.UUID) (databases.StagingStatus, error) { + statusForString := map[string]databases.StagingStatus{ + "new": databases.StagingStatusActive, + "pending": databases.StagingStatusActive, + "ready": databases.StagingStatusSucceeded, + "failed": databases.StagingStatusFailed, + } db.pruneStagingRequests() if request, found := db.StagingRequests[id]; found { - resource := fmt.Sprintf("request_archived_files/requests/%d", request.Id) - body, err := db.get(resource, url.Values{}) - if err != nil { - return databases.StagingStatusUnknown, err - } - type JDPResult struct { - Status string `json:"status"` // "new", "pending", or "ready" - } - var jdpResult JDPResult - err = json.Unmarshal(body, &jdpResult) - if err != nil { - return databases.StagingStatusUnknown, err - } - statusForString := map[string]databases.StagingStatus{ - "new": databases.StagingStatusActive, - "pending": databases.StagingStatusActive, - "ready": databases.StagingStatusSucceeded, - } - if status, ok := statusForString[jdpResult.Status]; ok { - slog.Debug(fmt.Sprintf("Queried JDP for staging status of transfer with staging ID %s (request ID: %d): %s", id, request.Id, jdpResult.Status)) - return status, nil + var status databases.StagingStatus + var statusStr string + for _, requestId := range request.Ids { + resource := fmt.Sprintf("request_archived_files/requests/%d", requestId) + body, err := db.get(resource, url.Values{}) + if err != nil { + return databases.StagingStatusUnknown, err + } + type JDPResult struct { + Status string `json:"status"` // "new", "pending", "ready", or "failed" + } + var jdpResult JDPResult + err = json.Unmarshal(body, &jdpResult) + if err != nil { + return databases.StagingStatusUnknown, err + } + if requestStatus, ok := statusForString[jdpResult.Status]; ok { + if status == databases.StagingStatusUnknown { // first status encountered + status = requestStatus + statusStr = jdpResult.Status + } else { + if requestStatus != status { // status update + if requestStatus != databases.StagingStatusSucceeded { + status = requestStatus + statusStr = jdpResult.Status + } + } + } + if status == databases.StagingStatusFailed { // one failure sinks them all + break + } + } else { + return databases.StagingStatusUnknown, fmt.Errorf("unrecognized JDP staging status string: %s", jdpResult.Status) + } } - return databases.StagingStatusUnknown, fmt.Errorf("unrecognized JDP staging status string: %s", jdpResult.Status) + slog.Debug(fmt.Sprintf("Queried JDP for staging status of transfer with staging ID %s (request IDs: %v): %s", id, request.Ids, statusStr)) + return status, nil } else { slog.Info(fmt.Sprintf("No staging request found for transfer with staging ID %s", id.String())) return databases.StagingStatusUnknown, nil @@ -658,10 +722,6 @@ func (db *Database) post(resource, orcid string, body io.Reader) ([]byte, error) case 200, 201, 204: defer resp.Body.Close() return io.ReadAll(resp.Body) - case 404: - return nil, &databases.ResourcesNotFoundError{ - Database: "JDP", - } case 503: return nil, &databases.UnavailableError{ Database: "jdp", @@ -809,6 +869,15 @@ func (db Database) addSpecificSearchParameters(params map[string]any, p *url.Val } } p.Add(name, fmt.Sprintf("%d", value)) + case "datasets": // specific JGI datasets requested + var value string + if value, ok = jsonValue.(string); !ok { + return &databases.InvalidSearchParameter{ + Database: "JDP", + Message: "Invalid JDP dataset(s) requested (must be comma-delimited string)", + } + } + p.Add(name, strings.TrimSpace(value)) case "extra": // comma-separated additional fields requested var value string if value, ok = jsonValue.(string); !ok { @@ -844,7 +913,7 @@ func (db *Database) pruneStagingRequests() { for uuid, request := range db.StagingRequests { requestAge := time.Since(request.Time) if requestAge > db.DeleteAfter { - slog.Debug(fmt.Sprintf("Pruning staging request with staging ID %s (request ID: %d): age (%s) exceeds limit (%s)", uuid.String(), request.Id, requestAge.String(), db.DeleteAfter.String())) + slog.Debug(fmt.Sprintf("Pruning staging request with staging ID %s (request IDs: %v): age (%s) exceeds limit (%s)", uuid.String(), request.Ids, requestAge.String(), db.DeleteAfter.String())) delete(db.StagingRequests, uuid) } } diff --git a/databases/jdp/database_test.go b/databases/jdp/database_test.go index 0c9de565..98e5b1dc 100644 --- a/databases/jdp/database_test.go +++ b/databases/jdp/database_test.go @@ -67,7 +67,7 @@ var isMockDatabase bool = false var mockJDPServer *httptest.Server var mockJDPSecret string = "mock_shared_secret" var mockOrcId string = "0000-0000-9876-0000" -var mockStagedFileId = 12345 +var mockStagedFileIds = []int{12345} const mockResponseBody string = `{ "organisms": [ @@ -203,7 +203,7 @@ func createMockJDPServer() *httptest.Server { response := struct { RequestId int `json:"request_id"` }{ - RequestId: mockStagedFileId, + RequestId: mockStagedFileIds[0], } w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(response) @@ -495,7 +495,7 @@ func TestStageFiles(t *testing.T) { id, err := db.StageFiles(mockOrcId, fileIds) assert.Nil(err, "Database StageFiles encountered an error") assert.NotNil(id, "Database StageFiles returned nil ID") - assert.Equal(mockStagedFileId, db.StagingRequests[id].Id, "Database StageFiles returned incorrect ID") + assert.Equal(mockStagedFileIds, db.StagingRequests[id].Ids, "Database StageFiles returned incorrect ID") } func TestStagingStatus(t *testing.T) { @@ -509,11 +509,11 @@ func TestStagingStatus(t *testing.T) { DeleteAfter: time.Duration(1) * time.Hour, } req1 := StagingRequest{ - Id: 789, + Ids: []int{789, 415}, Time: time.Now(), } req2 := StagingRequest{ - Id: 4, + Ids: []int{4, 8}, Time: time.Now(), } uuid1 := uuid.New() @@ -907,12 +907,12 @@ func TestPruneStagingRequests(t *testing.T) { } newUuid := uuid.New() db.StagingRequests[newUuid] = StagingRequest{ - Id: 1, + Ids: []int{1, 2}, Time: time.Now(), } oldUuid := uuid.New() db.StagingRequests[oldUuid] = StagingRequest{ - Id: 2, + Ids: []int{3, 4}, Time: time.Now().Add(-time.Hour), } db.pruneStagingRequests() diff --git a/services/version.go b/services/version.go index 196820d2..eb249d75 100644 --- a/services/version.go +++ b/services/version.go @@ -6,7 +6,7 @@ import ( // Version numbers var majorVersion = 0 -var minorVersion = 12 +var minorVersion = 13 var patchVersion = 1 // Version string diff --git a/transfers/dispatcher.go b/transfers/dispatcher.go index facb3e9f..cc515b34 100644 --- a/transfers/dispatcher.go +++ b/transfers/dispatcher.go @@ -160,6 +160,21 @@ func (d *dispatcherState) process() { } else { d.Channels.ReturnTransferId <- transferId } + err = d.initialize(transferId) + if err != nil { + slog.Error(fmt.Sprintf("Transfer %s failed: %s", transferId.String(), err.Error())) + status := TransferStatus{ + Code: TransferStatusFailed, + Message: err.Error(), + } + store.SetStatus(transferId, status) + publish(Message{ + Description: fmt.Sprintf("Transfer %s failed: %s", transferId.String(), err.Error()), + TransferId: transferId, + TransferStatus: status, + Time: time.Now(), + }) + } case request := <-d.Channels.CancelTransfer: err := d.cancel(request.Id, request.Orcid) if err == nil { @@ -244,29 +259,33 @@ func (d *dispatcherState) create(spec Specification) (uuid.UUID, error) { return uuid.UUID{}, err } - transferId, err := store.NewTransfer(spec) + return store.NewTransfer(spec) +} + +func (d *dispatcherState) initialize(transferId uuid.UUID) error { + descriptors, err := store.GetDescriptors(transferId) if err != nil { - return uuid.UUID{}, err + return err } - descriptors, err := store.GetDescriptors(transferId) + spec, err := store.GetSpecification(transferId) if err != nil { - return uuid.UUID{}, err + return err } // do we need to stage files for the source database? filesStaged := true descriptorsForEndpoint, err := descriptorsByEndpoint(spec, descriptors) if err != nil { - return uuid.UUID{}, err + return err } for source, descriptorsForSource := range descriptorsForEndpoint { sourceEndpoint, err := endpoints.NewEndpoint(source) if err != nil { - return uuid.UUID{}, err + return err } filesStaged, err = sourceEndpoint.FilesStaged(descriptorsForSource) if err != nil { - return uuid.UUID{}, err + return err } if !filesStaged { break @@ -279,7 +298,7 @@ func (d *dispatcherState) create(spec Specification) (uuid.UUID, error) { err = mover.MoveFiles(transferId) } - return transferId, err + return err } func validateSpecification(spec Specification) error { diff --git a/transfers/store.go b/transfers/store.go index 5b3080ba..748bfc79 100644 --- a/transfers/store.go +++ b/transfers/store.go @@ -236,13 +236,23 @@ func (s *storeState) process(decoder *gob.Decoder) { for running { select { case spec := <-s.Channels.RequestNewTransfer: - id, transfer, err := s.newTransfer(spec) + // create a new transfer ID and return it immediately + id := uuid.New() + s.Channels.ReturnNewTransfer <- id + + // create an entry in the store and finish setting up the transfer + newXfer, err := s.newTransfer(spec) if err != nil { s.Channels.Error <- err - } else { - transfers[id] = transfer - s.Channels.ReturnNewTransfer <- id } + transfers[id] = newXfer + size := newXfer.payloadSize() + publish(Message{ + Description: fmt.Sprintf("Created new transfer %s (%d file(s), %g GB)", id, newXfer.Status.NumFiles, float64(size)/float64(1024*1024*1024)), + TransferId: id, + TransferStatus: transfers[id].Status, + Time: time.Now(), + }) case id := <-s.Channels.RequestDescriptors: if transfer, found := transfers[id]; found { s.Channels.ReturnDescriptors <- transfer.Descriptors @@ -251,18 +261,7 @@ func (s *storeState) process(decoder *gob.Decoder) { } case id := <-s.Channels.RequestPayloadSize: if transfer, found := transfers[id]; found { - var size uint64 - for _, descriptor := range transfer.Descriptors { - switch v := descriptor["bytes"].(type) { - case int: - size += uint64(v) - case int64: - size += uint64(v) - default: - s.Channels.Error <- fmt.Errorf("invalid 'bytes' field type in descriptor: %T", v) - } - } - s.Channels.ReturnPayloadSize <- size + s.Channels.ReturnPayloadSize <- transfer.payloadSize() } else { s.Channels.Error <- TransferNotFoundError{Id: id} } @@ -319,15 +318,29 @@ type transferStoreEntry struct { Status TransferStatus } -func (s *storeState) newTransfer(spec Specification) (uuid.UUID, transferStoreEntry, error) { - id := uuid.New() +func (e transferStoreEntry) payloadSize() uint64 { + var size uint64 + for _, descriptor := range e.Descriptors { + switch v := descriptor["bytes"].(type) { + case int: + size += uint64(v) + case int64: + size += uint64(v) + default: + return 0 // ! + } + } + return size +} + +func (s *storeState) newTransfer(spec Specification) (transferStoreEntry, error) { source, err := databases.NewDatabase(spec.Source) if err != nil { - return id, transferStoreEntry{}, err + return transferStoreEntry{}, err } descriptors, err := source.Descriptors(spec.User.Orcid, spec.FileIds) if err != nil { - return id, transferStoreEntry{}, err + return transferStoreEntry{}, err } slices.SortFunc(descriptors, func(a, b map[string]any) int { return cmp.Compare(a["id"].(string), b["id"].(string)) @@ -340,23 +353,5 @@ func (s *storeState) newTransfer(spec Specification) (uuid.UUID, transferStoreEn }, } - var size uint64 - for _, descriptor := range entry.Descriptors { - switch v := descriptor["bytes"].(type) { - case int: - size += uint64(v) - case int64: - size += uint64(v) - default: - return id, transferStoreEntry{}, fmt.Errorf("invalid 'bytes' field type in descriptor: %T", v) - } - } - publish(Message{ - Description: fmt.Sprintf("Created new transfer %s (%d file(s), %g GB)", id, entry.Status.NumFiles, float64(size)/float64(1024*1024*1024)), - TransferId: id, - TransferStatus: entry.Status, - Time: time.Now(), - }) - - return id, entry, err + return entry, err }