|
1 |
| -"Dependencies for endpoints.py" |
| 1 | +""" |
| 2 | +Dependencies for endpoints.py |
| 3 | +""" |
| 4 | + |
2 | 5 | from typing import Annotated
|
3 | 6 |
|
4 |
| -from fastapi import Cookie, Depends, HTTPException, Response, status |
| 7 | +from fastapi import Cookie, Depends, HTTPException, status |
5 | 8 | from fastapi.security import OAuth2PasswordRequestForm
|
6 |
| -from sqlalchemy import select |
7 | 9 | from sqlalchemy.ext.asyncio import AsyncSession
|
8 |
| -from jwt import ExpiredSignatureError, InvalidTokenError |
| 10 | +from jwt import InvalidTokenError |
9 | 11 |
|
10 | 12 |
|
11 | 13 | from core.models import pg_db_helper, User
|
12 |
| -from auth.utils import decode_jwt, encode_jwt, validate_password |
| 14 | +from auth.utils import decode_jwt, validate_password |
13 | 15 | from api_v1.users.crud import get_user_by_username, get_user_by_id
|
14 | 16 |
|
15 | 17 |
|
16 | 18 | async def validate_auth_user_password(
|
17 | 19 | form_data: OAuth2PasswordRequestForm = Depends(),
|
18 |
| - session: AsyncSession = Depends(pg_db_helper.get_scoped_session), |
| 20 | + session: AsyncSession = Depends(pg_db_helper.scoped_session_dependency), |
19 | 21 | ):
|
20 | 22 | """
|
21 | 23 | Helper Function for login route to check a user by password
|
@@ -46,67 +48,53 @@ async def validate_auth_user_password(
|
46 | 48 | return user_from_db
|
47 | 49 |
|
48 | 50 |
|
49 |
| -async def validate_tokens( |
50 |
| - response: Response, |
| 51 | +async def validate_access_token( |
51 | 52 | access_token: Annotated[str | None, Cookie()] = None,
|
52 |
| - refresh_token: Annotated[str | None, Cookie()] = None, |
53 | 53 | ) -> dict:
|
54 |
| - if not refresh_token: |
| 54 | + if not access_token: |
55 | 55 | raise HTTPException(
|
56 | 56 | status_code=status.HTTP_401_UNAUTHORIZED,
|
57 | 57 | detail="user is not logged in",
|
58 | 58 | )
|
59 |
| - refresh_payload = {} |
| 59 | + token_payload = {} |
60 | 60 | try:
|
61 |
| - refresh_payload = decode_jwt( |
62 |
| - token=refresh_token, |
| 61 | + token_payload = decode_jwt( |
| 62 | + token=access_token, |
63 | 63 | )
|
64 | 64 | except InvalidTokenError as e:
|
65 | 65 | raise HTTPException(
|
66 | 66 | status_code=status.HTTP_401_UNAUTHORIZED,
|
67 | 67 | # NOTE: REMOVE EXCEPTION IN PROD
|
68 |
| - detail=f"invalid refresh token error: {e}", |
69 |
| - ) |
70 |
| - # if we dont need to create new access_token, use old one as new :) |
71 |
| - new_access_token = access_token |
72 |
| - # if access_token is expired, generate new one |
73 |
| - try: |
74 |
| - decode_jwt( |
75 |
| - token=access_token, |
| 68 | + detail=f"invalid access token error: {e}", |
76 | 69 | )
|
77 |
| - except ExpiredSignatureError as e: |
78 |
| - new_access_token = encode_jwt(payload={"sub": refresh_payload["sub"]}) |
79 |
| - |
80 |
| - # return new access token to set to cookie |
81 |
| - response.set_cookie("access_token", new_access_token) |
82 |
| - return refresh_payload |
| 70 | + return token_payload |
83 | 71 |
|
84 | 72 |
|
85 |
| -async def get_current_refresh_token_payload( |
86 |
| - refresh_token: Annotated[str | None, Cookie()] = None, |
| 73 | +async def get_current_access_token_payload( |
| 74 | + access_token: Annotated[str | None, Cookie()] = None, |
87 | 75 | ) -> dict:
|
88 |
| - if not refresh_token: |
| 76 | + if not access_token: |
89 | 77 | raise HTTPException(
|
90 | 78 | status_code=status.HTTP_401_UNAUTHORIZED,
|
91 | 79 | detail="user is not logged in",
|
92 | 80 | )
|
93 | 81 | payload = {}
|
94 | 82 | try:
|
95 | 83 | payload = decode_jwt(
|
96 |
| - token=refresh_token, |
| 84 | + token=access_token, |
97 | 85 | )
|
98 | 86 | except InvalidTokenError as e:
|
99 | 87 | raise HTTPException(
|
100 | 88 | status_code=status.HTTP_401_UNAUTHORIZED,
|
101 | 89 | # NOTE: REMOVE EXCEPTION IN PROD
|
102 |
| - detail=f"invalid refresh token error: {e}", |
| 90 | + detail=f"invalid access token error: {e}", |
103 | 91 | )
|
104 | 92 | return payload
|
105 | 93 |
|
106 | 94 |
|
107 | 95 | async def get_current_auth_user(
|
108 |
| - payload: dict = Depends(get_current_refresh_token_payload), |
109 |
| - session: AsyncSession = Depends(pg_db_helper.get_scoped_session), |
| 96 | + payload: dict = Depends(get_current_access_token_payload), |
| 97 | + session: AsyncSession = Depends(pg_db_helper.scoped_session_dependency), |
110 | 98 | ) -> User:
|
111 | 99 | id: int | None = payload.get("sub")
|
112 | 100 |
|
|
0 commit comments