diff --git a/fastapi-sqlalchemy/alembic/versions/2b3c8830abf6_add_auth_model.py b/fastapi-sqlalchemy/alembic/versions/2b3c8830abf6_add_auth_model.py new file mode 100644 index 00000000..807635c4 --- /dev/null +++ b/fastapi-sqlalchemy/alembic/versions/2b3c8830abf6_add_auth_model.py @@ -0,0 +1,34 @@ +"""Add auth model + +Revision ID: 2b3c8830abf6 +Revises: bea5e58f3328 +Create Date: 2021-09-10 16:22:57.620951 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "2b3c8830abf6" +down_revision = "bea5e58f3328" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("password_hash", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True) + op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False) + + +def downgrade(): + op.drop_index(op.f("ix_users_id"), table_name="users") + op.drop_index(op.f("ix_users_email"), table_name="users") + op.drop_table("users") diff --git a/fastapi-sqlalchemy/api/definitions/user.py b/fastapi-sqlalchemy/api/definitions/user.py new file mode 100644 index 00000000..ffbd396f --- /dev/null +++ b/fastapi-sqlalchemy/api/definitions/user.py @@ -0,0 +1,16 @@ +import strawberry + +from main.models import User as UserModel + + +@strawberry.type +class User: + id: int + email: str + + @classmethod + def from_instance(cls, instance: UserModel): + return cls( + id=instance.id, + email=instance.email, + ) diff --git a/fastapi-sqlalchemy/api/mutation.py b/fastapi-sqlalchemy/api/mutation.py new file mode 100644 index 00000000..eff0a752 --- /dev/null +++ b/fastapi-sqlalchemy/api/mutation.py @@ -0,0 +1,13 @@ +from strawberry.tools import create_type + +from .mutations.register_user import register_user +from .mutations.login_user import login_user + + +Mutation = create_type( + "Mutation", + [ + register_user, + login_user, + ], +) diff --git a/fastapi-sqlalchemy/api/mutations/__init__.py b/fastapi-sqlalchemy/api/mutations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastapi-sqlalchemy/api/mutations/login_user.py b/fastapi-sqlalchemy/api/mutations/login_user.py new file mode 100644 index 00000000..00790301 --- /dev/null +++ b/fastapi-sqlalchemy/api/mutations/login_user.py @@ -0,0 +1,44 @@ +import re +import strawberry + +from main.auth import login +from main.models import get_user_by_email + +from ..definitions.user import User + + +@strawberry.type +class LoginUserSuccess: + user: User + + +@strawberry.type +class LoginUserError: + error_message: str + + +LoginUserResponse = strawberry.union( + "LoginUserResponse", types=(LoginUserSuccess, LoginUserError) +) + + +@strawberry.mutation +def login_user(info, email: str, password: str) -> LoginUserResponse: + if not re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email): + return LoginUserError(error_message="Invalid email") + + db = info.context["db"] + + user = get_user_by_email(db, email) + if not user: + return LoginUserError(error_message="User not found") + + if not user.check_password(password): + return LoginUserError(error_message="Invalid password") + + # Login user + login(info.context["request"], user) + + return LoginUserSuccess( + user=User.from_instance(user), + ) diff --git a/fastapi-sqlalchemy/api/mutations/register_user.py b/fastapi-sqlalchemy/api/mutations/register_user.py new file mode 100644 index 00000000..d7a2e405 --- /dev/null +++ b/fastapi-sqlalchemy/api/mutations/register_user.py @@ -0,0 +1,60 @@ +import re +import strawberry + +from main.auth import login +from main.models import User as UserModel, get_user_by_email + +from ..definitions.user import User + + +@strawberry.input +class RegisterUserInput: + email: str + password: str + + +@strawberry.type +class RegisterUserSuccess: + user: User + + +@strawberry.type +class RegisterUserError: + error_message: str + + +RegisterUserResponse = strawberry.union( + "RegisterUserResponse", types=(RegisterUserSuccess, RegisterUserError) +) + + +@strawberry.mutation +def register_user(info, data: RegisterUserInput) -> RegisterUserResponse: + email = data.email + password = data.password + + if not re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email): + return RegisterUserError(error_message="Invalid email") + + if len(password) < 4: + return RegisterUserError(error_message="Password too short") + + db = info.context["db"] + + existing_user = get_user_by_email(db, email) + if existing_user: + return RegisterUserError(error_message="User already exists") + + user = UserModel( + email=email, + ) + user.set_password(password) + db.add(user) + db.commit() + + # Login user + login(info.context["request"], user) + + return RegisterUserSuccess( + user=User.from_instance(user), + ) diff --git a/fastapi-sqlalchemy/api/schema.py b/fastapi-sqlalchemy/api/schema.py index 0d4c9273..dbc08768 100644 --- a/fastapi-sqlalchemy/api/schema.py +++ b/fastapi-sqlalchemy/api/schema.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import strawberry from strawberry.extensions import Extension @@ -6,7 +6,9 @@ from main.models import get_movies from main.database import SessionLocal +from .mutation import Mutation from .definitions.movie import Movie +from .definitions.user import User class SQLAlchemySession(Extension): @@ -25,5 +27,14 @@ def top_rated_movies(self, info, limit: int = 250) -> List[Movie]: movies = get_movies(db, limit=limit) return [Movie.from_instance(movie) for movie in movies] + @strawberry.field + def current_user(self, info) -> Optional[User]: + request = info.context["request"] + + if request.user.is_authenticated: + return User.from_instance(request.user) + + return None + -schema = strawberry.Schema(Query, extensions=[SQLAlchemySession]) +schema = strawberry.Schema(Query, mutation=Mutation, extensions=[SQLAlchemySession]) diff --git a/fastapi-sqlalchemy/main/__init__.py b/fastapi-sqlalchemy/main/__init__.py index 98685315..4d46dfcd 100644 --- a/fastapi-sqlalchemy/main/__init__.py +++ b/fastapi-sqlalchemy/main/__init__.py @@ -1,9 +1,20 @@ from fastapi import FastAPI from strawberry.asgi import GraphQL +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.middleware.sessions import SessionMiddleware from api.schema import schema +from .middleware import SessionBackend + + +middleware = [ + Middleware(SessionMiddleware, secret_key="supersecretkey"), + Middleware(AuthenticationMiddleware, backend=SessionBackend()), +] + graphql_app = GraphQL(schema) -app = FastAPI() +app = FastAPI(middleware=middleware) app.mount("/graphql", graphql_app) diff --git a/fastapi-sqlalchemy/main/auth.py b/fastapi-sqlalchemy/main/auth.py new file mode 100644 index 00000000..ddec82ff --- /dev/null +++ b/fastapi-sqlalchemy/main/auth.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlalchemy import select +from sqlalchemy.exc import NoResultFound + +from main.models import User + + +SESSION_KEY = "_auth_user_id" + + +def login(request, user: User): + session = request.session + session[SESSION_KEY] = user.id + + +def logout(request): + session = request.session + del session[SESSION_KEY] + + +def get_user(db, request) -> Optional[User]: + session = request.session + if SESSION_KEY not in session: + return None + + try: + user = db.execute( + select(User).filter_by(id=request.session[SESSION_KEY]) + ).scalar_one() + except NoResultFound: + return None + + return user diff --git a/fastapi-sqlalchemy/main/middleware.py b/fastapi-sqlalchemy/main/middleware.py new file mode 100644 index 00000000..f461bf87 --- /dev/null +++ b/fastapi-sqlalchemy/main/middleware.py @@ -0,0 +1,29 @@ +from starlette.authentication import AuthenticationBackend, BaseUser, AuthCredentials + +from main.auth import get_user +from main.database import SessionLocal +from main.models import User + + +class ProxyDBUser(BaseUser): + def __init__(self, instance: User): + self.instance = instance + + # Proxy attributes from instance + def __getattr__(self, name): + return getattr(self.instance, name) + + @property + def is_authenticated(self) -> bool: + return True + + +class SessionBackend(AuthenticationBackend): + async def authenticate(self, request): + db = SessionLocal() + user = get_user(db, request) + + if not user: + return + + return AuthCredentials(["authenticated"]), ProxyDBUser(user) diff --git a/fastapi-sqlalchemy/main/models.py b/fastapi-sqlalchemy/main/models.py index 8808f781..a89daf02 100644 --- a/fastapi-sqlalchemy/main/models.py +++ b/fastapi-sqlalchemy/main/models.py @@ -1,6 +1,9 @@ +from typing import Optional from sqlalchemy import Column, Integer, String, Float, ForeignKey, select from sqlalchemy.orm import relationship, joinedload from sqlalchemy.orm import Session +from sqlalchemy.exc import NoResultFound +from passlib.hash import pbkdf2_sha256 from .database import Base @@ -37,3 +40,25 @@ def get_movies(db: Session, limit: int = 250): result = db.execute(query).unique() return result.scalars() + + +class User(Base): + __tablename__ = "users" + + id: int = Column(Integer, primary_key=True, index=True, nullable=False) + email: str = Column(String, unique=True, index=True, nullable=False) + password_hash: str = Column(String, nullable=False) + + def set_password(self, password: str): + self.password_hash = pbkdf2_sha256.hash(password) + + def check_password(self, password: str): + return pbkdf2_sha256.verify(password, self.password_hash) + + +def get_user_by_email(db, email: str) -> Optional[User]: + try: + user = db.execute(select(User).filter_by(email=email)).scalar_one() + return user + except NoResultFound: + return None diff --git a/fastapi-sqlalchemy/poetry.lock b/fastapi-sqlalchemy/poetry.lock index c34dd763..0202ccba 100644 --- a/fastapi-sqlalchemy/poetry.lock +++ b/fastapi-sqlalchemy/poetry.lock @@ -1,6 +1,6 @@ [[package]] name = "alembic" -version = "1.7.1" +version = "1.7.3" description = "A database migration tool for SQLAlchemy." category = "main" optional = false @@ -45,7 +45,7 @@ tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (> [[package]] name = "black" -version = "21.8b0" +version = "21.9b0" description = "The uncompromising code formatter." category = "dev" optional = false @@ -146,7 +146,7 @@ toml = "*" [[package]] name = "flake8-bugbear" -version = "21.4.3" +version = "21.9.1" description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." category = "dev" optional = false @@ -229,6 +229,14 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-black (>=0.3.7)", "pytest-mypy"] +[[package]] +name = "itsdangerous" +version = "2.0.1" +description = "Safely pass data to untrusted environments and back." +category = "main" +optional = false +python-versions = ">=3.6" + [[package]] name = "mako" version = "1.1.5" @@ -286,6 +294,20 @@ category = "main" optional = false python-versions = "*" +[[package]] +name = "passlib" +version = "1.7.4" +description = "comprehensive password hashing framework supporting over 30 schemes" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +argon2 = ["argon2-cffi (>=18.2.0)"] +bcrypt = ["bcrypt (>=3.1.0)"] +build_docs = ["sphinx (>=1.6)", "sphinxcontrib-fulltoc (>=1.2.0)", "cloud-sptheme (>=1.10.1)"] +totp = ["cryptography"] + [[package]] name = "pathspec" version = "0.9.0" @@ -586,12 +608,12 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "98fa69364a20e1543505c385c05fd9be26fdc5abbb60fc0e52136a65a1036d88" +content-hash = "8ff4ed5d052381824fa81934c554f559c82362b5a7acbae81c8e9454d52f766c" [metadata.files] alembic = [ - {file = "alembic-1.7.1-py3-none-any.whl", hash = "sha256:25f996b7408b11493d6a2d669fd9d2ff8d87883fe7434182bc7669d6caa526ab"}, - {file = "alembic-1.7.1.tar.gz", hash = "sha256:aea964d3dcc9c205b8759e4e9c1c3935ea3afeee259bffd7ed8414f8085140fb"}, + {file = "alembic-1.7.3-py3-none-any.whl", hash = "sha256:d0c580041f9f6487d5444df672a83da9be57398f39d6c1802bbedec6fefbeef6"}, + {file = "alembic-1.7.3.tar.gz", hash = "sha256:bc5bdf03d1b9814ee4d72adc0b19df2123f6c50a60c1ea761733f3640feedb8d"}, ] asgiref = [ {file = "asgiref-3.4.1-py3-none-any.whl", hash = "sha256:ffc141aa908e6f175673e7b1b3b7af4fdb0ecb738fc5c8b88f69f055c2415214"}, @@ -602,8 +624,8 @@ attrs = [ {file = "attrs-21.2.0.tar.gz", hash = "sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb"}, ] black = [ - {file = "black-21.8b0-py3-none-any.whl", hash = "sha256:2a0f9a8c2b2a60dbcf1ccb058842fb22bdbbcb2f32c6cc02d9578f90b92ce8b7"}, - {file = "black-21.8b0.tar.gz", hash = "sha256:570608d28aa3af1792b98c4a337dbac6367877b47b12b88ab42095cfc1a627c2"}, + {file = "black-21.9b0-py3-none-any.whl", hash = "sha256:380f1b5da05e5a1429225676655dddb96f5ae8c75bdf91e53d798871b902a115"}, + {file = "black-21.9b0.tar.gz", hash = "sha256:7de4cfc7eb6b710de325712d40125689101d21d25283eed7e9998722cf10eb91"}, ] cached-property = [ {file = "cached-property-1.5.2.tar.gz", hash = "sha256:9fa5755838eecbb2d234c3aa390bd80fbd3ac6b6869109bfc1b499f7bd89a130"}, @@ -630,8 +652,8 @@ flake8-black = [ {file = "flake8_black-0.2.3-py3-none-any.whl", hash = "sha256:cc080ba5b3773b69ba102b6617a00cc4ecbad8914109690cfda4d565ea435d96"}, ] flake8-bugbear = [ - {file = "flake8-bugbear-21.4.3.tar.gz", hash = "sha256:2346c81f889955b39e4a368eb7d508de723d9de05716c287dc860a4073dc57e7"}, - {file = "flake8_bugbear-21.4.3-py36.py37.py38-none-any.whl", hash = "sha256:4f305dca96be62bf732a218fe6f1825472a621d3452c5b994d8f89dae21dbafa"}, + {file = "flake8-bugbear-21.9.1.tar.gz", hash = "sha256:2f60c8ce0dc53d51da119faab2d67dea978227f0f92ed3c44eb7d65fb2e06a96"}, + {file = "flake8_bugbear-21.9.1-py36.py37.py38-none-any.whl", hash = "sha256:45bfdccfb9f2d8aa140e33cac8f46f1e38215c13d5aa8650e7e188d84e2f94c6"}, ] graphql-core = [ {file = "graphql-core-3.1.6.tar.gz", hash = "sha256:e65975b6a13878f9113a1fa5320760585b522d139944e005936b1b8358d0651a"}, @@ -718,6 +740,10 @@ importlib-resources = [ {file = "importlib_resources-5.2.2-py3-none-any.whl", hash = "sha256:2480d8e07d1890056cb53c96e3de44fead9c62f2ba949b0f2e4c4345f4afa977"}, {file = "importlib_resources-5.2.2.tar.gz", hash = "sha256:a65882a4d0fe5fbf702273456ba2ce74fe44892c25e42e057aca526b702a6d4b"}, ] +itsdangerous = [ + {file = "itsdangerous-2.0.1-py3-none-any.whl", hash = "sha256:5174094b9637652bdb841a3029700391451bd092ba3db90600dea710ba28e97c"}, + {file = "itsdangerous-2.0.1.tar.gz", hash = "sha256:9e724d68fc22902a1435351f84c3fb8623f303fffcc566a4cb952df8c572cff0"}, +] mako = [ {file = "Mako-1.1.5-py2.py3-none-any.whl", hash = "sha256:6804ee66a7f6a6416910463b00d76a7b25194cd27f1918500c5bd7be2a088a23"}, {file = "Mako-1.1.5.tar.gz", hash = "sha256:169fa52af22a91900d852e937400e79f535496191c63712e3b9fda5a9bed6fc3"}, @@ -811,6 +837,10 @@ mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] +passlib = [ + {file = "passlib-1.7.4-py2.py3-none-any.whl", hash = "sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1"}, + {file = "passlib-1.7.4.tar.gz", hash = "sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04"}, +] pathspec = [ {file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"}, {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, diff --git a/fastapi-sqlalchemy/pyproject.toml b/fastapi-sqlalchemy/pyproject.toml index 2859ed3b..3dff79ad 100644 --- a/fastapi-sqlalchemy/pyproject.toml +++ b/fastapi-sqlalchemy/pyproject.toml @@ -11,6 +11,8 @@ SQLAlchemy = {extras = ["mypy"], version = "^1.4.23"} alembic = "^1.7.1" strawberry-graphql = {extras = ["asgi"], version = "^0.77.0"} fastapi = "^0.68.1" +passlib = "^1.7.4" +itsdangerous = "^2.0.1" [tool.poetry.dev-dependencies] diff --git a/fastapi-sqlalchemy/setup.cfg b/fastapi-sqlalchemy/setup.cfg index a4fec64b..2f3075d2 100644 --- a/fastapi-sqlalchemy/setup.cfg +++ b/fastapi-sqlalchemy/setup.cfg @@ -1,2 +1,5 @@ [mypy] plugins = sqlalchemy.ext.mypy.plugin,strawberry.ext.mypy_plugin + +[mypy-passlib.*] +ignore_missing_imports = True