|
2 | 2 | import logging |
3 | 3 | import random |
4 | 4 | import urllib |
| 5 | +import json |
5 | 6 | from hashlib import sha256 |
6 | 7 | from http.server import HTTPServer, BaseHTTPRequestHandler |
7 | 8 | from urllib.parse import urlparse |
8 | 9 | from librespot.proto import Authentication_pb2 as Authentication |
| 10 | +from requests.structures import CaseInsensitiveDict |
| 11 | +from datetime import datetime, timedelta |
9 | 12 | import requests |
10 | 13 |
|
11 | 14 |
|
12 | 15 | class OAuth: |
13 | 16 | logger = logging.getLogger("Librespot:OAuth") |
| 17 | + OAUTH_PKCE_TOKEN = "OAUTH_PKCE_TOKEN" |
14 | 18 | __spotify_auth = "https://accounts.spotify.com/authorize?response_type=code&client_id=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&scope=%s" |
15 | 19 | __scopes = ["app-remote-control", "playlist-modify", "playlist-modify-private", "playlist-modify-public", "playlist-read", "playlist-read-collaborative", "playlist-read-private", "streaming", "ugc-image-upload", "user-follow-modify", "user-follow-read", "user-library-modify", "user-library-read", "user-modify", "user-modify-playback-state", "user-modify-private", "user-personalized", "user-read-birthdate", "user-read-currently-playing", "user-read-email", "user-read-play-history", "user-read-playback-position", "user-read-playback-state", "user-read-private", "user-read-recently-played", "user-top-read"] |
16 | 20 | __spotify_token = "https://accounts.spotify.com/api/token" |
17 | | - __spotify_token_data = {"grant_type": "authorization_code", "client_id": "", "redirect_uri": "", "code": "", "code_verifier": ""} |
| 21 | + __spotify_token_data = CaseInsensitiveDict({"grant_type": "", |
| 22 | + "client_id": ""}) |
18 | 23 | __client_id = "" |
19 | 24 | __redirect_url = "" |
20 | 25 | __code_verifier = "" |
21 | 26 | __code = "" |
22 | 27 | __token = "" |
| 28 | + __token_expires_at = datetime.now() |
| 29 | + __refresh_token = "" |
23 | 30 | __server = None |
24 | 31 | __oauth_url_callback = None |
25 | 32 | __success_page_content = None |
| 33 | + __listen_all_interfaces = False |
26 | 34 |
|
27 | 35 | def __init__(self, client_id, redirect_url, oauth_url_callback): |
28 | 36 | self.__client_id = client_id |
@@ -53,26 +61,83 @@ def get_auth_url(self): |
53 | 61 |
|
54 | 62 | def set_code(self, code): |
55 | 63 | self.__code = code |
| 64 | + return self |
56 | 65 |
|
57 | 66 | def set_scopes(self, scopes): |
58 | 67 | self.__scopes = scopes |
59 | 68 | return self |
60 | 69 |
|
| 70 | + def set_listen_all(self, listen_all: bool): |
| 71 | + self.__listen_all_interfaces = listen_all |
| 72 | + return self |
| 73 | + |
| 74 | + def ingest_token_response(self, result): |
| 75 | + self.__token = result["access_token"] |
| 76 | + self.__refresh_token = result["refresh_token"] |
| 77 | + if "expires_in" in result: |
| 78 | + self.__token_expires_at = datetime.now() + timedelta(seconds=result["expires_in"]) |
| 79 | + elif "expires_at" in result: |
| 80 | + self.__token_expires_at = datetime.fromtimestamp(result["expires_at"]) |
| 81 | + return self |
| 82 | + |
61 | 83 | def request_token(self): |
62 | 84 | if not self.__code: |
63 | 85 | raise RuntimeError("You need to provide a code before!") |
| 86 | + |
64 | 87 | request_data = self.__spotify_token_data |
| 88 | + request_data["grant_type"] = "authorization_code" |
65 | 89 | request_data["client_id"] = self.__client_id |
66 | 90 | request_data["redirect_uri"] = self.__redirect_url |
67 | 91 | request_data["code"] = self.__code |
68 | 92 | request_data["code_verifier"] = self.__code_verifier |
69 | | - request = requests.post( |
| 93 | + |
| 94 | + response = requests.post( |
| 95 | + self.__spotify_token, |
| 96 | + headers=CaseInsensitiveDict({"Content-Type": "application/x-www-form-urlencoded"}), |
| 97 | + data=request_data, |
| 98 | + ) |
| 99 | + if response.status_code != 200: |
| 100 | + raise RuntimeError("Received status code %d: %s" % (response.status_code, response.reason)) |
| 101 | + return self.ingest_token_response(response.json()) |
| 102 | + |
| 103 | + def refresh_token(self): |
| 104 | + if not self.__refresh_token: |
| 105 | + raise RuntimeError("You need to receive a token before!") |
| 106 | + |
| 107 | + if self.__token_expires_at > datetime.now(): |
| 108 | + return self |
| 109 | + |
| 110 | + request_data = self.__spotify_token_data |
| 111 | + request_data["grant_type"] = "refresh_token" |
| 112 | + request_data["client_id"] = self.__client_id |
| 113 | + request_data["refresh_token"] = self.__refresh_token |
| 114 | + |
| 115 | + response = requests.post( |
70 | 116 | self.__spotify_token, |
| 117 | + headers=CaseInsensitiveDict({"Content-Type": "application/x-www-form-urlencoded"}), |
71 | 118 | data=request_data, |
72 | 119 | ) |
73 | | - if request.status_code != 200: |
74 | | - raise RuntimeError("Received status code %d: %s" % (request.status_code, request.reason)) |
75 | | - self.__token = request.json()["access_token"] |
| 120 | + if response.status_code != 200: |
| 121 | + raise RuntimeError("Received status code %d: %s" % (response.status_code, response.reason)) |
| 122 | + return self.ingest_token_response(response.json()) |
| 123 | + |
| 124 | + def token(self): |
| 125 | + if not self.__token: |
| 126 | + raise RuntimeError("You need to request a token bore!") |
| 127 | + |
| 128 | + self.refresh_token() |
| 129 | + |
| 130 | + return self.__token |
| 131 | + |
| 132 | + def save_creds(self, cred_path: str): |
| 133 | + with open(cred_path, 'w',) as f: |
| 134 | + json.dump({ |
| 135 | + "client_id": self.__client_id, |
| 136 | + "access_token": self.__token, |
| 137 | + "expires_at": self.__token_expires_at.timestamp(), |
| 138 | + "refresh_token": self.__refresh_token, |
| 139 | + "type": self.OAUTH_PKCE_TOKEN |
| 140 | + }, f) |
76 | 141 |
|
77 | 142 | def get_credentials(self): |
78 | 143 | if not self.__token: |
@@ -123,8 +188,9 @@ def __start_server(self): |
123 | 188 |
|
124 | 189 | def run_callback_server(self): |
125 | 190 | url = urlparse(self.__redirect_url) |
| 191 | + address = "" if self.__listen_all_interfaces else url.hostname |
126 | 192 | self.__server = self.CallbackServer( |
127 | | - (url.hostname, url.port), |
| 193 | + (address, url.port), |
128 | 194 | self.CallbackRequestHandler, |
129 | 195 | url.path, |
130 | 196 | self.set_code, |
|
0 commit comments