Skip to content

Commit ecd31ba

Browse files
committed
Process repositories concurrently with asyncio
1 parent 00b9d3c commit ecd31ba

2 files changed

Lines changed: 43 additions & 30 deletions

File tree

oshminer/Wikifactory.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@
1717
import oshminer.filetypes as filetypes
1818
from oshminer.errors import exceptions
1919

20-
# Select transport with the Wikifactory API endpoint URL
21-
transport = AIOHTTPTransport(url = "https://wikifactory.com/api/graphql")
22-
# Create a GraphQL client using the defined transport
23-
client = Client(transport = transport, fetch_schema_from_transport = True)
24-
25-
async def get_files_info(project: dict) -> dict:
20+
async def get_files_info(project: dict, session) -> dict:
2621
# Provide a GraphQL query
2722
query = gql(
2823
"""
@@ -55,7 +50,7 @@ async def get_files_info(project: dict) -> dict:
5550

5651
# Execute the query on the transport
5752
# `execute_async()` is the asynchronous version of `execute()`
58-
API_response = await client.execute_async(query, variable_values = params)
53+
API_response = await session.execute(query, variable_values = params)
5954
if API_response["project"]["result"] is None:
6055
raise exceptions.BadRepoError # Raise error if Wikifactory API can't find this project
6156
# Only continue if there are files in the project
@@ -131,7 +126,7 @@ async def get_files_info(project: dict) -> dict:
131126

132127
return result
133128

134-
async def get_issues_level(project: dict) -> dict:
129+
async def get_issues_level(project: dict, session) -> dict:
135130
# Provide a GraphQL query
136131
query = gql(
137132
"""
@@ -175,7 +170,7 @@ async def get_issues_level(project: dict) -> dict:
175170
# Make queries to retrieve project issues
176171
while issues_has_next_page:
177172
# `execute_async()` is the asynchronous version of `execute()`
178-
API_response: dict = await client.execute_async(query, variable_values = params)
173+
API_response: dict = await session.execute(query, variable_values = params)
179174
if API_response["project"]["result"] is None:
180175
raise exceptions.BadRepoError # Raise error if Wikifactory API can't find this project
181176
if API_response["project"]["result"]["tracker"]["issues"]["totalCount"] > 0: # Only continue if there are issues
@@ -200,7 +195,7 @@ async def get_issues_level(project: dict) -> dict:
200195
}
201196
return result
202197

203-
async def get_commits_level(project: dict) -> dict:
198+
async def get_commits_level(project: dict, session) -> dict:
204199
# Provide a GraphQL query
205200
query = gql(
206201
"""
@@ -239,7 +234,7 @@ async def get_commits_level(project: dict) -> dict:
239234
contribs: list = []
240235
while contribs_has_next_page:
241236
# `execute_async()` is the asynchronous version of `execute()`
242-
API_response: dict = await client.execute_async(query, variable_values = params)
237+
API_response: dict = await session.execute(query, variable_values = params)
243238
if API_response["project"]["result"] is None:
244239
raise exceptions.BadRepoError # Raise error if Wikifactory API can't find this project
245240
if API_response["project"]["result"]["contributions"]["totalCount"] > 0: # Only continue if there are contributions
@@ -261,7 +256,7 @@ async def get_commits_level(project: dict) -> dict:
261256
}
262257
return result
263258

264-
async def get_tags(project: dict) -> dict:
259+
async def get_tags(project: dict, session) -> dict:
265260
# Provide a GraphQL query
266261
query = gql(
267262
"""
@@ -304,7 +299,7 @@ async def get_tags(project: dict) -> dict:
304299
# Keep going while there is a next page of user tags
305300
while user_tags_has_next_page:
306301
# `execute_async()` is the asynchronous version of `execute()`
307-
API_response: dict = await client.execute_async(query, variable_values = params)
302+
API_response: dict = await session.execute(query, variable_values = params)
308303
if API_response["project"]["result"] is None:
309304
raise exceptions.BadRepoError # Raise error if Wikifactory API can't find this project
310305
# Append project level tags to list
@@ -324,7 +319,7 @@ async def get_tags(project: dict) -> dict:
324319
result["tags"] = list(set(result["tags"]))
325320
return result
326321

327-
async def get_license(project: dict) -> dict:
322+
async def get_license(project: dict, session) -> dict:
328323
# Provide a GraphQL query
329324
query = gql(
330325
"""
@@ -355,7 +350,7 @@ async def get_license(project: dict) -> dict:
355350

356351
# Execute the query on the transport
357352
# `execute_async()` is the asynchronous version of `execute()`
358-
API_response: dict = await client.execute_async(query, variable_values = params)
353+
API_response: dict = await session.execute(query, variable_values = params)
359354
if API_response["project"]["result"] is None:
360355
raise exceptions.BadRepoError # Raise error if Wikifactory API can't find this project
361356
# Get license string for this project
@@ -416,11 +411,15 @@ async def make_Wikifactory_request(url: str, data: list) -> str:
416411
"requested_data": {}
417412
}
418413

414+
# Select transport with the Wikifactory API endpoint URL
415+
transport = AIOHTTPTransport(url = "https://wikifactory.com/api/graphql")
416+
419417
# Get "space" and "slug" components from this repository's URL
420418
space_slug: dict = parse_url(url)
421419

422-
for data_type in data:
423-
query_result: dict = await queries[data_type](space_slug)
424-
results["requested_data"].update(query_result)
420+
async with Client(transport = transport, fetch_schema_from_transport = True) as session:
421+
for data_type in data:
422+
query_result: dict = await queries[data_type](space_slug, session)
423+
results["requested_data"].update(query_result)
425424

426425
return results

oshminer/main.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# SPDX-License-Identifier: AGPL-3.0-or-later
55

66
# Python Standard Library imports
7+
import asyncio
78
import json
89
import sys
910

@@ -43,6 +44,17 @@ class MiningRequest(BaseModel):
4344
async def root():
4445
return {"message": "Dashboard data-mining backend is on"}
4546

47+
async def process_repo(repo: HttpUrl, requests: list[str], responses: list):
48+
platform: str = repo.host.replace("www.", "")
49+
try:
50+
repo_info: dict = await supported_domains[platform](repo, requests)
51+
except exceptions.BadRepoError:
52+
return JSONResponse(
53+
status_code = status.HTTP_400_BAD_REQUEST,
54+
content = f"Error with repository: {repo}"
55+
)
56+
responses.append(repo_info)
57+
4658
@app.post(
4759
"/data/",
4860
name = "API endpoint",
@@ -85,21 +97,23 @@ async def mining_request(request_body: MiningRequest):
8597
# Prepare API response
8698
#
8799

88-
response: list = []
100+
response_list: list = []
89101

90102
#
91103
# Construct, send API requests, and get results
92104
#
93105

94-
for repo in request_body.repo_urls:
95-
platform = repo.host.replace("www.", "")
96-
try:
97-
repo_info: dict = await supported_domains[platform](repo, request_body.requested_data)
98-
except exceptions.BadRepoError:
99-
return JSONResponse(
100-
status_code = status.HTTP_400_BAD_REQUEST,
101-
content = f"Error with repository: {repo}"
102-
)
103-
response.append(repo_info)
106+
await asyncio.gather(*[process_repo(repo, request_body.requested_data, response_list) for repo in request_body.repo_urls])
107+
108+
# for repo in request_body.repo_urls:
109+
# platform = repo.host.replace("www.", "")
110+
# try:
111+
# repo_info: dict = await supported_domains[platform](repo, request_body.requested_data)
112+
# except exceptions.BadRepoError:
113+
# return JSONResponse(
114+
# status_code = status.HTTP_400_BAD_REQUEST,
115+
# content = f"Error with repository: {repo}"
116+
# )
117+
# response_list.append(repo_info)
104118

105-
return response
119+
return response_list

0 commit comments

Comments
 (0)