diff --git a/pyhon/connection/auth.py b/pyhon/connection/auth.py index 1cc0217..96bfd83 100644 --- a/pyhon/connection/auth.py +++ b/pyhon/connection/auth.py @@ -3,6 +3,7 @@ import logging import re import secrets import urllib +from datetime import datetime, timedelta from pprint import pformat from typing import List, Tuple from urllib import parse @@ -16,6 +17,9 @@ _LOGGER = logging.getLogger(__name__) class HonAuth: + _TOKEN_EXPIRES_AFTER_HOURS = 8 + _TOKEN_EXPIRE_WARNING_HOURS = 7 + def __init__(self, session, email, password, device) -> None: self._session = session self._email = email @@ -26,6 +30,7 @@ class HonAuth: self._id_token = "" self._device = device self._called_urls: List[Tuple[int, str]] = [] + self._expires: datetime = datetime.utcnow() @property def cognito_token(self): @@ -43,6 +48,17 @@ class HonAuth: def refresh_token(self): return self._refresh_token + def _check_token_expiration(self, hours): + return datetime.utcnow() >= self._expires + timedelta(hours=hours) + + @property + def token_is_expired(self) -> bool: + return self._check_token_expiration(self._TOKEN_EXPIRES_AFTER_HOURS) + + @property + def token_expires_soon(self) -> bool: + return self._check_token_expiration(self._TOKEN_EXPIRE_WARNING_HOURS) + async def _error_logger(self, response, fail=True): result = "hOn Authentication Error\n" for i, (status, url) in enumerate(self._called_urls): @@ -72,6 +88,7 @@ class HonAuth: ) as response: self._called_urls.append((response.status, response.request_info.url)) text = await response.text() + self._expires = datetime.utcnow() if not (login_url := re.findall("url = '(.+?)'", text)): if "oauth/done#access_token=" in text: self._parse_token_data(text) @@ -237,12 +254,14 @@ class HonAuth: await self._error_logger(response, fail=False) return False data = await response.json() + self._expires = datetime.utcnow() self._id_token = data["id_token"] self._access_token = data["access_token"] return await self._api_auth() def clear(self): self._session.cookie_jar.clear_domain(const.AUTH_API.split("/")[-2]) + self._called_urls = [] self._cognito_token = "" self._id_token = "" self._access_token = "" diff --git a/pyhon/connection/handler.py b/pyhon/connection/handler.py index 935c614..d12b559 100644 --- a/pyhon/connection/handler.py +++ b/pyhon/connection/handler.py @@ -100,14 +100,18 @@ class HonConnectionHandler(HonBaseConnectionHandler): ) -> AsyncIterator: kwargs["headers"] = await self._check_headers(kwargs.get("headers", {})) async with method(*args, **kwargs) as response: - if response.status in [401, 403] and loop == 0: + if ( + self._auth.token_expires_soon or response.status in [401, 403] + ) and loop == 0: _LOGGER.info("Try refreshing token...") await self._auth.refresh() async with self._intercept( method, *args, loop=loop + 1, **kwargs ) as result: yield result - elif response.status in [401, 403] and loop == 1: + elif ( + self._auth.token_is_expired or response.status in [401, 403] + ) and loop == 1: _LOGGER.warning( "%s - Error %s - %s", response.request_info.url,