diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d997676..39d79ac 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -2,16 +2,29 @@ name: lint on: pull_request: +permissions: + contents: read + jobs: flake8: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: - python-version: 3.9 + python-version: "3.10" - uses: TrueBrain/actions-flake8@v2 with: flake8_version: 6.0.0 plugins: flake8-isort==6.0.0 + + ruff-format: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v6 + - uses: astral-sh/ruff-action@v3 + with: + version: "~=0.13.3" + args: format --check --diff diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 46ea20a..0d4d5e4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,26 +6,29 @@ on: branches: - main +permissions: + contents: read + jobs: test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v1 - with: - python-version: 3.9 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.10" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt -r requirements-dev.txt + - name: Install dependencies + run: | + python -m pip install uv + uv sync - - run: pip install pytest-github-actions-annotate-failures + - run: uv pip install pytest-github-actions-annotate-failures - - run: py.test --cov=rain_api_core --cov-report=term-missing --cov-report=xml --cov-branch --doctest-modules rain_api_core tests + - run: uv run pytest --cov=src/rain_api_core --cov-report=term-missing --cov-report=xml --cov-branch --doctest-modules src/rain_api_core tests - - name: Report 
coverage + uses: codecov/codecov-action@v6 + with: + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 5ae822e..56b3bf5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # Python *.pyc .venv +uv.lock # IDE .idea @@ -11,3 +12,6 @@ # Tests .hypothesis .coverage + +# Other +build diff --git a/extra/policy_gen_sandbox.py b/extra/policy_gen_sandbox.py index bd28523..2662578 100644 --- a/extra/policy_gen_sandbox.py +++ b/extra/policy_gen_sandbox.py @@ -46,10 +46,7 @@ def handle_text(): tk.Label(frm_content, text="Bucket map YAML").grid(row=0, column=0) txt_bucketmap = tk.Text(frm_content) - txt_bucketmap.bind( - "", - lambda _: window.after(1, handle_text) - ) + txt_bucketmap.bind("", lambda _: window.after(1, handle_text)) txt_bucketmap.grid(row=1, column=0, sticky="nsew") # Policy panel @@ -65,10 +62,7 @@ def handle_text(): tk.Label(frm_groups, text="User Groups: ").grid(row=0, column=0) var_group = tk.StringVar(value="null") entry_groups = tk.Entry(frm_groups, textvariable=var_group) - entry_groups.bind( - "", - lambda _: window.after(1, handle_text) - ) + entry_groups.bind("", lambda _: window.after(1, handle_text)) entry_groups.grid(row=0, column=1) # Minified size indicator diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..83a1e5b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,30 @@ +[build-system] +requires = ["uv_build>=0.10.2,<0.11.0"] +build-backend = "uv_build" + +[project] +name = "rain-api-core" +version = "0.1.0" +description = "RAIN API Core" +readme = "README.md" +authors = [ + { name = "Rohan Weeden", email = "reweeden@alaska.edu" } +] +requires-python = "~=3.10" +dependencies = [ + "cachetools~=5.0", + "jinja2~=3.0", + "netaddr~=1.0", + "pyjwt[crypto]~=2.0", + "pyyaml~=6.0", +] + +[dependency-groups] +dev = [ + "boto3~=1.35", + "hypothesis~=6.112", + "moto~=5.0", + "pytest~=8.3", + "pytest-cov~=5.0", + "pytest-mock~=3.14", +] diff --git a/rain_api_core/urs_util.py 
b/rain_api_core/urs_util.py deleted file mode 100644 index e663128..0000000 --- a/rain_api_core/urs_util.py +++ /dev/null @@ -1,326 +0,0 @@ -import logging -import os -from typing import Optional - -from rain_api_core.auth import JwtManager, UserProfile -from rain_api_core.aws_util import retrieve_secret -from rain_api_core.edl import EdlClient, EdlException -from rain_api_core.logging import log_context - -log = logging.getLogger(__name__) - - -def get_base_url(ctxt: dict = None) -> str: - # Make a redirect url using optional custom domain_name, otherwise use raw domain/stage provided by API Gateway. - try: - domain = os.getenv('DOMAIN_NAME') or f"{ctxt['domainName']}/{ctxt['stage']}" - return f'https://{domain}/' - except (TypeError, KeyError) as e: - log.error('could not create a redirect_url, because {}'.format(e)) - raise - - -def get_redirect_url(ctxt: dict = None) -> str: - return f'{get_base_url(ctxt)}login' - - -def do_auth(code: str, redirect_url: str, aux_headers: dict = {}) -> dict: - # App U:P from URS Application - auth = get_urs_creds()['UrsAuth'] - - data = { - 'grant_type': 'authorization_code', - 'code': code, - 'redirect_uri': redirect_url, - } - - headers = {'Authorization': 'Basic ' + auth} - headers.update(aux_headers) - - client = EdlClient() - try: - return client.request( - 'POST', - '/oauth/token', - data=data, - headers=headers, - ) - except EdlException: - return {} - - -def get_urs_url(ctxt: dict, to: str = None) -> str: - base_url = os.getenv('AUTH_BASE_URL', 'https://urs.earthdata.nasa.gov') + '/oauth/authorize' - - # From URS Application - client_id = get_urs_creds()['UrsId'] - - log.debug('domain name: {0}'.format(os.getenv('DOMAIN_NAME', 'no domainname set'))) - log.debug('if no domain name set: {}.execute-api.{}.amazonaws.com/{}'.format( - ctxt['apiId'], - os.getenv('AWS_DEFAULT_REGION', ''), - ctxt['stage'] - )) - - urs_url = f'{base_url}?client_id={client_id}&response_type=code&redirect_uri={get_redirect_url(ctxt)}' - if to: - 
urs_url += f"&state={to}" - - # Try to handle scripts - try: - download_agent = ctxt['identity']['userAgent'] - except KeyError: - log.debug("No User Agent!") - return urs_url - - if not download_agent.startswith('Mozilla'): - urs_url += "&app_type=401" - - return urs_url - - -def get_user_profile(urs_user_payload: dict, access_token) -> UserProfile: - return UserProfile( - user_id=urs_user_payload['uid'], - token=access_token, - groups=urs_user_payload['user_groups'], - first_name=urs_user_payload['first_name'], - last_name=urs_user_payload['last_name'], - email=urs_user_payload['email_address'], - ) - - -def get_profile( - user_id: str, - token: str, - temptoken: str = None, - aux_headers: dict = {}, -) -> Optional[UserProfile]: - if not user_id or not token: - return None - - # get_new_token_and_profile() will pass this function a temporary token with - # which to fetch the profile info. We don't want to keep it around, just use - # it here, once: - if temptoken: - headertoken = temptoken - else: - headertoken = token - - headers = {'Authorization': 'Bearer ' + headertoken} - headers.update(aux_headers) - params = {'client_id': get_urs_creds()['UrsId']} - - client = EdlClient() - try: - user_profile = client.request( - 'GET', - f'/api/users/{user_id}', - params=params, - headers=headers, - ) - return get_user_profile(user_profile, headertoken) - except EdlException as e: - log.warning('Error fetching profile: %s', e.inner) - if not temptoken: # This keeps get_new_token_and_profile() from calling this over and over - log.debug('because error above, going to get_new_token_and_profile()') - return get_new_token_and_profile(user_id, token, aux_headers) - - log.debug( - f"We got that 401 above and we're using a temptoken ({temptoken}), " - "so giving up and not getting a profile." 
- ) - return None - - -def get_new_token_and_profile( - user_id: str, - cookietoken: str, - aux_headers: dict = {}, -) -> Optional[UserProfile]: - # App U:P from URS Application - auth = get_urs_creds()['UrsAuth'] - data = {'grant_type': 'client_credentials'} - - headers = {'Authorization': 'Basic ' + auth} - headers.update(aux_headers) - - client = EdlClient() - try: - log.info('Attempting to get new Token') - - response = client.request( - 'POST', - '/oauth/token', - data=data, - headers=headers, - ) - new_token = response['access_token'] - - log.info('Retrieved new token: %s', new_token) - # Get user profile with new token - return get_profile( - user_id, - cookietoken, - new_token, - aux_headers=aux_headers, - ) - except EdlException: - return None - - -def user_in_group_list(private_groups: list, user_groups: list) -> bool: - client_id = get_urs_creds()['UrsId'] - log.info("Searching for private groups {0} in {1}".format(private_groups, user_groups)) - - group_names = {group["name"] for group in user_groups if group["client_id"] == client_id} - - for group in private_groups: - if group in group_names: - log.info("User belongs to private group {}".format(group)) - return True - return False - - -def user_in_group_urs(private_groups, user_id, token, user_profile=None, refresh_first=False, aux_headers=None): - aux_headers = aux_headers or {} # A safer default - new_profile = {} - - if refresh_first or not user_profile: - user_profile = get_profile(user_id, token, aux_headers=aux_headers) - new_profile = user_profile - - if ( - isinstance(user_profile, dict) - and 'user_groups' in user_profile - and user_in_group_list(private_groups, user_profile['user_groups']) - ): - log.info("User {0} belongs to private group".format(user_id)) - return True, new_profile - - # Couldn't find user in provided groups, but we may as well look at a fresh group list: - if not refresh_first: - # we have a maybe not so fresh user_profile and we could try again to see if someone added a 
group to this user: - log.debug(f"Could not validate user {user_id} belonging to groups {private_groups}, attempting profile refresh") - - return user_in_group_urs(private_groups, user_id, {}, refresh_first=True, aux_headers=aux_headers) - log.debug("Even after profile refresh, user {0} does not belong to groups {1}".format(user_id, private_groups)) - - return False, new_profile - - -def user_in_group(private_groups, user_profile: UserProfile, refresh_first=False, aux_headers=None): - aux_headers = aux_headers or {} # A safer default - - # If a new profile is fetched, it is assigned to this var, and returned so that a fresh jwt cookie can be set. - new_profile = None - - if not private_groups: - return False, new_profile - - if not user_profile: - return False, new_profile - - if refresh_first: - new_profile = get_profile(user_profile.user_id, user_profile.token, aux_headers=aux_headers) - user_profile.groups = new_profile.groups - - in_group = user_in_group_list(private_groups, user_profile.groups) - if in_group: - return True, new_profile - - if not in_group and not refresh_first: - # one last ditch effort to see if they were so very recently added to group: - user_profile = get_profile( - user_profile.user_id, - user_profile.token, - aux_headers=aux_headers - ) - return user_in_group(private_groups, user_profile, refresh_first=True, aux_headers=aux_headers) - - return False, new_profile - - -def get_urs_creds() -> dict: - """ - Fetches URS creds from secrets manager. 
- :return: looks like: - { - "UrsId": "stringofseeminglyrandomcharacters", - "UrsAuth": "verymuchlongerstringofseeminglyrandomcharacters" - } - :type: dict - """ - secret_name = os.getenv('URS_CREDS_SECRET_NAME') - - if not secret_name: - log.error('URS_CREDS_SECRET_NAME not set') - return {} - - secret = retrieve_secret(secret_name) - if not ('UrsId' in secret and 'UrsAuth' in secret): - log.error('AWS secret {} does not contain required keys "UrsId" and "UrsAuth"'.format(secret_name)) - - return secret - - -# This do_login() is mainly for chalice clients. -def do_login(args, context, jwt_manager: JwtManager, cookie_domain='', aux_headers=None): - aux_headers = aux_headers or {} # A safer default - - log.debug('the query_params: {}'.format(args)) - - if not args: - template_vars = {'contentstring': 'No params', 'title': 'Could Not Login'} - headers = {} - return 400, template_vars, headers - - if args.get('error', False): - contentstring = 'An error occurred while trying to log into URS. URS says: "{}". '.format(args.get('error', '')) - template_vars = {'contentstring': contentstring, 'title': 'Could Not Login'} - if args.get('error') == 'access_denied': - # This happens when user doesn't agree to EULA. Maybe other times too. - return_status = 401 - template_vars['contentstring'] = 'Be sure to agree to the EULA.' 
- template_vars['error_code'] = 'EULA_failure' - else: - return_status = 400 - - return return_status, template_vars, {} - - if 'code' not in args: - contentstring = 'Did not get the required CODE from URS' - - template_vars = {'contentstring': contentstring, 'title': 'Could Not Login'} - headers = {} - return 400, template_vars, headers - - log.debug('pre-do_auth() query params: {}'.format(args)) - redir_url = get_redirect_url(context) - auth = do_auth(args.get('code', ''), redir_url, aux_headers=aux_headers) - log.debug('auth: {}'.format(auth)) - if not auth: - log.debug('no auth returned from do_auth()') - - template_vars = {'contentstring': 'There was a problem talking to URS Login', 'title': 'Could Not Login'} - - return 400, template_vars, {} - - user_id = auth['endpoint'].split('/')[-1] - log_context(user_id=user_id) - - user_profile = get_profile(user_id, auth['access_token'], aux_headers={}) - log.debug('Got the user profile: {}'.format(user_profile)) - if user_profile is not None: - log.debug('urs-access-token: {}'.format(auth['access_token'])) - if 'state' in args: - redirect_to = args["state"] - else: - redirect_to = get_base_url(context) - - headers = {'Location': redirect_to} - headers.update(jwt_manager.get_header_to_set_auth_cookie(user_profile, cookie_domain)) - return 301, {}, headers - - template_vars = {'contentstring': 'Could not get user profile from URS', 'title': 'Could Not Login'} - return 400, template_vars, {} diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index e1d0a6e..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,6 +0,0 @@ -boto3==1.35.18 -hypothesis==6.112.0 -moto==5.0.14 -pytest-cov==5.0.0 -pytest-mock==3.14.0 -pytest==8.3.3 diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 246de4e..0000000 --- a/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -cachetools~=5.0 -jinja2~=3.0 -netaddr~=1.0 -pyjwt[crypto]~=2.0 -pyyaml~=6.0 - -pip~=24.0 # not directly required, 
pinned by Snyk to avoid a vulnerability diff --git a/setup.py b/setup.py deleted file mode 100644 index a45c3e2..0000000 --- a/setup.py +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env python - -from setuptools import setup - -with open("requirements.txt") as f: - requirements = f.readlines() - - -setup( - name="rain-api-core", - author="Alaska Satellite Facility", - url="https://github.com/asfadmin/rain-api-core", - packages=["rain_api_core"], - install_requires=requirements -) diff --git a/rain_api_core/__init__.py b/src/rain_api_core/__init__.py similarity index 100% rename from rain_api_core/__init__.py rename to src/rain_api_core/__init__.py diff --git a/rain_api_core/auth.py b/src/rain_api_core/auth.py similarity index 58% rename from rain_api_core/auth.py rename to src/rain_api_core/auth.py index 96db5a2..0fdaba8 100644 --- a/rain_api_core/auth.py +++ b/src/rain_api_core/auth.py @@ -1,9 +1,10 @@ import contextlib import dataclasses import logging +from collections.abc import Mapping from http.cookies import CookieError, SimpleCookie from time import time -from typing import List, Mapping, Optional +from typing import Optional from wsgiref.handlers import format_date_time as format_7231_date import jwt @@ -15,36 +16,36 @@ class UserProfile: user_id: str token: str - groups: List[str] + groups: list[str] first_name: str last_name: str email: str - iat: int = None - exp: int = None + iat: Optional[int] = None + exp: Optional[int] = None @classmethod def from_jwt_payload(cls, payload): return cls( - user_id=payload.get('urs-user-id'), - token=payload.get('urs-access-token'), - groups=payload.get('urs-groups'), - first_name=payload.get('first_name'), - last_name=payload.get('last_name'), - email=payload.get('email'), - iat=payload.get('iat'), - exp=payload.get('exp') + user_id=payload.get("urs-user-id"), + token=payload.get("urs-access-token"), + groups=payload.get("urs-groups"), + first_name=payload.get("first_name"), + last_name=payload.get("last_name"), + 
email=payload.get("email"), + iat=payload.get("iat"), + exp=payload.get("exp"), ) def to_jwt_payload(self): return { - 'urs-user-id': self.user_id, - 'urs-access-token': self.token, - 'urs-groups': self.groups, - 'first_name': self.first_name, - 'last_name': self.last_name, - 'email': self.email, - 'iat': self.iat, - 'exp': self.exp, + "urs-user-id": self.user_id, + "urs-access-token": self.token, + "urs-groups": self.groups, + "first_name": self.first_name, + "last_name": self.last_name, + "email": self.email, + "iat": self.iat, + "exp": self.exp, } @@ -56,7 +57,7 @@ def __init__( private_key: str, cookie_name: str, blacklist={}, - session_ttl_in_hours: float = 7 * 24 + session_ttl_in_hours: float = 7 * 24, ): self.algorithm = algorithm self.public_key = public_key @@ -66,7 +67,9 @@ def __init__( self.black_list = blacklist def _get_auth_cookie(self, headers: Mapping[str, str]): - cookie_string = headers.get('cookie') or headers.get('Cookie') or headers.get('COOKIE') + cookie_string = ( + headers.get("cookie") or headers.get("Cookie") or headers.get("COOKIE") + ) if not cookie_string: return {} @@ -79,17 +82,17 @@ def _decode_jwt(self, token: str): try: return jwt.decode(token.encode(), self.public_key, [self.algorithm]) except jwt.ExpiredSignatureError: - log.info('JWT has expired') + log.info("JWT has expired") except jwt.InvalidSignatureError: - log.info('JWT has failed verification') + log.info("JWT has failed verification") return None def _encode_jwt(self, payload: Mapping[str, str]) -> str: try: encoded = jwt.encode(payload, self.private_key, self.algorithm) except TypeError: - log.error('unable to encode jwt cookie') - return '' + log.error("unable to encode jwt cookie") + return "" return encoded def _jwt_payload_from_user_profile(self, user_profile: Optional[UserProfile]): @@ -97,14 +100,14 @@ def _jwt_payload_from_user_profile(self, user_profile: Optional[UserProfile]): return {} now = int(time()) return { - 'urs-user-id': user_profile.user_id, - 
'first_name': user_profile.first_name, - 'last_name': user_profile.last_name, - 'email': user_profile.email, - 'urs-access-token': user_profile.token, - 'urs-groups': user_profile.groups, - 'iat': now, - 'exp': now + self.session_ttl + "urs-user-id": user_profile.user_id, + "first_name": user_profile.first_name, + "last_name": user_profile.last_name, + "email": user_profile.email, + "urs-access-token": user_profile.token, + "urs-groups": user_profile.groups, + "iat": now, + "exp": now + self.session_ttl, } def _in_blacklist(self, user_profile: UserProfile): @@ -115,7 +118,10 @@ def _in_blacklist(self, user_profile: UserProfile): return True return False - def get_profile_from_headers(self, headers) -> Optional[UserProfile]: + def get_profile_from_headers( + self, + headers: Mapping[str, str], + ) -> Optional[UserProfile]: """Inspects headers for auth cookie and return user_profile if authenticated, None otherwise""" auth_cookie = self._get_auth_cookie(headers) if not auth_cookie: @@ -130,22 +136,26 @@ def get_profile_from_headers(self, headers) -> Optional[UserProfile]: return None return user_profile - def get_header_to_set_auth_cookie(self, user_profile: Optional[UserProfile], cookie_domain=''): - """ Gets a header to set auth-cookie + def get_header_to_set_auth_cookie( + self, + user_profile: Optional[UserProfile], + cookie_domain: str = "", + ): + """Gets a header to set auth-cookie Parameters: UserProfile: UserProfile to use in construction of a cookie, if none will return header to unset/logout """ payload = self._jwt_payload_from_user_profile(user_profile) - cookie_value = self._encode_jwt(payload) if payload else 'expired' - cookie_domain = f'; Domain={cookie_domain}' if cookie_domain else '' + cookie_value = self._encode_jwt(payload) if payload else "expired" + cookie_domain = f"; Domain={cookie_domain}" if cookie_domain else "" if payload: - expire_date = format_7231_date(payload['exp']) + expire_date = format_7231_date(payload["exp"]) else: - expire_date 
= 'Thu, 01 Jan 1970 00:00:00 GMT' + expire_date = "Thu, 01 Jan 1970 00:00:00 GMT" return { - 'SET-COOKIE': ( - f'{self.cookie_name}={cookie_value}; Expires={expire_date}; Path=/{cookie_domain}; Secure; ' - 'HttpOnly; SameSite=Lax' + "SET-COOKIE": ( + f"{self.cookie_name}={cookie_value}; Expires={expire_date}; Path=/{cookie_domain}; Secure; " + "HttpOnly; SameSite=Lax" ) } diff --git a/rain_api_core/aws_util.py b/src/rain_api_core/aws_util.py similarity index 60% rename from rain_api_core/aws_util.py rename to src/rain_api_core/aws_util.py index 7170bd5..2754f12 100644 --- a/rain_api_core/aws_util.py +++ b/src/rain_api_core/aws_util.py @@ -4,6 +4,7 @@ import os import urllib.request from time import time +from typing import Optional from boto3 import Session as boto_Session from boto3 import client as botoclient @@ -18,15 +19,15 @@ from rain_api_core.general_util import duration, return_timing_object log = logging.getLogger(__name__) -sts = botoclient('sts') +sts = botoclient("sts") session_cache = {} region_list_cache = [] s3_resource = None -region = '' +region = "" botosess = botosession.Session() role_creds_cache = { - os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN'): {}, - os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN'): {} + os.getenv("EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN"): {}, + os.getenv("EGRESS_APP_DOWNLOAD_ROLE_ARN"): {}, } @@ -36,7 +37,7 @@ def get_region() -> str: :return: string describing AWS region :type: string """ - global region # pylint: disable=global-statement + global region # pylint: disable=global-statement global botosess # pylint: disable=global-statement if not region: region = botosess.region_name @@ -45,16 +46,16 @@ def get_region() -> str: @functools.lru_cache(maxsize=None) def retrieve_secret(secret_name: str) -> dict: - global region # pylint: disable=global-statement + global region # pylint: disable=global-statement global botosess # pylint: disable=global-statement t0 = time() - region_name = os.getenv('AWS_DEFAULT_REGION') + 
region_name = os.getenv("AWS_DEFAULT_REGION") # Create a Secrets Manager client client = botosess.client( - service_name='secretsmanager', - region_name=region_name + service_name="secretsmanager", + region_name=region_name, ) # In this sample we only handle the specific exceptions for the 'GetSecretValue' API. @@ -63,23 +64,25 @@ def retrieve_secret(secret_name: str) -> dict: try: timer = time() - get_secret_value_response = client.get_secret_value( - SecretId=secret_name + get_secret_value_response = client.get_secret_value(SecretId=secret_name) + log.info( + return_timing_object( + service="secretsmanager", + endpoint=f"client().get_secret_value({secret_name})", + duration=duration(timer), + ) ) - log.info(return_timing_object( - service="secretsmanager", - endpoint=f"client().get_secret_value({secret_name})", - duration=duration(timer) - )) except ClientError as e: log.error("Encountered fatal error trying to read URS Secret: {0}".format(e)) raise e else: # Decrypts secret using the associated KMS CMK. # Depending on whether the secret is a string or binary, one of these fields will be populated. 
- if 'SecretString' in get_secret_value_response: - secret = json.loads(get_secret_value_response['SecretString']) - log.debug(f'ET for retrieving secret {secret_name} from secret store: {time() - t0:.4f} sec') + if "SecretString" in get_secret_value_response: + secret = json.loads(get_secret_value_response["SecretString"]) + log.debug( + f"ET for retrieving secret {secret_name} from secret store: {time() - t0:.4f} sec" + ) return secret return {} @@ -94,15 +97,15 @@ def get_s3_resource() -> boto_Session.resource: if not s3_resource: params = {} # Swift signature compatability - signature_version = os.getenv('S3_SIGNATURE_VERSION') + signature_version = os.getenv("S3_SIGNATURE_VERSION") if signature_version: - params['config'] = bc_Config(signature_version=signature_version) - s3_resource = botoresource('s3', **params) + params["config"] = bc_Config(signature_version=signature_version) + s3_resource = botoresource("s3", **params) return s3_resource -def read_s3(bucket: str, key: str, s3: ServiceResource = None) -> str: +def read_s3(bucket: str, key: str, s3: Optional[ServiceResource] = None) -> str: """ returns file :type bucket: str @@ -115,19 +118,21 @@ def read_s3(bucket: str, key: str, s3: ServiceResource = None) -> str: :return: str """ if not s3: - log.warning('creating a S3 resource in read_s3() function') + log.warning("creating a S3 resource in read_s3() function") s3 = get_s3_resource() t0 = time() log.info("Downloading config file {0} from s3://{1}...".format(key, bucket)) obj = s3.Object(bucket, key) - log.debug('ET for reading {} from S3: {} sec'.format(key, round(time() - t0, 4))) + log.debug("ET for reading {} from S3: {} sec".format(key, round(time() - t0, 4))) timer = time() - body = obj.get()['Body'].read().decode('utf-8') - log.info(return_timing_object( - service="s3", - endpoint=f"resource().Object(s3://{bucket}/{key}).get()", - duration=duration(timer) - )) + body = obj.get()["Body"].read().decode("utf-8") + log.info( + return_timing_object( + 
service="s3", + endpoint=f"resource().Object(s3://{bucket}/{key}).get()", + duration=duration(timer), + ) + ) return body @@ -142,7 +147,9 @@ def get_yaml(bucket: str, file_name: str) -> dict: cfg_yaml = read_s3(bucket, file_name) return safe_load(cfg_yaml) except ClientError as e: - log.error('Could not download yaml file s3://{}/{}, {}'.format(bucket, file_name, e)) + log.error( + "Could not download yaml file s3://{}/{}, {}".format(bucket, file_name, e) + ) raise @@ -155,7 +162,7 @@ def get_yaml_file(bucket: str, key: str) -> dict: return get_yaml(bucket, key) -def get_role_creds(user_id: str = None, in_region: bool = False): +def get_role_creds(user_id: Optional[str] = None, in_region: bool = False): """ :param user_id: string with URS username :param in_region: boolean If True a download role that works only in region will be returned @@ -165,12 +172,12 @@ def get_role_creds(user_id: str = None, in_region: bool = False): """ global sts # pylint: disable=global-statement if not user_id: - user_id = 'unauthenticated' + user_id = "unauthenticated" if in_region: - download_role_arn = os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN') + download_role_arn = os.getenv("EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN") else: - download_role_arn = os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN') + download_role_arn = os.getenv("EGRESS_APP_DOWNLOAD_ROLE_ARN") dl_arn_name = download_role_arn.split("/")[-1] # chained role assumption like this CANNOT currently be extended past 1 Hour. 
@@ -179,46 +186,71 @@ def get_role_creds(user_id: str = None, in_region: bool = False): session_params = { "RoleArn": download_role_arn, "RoleSessionName": f"{user_id}@{round(now)}", - "DurationSeconds": 3600 + "DurationSeconds": 3600, } session_offset = 0 if user_id not in role_creds_cache[download_role_arn]: fresh_session = sts.assume_role(**session_params) - log.info(return_timing_object( - service="sts", - endpoint=f"client().assume_role({dl_arn_name}/{user_id})", - duration=duration(now) - )) - role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now} + log.info( + return_timing_object( + service="sts", + endpoint=f"client().assume_role({dl_arn_name}/{user_id})", + duration=duration(now), + ) + ) + role_creds_cache[download_role_arn][user_id] = { + "session": fresh_session, + "timestamp": now, + } elif now - role_creds_cache[download_role_arn][user_id]["timestamp"] > 600: # If the session has been active for more than 10 minutes, grab a new one. log.info("Replacing 10 minute old session for {0}".format(user_id)) fresh_session = sts.assume_role(**session_params) - log.info(return_timing_object(service="sts", endpoint="client().assume_role()", duration=duration(now))) - role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now} + log.info( + return_timing_object( + service="sts", endpoint="client().assume_role()", duration=duration(now) + ) + ) + role_creds_cache[download_role_arn][user_id] = { + "session": fresh_session, + "timestamp": now, + } else: log.info("Reusing role credentials for {0}".format(user_id)) - session_offset = round(now - role_creds_cache[download_role_arn][user_id]["timestamp"]) + session_offset = round( + now - role_creds_cache[download_role_arn][user_id]["timestamp"] + ) - log.debug(f'assuming role: {0}, role session username: {1}'.format(download_role_arn, user_id)) + log.debug( + f"assuming role: {0}, role session username: {1}".format( + download_role_arn, user_id + ) + 
) return role_creds_cache[download_role_arn][user_id]["session"], session_offset -def get_role_session(creds: dict = None, user_id: str = None) -> boto_Session: +def get_role_session( + creds: Optional[dict] = None, + user_id: Optional[str] = None, +) -> boto_Session: global session_cache # pylint: disable=global-statement sts_resp = creds if creds else get_role_creds(user_id)[0] - log.debug('sts_resp: {0}'.format(sts_resp)) + log.debug("sts_resp: {0}".format(sts_resp)) - session_id = sts_resp['AssumedRoleUser']['AssumedRoleId'] + session_id = sts_resp["AssumedRoleUser"]["AssumedRoleId"] if session_id not in session_cache: now = time() session_cache[session_id] = boto_Session( - aws_access_key_id=sts_resp['Credentials']['AccessKeyId'], - aws_secret_access_key=sts_resp['Credentials']['SecretAccessKey'], - aws_session_token=sts_resp['Credentials']['SessionToken'] + aws_access_key_id=sts_resp["Credentials"]["AccessKeyId"], + aws_secret_access_key=sts_resp["Credentials"]["SecretAccessKey"], + aws_session_token=sts_resp["Credentials"]["SessionToken"], + ) + log.info( + return_timing_object( + service="boto3", endpoint="boto3.session()", duration=duration(now) + ) ) - log.info(return_timing_object(service="boto3", endpoint="boto3.session()", duration=duration(now))) else: log.info("Reusing session {0}".format(session_id)) return session_cache[session_id] @@ -228,19 +260,22 @@ def get_region_cidr_ranges() -> list: """ :return: Utility function to download AWS regions """ - global region_list_cache # pylint: disable=global-statement + global region_list_cache # pylint: disable=global-statement if not region_list_cache: # pylint: disable=used-before-assignment - url = 'https://ip-ranges.amazonaws.com/ip-ranges.json' + url = "https://ip-ranges.amazonaws.com/ip-ranges.json" now = time() req = urllib.request.Request(url) r = urllib.request.urlopen(req).read() # nosec URL is *always* https://ip-ranges... 
- log.info(return_timing_object(service="AWS", endpoint=url, duration=duration(now))) - region_list_json = json.loads(r.decode('utf-8')) + log.info( + return_timing_object(service="AWS", endpoint=url, duration=duration(now)) + ) + region_list_json = json.loads(r.decode("utf-8")) # Sort out ONLY values from this AWS region this_region = get_region() region_list_cache = [ - IPNetwork(pre["ip_prefix"]) for pre in region_list_json["prefixes"] + IPNetwork(pre["ip_prefix"]) + for pre in region_list_json["prefixes"] if "ip_prefix" in pre and "region" in pre and pre["region"] == this_region ] diff --git a/rain_api_core/bucket_map.py b/src/rain_api_core/bucket_map.py similarity index 91% rename from rain_api_core/bucket_map.py rename to src/rain_api_core/bucket_map.py index 811d5d4..2a37d6d 100644 --- a/rain_api_core/bucket_map.py +++ b/src/rain_api_core/bucket_map.py @@ -1,6 +1,7 @@ from collections import defaultdict +from collections.abc import Generator, Iterable, Sequence from dataclasses import dataclass, field -from typing import Generator, Iterable, Optional, Sequence, Tuple +from typing import Optional # By default, buckets are accessible to any logged in users. This is # represented by an empty set. @@ -8,8 +9,8 @@ def _is_accessible( - required_groups: Optional[set], - groups: Optional[Iterable[str]] + required_groups: Optional[set[str]], + groups: Optional[Iterable[str]], ) -> bool: # Check for public access if required_groups is None: @@ -23,14 +24,14 @@ def _is_accessible( @dataclass() -class BucketMapEntry(): +class BucketMapEntry: bucket: str bucket_path: str object_key: str headers: dict = field(default_factory=dict) _access_control: Optional[dict] = None - def is_accessible(self, groups: Iterable[str] = None) -> bool: + def is_accessible(self, groups: Optional[Iterable[str]] = None) -> bool: """Check if the object is accessible with the given permissions. 
Setting `groups` to an iterable implies that the user has logged in, @@ -41,7 +42,7 @@ def is_accessible(self, groups: Iterable[str] = None) -> bool: required_groups = self.get_required_groups() return _is_accessible(required_groups, groups) - def get_required_groups(self) -> Optional[set]: + def get_required_groups(self) -> Optional[set[str]]: """Get a set of permissions protecting this object. It is sufficient to have one of the permissions in the set in order to @@ -60,13 +61,13 @@ def get_required_groups(self) -> Optional[set]: return _DEFAULT_PERMISSION_FACTORY() -class BucketMap(): +class BucketMap: def __init__( self, bucket_map: dict, bucket_name_prefix: str = "", reverse: bool = False, - iam_compatible: bool = True + iam_compatible: bool = True, ): self.bucket_map = bucket_map self.access_control = _parse_access_control(bucket_map) @@ -121,21 +122,21 @@ def get_path(self, path: Sequence[str]) -> Optional[BucketMapEntry]: bucket=bucket, bucket_path=bucket_path, object_key=object_key, - headers=headers + headers=headers, ) return None - def entries(self): + def entries(self) -> Generator[BucketMapEntry]: for bucket, path_parts, headers in _walk_entries(self._get_map()): yield self._make_entry( bucket=bucket, bucket_path="/".join(path_parts), object_key="", - headers=headers + headers=headers, ) - def to_iam_policy(self, groups: Iterable[str] = None) -> dict: + def to_iam_policy(self, groups: Optional[Iterable[str]] = None) -> Optional[dict]: if not self._iam_compatible: _check_iam_compatible(self.access_control) generator = IamPolicyGenerator(groups) @@ -150,8 +151,8 @@ def _make_entry( bucket: str, bucket_path: str, object_key: str, - headers: Optional[dict] = None - ): + headers: Optional[dict] = None, + ) -> BucketMapEntry: return BucketMapEntry( bucket=self.bucket_name_prefix + bucket, bucket_path=bucket_path, @@ -160,11 +161,11 @@ def _make_entry( # TODO(reweeden): Do we really want to control access by # bucket? 
Wouldn't it make more sense to control access by # path instead? - _access_control=self.access_control.get(bucket) + _access_control=self.access_control.get(bucket), ) -def _walk_entries(node: dict, path=()) -> Generator[Tuple[str, tuple, Optional[dict]], None, None]: +def _walk_entries(node: dict, path=()) -> Generator[tuple[str, tuple, Optional[dict]]]: """A generator to recursively yield all leaves of a bucket map""" for key, val in node.items(): @@ -221,7 +222,7 @@ def _parse_access_control(bucket_map: dict) -> dict: # Convert to dictionary for easier lookup on individual buckets # We're relying on python's dictionary keys being insertion ordered access = defaultdict(dict) - for (rule, obj) in access_list: + for rule, obj in access_list: bucket, *prefix = rule.split("/", 1) access[bucket]["".join(prefix)] = obj @@ -274,12 +275,13 @@ def _get_longest_prefix(key: str, prefixes: Iterable[str]) -> Optional[str]: # generated bucketmap that makes heavy use of prefix permissions longest_prefix, _ = max( ( + # ruff hint (k, len(k)) for k in prefixes if key.startswith(k) and key != k ), key=lambda x: x[1], - default=(None, 0) + default=(None, 0), ) return longest_prefix @@ -294,7 +296,7 @@ def _access_text(access) -> str: class IamPolicyGenerator: - def __init__(self, groups: Iterable[str]): + def __init__(self, groups: Optional[Iterable[str]]): self.groups = groups def _is_accessible(self, required_groups: Optional[set]) -> bool: @@ -303,6 +305,7 @@ def _is_accessible(self, required_groups: Optional[set]) -> bool: def generate_policy(self, entries: Iterable[BucketMapEntry]) -> Optional[dict]: # Dedupe across buckets bucket_access = { + # ruff hint entry.bucket: entry._access_control for entry in entries } @@ -319,7 +322,9 @@ def generate_policy(self, entries: Iterable[BucketMapEntry]) -> Optional[dict]: get_object_statement.add_action("s3:ListBucket") get_object_statement.add_resource(f"arn:aws:s3:::{bucket}") - 
get_object_statement.add_resource(f"arn:aws:s3:::{bucket}/{key_prefix}*") + get_object_statement.add_resource( + f"arn:aws:s3:::{bucket}/{key_prefix}*", + ) if not get_object_statement.resource: return None @@ -340,13 +345,13 @@ def generate_policy(self, entries: Iterable[BucketMapEntry]) -> Optional[dict]: resource=[f"arn:aws:s3:::{bucket}" for bucket in buckets], condition={ "StringLike": { - "s3:prefix": [f"{prefix}*" for prefix in prefixes] - } - } + "s3:prefix": [f"{prefix}*" for prefix in prefixes], + }, + }, ).to_dict() for buckets, prefixes in list_bucket_conditions.items() - ) - ] + ), + ], } def _consolidate_access_rules(self, access_control: Optional[dict]) -> dict: @@ -394,8 +399,8 @@ def __init__( ): self.effect = effect # Using dict instead of set because sets are unordered. - self.action = dict((val, None) for val in action) - self.resource = dict((val, None) for val in resource) + self.action = {val: None for val in action} + self.resource = {val: None for val in resource} self.condition = condition def add_action(self, value: str): @@ -415,7 +420,7 @@ def to_dict(self) -> dict: statement = { "Effect": self.effect, "Action": list(self.action), - "Resource": list(self.resource) + "Resource": list(self.resource), } if self.condition is not None: statement["Condition"] = self.condition diff --git a/rain_api_core/edl.py b/src/rain_api_core/edl.py similarity index 79% rename from rain_api_core/edl.py rename to src/rain_api_core/edl.py index 8b1c813..2e122d1 100644 --- a/rain_api_core/edl.py +++ b/src/rain_api_core/edl.py @@ -32,8 +32,8 @@ class EdlClient: def __init__( self, base_url: str = os.getenv( - 'AUTH_BASE_URL', - 'https://urs.earthdata.nasa.gov', + "AUTH_BASE_URL", + "https://urs.earthdata.nasa.gov", ), ): self.base_url = base_url @@ -48,9 +48,9 @@ def request( ) -> dict: if params: params_encoded = urllib.parse.urlencode(params) - url_params = f'?{params_encoded}' + url_params = f"?{params_encoded}" else: - url_params = '' + url_params = "" # 
Separate variables so we can log the url without params url = urllib.parse.urljoin(self.base_url, endpoint) @@ -69,36 +69,36 @@ def request( ) log.debug( - 'Request(url=%r, data=%r, headers=%r)', + "Request(url=%r, data=%r, headers=%r)", url_with_params, data, headers, ) timer = Timer() - timer.mark(f'urlopen({url})') + timer.mark(f"urlopen({url})") try: with urllib.request.urlopen(request) as f: payload = f.read() - timer.mark('json.loads()') + timer.mark("json.loads()") msg = json.loads(payload) timer.mark() log.info( return_timing_object( - service='EDL', + service="EDL", endpoint=url, duration=timer.total.duration() * 1000, - unit='milliseconds', + unit="milliseconds", ), ) timer.log_all(log) return msg except urllib.error.URLError as e: - log.error('Error hitting endpoint %s: %s', url, e) + log.error("Error hitting endpoint %s: %s", url, e) timer.mark() - log.debug('ET for the attempt: %.4f', timer.total.duration()) + log.debug("ET for the attempt: %.4f", timer.total.duration()) self._parse_edl_error(e) except json.JSONDecodeError as e: @@ -110,18 +110,18 @@ def _parse_edl_error(self, e: urllib.error.URLError): try: msg = json.loads(payload) except json.JSONDecodeError: - log.error('Could not get json message from payload: %s', payload) + log.error("Could not get json message from payload: %s", payload) msg = {} if ( e.code in (403, 401) - and 'error_description' in msg - and 'eula' in msg['error_description'].lower() + and "error_description" in msg + and "eula" in msg["error_description"].lower() ): # sample json in this case: # `{"status_code": 403, "error_description": "EULA Acceptance Failure", # "resolution_url": "http://uat.urs.earthdata.nasa.gov/approve_app?client_id=LqWhtVpLmwaD4VqHeoN7ww"}` - log.warning('user needs to sign the EULA') + log.warning("user needs to sign the EULA") raise EulaException(e, msg, payload) else: payload = None diff --git a/rain_api_core/egress_util.py b/src/rain_api_core/egress_util.py similarity index 64% rename from 
rain_api_core/egress_util.py rename to src/rain_api_core/egress_util.py index 757d247..e779da3 100644 --- a/rain_api_core/egress_util.py +++ b/src/rain_api_core/egress_util.py @@ -30,37 +30,41 @@ def get_presigned_url( region_name, expire_seconds, user_id, - method='GET', - api_request_uuid=None + method="GET", + api_request_uuid=None, ) -> str: - timez = datetime.utcnow().strftime('%Y%m%dT%H%M%SZ') + timez = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") datez = timez[:8] region_id = "." + region_name if region_name != "us-east-1" else "" hostname = f"{bucket_name}.s3{region_id}.amazonaws.com" object_name = urllib.parse.quote(object_name) - cred = session['Credentials']['AccessKeyId'] - secret = session['Credentials']['SecretAccessKey'] - token = session['Credentials']['SessionToken'] + cred = session["Credentials"]["AccessKeyId"] + secret = session["Credentials"]["SecretAccessKey"] + token = session["Credentials"]["SessionToken"] aws4_request = "/".join([datez, region_name, "s3", "aws4_request"]) cred_string = f"{cred}/{aws4_request}" - can_query_string = "&".join([ - f"A-userid={user_id}", - "X-Amz-Algorithm=AWS4-HMAC-SHA256", - "X-Amz-Credential=" + urllib.parse.quote_plus(cred_string), - "X-Amz-Date=" + timez, - f"X-Amz-Expires={expire_seconds}", - "X-Amz-Security-Token=" + urllib.parse.quote_plus(token), - "X-Amz-SignedHeaders=host" - ]) + can_query_string = "&".join( + [ + f"A-userid={user_id}", + "X-Amz-Algorithm=AWS4-HMAC-SHA256", + "X-Amz-Credential=" + urllib.parse.quote_plus(cred_string), + "X-Amz-Date=" + timez, + f"X-Amz-Expires={expire_seconds}", + "X-Amz-Security-Token=" + urllib.parse.quote_plus(token), + "X-Amz-SignedHeaders=host", + ] + ) if api_request_uuid is not None: - can_query_string = "&".join([ - f"A-api-request-uuid={api_request_uuid}", - can_query_string, - ]) + can_query_string = "&".join( + [ + f"A-api-request-uuid={api_request_uuid}", + can_query_string, + ] + ) can_request = ( f"{method}\n" @@ -72,12 +76,9 @@ def get_presigned_url( ) 
can_request_hash = sha256(can_request.encode()).hexdigest() - string_to_sign = "\n".join([ - "AWS4-HMAC-SHA256", - timez, - aws4_request, - can_request_hash - ]) + string_to_sign = "\n".join( + ["AWS4-HMAC-SHA256", timez, aws4_request, can_request_hash] + ) step_one = hmacsha256(f"AWS4{secret}".encode(), datez).digest() step_two = hmacsha256(step_one, region_name).digest() diff --git a/rain_api_core/general_util.py b/src/rain_api_core/general_util.py similarity index 95% rename from rain_api_core/general_util.py rename to src/rain_api_core/general_util.py index 70e3055..be2d5f1 100644 --- a/rain_api_core/general_util.py +++ b/src/rain_api_core/general_util.py @@ -10,7 +10,7 @@ def return_timing_object(**timing): "endpoint": "Unknown", "method": "GET", "duration": 0, - "unit": "milliseconds" + "unit": "milliseconds", } timing_object.update({k.lower(): v for k, v in timing.items()}) return {"timing": timing_object} diff --git a/rain_api_core/logging.py b/src/rain_api_core/logging.py similarity index 78% rename from rain_api_core/logging.py rename to src/rain_api_core/logging.py index a23d60b..69ff82a 100644 --- a/rain_api_core/logging.py +++ b/src/rain_api_core/logging.py @@ -10,48 +10,50 @@ { "regex": r"(eyJ[A-Za-z0-9-_]{12})[A-Za-z0-9-_]*\.[A-Za-z0-9-_]*\.[A-Za-z0-9-_]*([A-Za-z0-9-_]{10})", "replace": "\\g<1>XXXXXX\\g<2>", - "description": "X-out JWT Token payload" + "description": "X-out JWT Token payload", }, { "regex": r"(EDL-[A-Za-z0-9]+)[A-Za-z0-9]{40}([A-Za-z0-9]{10})", "replace": "\\g<1>XXXXXX\\g<2>", - "description": "X-out non-JWT EDL token" + "description": "X-out non-JWT EDL token", }, { "regex": r"(Basic )[A-Za-z0-9+/=]{4,}", "replace": "\\g<1>XXXXXX", - "description": "X-out Basic Auth Credentials" + "description": "X-out Basic Auth Credentials", }, { "regex": r"([^A-Za-z0-9/+=][A-Za-z0-9/+=]{5})[A-Za-z0-9/+=]{30}([A-Za-z0-9/+=]{5}[^A-Za-z0-9/+=])", "replace": "\\g<1>XXXXXX\\g<2>", - "description": "X-out AWS Secret" - } + "description": "X-out AWS 
Secret", + }, ] def get_log(): - loglevel = os.getenv('LOGLEVEL', 'INFO') - logtype = os.getenv('LOGTYPE', 'json') - if logtype == 'flat': + loglevel = os.getenv("LOGLEVEL", "INFO") + logtype = os.getenv("LOGTYPE", "json") + if logtype == "flat": formatter = LogCensorFormatter( "%(levelname)s: %(message)s (%(filename)s line %(lineno)d/%(build_vers)s/%(maturity)s) - " "RequestId: %(request_id)s; OriginRequestId: %(origin_request_id)s; user_id: %(user_id)s; route: %(route)s" ) else: - formatter = JSONFormatter({ - "level": "%(levelname)s", - "RequestId": "%(request_id)s", - "OriginRequestId": "%(origin_request_id)s", - "message": "%(message)s", - "maturity": "%(maturity)s", - "user_id": "%(user_id)s", - "route": "%(route)s", - "build": "%(build_vers)s", - "filename": "%(filename)s", - "lineno": "%(lineno)s", - "exception": "%(exc_obj)s" - }) + formatter = JSONFormatter( + { + "level": "%(levelname)s", + "RequestId": "%(request_id)s", + "OriginRequestId": "%(origin_request_id)s", + "message": "%(message)s", + "maturity": "%(maturity)s", + "user_id": "%(user_id)s", + "route": "%(route)s", + "build": "%(build_vers)s", + "filename": "%(filename)s", + "lineno": "%(lineno)s", + "exception": "%(exc_obj)s", + } + ) logger = logging.getLogger() @@ -65,19 +67,19 @@ def get_log(): logger.addHandler(handler) logger.setLevel(loglevel) - if os.getenv("QUIETBOTO", 'TRUE').upper() == 'TRUE': + if os.getenv("QUIETBOTO", "TRUE").upper() == "TRUE": # BOTO, be quiet plz - logging.getLogger('boto3').setLevel(logging.ERROR) - logging.getLogger('botocore').setLevel(logging.ERROR) - logging.getLogger('nose').setLevel(logging.ERROR) - logging.getLogger('elasticsearch').setLevel(logging.ERROR) - logging.getLogger('s3transfer').setLevel(logging.ERROR) - logging.getLogger('urllib3').setLevel(logging.ERROR) - logging.getLogger('connectionpool').setLevel(logging.ERROR) + logging.getLogger("boto3").setLevel(logging.ERROR) + logging.getLogger("botocore").setLevel(logging.ERROR) + 
logging.getLogger("nose").setLevel(logging.ERROR) + logging.getLogger("elasticsearch").setLevel(logging.ERROR) + logging.getLogger("s3transfer").setLevel(logging.ERROR) + logging.getLogger("urllib3").setLevel(logging.ERROR) + logging.getLogger("connectionpool").setLevel(logging.ERROR) return logger -class PercentPlaceholder(): +class PercentPlaceholder: """A placeholder in a log format object The placeholder can be formatted with the % operator. @@ -87,7 +89,7 @@ class PercentPlaceholder(): 'hello' """ - __slots__ = ("name", ) + __slots__ = ("name",) def __init__(self, name: str): self.name = name @@ -97,7 +99,7 @@ def __mod__(self, args): return args[self.name] -class JSONPercentStyle(): +class JSONPercentStyle: """Format log records into a JSON object (dict, list) using percent formatting The `fmt` dict will be searched for percent formatting strings. When a value in @@ -120,7 +122,9 @@ class JSONPercentStyle(): default_format = {"message": "%(message)s"} placeholder_pattern = re.compile(r"^%\((\w+)\)s$") - validation_pattern = re.compile(r'%\(\w+\)[#0+ -]*(\*|\d+)?(\.(\*|\d+))?[diouxefgcrsa%]', re.I) + validation_pattern = re.compile( + r"%\(\w+\)[#0+ -]*(\*|\d+)?(\.(\*|\d+))?[diouxefgcrsa%]", re.I + ) def __init__(self, fmt: dict): self._fmt = self._convert_placeholders(fmt or self.default_format) @@ -129,6 +133,7 @@ def __init__(self, fmt: dict): def _convert_placeholders(self, obj): """Convert '%(name)s' values into PercentPlaceholder objects""" + def func(obj): if isinstance(obj, str): m = self.placeholder_pattern.match(obj) @@ -185,10 +190,16 @@ def format(self, record: logging.LogRecord) -> str: if self.usesTime(): record.asctime = self.formatTime(record, self.datefmt) - record.exc_obj = self.formatException(record.exc_info).split("\n") if record.exc_info else None + record.exc_obj = ( + self.formatException(record.exc_info).split("\n") + if record.exc_info + else None + ) obj = self.formatMessage(record) - assert not any(isinstance(val, 
PercentPlaceholder) for val in _iter_json_values(obj)) + assert not any( + isinstance(val, PercentPlaceholder) for val in _iter_json_values(obj) + ) return filter_log_credentials(json.dumps(obj, default=str)) @@ -204,7 +215,7 @@ def __init__(self, *args, **kwargs): "request_id": None, "origin_request_id": None, "user_id": None, - "route": None + "route": None, } def filter(self, record: logging.LogRecord): @@ -236,6 +247,7 @@ def filter_log_credentials(msg: str): # Helpers for traversing json like structures of nested dict/lists + def _fmt_json_val(val, args): if isinstance(val, (str, PercentPlaceholder)) and args: return val % args diff --git a/rain_api_core/timer.py b/src/rain_api_core/timer.py similarity index 90% rename from rain_api_core/timer.py rename to src/rain_api_core/timer.py index 58a65af..478fe74 100644 --- a/rain_api_core/timer.py +++ b/src/rain_api_core/timer.py @@ -1,11 +1,12 @@ import logging import time +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable, Optional +from typing import Optional @dataclass(eq=False) -class Interval(): +class Interval: start: Optional[float] = None end: Optional[float] = None @@ -16,7 +17,7 @@ def duration(self) -> float: return self.end - self.start -class Timer(): +class Timer: """A helper for recording the times of a sequence of events. This object is not thread safe. @@ -28,7 +29,7 @@ def __init__(self, timer: Callable[[], float] = time.time): self.last_name: Optional[str] = None self.total = Interval() - def mark(self, name: str = None) -> float: + def mark(self, name: Optional[str] = None) -> float: """Record a new event. 
If called without `name`, any previously started event will be marked diff --git a/src/rain_api_core/urs_util.py b/src/rain_api_core/urs_util.py new file mode 100644 index 0000000..0ad9064 --- /dev/null +++ b/src/rain_api_core/urs_util.py @@ -0,0 +1,380 @@ +import logging +import os +from typing import Optional + +from rain_api_core.auth import JwtManager, UserProfile +from rain_api_core.aws_util import retrieve_secret +from rain_api_core.edl import EdlClient, EdlException +from rain_api_core.logging import log_context + +log = logging.getLogger(__name__) + + +def get_base_url(ctxt: Optional[dict] = None) -> str: + # Make a redirect url using optional custom domain_name, otherwise use raw domain/stage provided by API Gateway. + try: + domain = os.getenv("DOMAIN_NAME") or f"{ctxt['domainName']}/{ctxt['stage']}" + return f"https://{domain}/" + except (TypeError, KeyError) as e: + log.error("could not create a redirect_url, because {}".format(e)) + raise + + +def get_redirect_url(ctxt: Optional[dict] = None) -> str: + return f"{get_base_url(ctxt)}login" + + +def do_auth(code: str, redirect_url: str, aux_headers: dict = {}) -> dict: + # App U:P from URS Application + auth = get_urs_creds()["UrsAuth"] + + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_url, + } + + headers = {"Authorization": "Basic " + auth} + headers.update(aux_headers) + + client = EdlClient() + try: + return client.request( + "POST", + "/oauth/token", + data=data, + headers=headers, + ) + except EdlException: + return {} + + +def get_urs_url(ctxt: dict, to: Optional[str] = None) -> str: + base_url = ( + os.getenv("AUTH_BASE_URL", "https://urs.earthdata.nasa.gov") + + "/oauth/authorize" + ) + + # From URS Application + client_id = get_urs_creds()["UrsId"] + + log.debug("domain name: {0}".format(os.getenv("DOMAIN_NAME", "no domainname set"))) + log.debug( + "if no domain name set: {}.execute-api.{}.amazonaws.com/{}".format( + ctxt["apiId"], 
os.getenv("AWS_DEFAULT_REGION", ""), ctxt["stage"] + ) + ) + + urs_url = f"{base_url}?client_id={client_id}&response_type=code&redirect_uri={get_redirect_url(ctxt)}" + if to: + urs_url += f"&state={to}" + + # Try to handle scripts + try: + download_agent = ctxt["identity"]["userAgent"] + except KeyError: + log.debug("No User Agent!") + return urs_url + + if not download_agent.startswith("Mozilla"): + urs_url += "&app_type=401" + + return urs_url + + +def get_user_profile(urs_user_payload: dict, access_token) -> UserProfile: + return UserProfile( + user_id=urs_user_payload["uid"], + token=access_token, + groups=urs_user_payload["user_groups"], + first_name=urs_user_payload["first_name"], + last_name=urs_user_payload["last_name"], + email=urs_user_payload["email_address"], + ) + + +def get_profile( + user_id: str, + token: str, + temptoken: Optional[str] = None, + aux_headers: dict = {}, +) -> Optional[UserProfile]: + if not user_id or not token: + return None + + # get_new_token_and_profile() will pass this function a temporary token with + # which to fetch the profile info. 
We don't want to keep it around, just use + # it here, once: + if temptoken: + headertoken = temptoken + else: + headertoken = token + + headers = {"Authorization": "Bearer " + headertoken} + headers.update(aux_headers) + params = {"client_id": get_urs_creds()["UrsId"]} + + client = EdlClient() + try: + user_profile = client.request( + "GET", + f"/api/users/{user_id}", + params=params, + headers=headers, + ) + return get_user_profile(user_profile, headertoken) + except EdlException as e: + log.warning("Error fetching profile: %s", e.inner) + if ( + not temptoken + ): # This keeps get_new_token_and_profile() from calling this over and over + log.debug("because error above, going to get_new_token_and_profile()") + return get_new_token_and_profile(user_id, token, aux_headers) + + log.debug( + f"We got that 401 above and we're using a temptoken ({temptoken}), " + "so giving up and not getting a profile." + ) + return None + + +def get_new_token_and_profile( + user_id: str, + cookietoken: str, + aux_headers: dict = {}, +) -> Optional[UserProfile]: + # App U:P from URS Application + auth = get_urs_creds()["UrsAuth"] + data = {"grant_type": "client_credentials"} + + headers = {"Authorization": "Basic " + auth} + headers.update(aux_headers) + + client = EdlClient() + try: + log.info("Attempting to get new Token") + + response = client.request( + "POST", + "/oauth/token", + data=data, + headers=headers, + ) + new_token = response["access_token"] + + log.info("Retrieved new token: %s", new_token) + # Get user profile with new token + return get_profile( + user_id, + cookietoken, + new_token, + aux_headers=aux_headers, + ) + except EdlException: + return None + + +def user_in_group_list(private_groups: list, user_groups: list) -> bool: + client_id = get_urs_creds()["UrsId"] + log.info( + "Searching for private groups {0} in {1}".format(private_groups, user_groups) + ) + + group_names = { + group["name"] for group in user_groups if group["client_id"] == client_id + } + + for 
group in private_groups: + if group in group_names: + log.info("User belongs to private group {}".format(group)) + return True + return False + + +def user_in_group_urs( + private_groups, + user_id, + token, + user_profile=None, + refresh_first=False, + aux_headers=None, +): + aux_headers = aux_headers or {} # A safer default + new_profile = {} + + if refresh_first or not user_profile: + user_profile = get_profile(user_id, token, aux_headers=aux_headers) + new_profile = user_profile + + if ( + isinstance(user_profile, dict) + and "user_groups" in user_profile + and user_in_group_list(private_groups, user_profile["user_groups"]) + ): + log.info("User {0} belongs to private group".format(user_id)) + return True, new_profile + + # Couldn't find user in provided groups, but we may as well look at a fresh group list: + if not refresh_first: + # we have a maybe not so fresh user_profile and we could try again to see if someone added a group to this user: + log.debug( + f"Could not validate user {user_id} belonging to groups {private_groups}, attempting profile refresh" + ) + + return user_in_group_urs( + private_groups, user_id, {}, refresh_first=True, aux_headers=aux_headers + ) + log.debug( + "Even after profile refresh, user {0} does not belong to groups {1}".format( + user_id, private_groups + ) + ) + + return False, new_profile + + +def user_in_group( + private_groups, + user_profile: UserProfile, + refresh_first=False, + aux_headers=None, +): + aux_headers = aux_headers or {} # A safer default + + # If a new profile is fetched, it is assigned to this var, and returned so that a fresh jwt cookie can be set. 
+ new_profile = None + + if not private_groups: + return False, new_profile + + if not user_profile: + return False, new_profile + + if refresh_first: + new_profile = get_profile( + user_profile.user_id, + user_profile.token, + aux_headers=aux_headers, + ) + user_profile.groups = new_profile.groups + + in_group = user_in_group_list(private_groups, user_profile.groups) + if in_group: + return True, new_profile + + if not in_group and not refresh_first: + # one last ditch effort to see if they were so very recently added to group: + user_profile = get_profile( + user_profile.user_id, user_profile.token, aux_headers=aux_headers + ) + return user_in_group( + private_groups, + user_profile, + refresh_first=True, + aux_headers=aux_headers, + ) + + return False, new_profile + + +def get_urs_creds() -> dict: + """ + Fetches URS creds from secrets manager. + :return: looks like: + { + "UrsId": "stringofseeminglyrandomcharacters", + "UrsAuth": "verymuchlongerstringofseeminglyrandomcharacters" + } + :type: dict + """ + secret_name = os.getenv("URS_CREDS_SECRET_NAME") + + if not secret_name: + log.error("URS_CREDS_SECRET_NAME not set") + return {} + + secret = retrieve_secret(secret_name) + if not ("UrsId" in secret and "UrsAuth" in secret): + log.error( + 'AWS secret {} does not contain required keys "UrsId" and "UrsAuth"'.format( + secret_name + ) + ) + + return secret + + +# This do_login() is mainly for chalice clients. +def do_login( + args, context, jwt_manager: JwtManager, cookie_domain="", aux_headers=None +): + aux_headers = aux_headers or {} # A safer default + + log.debug("the query_params: {}".format(args)) + + if not args: + template_vars = {"contentstring": "No params", "title": "Could Not Login"} + headers = {} + return 400, template_vars, headers + + if args.get("error", False): + contentstring = ( + 'An error occurred while trying to log into URS. URS says: "{}". 
'.format( + args.get("error", "") + ) + ) + template_vars = {"contentstring": contentstring, "title": "Could Not Login"} + if args.get("error") == "access_denied": + # This happens when user doesn't agree to EULA. Maybe other times too. + return_status = 401 + template_vars["contentstring"] = "Be sure to agree to the EULA." + template_vars["error_code"] = "EULA_failure" + else: + return_status = 400 + + return return_status, template_vars, {} + + if "code" not in args: + contentstring = "Did not get the required CODE from URS" + + template_vars = {"contentstring": contentstring, "title": "Could Not Login"} + headers = {} + return 400, template_vars, headers + + log.debug("pre-do_auth() query params: {}".format(args)) + redir_url = get_redirect_url(context) + auth = do_auth(args.get("code", ""), redir_url, aux_headers=aux_headers) + log.debug("auth: {}".format(auth)) + if not auth: + log.debug("no auth returned from do_auth()") + + template_vars = { + "contentstring": "There was a problem talking to URS Login", + "title": "Could Not Login", + } + + return 400, template_vars, {} + + user_id = auth["endpoint"].split("/")[-1] + log_context(user_id=user_id) + + user_profile = get_profile(user_id, auth["access_token"], aux_headers={}) + log.debug("Got the user profile: {}".format(user_profile)) + if user_profile is not None: + log.debug("urs-access-token: {}".format(auth["access_token"])) + if "state" in args: + redirect_to = args["state"] + else: + redirect_to = get_base_url(context) + + headers = {"Location": redirect_to} + headers.update( + jwt_manager.get_header_to_set_auth_cookie(user_profile, cookie_domain) + ) + return 301, {}, headers + + template_vars = { + "contentstring": "Could not get user profile from URS", + "title": "Could Not Login", + } + return 400, template_vars, {} diff --git a/rain_api_core/view_util.py b/src/rain_api_core/view_util.py similarity index 56% rename from rain_api_core/view_util.py rename to src/rain_api_core/view_util.py index 
d11ea28..e19526c 100644 --- a/rain_api_core/view_util.py +++ b/src/rain_api_core/view_util.py @@ -22,15 +22,15 @@ log = logging.getLogger(__name__) -HTML_TEMPLATE_STATUS = '' -HTML_TEMPLATE_LOCAL_CACHEDIR = '/tmp/templates/' # nosec We want to leverage instance persistance -HTML_TEMPLATE_PROJECT_DIR = Path().resolve() / 'templates' +HTML_TEMPLATE_STATUS = "" +HTML_TEMPLATE_LOCAL_CACHEDIR = "/tmp/templates/" # nosec We want to leverage instance persistance +HTML_TEMPLATE_PROJECT_DIR = Path().resolve() / "templates" _HOURS_PER_WEEK = 7 * 24 -SESSTTL = int(os.getenv('SESSION_TTL', _HOURS_PER_WEEK)) * 60 * 60 +SESSTTL = int(os.getenv("SESSION_TTL", _HOURS_PER_WEEK)) * 60 * 60 -JWT_ALGO = os.getenv('JWT_ALGO', 'RS256') -JWT_COOKIE_NAME = os.getenv('JWT_COOKIENAME', 'asf-urs') +JWT_ALGO = os.getenv("JWT_ALGO", "RS256") +JWT_COOKIE_NAME = os.getenv("JWT_COOKIENAME", "asf-urs") JWT_BLACKLIST = {} @@ -46,14 +46,16 @@ def __init__( self.bucket = bucket self.template_dir = template_dir self.jinja_env = Environment( - loader=FileSystemLoader([ - self.cache_dir, - HTML_TEMPLATE_PROJECT_DIR, - # For legacy compatibility with projects that don't install - # this module with pip and rely on this behavior - os.path.join(os.path.dirname(__file__), '../', 'templates') - ]), - autoescape=select_autoescape(['html', 'xml']) + loader=FileSystemLoader( + [ + self.cache_dir, + HTML_TEMPLATE_PROJECT_DIR, + # For legacy compatibility with projects that don't install + # this module with pip and rely on this behavior + os.path.join(os.path.dirname(__file__), "../", "templates"), + ] + ), + autoescape=select_autoescape(["html", "xml"]), ) self._downloaded = False @@ -62,75 +64,82 @@ def download_templates(self): try: os.mkdir(self.cache_dir, 0o700) except FileExistsError: - log.debug('%s already exists', self.cache_dir) + log.debug("%s already exists", self.cache_dir) if not self.bucket or not self.template_dir: return template_dir = self.template_dir - if not template_dir.endswith('/'): - 
template_dir = f'{template_dir}/' + if not template_dir.endswith("/"): + template_dir = f"{template_dir}/" # For logging - s3_uri = f's3://{self.bucket}/{template_dir}' + s3_uri = f"s3://{self.bucket}/{template_dir}" try: start = time() - client = botoclient('s3') + client = botoclient("s3") result = client.list_objects( Bucket=self.bucket, Prefix=template_dir, - Delimiter='/' + Delimiter="/", + ) + log.info( + return_timing_object( + service="s3", + endpoint=f"client().list_objects({s3_uri})", + duration=duration(start), + ) ) - log.info(return_timing_object( - service='s3', - endpoint=f'client().list_objects({s3_uri})', - duration=duration(start) - )) download_start = time() - for entry in result.get('Contents', []): - key = entry['Key'] + for entry in result.get("Contents", []): + key = entry["Key"] filename = os.path.basename(key) if not filename: continue local_path = os.path.join(self.cache_dir, filename) - log.debug('attempting to save %s', local_path) + log.debug("attempting to save %s", local_path) start = time() client.download_file(self.bucket, key, local_path) - log.info(return_timing_object( - service='s3', - endpoint=f'client().download_file({s3_uri}/{key})', - duration=duration(start) - )) - - log.debug('ET for download_templates: %.4fs', time() - download_start) + log.info( + return_timing_object( + service="s3", + endpoint=f"client().download_file({s3_uri}/{key})", + duration=duration(start), + ) + ) + + log.debug("ET for download_templates: %.4fs", time() - download_start) except Exception: - log.warning('Failed to download HTML templates from %s', s3_uri, exc_info=True) + log.warning( + "Failed to download HTML templates from %s", s3_uri, exc_info=True + ) finally: self._downloaded = True - def render(self, template_name: str = 'root.html', *args, **kwargs) -> str: + def render(self, template_name: str = "root.html", *args, **kwargs) -> str: if not self._downloaded: self.download_templates() try: template = 
self.jinja_env.get_template(template_name) except TemplateNotFound as e: - log.error('Template not found: %s', e) - return 'Cannot find the HTML template directory' + log.error("Template not found: %s", e) + return "Cannot find the HTML template directory" return template.render(*args, **kwargs) @functools.lru_cache(maxsize=None) def get_jwt_keys() -> dict: - raw_keys = retrieve_secret(os.getenv('JWT_KEY_SECRET_NAME', '')) + raw_keys = retrieve_secret(os.getenv("JWT_KEY_SECRET_NAME", "")) return { - k: base64.b64decode(v.encode('utf-8')) + # ruff hint + k: base64.b64decode(v.encode("utf-8")) for k, v in raw_keys.items() } @@ -149,9 +158,9 @@ def get_cookie_vars(headers: dict) -> dict: decoded_payload = decode_jwt_payload(cooks[JWT_COOKIE_NAME], JWT_ALGO) return {JWT_COOKIE_NAME: decoded_payload} else: - log.debug('could not find jwt cookie in get_cookie_vars()') + log.debug("could not find jwt cookie in get_cookie_vars()") except KeyError as e: - log.debug('Key error trying to get cookie vars: {}'.format(e)) + log.debug("Key error trying to get cookie vars: {}".format(e)) return {} @@ -165,7 +174,7 @@ def get_cookie_expiration_date_str() -> str: def get_cookies(hdrs: dict) -> dict: - cookie_string = hdrs.get('cookie') or hdrs.get('Cookie') or hdrs.get('COOKIE') + cookie_string = hdrs.get("cookie") or hdrs.get("Cookie") or hdrs.get("COOKIE") if not cookie_string: return {} @@ -174,6 +183,7 @@ def get_cookies(hdrs: dict) -> dict: cookie.load(cookie_string) return { + # ruff hint key: morsel.value for key, morsel in cookie.items() } @@ -181,61 +191,75 @@ def get_cookies(hdrs: dict) -> dict: def make_jwt_payload(payload: dict, algo: str = JWT_ALGO) -> str: try: - log.debug('using secret: {}'.format(os.getenv('JWT_KEY_SECRET_NAME', ''))) + log.debug("using secret: {}".format(os.getenv("JWT_KEY_SECRET_NAME", ""))) timer = time() - encoded = jwt.encode(payload, get_jwt_keys()['rsa_priv_key'], algorithm=algo) - log.info(return_timing_object(service="jwt", 
endpoint="jwt.encode()", duration=duration(timer))) + encoded = jwt.encode(payload, get_jwt_keys()["rsa_priv_key"], algorithm=algo) + log.info( + return_timing_object( + service="jwt", endpoint="jwt.encode()", duration=duration(timer) + ) + ) return encoded except KeyError as e: - log.error('jwt_keys may be malformed: ') + log.error("jwt_keys may be malformed: ") log.error(e) - return '' + return "" except (ValueError, AttributeError) as e: # TODO(reweeden): how can these error types possibly be triggered!? jwt.encode will raise a TypeError on bad # input, but never ValueError or AttributeError. - log.error('problem with encoding cookie: {}'.format(e)) - return '' + log.error("problem with encoding cookie: {}".format(e)) + return "" def decode_jwt_payload(jwt_payload: str, algo: str = JWT_ALGO) -> dict: try: - rsa_pub_key = get_jwt_keys()['rsa_pub_key'] + rsa_pub_key = get_jwt_keys()["rsa_pub_key"] timer = time() cookiedecoded = jwt.decode(jwt_payload, rsa_pub_key, [algo]) - log.info(return_timing_object(service="jwt", endpoint="jwt.decode()", duration=duration(timer))) + log.info( + return_timing_object( + service="jwt", endpoint="jwt.decode()", duration=duration(timer) + ) + ) except jwt.ExpiredSignatureError: # Signature has expired - log.info('JWT has expired') + log.info("JWT has expired") # TODO what more to do with this, if anything? return {} except jwt.InvalidSignatureError: - log.info('JWT has failed verification. returning empty dict') + log.info("JWT has failed verification. 
returning empty dict") return {} if os.getenv("BLACKLIST_ENDPOINT"): if is_jwt_blacklisted(cookiedecoded): return {} else: - log.debug('No environment variable BLACKLIST_ENDPOINT') + log.debug("No environment variable BLACKLIST_ENDPOINT") - log.debug('cookiedecoded {}'.format(cookiedecoded)) + log.debug("cookiedecoded {}".format(cookiedecoded)) return cookiedecoded def craft_cookie_domain_payloadpiece(cookie_domain: str) -> str: if cookie_domain: - return f'; Domain={cookie_domain}' + return f"; Domain={cookie_domain}" - return '' + return "" -def make_set_cookie_headers_jwt(payload: dict, expdate: str = '', cookie_domain: str = '') -> dict: +def make_set_cookie_headers_jwt( + payload: dict, + expdate: str = "", + cookie_domain: str = "", +) -> dict: jwt_payload = make_jwt_payload(payload) cookie_domain_payloadpiece = craft_cookie_domain_payloadpiece(cookie_domain) if not expdate: expdate = get_cookie_expiration_date_str() - headers = {'SET-COOKIE': f'{JWT_COOKIE_NAME}={jwt_payload}; Expires={expdate}; Path=/{cookie_domain_payloadpiece}'} + headers = { + "SET-COOKIE": f"{JWT_COOKIE_NAME}={jwt_payload}; Expires={expdate}; Path=/{cookie_domain_payloadpiece}" + } return headers @@ -248,13 +272,19 @@ def is_jwt_blacklisted(decoded_jwt: dict) -> bool: if user_blacklist_time is not None: jwt_mint_time = decoded_jwt["iat"] - log.debug(f"JWT was minted @: {jwt_mint_time}, the Blacklist is for cookies BEFORE: {user_blacklist_time}") + log.debug( + f"JWT was minted @: {jwt_mint_time}, the Blacklist is for cookies BEFORE: {user_blacklist_time}" + ) if user_blacklist_time >= jwt_mint_time: - log.info(f"User {urs_user_id}'s JWT was minted before blacklist date and is INVALID") + log.info( + f"User {urs_user_id}'s JWT was minted before blacklist date and is INVALID" + ) return True else: - log.info(f"User {urs_user_id}s JWT was minted AFTER blacklist date and is still VALID") + log.info( + f"User {urs_user_id}s JWT was minted AFTER blacklist date and is still VALID" + ) 
log.info(f"User {urs_user_id} is NOT in the blacklist") return False @@ -266,19 +296,25 @@ def is_jwt_blacklisted(decoded_jwt: dict) -> bool: def set_jwt_blacklist() -> dict: global JWT_BLACKLIST # pylint: disable=global-statement - if JWT_BLACKLIST and time() - JWT_BLACKLIST["timestamp"] <= (10 * 60): # If cached in the last 10 minutes + if JWT_BLACKLIST and time() - JWT_BLACKLIST["timestamp"] <= ( + 10 * 60 + ): # If cached in the last 10 minutes return JWT_BLACKLIST endpoint = os.getenv("BLACKLIST_ENDPOINT") # Bandit complains with B310 on the line below. We know the URL, this is safe! timer = time() - output = urllib.request.urlopen(endpoint).read().decode('utf-8') # nosec - log.info(return_timing_object(service="blacklist", endpoint=endpoint, duration=duration(timer))) + output = urllib.request.urlopen(endpoint).read().decode("utf-8") # nosec + log.info( + return_timing_object( + service="blacklist", endpoint=endpoint, duration=duration(timer) + ) + ) blacklist = json.loads(output)["blacklist"] contents = { "blacklist": blacklist, - "timestamp": time() + "timestamp": time(), } JWT_BLACKLIST = contents # Cache it diff --git a/tests/test_auth.py b/tests/test_auth.py index 1951907..1ccdfd5 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -6,79 +6,81 @@ from rain_api_core.auth import JwtManager, UserProfile -MODULE = 'rain_api_core.auth' +MODULE = "rain_api_core.auth" @pytest.fixture def jwt_manager(jwt_priv_key, jwt_pub_key): return JwtManager( - algorithm='RS256', + algorithm="RS256", public_key=jwt_pub_key, private_key=jwt_priv_key, - cookie_name='auth-cookie', + cookie_name="auth-cookie", ) def test_decode_jwt(jwt_manager, jwt_priv_key): - payload = {'foo': 'bar'} - encoded = jwt.encode(payload, jwt_priv_key, 'RS256') + payload = {"foo": "bar"} + encoded = jwt.encode(payload, jwt_priv_key, "RS256") assert jwt_manager._decode_jwt(encoded) == payload def test_decode_jwt_expired(jwt_manager, jwt_priv_key): - payload = {'foo': 'bar', 'exp': 0} - encoded 
= jwt.encode(payload, jwt_priv_key, 'RS256') + payload = {"foo": "bar", "exp": 0} + encoded = jwt.encode(payload, jwt_priv_key, "RS256") assert jwt_manager._decode_jwt(encoded) is None def test_decode_jwt_invalid(jwt_manager): - encoded = b".".join(( - jwt.utils.base64url_encode(b'{"alg": "RS256"}'), - jwt.utils.base64url_encode(b'{"not valid'), - jwt.utils.base64url_encode(b"some bytes"), - )).decode() + encoded = b".".join( + ( + jwt.utils.base64url_encode(b'{"alg": "RS256"}'), + jwt.utils.base64url_encode(b'{"not valid'), + jwt.utils.base64url_encode(b"some bytes"), + ) + ).decode() assert jwt_manager._decode_jwt(encoded) is None def test_get_auth_cookie(jwt_manager): - jwt_manager.cookie_name = 'auth-cookie' + jwt_manager.cookie_name = "auth-cookie" - headers = {'Cookie': 'auth-cookie=foo'} - assert jwt_manager._get_auth_cookie(headers).value == 'foo' - headers = {'Cookie': 'auth-cookie=foo; not-auth-cookie=bar'} - assert jwt_manager._get_auth_cookie(headers).value == 'foo' - headers = {'Cookie': 'not-auth-cookie=foo'} + headers = {"Cookie": "auth-cookie=foo"} + assert jwt_manager._get_auth_cookie(headers).value == "foo" + headers = {"Cookie": "auth-cookie=foo; not-auth-cookie=bar"} + assert jwt_manager._get_auth_cookie(headers).value == "foo" + headers = {"Cookie": "not-auth-cookie=foo"} assert jwt_manager._get_auth_cookie(headers) is None def test_in_blacklist(jwt_manager): jwt_manager.black_list = { - 'blacklisted_user': 100, + "blacklisted_user": 100, } profile = UserProfile( - user_id='blacklisted_user', - token='test_token', - groups=['test_group1', 'test_group2'], - first_name='test_first_name', - last_name='test_last_name', - email='test_email', - iat=75 + user_id="blacklisted_user", + token="test_token", + groups=["test_group1", "test_group2"], + first_name="test_first_name", + last_name="test_last_name", + email="test_email", + iat=75, ) assert jwt_manager._in_blacklist(profile) is True profile.iat = 115 assert jwt_manager._in_blacklist(profile) is 
False -@mock.patch(f'{MODULE}.time', autospec=True) +@mock.patch(f"{MODULE}.time", autospec=True) def test_jwt_payload_from_user_profile(mock_time, jwt_manager): profile = UserProfile( - user_id='test_user_id', - token='test_token', - groups=['test_group1', 'test_group2'], - first_name='test_first_name', - last_name='test_last_name', - email='test_email', + user_id="test_user_id", + token="test_token", + groups=["test_group1", "test_group2"], + first_name="test_first_name", + last_name="test_last_name", + email="test_email", ) mock_time.return_value = 1 @@ -86,135 +88,139 @@ def test_jwt_payload_from_user_profile(mock_time, jwt_manager): payload = jwt_manager._jwt_payload_from_user_profile(profile) assert payload == { - 'urs-user-id': 'test_user_id', - 'first_name': 'test_first_name', - 'last_name': 'test_last_name', - 'email': 'test_email', - 'urs-access-token': 'test_token', - 'urs-groups': ['test_group1', 'test_group2'], - 'iat': 1, - 'exp': 6 + "urs-user-id": "test_user_id", + "first_name": "test_first_name", + "last_name": "test_last_name", + "email": "test_email", + "urs-access-token": "test_token", + "urs-groups": ["test_group1", "test_group2"], + "iat": 1, + "exp": 6, } -@mock.patch(f'{MODULE}.JwtManager._in_blacklist', autospec=True) +@mock.patch(f"{MODULE}.JwtManager._in_blacklist", autospec=True) def test_get_profile_from_header(mock_in_blacklist, jwt_manager, jwt_priv_key): mock_in_blacklist.return_value = False payload = { - 'urs-user-id': 'test_user', - 'first_name': 'test', - 'last_name': 'user', - 'email': 'user@emailwebsite.com', - 'urs-access-token': 'foo', - 'urs-groups': [] + "urs-user-id": "test_user", + "first_name": "test", + "last_name": "user", + "email": "user@emailwebsite.com", + "urs-access-token": "foo", + "urs-groups": [], } headers = { - 'Cookie': f'auth-cookie={jwt.encode(payload, jwt_priv_key, "RS256")}' + "Cookie": f"auth-cookie={jwt.encode(payload, jwt_priv_key, 'RS256')}", } user_profile = 
jwt_manager.get_profile_from_headers(headers) - assert user_profile.user_id == 'test_user' + assert user_profile.user_id == "test_user" headers = {} assert jwt_manager.get_profile_from_headers(headers) is None -@mock.patch(f'{MODULE}.JwtManager._in_blacklist', autospec=True) -def test_get_profile_from_header_jwt_blacklisted(mock_in_blacklist, jwt_manager, jwt_priv_key): +@mock.patch(f"{MODULE}.JwtManager._in_blacklist", autospec=True) +def test_get_profile_from_header_jwt_blacklisted( + mock_in_blacklist, + jwt_manager, + jwt_priv_key, +): mock_in_blacklist.return_value = True payload = { - 'urs-user-id': 'test_user', - 'first_name': 'test', - 'last_name': 'user', - 'email': 'user@emailwebsite.com', - 'urs-access-token': 'foo', - 'urs-groups': [] + "urs-user-id": "test_user", + "first_name": "test", + "last_name": "user", + "email": "user@emailwebsite.com", + "urs-access-token": "foo", + "urs-groups": [], } headers = { - 'Cookie': f'auth-cookie={jwt.encode(payload, jwt_priv_key, "RS256")}' + "Cookie": f"auth-cookie={jwt.encode(payload, jwt_priv_key, 'RS256')}", } user_profile = jwt_manager.get_profile_from_headers(headers) assert user_profile is None -@mock.patch(f'{MODULE}.JwtManager._encode_jwt', autospec=True) -@mock.patch(f'{MODULE}.time', autospec=True) +@mock.patch(f"{MODULE}.JwtManager._encode_jwt", autospec=True) +@mock.patch(f"{MODULE}.time", autospec=True) def test_get_header_to_set_auth_cookie( mock_time, mock_encode_jwt, jwt_manager, ): - jwt_manager.cookie_name = 'auth-cookie' + jwt_manager.cookie_name = "auth-cookie" jwt_manager.session_ttl = 1 - mock_encode_jwt.return_value = 'COOKIE_VALUE' + mock_encode_jwt.return_value = "COOKIE_VALUE" mock_time.return_value = 0 profile = UserProfile( - user_id='test_user_id', - token='test_token', - groups=['test_group1', 'test_group2'], - first_name='test_first_name', - last_name='test_last_name', - email='test_email', + user_id="test_user_id", + token="test_token", + groups=["test_group1", "test_group2"], + 
first_name="test_first_name", + last_name="test_last_name", + email="test_email", iat=0, - exp=0 + exp=0, ) - header = jwt_manager.get_header_to_set_auth_cookie(profile, '') + header = jwt_manager.get_header_to_set_auth_cookie(profile, "") assert header == { - 'SET-COOKIE': ( - 'auth-cookie=COOKIE_VALUE; Expires=Thu, 01 Jan 1970 00:00:01 GMT; Path=/; Secure; HttpOnly; SameSite=Lax' + "SET-COOKIE": ( + "auth-cookie=COOKIE_VALUE; Expires=Thu, 01 Jan 1970 00:00:01 GMT; Path=/; Secure; HttpOnly; SameSite=Lax" ) } - header = jwt_manager.get_header_to_set_auth_cookie(profile, 'DOMAIN') + header = jwt_manager.get_header_to_set_auth_cookie(profile, "DOMAIN") assert header == { - 'SET-COOKIE': ( - 'auth-cookie=COOKIE_VALUE; Expires=Thu, 01 Jan 1970 00:00:01 GMT; Path=/; Domain=DOMAIN; Secure; HttpOnly; ' - 'SameSite=Lax' + "SET-COOKIE": ( + "auth-cookie=COOKIE_VALUE; Expires=Thu, 01 Jan 1970 00:00:01 GMT; Path=/; Domain=DOMAIN; Secure; HttpOnly; " + "SameSite=Lax" ) } def test_get_header_to_set_auth_cookie_logout(jwt_manager): - header = jwt_manager.get_header_to_set_auth_cookie(None, 'DOMAIN') + header = jwt_manager.get_header_to_set_auth_cookie(None, "DOMAIN") assert header == { - 'SET-COOKIE': ( - 'auth-cookie=expired; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Path=/; Domain=DOMAIN; Secure; HttpOnly; ' - 'SameSite=Lax' + "SET-COOKIE": ( + "auth-cookie=expired; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Path=/; Domain=DOMAIN; Secure; HttpOnly; " + "SameSite=Lax" ) } -@mock.patch(f'{MODULE}.time', autospec=True) +@mock.patch(f"{MODULE}.time", autospec=True) def test_jwt_manager_session_ttl_sub_hour(mock_time): jwt_manager = JwtManager( "RSA256", "", "private_key", "cookie_name", - session_ttl_in_hours=0.5 + session_ttl_in_hours=0.5, ) profile = UserProfile( - user_id='test_user_id', - token='test_token', - groups=['test_group1', 'test_group2'], - first_name='test_first_name', - last_name='test_last_name', - email='test_email', + user_id="test_user_id", + token="test_token", + 
groups=["test_group1", "test_group2"], + first_name="test_first_name", + last_name="test_last_name", + email="test_email", ) mock_time.return_value = 1 payload = jwt_manager._jwt_payload_from_user_profile(profile) assert payload == { - 'urs-user-id': 'test_user_id', - 'first_name': 'test_first_name', - 'last_name': 'test_last_name', - 'email': 'test_email', - 'urs-access-token': 'test_token', - 'urs-groups': ['test_group1', 'test_group2'], - 'iat': 1, - 'exp': 1801, + "urs-user-id": "test_user_id", + "first_name": "test_first_name", + "last_name": "test_last_name", + "email": "test_email", + "urs-access-token": "test_token", + "urs-groups": ["test_group1", "test_group2"], + "iat": 1, + "exp": 1801, } diff --git a/tests/test_aws_util.py b/tests/test_aws_util.py index d7096a0..2a09950 100644 --- a/tests/test_aws_util.py +++ b/tests/test_aws_util.py @@ -142,17 +142,18 @@ def test_get_role_creds(monkeypatch): assert session == { "AssumedRoleUser": { "Arn": mock.ANY, - "AssumedRoleId": mock.ANY + "AssumedRoleId": mock.ANY, }, "Credentials": { "AccessKeyId": mock.ANY, "Expiration": mock.ANY, "SecretAccessKey": mock.ANY, - "SessionToken": mock.ANY + "SessionToken": mock.ANY, }, "PackedPolicySize": 6, "ResponseMetadata": { "HTTPHeaders": { + "content-type": "text/xml", "date": mock.ANY, "server": "amazon.com", "x-amzn-requestid": mock.ANY, @@ -210,14 +211,18 @@ def test_get_region_cidr_ranges(mock_request, data): IPNetwork("13.34.43.192/27"), IPNetwork("15.181.232.0/21"), IPNetwork("52.93.127.163/32"), - IPNetwork("3.2.0.0/24") + IPNetwork("3.2.0.0/24"), ] @mock.patch(f"{MODULE}.urllib.request", autospec=True) @mock.patch(f"{MODULE}.region_list_cache", []) def test_get_region_cidr_ranges_cached(mock_request): - mock_request.urlopen("").read.return_value = b'{"prefixes": [{"ip_prefix": "10.0.0.1/24", "region": "us-east-1"}]}' + mock_request.urlopen( + "" + ).read.return_value = ( + b'{"prefixes": [{"ip_prefix": "10.0.0.1/24", "region": "us-east-1"}]}' + ) 
get_region_cidr_ranges() assert mock_request.urlopen.call_count == 2 @@ -228,15 +233,17 @@ def test_get_region_cidr_ranges_cached(mock_request): @mock.patch(f"{MODULE}.urllib.request", autospec=True) @mock.patch(f"{MODULE}.region_list_cache", []) def test_get_region_cidr_ranges_bad_data(mock_request): - mock_request.urlopen("").read.return_value = json.dumps({ - "prefixes": [ - {}, - { - "ip_prefix": "10.0.0.1/24", - "region": "us-east-1" - } - ] - }).encode() + mock_request.urlopen("").read.return_value = json.dumps( + { + "prefixes": [ + {}, + { + "ip_prefix": "10.0.0.1/24", + "region": "us-east-1", + }, + ], + } + ).encode() assert get_region_cidr_ranges() == [ IPNetwork("10.0.0.1/24"), diff --git a/tests/test_bucket_map.py b/tests/test_bucket_map.py index c53a021..94cbc7c 100644 --- a/tests/test_bucket_map.py +++ b/tests/test_bucket_map.py @@ -15,19 +15,19 @@ def sample_bucket_map(): "productX": "bucket", "nested": { "nested2a": { - "nested3": "nested-bucket-public" + "nested3": "nested-bucket-public", }, - "nested2b": "nested-bucket-private" - } + "nested2b": "nested-bucket-private", + }, }, "PUBLIC_BUCKETS": { "browse-bucket": "General browse Imagery", - "bucket/browse": "ProductX Browse Imagery" + "bucket/browse": "ProductX Browse Imagery", }, "PRIVATE_BUCKETS": { "bucket/2020/12": ["science_team"], - "nested-bucket-private": [] - } + "nested-bucket-private": [], + }, } @@ -41,19 +41,19 @@ def sample_bucket_map_iam(): "productX": "bucket", "nested": { "nested2a": { - "nested3": "nested-bucket-public" + "nested3": "nested-bucket-public", }, - "nested2b": "nested-bucket-private" - } + "nested2b": "nested-bucket-private", + }, }, "PUBLIC_BUCKETS": { "browse-bucket": "General browse Imagery", - "bucket/browse": "ProductX Browse Imagery" + "bucket/browse": "ProductX Browse Imagery", }, "PRIVATE_BUCKETS": { "bucket": ["science_team"], - "nested-bucket-private": [] - } + "nested-bucket-private": [], + }, } @@ -67,7 +67,7 @@ def groups_bucket_map(): "bucket1": 
["group1"], "bucket2": ["group2"], "bucket3": ["group3"], - } + }, } @@ -78,12 +78,14 @@ def test_get_simple(): b_map = BucketMap(bucket_map, bucket_name_prefix="pre-") entry = b_map.get("PATH/obj1") + assert entry is not None assert entry.bucket == "pre-bucket-name" assert entry.bucket_path == "PATH" assert entry.object_key == "obj1" assert entry.headers == {} dir_entry = b_map.get("PATH/") + assert dir_entry is not None assert dir_entry.bucket == "pre-bucket-name" assert dir_entry.bucket_path == "PATH" assert dir_entry.object_key == "" @@ -98,13 +100,14 @@ def test_get_simple(): ( {"foo": "bucket1"}, {"foo": {"bucket": "bucket1"}}, - {"MAP": {"foo": {"bucket": "bucket1"}}} - ) + {"MAP": {"foo": {"bucket": "bucket1"}}}, + ), ) def test_get_compatibility(bucket_map): b_map = BucketMap(bucket_map) entry = b_map.get("foo/bar") + assert entry is not None assert entry.bucket == "bucket1" assert entry.bucket_path == "foo" assert entry.object_key == "bar" @@ -115,20 +118,22 @@ def test_get_nested(): bucket_map = { "PATH": { "LEVEL1": { - "LEVEL2": "bucket-name" - } + "LEVEL2": "bucket-name", + }, }, } b_map = BucketMap(bucket_map) entry = b_map.get("PATH/LEVEL1/LEVEL2/obj1") + assert entry is not None assert entry.bucket == "bucket-name" assert entry.bucket_path == "PATH/LEVEL1/LEVEL2" assert entry.object_key == "obj1" assert entry.headers == {} dir_entry = b_map.get("PATH/LEVEL1/LEVEL2/") + assert dir_entry is not None assert dir_entry.bucket == "bucket-name" assert dir_entry.bucket_path == "PATH/LEVEL1/LEVEL2" assert dir_entry.object_key == "" @@ -142,19 +147,21 @@ def test_get_with_headers(): "PATH": { "bucket": "bucket-name", "headers": { - "Header1": "Value1" - } - } + "Header1": "Value1", + }, + }, } b_map = BucketMap(bucket_map) entry = b_map.get("PATH/obj1") + assert entry is not None assert entry.bucket == "bucket-name" assert entry.bucket_path == "PATH" assert entry.object_key == "obj1" assert entry.headers == {"Header1": "Value1"} dir_entry = 
b_map.get("PATH/") + assert dir_entry is not None assert dir_entry.bucket == "bucket-name" assert dir_entry.bucket_path == "PATH" assert dir_entry.object_key == "" @@ -166,19 +173,21 @@ def test_get_with_headers(): def test_get_reverse(): bucket_map = { "PATH": { - "STAGE": "bucket-name" - } + "STAGE": "bucket-name", + }, } b_map = BucketMap(bucket_map, reverse=True) entry = b_map.get("STAGE/PATH/obj1") + assert entry is not None assert entry.bucket == "bucket-name" assert entry.bucket_path == "PATH/STAGE" assert entry.object_key == "obj1" dir_entry = b_map.get("STAGE/PATH/") + assert dir_entry is not None assert dir_entry.bucket == "bucket-name" - assert entry.bucket_path == "PATH/STAGE" + assert dir_entry.bucket_path == "PATH/STAGE" assert dir_entry.object_key == "" assert b_map.get("STAGE/PATH") is None @@ -189,8 +198,8 @@ def test_get_reverse(): ( {"foo": "bucket1"}, {"foo": {"bucket": "bucket1"}}, - {"MAP": {"foo": {"bucket": "bucket1"}}} - ) + {"MAP": {"foo": {"bucket": "bucket1"}}}, + ), ) def test_get_path_compatibility(bucket_map): # Using a tuple instead of a list to ensure the input is not modified @@ -201,6 +210,7 @@ def test_get_path_compatibility(bucket_map): b_map = BucketMap(bucket_map) entry = b_map.get_path(path_list) + assert entry is not None assert entry.bucket == "bucket1" assert entry.bucket_path == "foo" assert entry.object_key == "bar/baz" @@ -241,7 +251,7 @@ def test_entries_simple(): def test_entries_multiple(): bucket_map = { "PATH": "bucket1", - "PATH2": "bucket2" + "PATH2": "bucket2", } b_map = BucketMap(bucket_map, bucket_name_prefix="pre-") @@ -255,7 +265,7 @@ def test_entries_multiple(): bucket="pre-bucket2", bucket_path="PATH2", object_key="", - ) + ), ] @@ -264,8 +274,8 @@ def test_entries_multiple(): ( {"foo": "bucket1"}, {"foo": {"bucket": "bucket1"}}, - {"MAP": {"foo": {"bucket": "bucket1"}}} - ) + {"MAP": {"foo": {"bucket": "bucket1"}}}, + ), ) def test_entries_compatibility(bucket_map): b_map = BucketMap(bucket_map) @@ -283,8 
+293,8 @@ def test_entries_nested(): bucket_map = { "PATH": { "LEVEL1": { - "LEVEL2": "bucket-name" - } + "LEVEL2": "bucket-name", + }, }, } @@ -312,7 +322,7 @@ def test_entries(sample_bucket_map): bucket="browse-bucket", bucket_path="general-browse", object_key="", - _access_control={"": None} + _access_control={"": None}, ), BucketMapEntry( bucket="bucket", @@ -321,8 +331,8 @@ def test_entries(sample_bucket_map): _access_control={ "2020/12": {"science_team"}, "browse": None, - "": set() - } + "": set(), + }, ), BucketMapEntry( bucket="nested-bucket-public", @@ -333,7 +343,7 @@ def test_entries(sample_bucket_map): bucket="nested-bucket-private", bucket_path="nested/nested2b", object_key="", - _access_control={"": set()} + _access_control={"": set()}, ), ] @@ -343,9 +353,9 @@ def test_entries_with_headers(): "PATH": { "bucket": "bucket-name", "headers": { - "Header1": "Value1" - } - } + "Header1": "Value1", + }, + }, } b_map = BucketMap(bucket_map) @@ -354,7 +364,7 @@ def test_entries_with_headers(): bucket="bucket-name", bucket_path="PATH", object_key="", - headers={"Header1": "Value1"} + headers={"Header1": "Value1"}, ) ] @@ -369,10 +379,18 @@ def test_check_bucket_access(sample_bucket_map): assert b_map.get("ANY_AUTHED/obj1").is_accessible() is False assert b_map.get("ANY_AUTHED/obj1").is_accessible(groups=[]) is True assert b_map.get("general-browse/obj1").is_accessible() is True - assert b_map.get("general-browse/obj1").is_accessible(groups=["science_team"]) is True + assert ( + b_map.get("general-browse/obj1").is_accessible(groups=["science_team"]) is True + ) assert b_map.get("productX/browse/obj1").is_accessible() is True - assert b_map.get("productX/2020/12/obj1").is_accessible(groups=["science_team"]) is True - assert b_map.get("productX/2020/23/obj2").is_accessible(groups=["science_team"]) is True + assert ( + b_map.get("productX/2020/12/obj1").is_accessible(groups=["science_team"]) + is True + ) + assert ( + 
b_map.get("productX/2020/23/obj2").is_accessible(groups=["science_team"]) + is True + ) assert b_map.get("productX/2020/12/obj1").is_accessible() is False assert b_map.get("nested/nested2b/obj1").is_accessible() is False assert b_map.get("nested/nested2b/obj1").is_accessible(groups=[]) is True @@ -382,14 +400,14 @@ def test_check_bucket_access_conflicting(): # When a bucket is configured to be both public and private bucket_map = { "MAP": { - "PATH": "bucket" + "PATH": "bucket", }, "PUBLIC_BUCKETS": [ - "bucket" + "bucket", ], "PRIVATE_BUCKETS": { - "bucket": ["some_permission"] - } + "bucket": ["some_permission"], + }, } b_map = BucketMap(bucket_map) @@ -401,15 +419,15 @@ def test_check_bucket_access_longest_prefix_first(): # Longer prefixes should be checked first bucket_map = { "MAP": { - "PATH": "bucket" + "PATH": "bucket", }, "PUBLIC_BUCKETS": [ - "bucket" + "bucket", ], "PRIVATE_BUCKETS": { "bucket/foobar": ["other_permission"], "bucket/foo": ["some_permission"], - } + }, } b_map = BucketMap(bucket_map, iam_compatible=False) @@ -422,15 +440,15 @@ def test_check_bucket_access_longest_prefix_first_order(): # Longer prefixes should be checked first bucket_map = { "MAP": { - "PATH": "bucket" + "PATH": "bucket", }, "PUBLIC_BUCKETS": [ - "bucket" + "bucket", ], "PRIVATE_BUCKETS": { "bucket/foo": ["some_permission"], "bucket/foobar": ["other_permission"], - } + }, } b_map = BucketMap(bucket_map, iam_compatible=False) @@ -443,16 +461,16 @@ def test_check_bucket_access_longest_prefix_first_conflicting(): # Longer prefixes should be checked first bucket_map = { "MAP": { - "PATH": "bucket" + "PATH": "bucket", }, "PUBLIC_BUCKETS": [ "bucket/foo", - "bucket/foobar" + "bucket/foobar", ], "PRIVATE_BUCKETS": { "bucket/foo": ["some_permission"], "bucket/foobar": ["other_permission"], - } + }, } b_map = BucketMap(bucket_map, iam_compatible=False) @@ -466,17 +484,17 @@ def test_check_bucket_access_nested_paths(): "MAP": { "nested": { "nested2a": { - "nested3": 
"nested-bucket-public" + "nested3": "nested-bucket-public", }, - "nested2b": "nested-bucket-private" - } + "nested2b": "nested-bucket-private", + }, }, "PUBLIC_BUCKETS": { - "nested-bucket-public": "Public bucket in 'nested'" + "nested-bucket-public": "Public bucket in 'nested'", }, "PRIVATE_BUCKETS": { - "nested-bucket-private": ["science_team"] - } + "nested-bucket-private": ["science_team"], + }, } b_map = BucketMap(bucket_map) @@ -485,50 +503,62 @@ def test_check_bucket_access_nested_paths(): assert b_map.get("nested/nested2a/nested3") is None assert b_map.get("nested/nested2b/obj1").is_accessible() is False - assert b_map.get("nested/nested2b/obj1").is_accessible(groups=["wrong_group"]) is False - assert b_map.get("nested/nested2b/obj1").is_accessible(groups=["science_team"]) is True + assert ( + b_map.get("nested/nested2b/obj1").is_accessible(groups=["wrong_group"]) is False + ) + assert ( + b_map.get("nested/nested2b/obj1").is_accessible(groups=["science_team"]) is True + ) assert b_map.get("nested/nested2a/nested3/obj1").is_accessible() is True def test_check_bucket_access_nested_private_first(): bucket_map = { "MAP": { - "PATH": "bucket" + "PATH": "bucket", }, "PUBLIC_BUCKETS": [ - "bucket/foo/browse" + "bucket/foo/browse", ], "PRIVATE_BUCKETS": { - "bucket/foo": ["some_permission"] - } + "bucket/foo": ["some_permission"], + }, } b_map = BucketMap(bucket_map, iam_compatible=False) assert b_map.get("PATH/obj1").is_accessible() is False assert b_map.get("PATH/foo/obj1").is_accessible() is False assert b_map.get("PATH/foo/obj1").is_accessible(groups=["some_permission"]) is True - assert b_map.get("PATH/foo/browse/obj1").is_accessible(groups=["some_permission"]) is True + assert ( + b_map.get("PATH/foo/browse/obj1").is_accessible(groups=["some_permission"]) + is True + ) assert b_map.get("PATH/foo/browse/obj1").is_accessible() is True def test_check_bucket_access_nested_public_first(): bucket_map = { "MAP": { - "PATH": "bucket" + "PATH": "bucket", }, 
"PUBLIC_BUCKETS": [ - "bucket/browse" + "bucket/browse", ], "PRIVATE_BUCKETS": { - "bucket/browse/foo": ["some_permission"] - } + "bucket/browse/foo": ["some_permission"], + }, } b_map = BucketMap(bucket_map, iam_compatible=False) assert b_map.get("PATH/obj1").is_accessible() is False assert b_map.get("PATH/browse/foo/obj1").is_accessible() is False - assert b_map.get("PATH/browse/foo/obj1").is_accessible(groups=["some_permission"]) is True - assert b_map.get("PATH/browse/obj1").is_accessible(groups=["some_permission"]) is True + assert ( + b_map.get("PATH/browse/foo/obj1").is_accessible(groups=["some_permission"]) + is True + ) + assert ( + b_map.get("PATH/browse/obj1").is_accessible(groups=["some_permission"]) is True + ) assert b_map.get("PATH/browse/obj1").is_accessible() is True @@ -536,27 +566,27 @@ def test_check_iam_compatible_nested_private_first(): _ = BucketMap( { "PUBLIC_BUCKETS": ["bucket/browse"], - "PRIVATE_BUCKETS": {"bucket": ["group_1"]} + "PRIVATE_BUCKETS": {"bucket": ["group_1"]}, }, - iam_compatible=True + iam_compatible=True, ) _ = BucketMap( { "PRIVATE_BUCKETS": { "bucket": ["group_1"], - "bucket/foo/": ["group_1", "group_2"] - } + "bucket/foo/": ["group_1", "group_2"], + }, }, - iam_compatible=True + iam_compatible=True, ) _ = BucketMap( { "PRIVATE_BUCKETS": { "bucket": ["group_1"], - "bucket/foo/": [] - } + "bucket/foo/": [], + }, }, - iam_compatible=True + iam_compatible=True, ) @@ -564,19 +594,17 @@ def test_check_iam_compatible_nested_protected_then_private(): with pytest.raises(ValueError): _ = BucketMap( { - "PRIVATE_BUCKETS": { - "bucket/foo/": ["group_1"] - } + "PRIVATE_BUCKETS": {"bucket/foo/": ["group_1"]}, }, - iam_compatible=True + iam_compatible=True, ) with pytest.raises(ValueError): _ = BucketMap( { - "PRIVATE_BUCKETS": {"bucket/foo": ["group_1"]} + "PRIVATE_BUCKETS": {"bucket/foo": ["group_1"]}, }, - iam_compatible=True + iam_compatible=True, ) @@ -585,26 +613,26 @@ def test_check_iam_compatible_nested_public_first(): _ = 
BucketMap( { "PUBLIC_BUCKETS": ["bucket"], - "PRIVATE_BUCKETS": {"bucket/foo": ["group_1"]} + "PRIVATE_BUCKETS": {"bucket/foo": ["group_1"]}, }, - iam_compatible=True + iam_compatible=True, ) with pytest.raises(ValueError, match="'foo' has protected access"): _ = BucketMap( { "PUBLIC_BUCKETS": ["bucket"], - "PRIVATE_BUCKETS": {"bucket/foo": []} + "PRIVATE_BUCKETS": {"bucket/foo": []}, }, - iam_compatible=True + iam_compatible=True, ) def test_check_bucket_access_malformed(): bucket_map = { "MAP": { - "PATH": "bucket" + "PATH": "bucket", }, - "PUBLIC_BUCKETS": 10 + "PUBLIC_BUCKETS": 10, } b_map = BucketMap(bucket_map) @@ -642,10 +670,10 @@ def test_to_iam_policy_simple(): "Action": ["s3:GetObject", "s3:ListBucket"], "Resource": [ "arn:aws:s3:::bucket-name", - "arn:aws:s3:::bucket-name/*" - ] - } - ] + "arn:aws:s3:::bucket-name/*", + ], + }, + ], } @@ -670,10 +698,10 @@ def test_to_iam_policy_simple_duplicates(): "arn:aws:s3:::bucket-name1", "arn:aws:s3:::bucket-name1/*", "arn:aws:s3:::bucket-name2", - "arn:aws:s3:::bucket-name2/*" - ] + "arn:aws:s3:::bucket-name2/*", + ], } - ] + ], } @@ -681,8 +709,8 @@ def test_to_iam_policy_private(): bucket_map = { "PATH": "bucket-name", "PRIVATE_BUCKETS": { - "bucket-name": ["science_team"] - } + "bucket-name": ["science_team"], + }, } b_map = BucketMap(bucket_map) @@ -693,11 +721,11 @@ def test_to_iam_policy_private_with_public_prefix(): bucket_map = { "PATH": "bucket-name", "PUBLIC_BUCKETS": { - "bucket-name/public/": "Public browse imagery" + "bucket-name/public/": "Public browse imagery", }, "PRIVATE_BUCKETS": { "bucket-name": ["science_team"], - } + }, } b_map = BucketMap(bucket_map) @@ -708,8 +736,8 @@ def test_to_iam_policy_private_with_public_prefix(): "Effect": "Allow", "Action": ["s3:GetObject"], "Resource": [ - "arn:aws:s3:::bucket-name/public/*" - ] + "arn:aws:s3:::bucket-name/public/*", + ], }, { "Effect": "Allow", @@ -719,11 +747,11 @@ def test_to_iam_policy_private_with_public_prefix(): ], "Condition": { 
"StringLike": { - "s3:prefix": ["public/*"] - } - } - } - ] + "s3:prefix": ["public/*"], + }, + }, + }, + ], } @@ -732,8 +760,8 @@ def test_to_iam_policy_private_with_protected_prefix(): "PATH": "bucket-name", "PRIVATE_BUCKETS": { "bucket-name": ["science_team"], - "bucket-name/public/": [] - } + "bucket-name/public/": [], + }, } b_map = BucketMap(bucket_map) @@ -744,8 +772,8 @@ def test_to_iam_policy_private_with_protected_prefix(): "Effect": "Allow", "Action": ["s3:GetObject"], "Resource": [ - "arn:aws:s3:::bucket-name/public/*" - ] + "arn:aws:s3:::bucket-name/public/*", + ], }, { "Effect": "Allow", @@ -755,13 +783,11 @@ def test_to_iam_policy_private_with_protected_prefix(): ], "Condition": { "StringLike": { - "s3:prefix": [ - "public/*" - ] - } - } - } - ] + "s3:prefix": ["public/*"], + }, + }, + }, + ], } @@ -770,8 +796,8 @@ def test_to_iam_policy_private_with_protected_prefix_full_access(): "PATH": "bucket-name", "PRIVATE_BUCKETS": { "bucket-name": ["science_team"], - "bucket-name/public/": [] - } + "bucket-name/public/": [], + }, } b_map = BucketMap(bucket_map) @@ -783,10 +809,10 @@ def test_to_iam_policy_private_with_protected_prefix_full_access(): "Action": ["s3:GetObject", "s3:ListBucket"], "Resource": [ "arn:aws:s3:::bucket-name", - "arn:aws:s3:::bucket-name/*" - ] + "arn:aws:s3:::bucket-name/*", + ], } - ] + ], } @@ -801,7 +827,7 @@ def test_to_iam_policy_private_with_multiple_nested_private(): "bucket-name/closed1/open1/": [], "bucket-name/closed1/open2/": [], "bucket-name/closed1/open3/": [], - } + }, } b_map = BucketMap(bucket_map) @@ -815,7 +841,7 @@ def test_to_iam_policy_private_with_multiple_nested_private(): "arn:aws:s3:::bucket-name/closed1/open1/*", "arn:aws:s3:::bucket-name/closed1/open2/*", "arn:aws:s3:::bucket-name/closed1/open3/*", - ] + ], }, { "Effect": "Allow", @@ -828,12 +854,12 @@ def test_to_iam_policy_private_with_multiple_nested_private(): "s3:prefix": [ "closed1/open1/*", "closed1/open2/*", - "closed1/open3/*" + "closed1/open3/*", 
] } - } - } - ] + }, + }, + ], } @@ -848,7 +874,7 @@ def test_to_iam_policy_private_with_multiple_nested_protected_multiple_buckets() "bucket2/closed2/": ["science_team"], "bucket1/closed1/open1/": [], "bucket2/closed1/open2/": [], - } + }, } b_map = BucketMap(bucket_map) @@ -861,7 +887,7 @@ def test_to_iam_policy_private_with_multiple_nested_protected_multiple_buckets() "Resource": [ "arn:aws:s3:::bucket1/closed1/open1/*", "arn:aws:s3:::bucket2/closed1/open2/*", - ] + ], }, { "Effect": "Allow", @@ -875,7 +901,7 @@ def test_to_iam_policy_private_with_multiple_nested_protected_multiple_buckets() "closed1/open1/*", ] } - } + }, }, { "Effect": "Allow", @@ -889,9 +915,9 @@ def test_to_iam_policy_private_with_multiple_nested_protected_multiple_buckets() "closed1/open2/*", ] } - } - } - ] + }, + }, + ], } @@ -904,7 +930,7 @@ def test_to_iam_policy_merge_prefix_resources(): "bucket2": ["science_team"], "bucket1/theprefix/": [], "bucket2/theprefix/": [], - } + }, } b_map = BucketMap(bucket_map) @@ -917,7 +943,7 @@ def test_to_iam_policy_merge_prefix_resources(): "Resource": [ "arn:aws:s3:::bucket1/theprefix/*", "arn:aws:s3:::bucket2/theprefix/*", - ] + ], }, { "Effect": "Allow", @@ -932,9 +958,9 @@ def test_to_iam_policy_merge_prefix_resources(): "theprefix/*", ] } - } - } - ] + }, + }, + ], } @@ -957,7 +983,7 @@ def test_to_iam_policy(sample_bucket_map_iam): "arn:aws:s3:::pre-nested-bucket-public/*", "arn:aws:s3:::pre-nested-bucket-private", "arn:aws:s3:::pre-nested-bucket-private/*", - ] + ], }, { "Effect": "Allow", @@ -971,9 +997,9 @@ def test_to_iam_policy(sample_bucket_map_iam): "browse*", ] } - } - } - ] + }, + }, + ], } @@ -981,12 +1007,12 @@ def test_to_iam_policy_checks_compatibility(): bucket_map = { "PATH": "bucket", "PRIVATE_BUCKETS": { - "bucket/foo/": ["group"] - } + "bucket/foo/": ["group"], + }, } b_map = BucketMap(bucket_map, iam_compatible=False) - with pytest.raises(ValueError, match="'foo/' has {'group'}"): + with pytest.raises(ValueError, 
match=r"'foo/' has {'group'}"): b_map.to_iam_policy() @@ -1009,9 +1035,9 @@ def test_to_iam_policy_groups_single_access(groups_bucket_map): "Resource": [ "arn:aws:s3:::bucket1", "arn:aws:s3:::bucket1/*", - ] + ], } - ] + ], } assert b_map.to_iam_policy(groups=("group2",)) == { @@ -1023,9 +1049,9 @@ def test_to_iam_policy_groups_single_access(groups_bucket_map): "Resource": [ "arn:aws:s3:::bucket2", "arn:aws:s3:::bucket2/*", - ] + ], } - ] + ], } @@ -1042,7 +1068,7 @@ def test_to_iam_policy_groups_multiple_access(groups_bucket_map): "arn:aws:s3:::bucket2/*", "arn:aws:s3:::bucket3", "arn:aws:s3:::bucket3/*", - ] + ], } - ] + ], } diff --git a/tests/test_egress_util.py b/tests/test_egress_util.py index 0b6131d..3c3d9ae 100644 --- a/tests/test_egress_util.py +++ b/tests/test_egress_util.py @@ -29,10 +29,17 @@ def test_get_presigned_url(mock_datetime): "Credentials": { "AccessKeyId": "access_key_id", "SecretAccessKey": "secret_access_key", - "SessionToken": "session_token" + "SessionToken": "session_token", } } - presigned_url = get_presigned_url(session, "bucket_name", "object_name", "region_name", 1000, "user_id") + presigned_url = get_presigned_url( + session, + "bucket_name", + "object_name", + "region_name", + 1000, + "user_id", + ) assert presigned_url == ( "https://bucket_name.s3.region_name.amazonaws.com/object_name" "?A-userid=user_id" @@ -53,10 +60,17 @@ def test_get_presigned_url_with_spaces(mock_datetime): "Credentials": { "AccessKeyId": "access_key_id", "SecretAccessKey": "secret_access_key", - "SessionToken": "session_token" + "SessionToken": "session_token", } } - presigned_url = get_presigned_url(session, "bucket_name", "has spaces ", "region_name", 1000, "user_id") + presigned_url = get_presigned_url( + session, + "bucket_name", + "has spaces ", + "region_name", + 1000, + "user_id", + ) assert presigned_url == ( "https://bucket_name.s3.region_name.amazonaws.com/has%20spaces%20" "?A-userid=user_id" @@ -77,10 +91,17 @@ def 
test_get_presigned_url_with_colons(mock_datetime): "Credentials": { "AccessKeyId": "access_key_id", "SecretAccessKey": "secret_access_key", - "SessionToken": "session_token" + "SessionToken": "session_token", } } - presigned_url = get_presigned_url(session, "bucket_name", "has_:colons:", "region_name", 1000, "user_id") + presigned_url = get_presigned_url( + session, + "bucket_name", + "has_:colons:", + "region_name", + 1000, + "user_id", + ) assert presigned_url == ( "https://bucket_name.s3.region_name.amazonaws.com/has_%3Acolons%3A" "?A-userid=user_id" @@ -101,10 +122,17 @@ def test_get_presigned_url_with_newlines(mock_datetime): "Credentials": { "AccessKeyId": "access_key_id", "SecretAccessKey": "secret_access_key", - "SessionToken": "session_token" + "SessionToken": "session_token", } } - presigned_url = get_presigned_url(session, "bucket_name", "has\nnewlines\n", "region_name", 1000, "user_id") + presigned_url = get_presigned_url( + session, + "bucket_name", + "has\nnewlines\n", + "region_name", + 1000, + "user_id", + ) assert presigned_url == ( "https://bucket_name.s3.region_name.amazonaws.com/has%0Anewlines%0A" "?A-userid=user_id" @@ -125,7 +153,7 @@ def test_get_presigned_url_with_api_request_uuid(mock_datetime): "Credentials": { "AccessKeyId": "access_key_id", "SecretAccessKey": "secret_access_key", - "SessionToken": "session_token" + "SessionToken": "session_token", } } presigned_url = get_presigned_url( @@ -135,7 +163,7 @@ def test_get_presigned_url_with_api_request_uuid(mock_datetime): "region_name", 500, "user_id", - api_request_uuid="uuid_value" + api_request_uuid="uuid_value", ) assert presigned_url == ( "https://bucket_name.s3.region_name.amazonaws.com/object_name" diff --git a/tests/test_general_util.py b/tests/test_general_util.py index d4884a7..96dd13e 100644 --- a/tests/test_general_util.py +++ b/tests/test_general_util.py @@ -12,7 +12,7 @@ def test_return_timing_object(): "endpoint": "Unknown", "method": "GET", "duration": 0, - "unit": 
"milliseconds" + "unit": "milliseconds", } } assert return_timing_object(Service="some_service", OTHER_KEY="OTHER_VALUE") == { @@ -22,7 +22,7 @@ def test_return_timing_object(): "method": "GET", "duration": 0, "unit": "milliseconds", - "other_key": "OTHER_VALUE" + "other_key": "OTHER_VALUE", } } diff --git a/tests/test_logging.py b/tests/test_logging.py index e0603f1..88bd59b 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -93,7 +93,7 @@ def test_get_log(capsys, monkeypatch): "user_id": None, "route": None, "build": "NOBUILD", - "exception": None + "exception": None, } @@ -116,13 +116,13 @@ def test_get_log_json_object(monkeypatch, capsys): "OriginRequestId": None, "message": { "Some": "json", - "object": 100 + "object": 100, }, "maturity": "DEV", "user_id": None, "route": None, "build": "NOBUILD", - "exception": None + "exception": None, } @@ -148,7 +148,7 @@ def test_get_log_json_object_exception(monkeypatch, capsys): "OriginRequestId": None, "message": { "Some": "json", - "object": 100 + "object": 100, }, "maturity": "DEV", "user_id": None, @@ -158,8 +158,8 @@ def test_get_log_json_object_exception(monkeypatch, capsys): "Traceback (most recent call last):", mock.ANY, ' raise Exception("Test Exception")', - "Exception: Test Exception" - ] + "Exception: Test Exception", + ], } @@ -176,7 +176,7 @@ def test_get_log_flat(capsys, monkeypatch): assert re.match( r"INFO: test message: 100, Creds: Basic XXXXXX \([a-z_]+.py line [0-9]+/NOBUILD/DEV\) - " "RequestId: None; OriginRequestId: None; user_id: None; route: None\n", - stdout + stdout, ) @@ -195,7 +195,9 @@ def test_filter_log_credentials(): EDL_TOKEN = "EDL-ABBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBCCCCCCCCCC" BASIC_AUTH_TOKEN = "Basic AAAAABBBBB" AWS_TOKEN = ":AAAAABBBBBBBBBBBBBBBBBBBBBBBBBBBBBBCCCCC:" - assert filter_log_credentials(JWT_TOKEN) == "eyJ0eXAiOiJKV1QXXXXXXuexWU3mqgA" + assert ( + filter_log_credentials(JWT_TOKEN) == "eyJ0eXAiOiJKV1QXXXXXXuexWU3mqgA" + ) assert 
filter_log_credentials(EDL_TOKEN) == "EDL-AXXXXXXCCCCCCCCCC" assert filter_log_credentials(BASIC_AUTH_TOKEN) == "Basic XXXXXX" assert filter_log_credentials(AWS_TOKEN) == ":AAAAAXXXXXXCCCCC:" @@ -255,7 +257,7 @@ def test_json_logging_exception(logger, caplog): def test_json_logging_quotes(logger, log_io): obj = { - "foo': 'baz', 'qux": "bar" + "foo': 'baz', 'qux": "bar", } logger.info(obj) @@ -266,7 +268,7 @@ def test_json_logging_quotes(logger, log_io): def test_json_logging_quotes_malformed(logger, log_io): obj = { - "foo'": "bar" + "foo'": "bar", } logger.info(obj) @@ -276,7 +278,7 @@ def test_json_logging_quotes_malformed(logger, log_io): def test_json_logging_not_serializable(logger, log_io): - class SomeClass(): + class SomeClass: def __repr__(self) -> str: return "SomeClass()" @@ -302,28 +304,32 @@ def test_json_logging_time_as_field(logger, custom_log_handler, log_io): def test_json_logging_time_in_field(logger, custom_log_handler, log_io): - custom_log_handler.setFormatter(JSONFormatter("the time is %(asctime)s", datefmt="the_date")) + custom_log_handler.setFormatter( + JSONFormatter("the time is %(asctime)s", datefmt="the_date") + ) logger.info("hello") assert log_io.getvalue() == '"the time is the_date"\n' def test_json_logging_deep_format(logger, custom_log_handler, log_io): - custom_log_handler.setFormatter(JSONFormatter({ - "key1": { - "route": "%(route)s", - "key2": [ - {"message": "%(message)s"}, - {"message": "%(message)s", "maturity": "%(maturity)s"}, - { - "key3": ["%(build_vers)s"] - } - ] - }, - "constant": 100, - "string_constant": "FOO", - "format_string": "%(request_id)s from %(origin_request_id)s" - })) + custom_log_handler.setFormatter( + JSONFormatter( + { + "key1": { + "route": "%(route)s", + "key2": [ + {"message": "%(message)s"}, + {"message": "%(message)s", "maturity": "%(maturity)s"}, + {"key3": ["%(build_vers)s"]}, + ], + }, + "constant": 100, + "string_constant": "FOO", + "format_string": "%(request_id)s from 
%(origin_request_id)s", + } + ) + ) obj = {"foo": "bar"} logger.info(obj) @@ -335,11 +341,11 @@ def test_json_logging_deep_format(logger, custom_log_handler, log_io): {"message": obj}, {"message": obj, "maturity": "DEV"}, { - "key3": ["NOBUILD"] - } - ] + "key3": ["NOBUILD"], + }, + ], }, "constant": 100, "string_constant": "FOO", - "format_string": "the_request_id from the_origin_request_id" + "format_string": "the_request_id from the_origin_request_id", } diff --git a/tests/test_urs_util.py b/tests/test_urs_util.py index 5915922..eda45ca 100644 --- a/tests/test_urs_util.py +++ b/tests/test_urs_util.py @@ -27,24 +27,24 @@ def context(): return { "apiId": "test_apiId", "identity": { - "userAgent": "Mozilla ..." + "userAgent": "Mozilla ...", }, "domainName": "example.com", - "stage": "DEV" + "stage": "DEV", } @pytest.fixture def user_profile(): return UserProfile( - user_id='test_user', - first_name='John', - last_name='Smith', - email='j.smith@email.com', + user_id="test_user", + first_name="John", + last_name="Smith", + email="j.smith@email.com", groups=[], - token='test_token', + token="test_token", iat=0, - exp=0 + exp=0, ) @@ -180,13 +180,22 @@ def test_get_new_token_and_profile(mock_get_urs_creds, mock_get_profile, mock_cl mock_client().request.return_value = {"access_token": "token"} assert get_new_token_and_profile("user_id", "cookietoken") == {"foo": "bar"} - mock_get_profile.assert_called_once_with("user_id", "cookietoken", "token", aux_headers={}) + mock_get_profile.assert_called_once_with( + "user_id", + "cookietoken", + "token", + aux_headers={}, + ) @mock.patch(f"{MODULE}.EdlClient", autospec=True) @mock.patch(f"{MODULE}.get_profile", autospec=True) @mock.patch(f"{MODULE}.get_urs_creds", autospec=True) -def test_get_new_token_and_profile_error(mock_get_urs_creds, mock_get_profile, mock_client): +def test_get_new_token_and_profile_error( + mock_get_urs_creds, + mock_get_profile, + mock_client, +): mock_get_urs_creds.return_value = {"UrsAuth": 
"URS_AUTH"} mock_client().request.side_effect = EdlException( urllib.error.URLError("test error"), @@ -206,7 +215,7 @@ def test_user_in_group_list(mock_get_urs_creds): {"client_id": "CLIENT_ID_1", "name": "GROUP_1"}, {"client_id": "CLIENT_ID_2", "name": "GROUP_1"}, {"client_id": "CLIENT_ID_3", "name": "GROUP_1"}, - {"client_id": "CLIENT_ID_3", "name": "GROUP_2"} + {"client_id": "CLIENT_ID_3", "name": "GROUP_2"}, ] mock_get_urs_creds.return_value = {"UrsId": "CLIENT_ID_1"} assert user_in_group_list([], user_groups) is False @@ -228,13 +237,16 @@ def test_user_in_group_urs(mock_get_profile, mock_user_in_group_list): mock_get_profile.return_value = {"user_groups": [], "new_profile": True} mock_user_in_group_list.return_value = True - assert user_in_group_urs(private_groups, "user_id", "token", user_profile) == (True, {}) + assert user_in_group_urs(private_groups, "user_id", "token", user_profile) == ( + True, + {}, + ) mock_user_in_group_list.assert_called_once() mock_user_in_group_list.return_value = False assert user_in_group_urs(private_groups, "user_id", "token", user_profile) == ( False, - {"user_groups": [], "new_profile": True} + {"user_groups": [], "new_profile": True}, ) mock_get_profile.assert_called_once() @@ -244,12 +256,12 @@ def test_user_in_group(mock_user_in_group_list): mock_user_in_group_list.return_value = True user = UserProfile( - user_id='test_user_id', - token='test_token', + user_id="test_user_id", + token="test_token", groups=[], - first_name='test_first_name', - last_name='test_last_name', - email='test_email', + first_name="test_first_name", + last_name="test_last_name", + email="test_email", ) assert user_in_group([], {}) == (False, None) @@ -265,12 +277,12 @@ def test_user_in_group(mock_user_in_group_list): def test_user_in_group_refresh(mock_get_profile, mock_user_in_group_list): mock_user_in_group_list.return_value = True user = UserProfile( - user_id='test_user_id', - token='test_token', + user_id="test_user_id", + token="test_token", 
groups=[], - first_name='test_first_name', - last_name='test_last_name', - email='test_email', + first_name="test_first_name", + last_name="test_last_name", + email="test_email", ) mock_get_profile.return_value = user @@ -292,7 +304,7 @@ def test_get_urs_creds(mock_retrieve_secret, monkeypatch): secret = { "UrsId": "URS_ID", - "UrsAuth": "URS_AUTH" + "UrsAuth": "URS_AUTH", } mock_retrieve_secret.return_value = secret assert get_urs_creds() == secret @@ -302,61 +314,61 @@ def test_get_urs_creds(mock_retrieve_secret, monkeypatch): @mock.patch(f"{MODULE}.get_profile", autospec=True) @mock.patch(f"{MODULE}.JwtManager.get_header_to_set_auth_cookie", autospec=True) def test_do_login( - mock_get_header_to_set_auth_cookie, - mock_get_profile, - mock_do_auth, - context, - user_profile + mock_get_header_to_set_auth_cookie, + mock_get_profile, + mock_do_auth, + context, + user_profile, ): mock_do_auth.return_value = { "endpoint": "ENDPOINT", - "access_token": "ACCESS_TOKEN" + "access_token": "ACCESS_TOKEN", } - user_profile.groups = ['GROUP_1'] + user_profile.groups = ["GROUP_1"] mock_get_profile.return_value = user_profile mock_get_header_to_set_auth_cookie.return_value = { - "SET-COOKIE": "foo=bar" + "SET-COOKIE": "foo=bar", } args = { - "code": "URS_CODE" + "code": "URS_CODE", } - jwt_manager = JwtManager('algorithm', 'pub_key', 'priv_key', 'cookie-name') + jwt_manager = JwtManager("algorithm", "pub_key", "priv_key", "cookie-name") assert do_login(args, context, jwt_manager) == ( 301, {}, { "Location": "https://example.com/DEV/", - "SET-COOKIE": "foo=bar" - } + "SET-COOKIE": "foo=bar", + }, ) args = { "code": "URS_CODE", - "state": "https://somewhere-else.com" + "state": "https://somewhere-else.com", } assert do_login(args, context, jwt_manager) == ( 301, {}, { "Location": "https://somewhere-else.com", - "SET-COOKIE": "foo=bar" - } + "SET-COOKIE": "foo=bar", + }, ) @mock.patch(f"{MODULE}.do_auth", autospec=True) def test_do_login_failed_auth(mock_do_auth, context): 
mock_do_auth.return_value = {} - jwt_manager = JwtManager('algorithm', 'pub_key', 'priv_key', 'cookie-name') + jwt_manager = JwtManager("algorithm", "pub_key", "priv_key", "cookie-name") assert do_login({"code": "URS_CODE"}, context, jwt_manager) == ( 400, { "contentstring": "There was a problem talking to URS Login", - "title": "Could Not Login" + "title": "Could Not Login", }, - {} + {}, ) @@ -365,48 +377,48 @@ def test_do_login_failed_auth(mock_do_auth, context): def test_do_login_failed_profile(mock_get_profile, mock_do_auth, context): mock_do_auth.return_value = { "endpoint": "ENDPOINT", - "access_token": "ACCESS_TOKEN" + "access_token": "ACCESS_TOKEN", } mock_get_profile.return_value = None - jwt_manager = JwtManager('algorithm', 'pub_key', 'priv_key', 'cookie-name') + jwt_manager = JwtManager("algorithm", "pub_key", "priv_key", "cookie-name") assert do_login({"code": "URS_CODE"}, context, jwt_manager) == ( 400, { "contentstring": "Could not get user profile from URS", - "title": "Could Not Login" + "title": "Could Not Login", }, - {} + {}, ) def test_do_login_error(): - jwt_manager = JwtManager('algorithm', 'pub_key', 'priv_key', 'cookie-name') + jwt_manager = JwtManager("algorithm", "pub_key", "priv_key", "cookie-name") assert do_login({}, {}, jwt_manager) == ( 400, { "contentstring": "No params", - "title": "Could Not Login" + "title": "Could Not Login", }, - {} + {}, ) assert do_login({"error": "URS_ERROR"}, {}, jwt_manager) == ( 400, { "contentstring": 'An error occurred while trying to log into URS. URS says: "URS_ERROR". 
', - "title": "Could Not Login" + "title": "Could Not Login", }, - {} + {}, ) assert do_login({"error": "access_denied"}, {}, jwt_manager) == ( 401, { "contentstring": "Be sure to agree to the EULA.", "title": "Could Not Login", - "error_code": "EULA_failure" + "error_code": "EULA_failure", }, - {} + {}, ) assert do_login({"foo": "bar"}, {}, jwt_manager) == ( 400, @@ -414,5 +426,5 @@ def test_do_login_error(): "contentstring": "Did not get the required CODE from URS", "title": "Could Not Login", }, - {} + {}, ) diff --git a/tests/test_view_util.py b/tests/test_view_util.py index 874a047..3def765 100644 --- a/tests/test_view_util.py +++ b/tests/test_view_util.py @@ -42,28 +42,32 @@ def template_dir(data, mocker): return path -cookie_key_characters = st.sampled_from(string.ascii_letters + string.digits + "!#%&'*+-.^_`|~") -cookie_value_characters = st.sampled_from(string.ascii_letters + string.digits + "!#$%&'()*+-./:<=>?@[]^_`{|}~") +cookie_key_characters = st.sampled_from( + string.ascii_letters + string.digits + "!#%&'*+-.^_`|~" +) +cookie_value_characters = st.sampled_from( + string.ascii_letters + string.digits + "!#$%&'()*+-./:<=>?@[]^_`{|}~" +) @mock.patch(f"{MODULE}.retrieve_secret", autospec=True) def test_get_jwt_keys(mock_retrieve_secret): mock_retrieve_secret.return_value = { "foo": "YmFy", - "baz": "cXV4" + "baz": "cXV4", } get_jwt_keys.cache_clear() assert get_jwt_keys() == { "foo": b"bar", - "baz": b"qux" + "baz": b"qux", } @mock.patch(f"{MODULE}.retrieve_secret", autospec=True) def test_get_jwt_keys_error(mock_retrieve_secret): mock_retrieve_secret.return_value = { - "foo": "bar" + "foo": "bar", } get_jwt_keys.cache_clear() @@ -196,7 +200,7 @@ def test_get_cookie_expiration_date_str(mock_time): @given( name=st.text(cookie_key_characters, min_size=1), - value=st.text(cookie_value_characters) + value=st.text(cookie_value_characters), ) def test_get_cookies_valid(name, value): cookie = SimpleCookie() @@ -270,11 +274,13 @@ def 
test_decode_jwt_payload_expired_token(mock_get_jwt_keys, jwt_pub_key, jwt_pr def test_decode_jwt_payload_invalid_signature(mock_get_jwt_keys, jwt_pub_key): mock_get_jwt_keys.return_value = {"rsa_pub_key": jwt_pub_key} - encoded = b".".join(( - jwt.utils.base64url_encode(b'{"alg": "RS256"}'), - jwt.utils.base64url_encode(b'{"not valid'), - jwt.utils.base64url_encode(b"some bytes"), - )) + encoded = b".".join( + ( + jwt.utils.base64url_encode(b'{"alg": "RS256"}'), + jwt.utils.base64url_encode(b'{"not valid'), + jwt.utils.base64url_encode(b"some bytes"), + ) + ) assert decode_jwt_payload(encoded) == {} @@ -298,7 +304,7 @@ def test_decode_jwt_payload_blacklist( mock_is_jwt_blacklisted, jwt_pub_key, jwt_priv_key, - monkeypatch + monkeypatch, ): mock_get_jwt_keys.return_value = {"rsa_pub_key": jwt_pub_key} mock_is_jwt_blacklisted.return_value = True @@ -314,7 +320,10 @@ def test_decode_jwt_payload_blacklist( @mock.patch(f"{MODULE}.make_jwt_payload", autospec=True) @mock.patch(f"{MODULE}.get_cookie_expiration_date_str", autospec=True) -def test_make_set_cookie_headers_jwt(mock_get_cookie_expiration_date_str, mock_make_jwt_payload): +def test_make_set_cookie_headers_jwt( + mock_get_cookie_expiration_date_str, + mock_make_jwt_payload, +): mock_get_cookie_expiration_date_str.return_value = "THE_EXPDATE" mock_make_jwt_payload.return_value = "THE_JWT_PAYLOAD" @@ -324,7 +333,9 @@ def test_make_set_cookie_headers_jwt(mock_get_cookie_expiration_date_str, mock_m assert make_set_cookie_headers_jwt("", expdate="EXPLICIT_EXPDATE") == { "SET-COOKIE": "asf-urs=THE_JWT_PAYLOAD; Expires=EXPLICIT_EXPDATE; Path=/" } - assert make_set_cookie_headers_jwt("", expdate="EXPLICIT_EXPDATE", cookie_domain="THE_DOMAIN") == { + assert make_set_cookie_headers_jwt( + "", expdate="EXPLICIT_EXPDATE", cookie_domain="THE_DOMAIN" + ) == { "SET-COOKIE": "asf-urs=THE_JWT_PAYLOAD; Expires=EXPLICIT_EXPDATE; Path=/; Domain=THE_DOMAIN" } @@ -332,12 +343,14 @@ def 
test_make_set_cookie_headers_jwt(mock_get_cookie_expiration_date_str, mock_m @mock.patch(f"{MODULE}.set_jwt_blacklist", autospec=True) @mock.patch(f"{MODULE}.JWT_BLACKLIST", new_callable=dict) def test_is_jwt_blacklisted(jwt_blacklist, mock_set_jwt_blacklist): - jwt_blacklist.update({ - "blacklist": { - "user_id": 1000 - }, - "timestamp": 0 - }) + jwt_blacklist.update( + { + "blacklist": { + "user_id": 1000, + }, + "timestamp": 0, + } + ) assert is_jwt_blacklisted({"urs-user-id": "user_id", "iat": 10}) is True mock_set_jwt_blacklist.assert_called_once() @@ -368,7 +381,7 @@ def test_set_jwt_blacklist(jwt_blacklist, mock_request, mock_time, monkeypatch): assert JWT_BLACKLIST == { "blacklist": {"foo": "bar"}, - "timestamp": 0 + "timestamp": 0, } # The object itself is not touched, only the reference that JWT_BLACKLIST points to is changed assert jwt_blacklist == {} @@ -391,7 +404,7 @@ def test_set_jwt_blacklist_cached(jwt_blacklist, mock_request, mock_time, monkey assert mock_request.urlopen.call_count == 2 assert JWT_BLACKLIST == { "blacklist": {"foo": "bar"}, - "timestamp": 0 + "timestamp": 0, } # The object itself is not touched, only the reference that JWT_BLACKLIST points to is changed assert jwt_blacklist == {} @@ -403,7 +416,7 @@ def test_set_jwt_blacklist_cached(jwt_blacklist, mock_request, mock_time, monkey assert mock_request.urlopen.call_count == 2 assert JWT_BLACKLIST == { "blacklist": {"foo": "bar"}, - "timestamp": 0 + "timestamp": 0, } # Third call, after some time has passed the data is re-fetched @@ -417,5 +430,5 @@ def test_set_jwt_blacklist_cached(jwt_blacklist, mock_request, mock_time, monkey # Variable updated assert JWT_BLACKLIST == { "blacklist": {"baz": "qux"}, - "timestamp": 1000 + "timestamp": 1000, }