A Simple JWT Authentication Tutorial

Ahmet Bilal Akcan

Step by step guide to implement JWT Authentication in FastAPI

Hello everyone,

In this beginner-friendly tutorial, we will go through the steps of implementing JWT Authentication in FastAPI using fastapi-another-jwt-auth. We will be using Pydantic for data validation and Passlib for password hashing.

This tutorial assumes basic knowledge of Python, FastAPI, and SQLAlchemy, none of which are covered here.

We are going to write some code, so let's jump right into it.


Photo by Markus Spiske on Unsplash

Introduction to JWT

A JWT (JSON Web Token) is an encoded and signed JSON object that lets two parties exchange data securely, in the sense that any tampering can be detected. In our case these two parties are the client and the server.

How does it work?

We call JWTs "self-contained" because they contain all the information needed to authenticate a user. This means that we don't need to store any session information on the server side, which makes the application easier to scale.

A JWT consists of a header, a payload, and a signature, separated by dots.

header

The header contains the algorithm used to generate the signature and the type of the token.

{
	"alg": "HS256",
	"typ": "JWT"
}

payload

The payload contains the data we want to transfer. Some of the fields are reserved and have special meanings. An example payload:

{
	"sub": "f1677c52-0f8b-4f31-baaf-cc0b444a5b99", // subject
	"iat": 1700776495, // issued at
	"nbf": 1700776495, // not before
	"jti": "fe71746e-397d-4b81-8732-76f8cd042fa3", // JWT ID
	"exp": 1700777695, // expiration time
	"type": "access" // type of the token
}

signature

The signature is generated by combining the header and the payload with a secret key using the algorithm specified in the header. This signature is going to be used to verify that the data has not been changed.
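To make the three parts concrete, here is a minimal sketch that builds and checks a token by hand using only the standard library. It uses HS256 (the algorithm from the header example), whereas the rest of this tutorial uses RS256 through fastapi-another-jwt-auth. The helper names b64url, sign_hs256 and verify_hs256 are made up for illustration; real code should rely on a JWT library.

```python
import base64
import hashlib
import hmac
import json


def b64url(data: bytes) -> str:
    """Base64url-encode without padding, the encoding JWTs use."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def sign_hs256(header: dict, payload: dict, secret: bytes) -> str:
    """Return 'header.payload.signature' signed with HMAC-SHA256."""
    signing_input = ".".join(
        b64url(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in (header, payload)
    )
    signature = hmac.new(secret, signing_input.encode("ascii"), hashlib.sha256).digest()
    return f"{signing_input}.{b64url(signature)}"


def verify_hs256(token: str, secret: bytes) -> bool:
    """Recompute the signature over header.payload and compare in constant time."""
    signing_input, _, signature = token.rpartition(".")
    expected = hmac.new(secret, signing_input.encode("ascii"), hashlib.sha256).digest()
    return hmac.compare_digest(b64url(expected), signature)


token = sign_hs256({"alg": "HS256", "typ": "JWT"}, {"sub": "123"}, b"demo-secret")
```

Changing any character of the header or payload changes the expected signature, so verify_hs256 returns False for a modified token or a wrong secret.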

Setting up the project

In this tutorial, we are going to use the access and refresh token method. The access token is used to access the protected routes, and the refresh token is used to obtain a new access token once the current one expires.

Directory structure

|___ .env
|___ src
|    |___ __init__.py
|    |___ main.py
|    |___ database.py
|    |___ config.py
|    |___ oauth.py
|    |___ auth
|    |    |___ __init__.py
|    |    |___ dependencies.py
|    |    |___ services.py
|    |    |___ schemas.py
|    |    |___ models.py
|    |    |___ router.py
|    |    |___ utils.py
|    |___ dashboard
|         |___ __init__.py
|         |___ router.py
 

Initialize FastAPI

We initialize our FastAPI application in main.py with some default development settings.

main.py

src/main.py

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
 
from src.database import Base, engine
from src.auth.models import *
 
from src.auth.router import router as auth_router
from src.dashboard.router import router as dashboard_router
 
origins = ["http://localhost:3000"]
 
app = FastAPI()
 
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
 
app.include_router(auth_router, tags=["auth"], prefix="/api/auth")
app.include_router(dashboard_router, tags=["dashboard"], prefix="/api/dashboard")
 

.env

Now let's take a look at what we need in our .env file.

.env

...
JWT_PRIVATE_KEY=...
JWT_PUBLIC_KEY=...
JWT_ALGORITHM=RS256

REFRESH_TOKEN_EXPIRE_MINUTES=60
ACCESS_TOKEN_EXPIRE_MINUTES=15
...

config.py

We are going to use the asymmetric RSA algorithm for our JWTs. That means we need to generate a private and a public key. We can do that using this website: cryptotool.net

Ideally, you can use the `openssl` command to generate your keys.

We add our environment variables to config.py.

src/config.py

# Pydantic v1 import; in Pydantic v2, BaseSettings lives in the pydantic-settings package.
from pydantic import BaseSettings

class Settings(BaseSettings):
    app_name: str = "A simple JWT Authentication Tutorial"

    POSTGRES_USER: str
    POSTGRES_PASSWORD: str
    POSTGRES_SERVER: str
    POSTGRES_PORT: str
    POSTGRES_DB: str

    JWT_PRIVATE_KEY: str
    JWT_PUBLIC_KEY: str
    JWT_ALGORITHM: str

    REFRESH_TOKEN_EXPIRE_MINUTES: int
    ACCESS_TOKEN_EXPIRE_MINUTES: int

    CLIENT_ORIGIN: str

    @property
    def database_url(self) -> str:
        # Built when accessed; a plain string default would keep the
        # "{POSTGRES_USER}" placeholders as literal text.
        return (
            f"postgresql://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}"
            f"@{self.POSTGRES_SERVER}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}"
        )
 
    class Config:
        env_file = ".env"
 
 
settings = Settings()
 

oauth.py

Now let's take a look at how this JWT package works.

src/oauth.py

import base64
from typing import List
from pydantic import BaseModel
from sqlalchemy.orm import Session
 
from fastapi import Depends, HTTPException, status
from fastapi_another_jwt_auth import AuthJWT
from fastapi_another_jwt_auth.exceptions import InvalidHeaderError, MissingTokenError
 
from src.config import settings
from src.auth.models import User
from src.database import get_db
 
class Settings(BaseModel):
    authjwt_algorithm: str = settings.JWT_ALGORITHM
    authjwt_decode_algorithms: List[str] = [settings.JWT_ALGORITHM]
    authjwt_token_location: set = {"cookies", "headers"}
    authjwt_access_cookie_key: str = "access_token"
    authjwt_refresh_cookie_key: str = "refresh_token"
    authjwt_cookie_csrf_protect: bool = False
    authjwt_public_key: str = base64.b64decode(settings.JWT_PUBLIC_KEY).decode("utf-8")
    authjwt_private_key: str = base64.b64decode(settings.JWT_PRIVATE_KEY).decode("utf-8")
 
@AuthJWT.load_config
def get_config():
    return Settings()

We start by defining our settings. They include the algorithm used to generate the signature, the location of the token, the names of the cookies, and the public and private keys. Finally, we load the configuration. From now on, we can use the AuthJWT instance as a dependency in our routes to deal with authentication and authorization.
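Because the settings base64-decode both keys, the values in .env are presumably the PEM files encoded as single-line base64 strings. The sketch below shows that round trip with the standard library only. The PEM text is a placeholder rather than a real key, and the helper names encode_for_env and decode_from_env are hypothetical.

```python
import base64

# Placeholder PEM text; a real key would be generated with a tool such as openssl.
pem_text = "-----BEGIN PUBLIC KEY-----\nPLACEHOLDER\n-----END PUBLIC KEY-----\n"


def encode_for_env(pem: str) -> str:
    """Turn a multi-line PEM string into a single-line value suitable for .env."""
    return base64.b64encode(pem.encode("utf-8")).decode("ascii")


def decode_from_env(value: str) -> str:
    """Mirror of the decoding done in src/oauth.py."""
    return base64.b64decode(value).decode("utf-8")


env_value = encode_for_env(pem_text)
restored = decode_from_env(env_value)
```

Encoding the keys this way avoids having to deal with line breaks inside the .env file.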

class UserNotFound(Exception):
    pass
 
def require_user(db: Session = Depends(get_db), Authorize: AuthJWT = Depends()):
    user_id = None
    try:
        Authorize.jwt_required()
        user_id = Authorize.get_jwt_subject()
        user = db.query(User).filter(User.id == user_id).first()
        if not user:
            raise UserNotFound("User no longer exists")
    except MissingTokenError as e:
        if "access" in str(e.message):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="You must refresh",
            )
    except InvalidHeaderError as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid Header Error",
        )
    except UserNotFound:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        )
    except Exception:
        # Any other failure (for example an expired token) is swallowed here,
        # so the calling route must be prepared for a user_id of None.
        pass
    return user_id

We define a dependency function to be used in our routes. It extracts the user id from the JWT. If the access token is missing, we raise a 401 error; if the header is invalid, a 403 error; and if the user no longer exists, a 404 error.

Login route

schemas.py

Right before we start writing our authentication routes, let's create some basic schemas using pydantic to be used in the routes.

src/auth/schemas.py

from pydantic import BaseModel, EmailStr, constr
 
class UserSignin(BaseModel):
    email: EmailStr
    password: constr(min_length=8)
 
class UserSignup(BaseModel):
    first_name: str
    last_name: str
    email: str
    password: constr(min_length=8)
    passwordConfirm: str
 

router.py

Let's take a look at the login flow first.

src/auth/router.py

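from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi_another_jwt_auth import AuthJWT
from sqlalchemy.orm import Session

from src.auth.dependencies import authenticate_user
from src.auth.schemas import UserSignin
from src.auth.utils import create_access_token, create_refresh_token
from src.config import settings
from src.database import get_db
# `User` below is the response schema; its definition is not shown here.

ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
REFRESH_TOKEN_EXPIRE_MINUTES = settings.REFRESH_TOKEN_EXPIRE_MINUTES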
 
router = APIRouter()
 
@router.post("/login", response_model=User)
async def login(
    payload: UserSignin,
    response: Response,
    db: Session = Depends(get_db),
    Authorize: AuthJWT = Depends(),
):
    user = await authenticate_user(db, payload.email, payload.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

Let's break down what is going on so far. We are using the Depends function to get the database session and the AuthJWT instance. We are using the authenticate_user function to check if the user exists and the password is correct, which we are going to create now.

src/auth/dependencies.py

from sqlalchemy.orm import Session

from src.auth.models import User
from src.auth.services import get_user_by_email
from src.auth.utils import verify_password
 
async def authenticate_user(db: Session, email: str, password: str) -> bool | User:
    user = await get_user_by_email(db, email)
    if user is None:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user

We are using the get_user_by_email function to get the user from the database. We verify the password using the verify_password function which comes from our utils.py.

src/auth/utils.py

from passlib.context import CryptContext
 
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 
def verify_password(raw_password, hashed_password):
    return pwd_context.verify(raw_password, hashed_password)
 

The verify_password function takes the raw password and the hashed password stored in our database, and returns a boolean.
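To illustrate why verification needs the stored hash rather than a fresh one, here is a rough standard-library sketch of salted hashing. It uses PBKDF2 from hashlib as a stand-in and is not what Passlib's bcrypt scheme does internally. The names hash_password and check_password, the storage format, and the iteration count are all assumptions made for this example.

```python
import hashlib
import hmac
import os

ITERATIONS = 100_000  # illustrative work factor, not a recommendation


def hash_password(raw_password: str) -> str:
    """Hash with a fresh random salt and store the salt next to the digest."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", raw_password.encode("utf-8"), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"


def check_password(raw_password: str, stored: str) -> bool:
    """Re-hash the candidate with the stored salt and compare in constant time."""
    salt_hex, digest_hex = stored.split("$")
    digest = hashlib.pbkdf2_hmac(
        "sha256", raw_password.encode("utf-8"), bytes.fromhex(salt_hex), ITERATIONS
    )
    return hmac.compare_digest(digest.hex(), digest_hex)


stored = hash_password("correct horse")
```

Since the salt is random, hashing the same password twice gives different strings, which is why a dedicated verify step is needed instead of a plain string comparison.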

src/auth/router.py

    access_token = create_access_token(Authorize=Authorize, user=user)
    refresh_token = create_refresh_token(Authorize=Authorize, user=user)
 
    response.set_cookie(
        key="access_token",
        value=access_token,
        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        expires=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        path="/",
        domain=None,
        secure=False,
        httponly=True,
        samesite="lax",
    )
    response.set_cookie(
        key="refresh_token",
        value=refresh_token,
        max_age=REFRESH_TOKEN_EXPIRE_MINUTES * 60,
        expires=REFRESH_TOKEN_EXPIRE_MINUTES * 60,
        path="/",
        domain=None,
        secure=False,
        httponly=True,
        samesite="lax",
    )
    return user

We go on with creating the access and the refresh tokens using the create_access_token and the create_refresh_token functions which we are going to create now.

Then we set the cookies using the set_cookie method of FastAPI's Response object. For reference, here is its signature:

def set_cookie(
    key: str,
    value: str = "",
    max_age: int | None = None,
    expires: datetime | str | int | None = None,
    path: str = "/",
    domain: str | None = None,
    secure: bool = False,
    httponly: bool = False,
    samesite: Literal['lax', 'strict', 'none'] | None = "lax"
) -> None:
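To see what these parameters turn into on the wire, the following sketch builds a comparable Set-Cookie value with Python's standard http.cookies module. This is not FastAPI code, and build_set_cookie is a hypothetical helper; it only illustrates the attributes that the login route sets.

```python
from http.cookies import SimpleCookie


def build_set_cookie(name: str, value: str, max_age_seconds: int) -> str:
    """Return the value of a Set-Cookie header with the attributes used in the routes."""
    cookie = SimpleCookie()
    cookie[name] = value
    morsel = cookie[name]
    morsel["max-age"] = max_age_seconds  # lifetime in seconds
    morsel["path"] = "/"                 # sent for every path on the site
    morsel["httponly"] = True            # not readable from JavaScript
    morsel["samesite"] = "lax"           # withheld on most cross-site requests
    return morsel.OutputString()


header_value = build_set_cookie("access_token", "token-value", 15 * 60)
```

The httponly flag is the important one here: it keeps the token out of reach of client-side scripts. In production over HTTPS you would usually also want secure=True, which the tutorial leaves off for local development.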

src/auth/utils.py

...
def create_access_token(user: User, Authorize: AuthJWT = Depends()):
    access_token = Authorize.create_access_token(
        subject=str(user.id),
        expires_time=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return access_token
 
def create_refresh_token(user: User, Authorize: AuthJWT = Depends()):
    refresh_token = Authorize.create_refresh_token(
        subject=str(user.id),
        expires_time=timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES),
    )
    return refresh_token

Notice that we are using the user_id as the subject of the JWT. Later we will use this subject to get the user from the database.
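As a rough picture of what ends up inside each token, the sketch below assembles the claims shown in the payload example earlier. It is an illustration only, not the library's internal code, and build_claims is a hypothetical helper.

```python
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional


def build_claims(
    user_id: str,
    token_type: str,
    lifetime: timedelta,
    now: Optional[datetime] = None,
) -> dict:
    """Assemble a payload with the subject, timing claims, a unique id and the token type."""
    issued_at = int((now or datetime.now(timezone.utc)).timestamp())
    return {
        "sub": user_id,                                    # who the token is about
        "iat": issued_at,                                  # issued at
        "nbf": issued_at,                                  # not valid before
        "jti": str(uuid.uuid4()),                          # unique token id
        "exp": issued_at + int(lifetime.total_seconds()),  # expiration time
        "type": token_type,                                # "access" or "refresh"
    }


fixed_now = datetime(2023, 11, 23, tzinfo=timezone.utc)
access_claims = build_claims("user-1", "access", timedelta(minutes=15), fixed_now)
refresh_claims = build_claims("user-1", "refresh", timedelta(minutes=60), fixed_now)
```

Both tokens share the same subject; only the type and the expiration differ, which is what lets the server tell an access token from a refresh token.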

This is it for the login route. Now let's take a look at the signup route which consists of two steps, creating the user and logging the user in. So it should get easier from here.

Signup route

src/auth/router.py

@router.post("/register", status_code=status.HTTP_201_CREATED, response_model=User)
async def signup_user(
    payload: UserSignup,
    response: Response,
    db: Session = Depends(get_db),
    Authorize: AuthJWT = Depends(),
):
    checked_user = await check_user(db=db, email=payload.email)
    if checked_user:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="An account by that email already registered!",
        )
    if payload.password != payload.passwordConfirm:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Passwords do not match!",
        )
    new_user = await create_user(db=db, payload=payload)
 
    access_token = create_access_token(Authorize=Authorize, user=new_user)
    refresh_token = create_refresh_token(Authorize=Authorize, user=new_user)
 
    response.set_cookie(
        key="access_token",
        value=access_token,
        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        expires=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        path="/",
        domain=None,
        secure=False,
        httponly=True,
        samesite="lax",
    )
    response.set_cookie(
        key="refresh_token",
        value=refresh_token,
        max_age=REFRESH_TOKEN_EXPIRE_MINUTES * 60,
        expires=REFRESH_TOKEN_EXPIRE_MINUTES * 60,
        path="/",
        domain=None,
        secure=False,
        httponly=True,
        samesite="lax",
    )
    return new_user

We begin with checking if the user already exists using the check_user function. It is a simple function that checks if the user exists in the database using SQLAlchemy. If the user exists, we raise a 409 error.

Note that we perform a server-side password confirmation check, though it may not be necessary depending on what your client already validates.

If the passwords do not match, we raise a 400 error. If everything is fine, we create the user using the create_user function.

The rest is the same as the login route; we create the access and the refresh tokens and set the cookies.

Refresh route

Now that the user is logged in, we can take a look at the refresh route. As we discussed earlier, the access token expires earlier than the refresh token, so we need a route to renew the access token.

src/auth/router.py

@router.get("/refresh", response_model=User)
def refresh_token(
    response: Response,
    Authorize: AuthJWT = Depends(),
    db: Session = Depends(get_db),
):
    try:
        Authorize.jwt_refresh_token_required()
        user_id = Authorize.get_jwt_subject()
 
        if not user_id:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Could not identify token",
            )
        user = get_user_by_id(db, user_id)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="The user belonging to this token does not exist.",
            )
        access_token = Authorize.create_access_token(
            subject=str(user.id),
            expires_time=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        )
    except MissingTokenError as e:
        if "refresh" in e.message:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Missing refresh token!",
            )
        # Pass the message text; the exception object itself is not JSON-serializable.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e.message))
    response.set_cookie(
        key="access_token",
        value=access_token,
        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        expires=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        path="/",
        domain=None,
        secure=False,
        httponly=True,
        samesite="lax",
    )
    return user

We start with checking if the refresh token is missing. If it is, we raise a 403 error. If the refresh token is present and we can extract the user_id from the JWT, we create a new access token and set the cookie.

Dashboard route

Now let's look at what happens when a logged-in user tries to access a protected route.

src/dashboard/router.py

router = APIRouter()
 
@router.get("/dashboard")
async def get_dashboard(
    user_id: str = Depends(require_user),
    db: Session = Depends(get_db)
):
    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )
    # do something with the user
    return {"message": "Hello from the dashboard!"}

Note how we use the require_user function as a dependency to get the user id from the JWT. If the user is not found, we raise a 404 error. It is that simple to protect a route from unauthorized users.

Logout route

src/auth/router.py

@router.post("/logout", status_code=status.HTTP_200_OK)
async def logout(
    response: Response,
    Authorize: AuthJWT = Depends(),
    user_id: str = Depends(require_user),
    db: Session = Depends(get_db),
):
    if user_id is not None:
        try:
            Authorize.unset_jwt_cookies()
        except Exception as e:
            return {"status": "error", "message": str(e)}
        return {"status": "success"}
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found!")

We use the unset_jwt_cookies method of the AuthJWT instance to unset the cookies. That's it for the logout route.

Conclusion

That's a wrap! We have successfully implemented JWT Authentication in FastAPI using fastapi-another-jwt-auth. I hope you enjoyed this tutorial and learned something new. Feel free to reach out to me if you have any questions or suggestions. Cheers ✌🏼

December 11, 2023