11from __future__ import annotations
22
3+ from enum import IntEnum
34from http .server import BaseHTTPRequestHandler , HTTPServer
45from pathlib import Path
56from threading import Thread
@@ -69,7 +70,7 @@ class Session(LibrespotSession):
6970 def __init__ (
7071 self ,
7172 session_builder : LibrespotSession .Builder ,
72- token : TokenProvider . StoredToken ,
73+ oauth : OAuth ,
7374 language : str = "en" ,
7475 ) -> None :
7576 """
@@ -89,7 +90,7 @@ def __init__(
8990 ),
9091 ApResolver .get_random_accesspoint (),
9192 )
92- self .__token = token
93+ self .__oauth = oauth
9394 self .__language = language
9495 self .connect ()
9596 self .authenticate (session_builder .login_credentials )
@@ -112,8 +113,7 @@ def from_file(cred_file: Path | str, language: str = "en") -> Session:
112113 .build ()
113114 )
114115 session = LibrespotSession .Builder (conf ).stored_file (str (cred_file ))
115- token = session .login_credentials .auth_data # TODO: this is wrong
116- return Session (session , token , language )
116+ return Session (session , OAuth (), language ) # TODO
117117
118118 @staticmethod
119119 def from_oauth (
@@ -148,7 +148,7 @@ def from_oauth(
148148 typ = Authentication .AuthenticationType .values ()[3 ],
149149 auth_data = token .access_token .encode (),
150150 )
151- return Session (session , token , language )
151+ return Session (session , auth , language )
152152
153153 def __get_playable (
154154 self , playable_id : PlayableId , quality : Quality
@@ -188,9 +188,9 @@ def get_episode(self, episode_id: str) -> Episode:
188188 self .api (),
189189 )
190190
191- def token (self ) -> TokenProvider . StoredToken :
192- """Returns API token """
193- return self .__token
191+ def oauth (self ) -> OAuth :
192+ """Returns OAuth service """
193+ return self .__oauth
194194
195195 def language (self ) -> str :
196196 """Returns session language"""
@@ -288,7 +288,7 @@ def __init__(self, session: Session):
288288 self ._session = session
289289
290290 def get_token (self , * scopes ) -> TokenProvider .StoredToken :
291- return self ._session .token ()
291+ return self ._session .oauth (). get_token ()
292292
293293 class StoredToken (LibrespotTokenProvider .StoredToken ):
294294 def __init__ (self , obj ):
@@ -309,6 +309,11 @@ def __init__(self):
309309 self .__server_thread .start ()
310310
311311 def get_authorization_url (self ) -> str :
312+ """
313+ Generate OAuth URL
314+ Returns:
315+ OAuth URL
316+ """
312317 self .__code_verifier = generate_code_verifier ()
313318 code_challenge = get_code_challenge (self .__code_verifier )
314319 params = {
@@ -322,19 +327,48 @@ def get_authorization_url(self) -> str:
322327 return f"{ AUTH_URL } authorize?{ urlencode (params )} "
323328
324329 def await_token (self ) -> TokenProvider .StoredToken :
330+ """
331+ Blocks until server thread gets token
332+ Returns:
333+ StoredToken
334+ """
325335 self .__server_thread .join ()
326336 return self .__token
327337
328- def set_token (self , code : str ) -> None :
338+ def get_token (self ) -> TokenProvider .StoredToken :
339+ """
340+ Gets a valid token
341+ Returns:
342+ StoredToken
343+ """
344+ if self .__token is None :
345+ raise RuntimeError ("Session isn't authenticated!" )
346+ elif self .__token .expired ():
347+ self .set_token (self .__token .refresh_token , OAuth .RequestType .REFRESH )
348+ return self .__token
349+
350+ def set_token (self , code : str , request_type : RequestType ) -> None :
351+ """
352+ Fetches and sets stored token
353+ Returns:
354+ StoredToken
355+ """
329356 token_url = f"{ AUTH_URL } api/token"
330357 headers = {"Content-Type" : "application/x-www-form-urlencoded" }
331- body = {
332- "grant_type" : "authorization_code" ,
333- "code" : code ,
334- "redirect_uri" : REDIRECT_URI ,
335- "client_id" : CLIENT_ID ,
336- "code_verifier" : self .__code_verifier ,
337- }
358+ if request_type == OAuth .RequestType .LOGIN :
359+ body = {
360+ "grant_type" : "authorization_code" ,
361+ "code" : code ,
362+ "redirect_uri" : REDIRECT_URI ,
363+ "client_id" : CLIENT_ID ,
364+ "code_verifier" : self .__code_verifier ,
365+ }
366+ elif request_type == OAuth .RequestType .REFRESH :
367+ body = {
368+ "grant_type" : "refresh_token" ,
369+ "refresh_token" : code ,
370+ "client_id" : CLIENT_ID ,
371+ }
338372 response = post (token_url , headers = headers , data = body )
339373 if response .status_code != 200 :
340374 raise IOError (
@@ -348,6 +382,10 @@ def __run_server(self) -> None:
348382 httpd .authenticator = self
349383 httpd .serve_forever ()
350384
385+ class RequestType (IntEnum ):
386+ LOGIN = 0
387+ REFRESH = 1
388+
351389 class OAuthHTTPServer (HTTPServer ):
352390 authenticator : OAuth
353391
@@ -371,7 +409,9 @@ def do_GET(self) -> None:
371409
372410 if code :
373411 if isinstance (self .server , OAuth .OAuthHTTPServer ):
374- self .server .authenticator .set_token (code [0 ])
412+ self .server .authenticator .set_token (
413+ code [0 ], OAuth .RequestType .LOGIN
414+ )
375415 self .send_response (200 )
376416 self .send_header ("Content-type" , "text/html" )
377417 self .end_headers ()
0 commit comments