diff --git a/fasthtml/_modidx.py b/fasthtml/_modidx.py
index 412a2a14..2577e3ab 100644
--- a/fasthtml/_modidx.py
+++ b/fasthtml/_modidx.py
@@ -254,6 +254,17 @@
'fasthtml.pico.PicoBusy': ('api/pico.html#picobusy', 'fasthtml/pico.py'),
'fasthtml.pico.Search': ('api/pico.html#search', 'fasthtml/pico.py'),
'fasthtml.pico.set_pico_cls': ('api/pico.html#set_pico_cls', 'fasthtml/pico.py')},
+ 'fasthtml.ratelimit': { 'fasthtml.ratelimit.TokenBucket': ('api/ratelimit.html#tokenbucket', 'fasthtml/ratelimit.py'),
+ 'fasthtml.ratelimit.TokenBucket.__init__': ( 'api/ratelimit.html#tokenbucket.__init__',
+ 'fasthtml/ratelimit.py'),
+ 'fasthtml.ratelimit.TokenBucket.__repr__': ( 'api/ratelimit.html#tokenbucket.__repr__',
+ 'fasthtml/ratelimit.py'),
+ 'fasthtml.ratelimit.TokenBucket._prune': ( 'api/ratelimit.html#tokenbucket._prune',
+ 'fasthtml/ratelimit.py'),
+ 'fasthtml.ratelimit.TokenBucket.wait': ('api/ratelimit.html#tokenbucket.wait', 'fasthtml/ratelimit.py'),
+ 'fasthtml.ratelimit.client_ip': ('api/ratelimit.html#client_ip', 'fasthtml/ratelimit.py'),
+ 'fasthtml.ratelimit.limiter': ('api/ratelimit.html#limiter', 'fasthtml/ratelimit.py'),
+ 'fasthtml.ratelimit.parse_rate': ('api/ratelimit.html#parse_rate', 'fasthtml/ratelimit.py')},
'fasthtml.starlette': {},
'fasthtml.stripe_otp': { 'fasthtml.stripe_otp.Payment': ('explains/stripe.html#payment', 'fasthtml/stripe_otp.py'),
'fasthtml.stripe_otp._search_app': ('explains/stripe.html#_search_app', 'fasthtml/stripe_otp.py'),
diff --git a/fasthtml/core.py b/fasthtml/core.py
index 52ba7419..4c8ca3bd 100644
--- a/fasthtml/core.py
+++ b/fasthtml/core.py
@@ -20,7 +20,7 @@
from fastcore.style import S
from types import UnionType, SimpleNamespace as ns, GenericAlias
-from typing import get_args, get_origin, Union, Mapping, List, Any
+from typing import get_args, get_origin, Union, Mapping, List, Any, Callable
from datetime import datetime,date
from dataclasses import dataclass
from inspect import Parameter,get_annotations
@@ -664,7 +664,7 @@ def add_route(self:FastHTML, route):
# %% ../nbs/api/00_core.ipynb #26b147ba
@patch
-def _endp(self:FastHTML, f, body_wrap, before=None):
+def _endp(self:FastHTML, f, body_wrap, before:Optional[Callable|tuple]=None):
"Create endpoint wrapper with before/after middleware processing"
sig = signature_ex(f, True)
for n,p in sig.parameters.items(): (msg:=_check_anno(n,p.annotation)) and warn(msg)
@@ -679,7 +679,8 @@ async def _f(req):
else: bf,skip = b,[]
if not any(re.fullmatch(r, req.url.path) for r in skip):
resp = await _wrap_call(bf, req, _params(bf))
- if not resp and before: resp = await _wrap_call(before, req, _params(before))
+ for b in listify(before):
+ if not resp: resp = await _wrap_call(b, req, _params(b))
req.body_wrap = body_wrap
if not resp: resp = await _wrap_call(f, req, sig.parameters)
for a in self.after:
@@ -734,7 +735,7 @@ def nested_name(f):
# %% ../nbs/api/00_core.ipynb #72760b09
@patch
-def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before=None):
+def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before:Optional[Callable|tuple]=None):
"Add HTTP route to FastHTML app with automatic method detection"
n,fn,p = name,nested_name(func),None if callable(path) else path
if methods: m = [methods] if isinstance(methods,str) else methods
@@ -751,7 +752,7 @@ def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body
# %% ../nbs/api/00_core.ipynb #f5cb2c2b
@patch
-def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before=None):
+def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before:Optional[Callable|tuple]=None):
"Add a route at `path`"
def f(func):
return self._add_route(func, path, methods, name=name, include_in_schema=include_in_schema, body_wrap=body_wrap, host=host, before=before)
diff --git a/fasthtml/ratelimit.py b/fasthtml/ratelimit.py
new file mode 100644
index 00000000..dc067b23
--- /dev/null
+++ b/fasthtml/ratelimit.py
@@ -0,0 +1,74 @@
+"""Simple token-bucket rate limiting for FastHTML routes"""
+
+# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/api/07_ratelimit.ipynb.
+
+# %% auto #0
+__all__ = ['parse_rate', 'TokenBucket', 'client_ip', 'limiter']
+
+# %% ../nbs/api/07_ratelimit.ipynb #b33aec8a
+import re,time
+from math import ceil
+from typing import Callable
+from starlette.responses import Response
+from fastcore.utils import *
+
+# %% ../nbs/api/07_ratelimit.ipynb #1078d10b
+_units = dict(s=1, m=60, h=3600, d=86400)
+
+def parse_rate(s):
+ "Parse rate string like `'5/m'`, `'1 per day'`, `'100/2h'`"
+ m = re.match(r'(\d+)\s*(?:/|per)\s*(\d+)?\s*([smhd])', s, re.I)
+ if not m: raise ValueError(f"Invalid rate: {s}")
+ n,mult,unit = m.groups()
+ return int(n), _units[unit.lower()] * int(mult or 1)
+
+# %% ../nbs/api/07_ratelimit.ipynb #7b28079a
+class TokenBucket:
+ "Token-bucket rate limiter"
+ def __init__(self,
+ max_reqs:str|int, # Rate string ('5/m') or max requests per window
+ window_secs:int=None # Window in seconds (required if `max_reqs` is int)
+ ):
+ if window_secs is None: max_reqs,window_secs = parse_rate(max_reqs)
+ store_attr()
+ self.rate = max_reqs / window_secs
+ self.buckets = {}
+ def __repr__(self): return f'TokenBucket({self.max_reqs}, {self.window_secs})'
+
+ def _prune(self):
+ cutoff = time.time() - self.window_secs
+ self.buckets = {k:(t,ts) for k,(t,ts) in self.buckets.items() if ts > cutoff}
+ def wait(self, key):
+ "Return 0 if allowed, else seconds to wait"
+ self._prune()
+ now = time.time()
+ tokens, last = self.buckets.get(key, (self.max_reqs, now))
+ tokens = min(self.max_reqs, tokens + (now - last) * self.rate)
+ if tokens < 1: return (1 - tokens) / self.rate
+ self.buckets[key] = (tokens - 1, now)
+ return 0
+
+# %% ../nbs/api/07_ratelimit.ipynb #5b3c8231
+def client_ip(req, **kwargs):
+ "Get client IP from `X-Forwarded-For` header (assumes deployment behind a single Caddy reverse proxy)"
+ return req.headers.get('x-forwarded-for', '').split(',')[0].strip() or (req.client and req.client.host) or ''
+
+# %% ../nbs/api/07_ratelimit.ipynb #97f054cb
+def limiter(rate:str|tuple, # Rate string ('5/m') or (max_reqs, window_secs) tuple
+ key:str|Callable=client_ip, # Key to limit by: route param name, or callable(req, **kwargs)
+ on_limit:Callable=None # Optional callback(wait_secs) to return custom response on 429
+ ):
+ "Create a `before` function that rate-limits requests"
+ bucket = TokenBucket(*rate) if isinstance(rate, tuple) else TokenBucket(rate)
+ def _429(w):
+ if on_limit: return on_limit(w)
+ return Response('Too many requests', status_code=429, headers={'Retry-After': str(ceil(w))})
+ if callable(key):
+ def before(req):
+ if w:=bucket.wait(key(req)): return _429(w)
+ return before
+ async def before(req):
+ v = ifnone(req.path_params.get(key), req.query_params.get(key))
+ if v is None: v = (await req.form()).get(key)
+ if w:=bucket.wait(str(v or '')): return _429(w)
+ return before
diff --git a/nbs/api/00_core.ipynb b/nbs/api/00_core.ipynb
index 3a2da93e..ad636f22 100644
--- a/nbs/api/00_core.ipynb
+++ b/nbs/api/00_core.ipynb
@@ -51,7 +51,7 @@
"from fastcore.style import S\n",
"\n",
"from types import UnionType, SimpleNamespace as ns, GenericAlias\n",
- "from typing import get_args, get_origin, Union, Mapping, List, Any\n",
+ "from typing import get_args, get_origin, Union, Mapping, List, Any, Callable\n",
"from datetime import datetime,date\n",
"from dataclasses import dataclass\n",
"from inspect import Parameter,get_annotations\n",
@@ -1966,7 +1966,7 @@
"source": [
"#| export\n",
"@patch\n",
- "def _endp(self:FastHTML, f, body_wrap, before=None):\n",
+ "def _endp(self:FastHTML, f, body_wrap, before:Optional[Callable|tuple]=None):\n",
" \"Create endpoint wrapper with before/after middleware processing\"\n",
" sig = signature_ex(f, True)\n",
" for n,p in sig.parameters.items(): (msg:=_check_anno(n,p.annotation)) and warn(msg)\n",
@@ -1981,7 +1981,8 @@
" else: bf,skip = b,[]\n",
" if not any(re.fullmatch(r, req.url.path) for r in skip):\n",
" resp = await _wrap_call(bf, req, _params(bf))\n",
- " if not resp and before: resp = await _wrap_call(before, req, _params(before))\n",
+ " for b in listify(before):\n",
+ " if not resp: resp = await _wrap_call(b, req, _params(b))\n",
" req.body_wrap = body_wrap\n",
" if not resp: resp = await _wrap_call(f, req, sig.parameters)\n",
" for a in self.after:\n",
@@ -2118,7 +2119,7 @@
"source": [
"#| export\n",
"@patch\n",
- "def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before=None):\n",
+ "def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before:Optional[Callable|tuple]=None):\n",
" \"Add HTTP route to FastHTML app with automatic method detection\"\n",
" n,fn,p = name,nested_name(func),None if callable(path) else path\n",
" if methods: m = [methods] if isinstance(methods,str) else methods\n",
@@ -2143,7 +2144,7 @@
"source": [
"#| export\n",
"@patch\n",
- "def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before=None):\n",
+ "def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before:Optional[Callable|tuple]=None):\n",
" \"Add a route at `path`\"\n",
" def f(func):\n",
" return self._add_route(func, path, methods, name=name, include_in_schema=include_in_schema, body_wrap=body_wrap, host=host, before=before)\n",
@@ -3519,6 +3520,14 @@
"with cli: test_eq(cli.get('/').text, 'fired')"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "b9fdbdd8",
+ "metadata": {},
+ "source": [
+ "A route-level `before` can set session data, short-circuit with a response, or raise to block access:"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -3541,6 +3550,58 @@
"test_eq(cli.get('/dashboard?user=other').status_code, 403)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "35615416",
+ "metadata": {},
+ "source": [
+ "Multiple `before` functions run in order, each receiving the injected request params:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6b751acb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def set_a(sess): sess['a'] = 1\n",
+ "def set_b(sess): sess['b'] = 2\n",
+ "\n",
+ "@rt('/multi', before=(set_a, set_b))\n",
+ "def get(sess): return f\"{sess.get('a')},{sess.get('b')}\"\n",
+ "\n",
+ "test_eq(cli.get('/multi').text, '1,2')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "49cc8791",
+ "metadata": {},
+ "source": [
+ "A truthy return from any `before` becomes the response; remaining `before`s and the handler are skipped (return `None` to continue):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "83188ab0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "calls = []\n",
+ "def first(sess): \n",
+ " calls.append('first')\n",
+ " return 'stop'\n",
+ "def second(sess): calls.append('second')\n",
+ "\n",
+ "@rt('/chain', before=(first, second))\n",
+ "def get(sess): return 'handler'\n",
+ "\n",
+ "test_eq(cli.get('/chain').text, 'stop')\n",
+ "test_eq(calls, ['first'])"
+ ]
+ },
{
"cell_type": "markdown",
"id": "6a014add",
diff --git a/nbs/api/07_ratelimit.ipynb b/nbs/api/07_ratelimit.ipynb
new file mode 100644
index 00000000..d032be5c
--- /dev/null
+++ b/nbs/api/07_ratelimit.ipynb
@@ -0,0 +1,711 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "6de889b6",
+ "metadata": {},
+ "source": [
+ "# Rate Limiting\n",
+ "> Simple token-bucket rate limiting for FastHTML routes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ec249d73",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| default_exp ratelimit"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b33aec8a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "import re,time\n",
+ "from math import ceil\n",
+ "from typing import Callable\n",
+ "from starlette.responses import Response\n",
+ "from fastcore.utils import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c1e39599",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "from starlette.testclient import TestClient\n",
+ "from nbdev.showdoc import show_doc\n",
+ "from fastcore.test import *\n",
+ "from fasthtml.fastapp import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1078d10b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "_units = dict(s=1, m=60, h=3600, d=86400)\n",
+ "\n",
+ "def parse_rate(s):\n",
+ " \"Parse rate string like `'5/m'`, `'1 per day'`, `'100/2h'`\"\n",
+ " m = re.match(r'(\\d+)\\s*(?:/|per)\\s*(\\d+)?\\s*([smhd])', s, re.I)\n",
+ " if not m: raise ValueError(f\"Invalid rate: {s}\")\n",
+ " n,mult,unit = m.groups()\n",
+ " return int(n), _units[unit.lower()] * int(mult or 1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23e2936b",
+ "metadata": {},
+ "source": [
+ "`parse_rate` accepts strings in the form `\"{count}/{window}\"` or `\"{count} per {window}\"`, where:\n",
+ "\n",
+ "- **count** — integer number of allowed requests (e.g. `5`, `100`)\n",
+ "- **window** — optional multiplier + unit: `s`econds, `m`inutes, `h`ours, `d`ays\n",
+ "\n",
+ "Examples: `'5/m'`, `'100/2h'`, `'1 per day'`, `5 per 2 hours`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "19fecdf0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_eq(parse_rate('5/m'), (5, 60))\n",
+ "test_eq(parse_rate('100/2h'), (100, 7200))\n",
+ "test_eq(parse_rate('1 per day'), (1, 86400))\n",
+ "test_eq(parse_rate('5 per 2 hours'), (5, 7200))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7b28079a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "class TokenBucket:\n",
+ " \"Token-bucket rate limiter\"\n",
+ " def __init__(self,\n",
+ " max_reqs:str|int, # Rate string ('5/m') or max requests per window\n",
+ " window_secs:int=None # Window in seconds (required if `max_reqs` is int)\n",
+ " ):\n",
+ " if window_secs is None: max_reqs,window_secs = parse_rate(max_reqs)\n",
+ " store_attr()\n",
+ " self.rate = max_reqs / window_secs\n",
+ " self.buckets = {}\n",
+ " def __repr__(self): return f'TokenBucket({self.max_reqs}, {self.window_secs})'\n",
+ "\n",
+ " def _prune(self):\n",
+ " cutoff = time.time() - self.window_secs\n",
+ " self.buckets = {k:(t,ts) for k,(t,ts) in self.buckets.items() if ts > cutoff}\n",
+ " def wait(self, key):\n",
+ " \"Return 0 if allowed, else seconds to wait\"\n",
+ " self._prune()\n",
+ " now = time.time()\n",
+ " tokens, last = self.buckets.get(key, (self.max_reqs, now))\n",
+ " tokens = min(self.max_reqs, tokens + (now - last) * self.rate)\n",
+ " if tokens < 1: return (1 - tokens) / self.rate\n",
+ " self.buckets[key] = (tokens - 1, now)\n",
+ " return 0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f27963b0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "---\n",
+ "\n",
+ "### TokenBucket\n",
+ "\n",
+ "```python\n",
+ "\n",
+ "def TokenBucket(\n",
+ " max_reqs:str | int, # Rate string ('5/m') or max requests per window\n",
+ " window_secs:int=None, # Window in seconds (required if `max_reqs` is int)\n",
+ "):\n",
+ "\n",
+ "\n",
+ "```\n",
+ "\n",
+ "*Token-bucket rate limiter*"
+ ],
+ "text/plain": [
+ "def TokenBucket(\n",
+ " max_reqs:str | int, # Rate string ('5/m') or max requests per window\n",
+ " window_secs:int=None, # Window in seconds (required if `max_reqs` is int)\n",
+ "):\n",
+ "\"\"\"Token-bucket rate limiter\"\"\""
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "show_doc(TokenBucket)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "349d2064",
+ "metadata": {},
+ "source": [
+ "Implements the [token bucket algorithm](https://en.wikipedia.org/wiki/Token_bucket). Tokens are added at a steady rate and consumed by requests. When the bucket is empty, requests are rejected and told how long to wait."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f0028515",
+ "metadata": {},
+ "source": [
+ "You can create a bucket with either a rate string or explicit values:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "417a59a6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TokenBucket(3, 10)"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tb = TokenBucket(3, 10)\n",
+ "tb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "623ac5d8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TokenBucket(3, 10)"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tb = TokenBucket('3/10s')\n",
+ "tb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "38712997",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "---\n",
+ "\n",
+ "### TokenBucket.wait\n",
+ "\n",
+ "```python\n",
+ "\n",
+ "def wait(\n",
+ " key\n",
+ "):\n",
+ "\n",
+ "\n",
+ "```\n",
+ "\n",
+ "*Return 0 if allowed, else seconds to wait*"
+ ],
+ "text/plain": [
+ "def wait(\n",
+ " key\n",
+ "):\n",
+ "\"\"\"Return 0 if allowed, else seconds to wait\"\"\""
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "show_doc(TokenBucket.wait)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ff831ed1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_eq(tb.wait('x'), 0)\n",
+ "test_eq(tb.wait('x'), 0)\n",
+ "test_eq(tb.wait('x'), 0)\n",
+ "test(tb.wait('x'), 0, operator.gt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "73587949",
+ "metadata": {},
+ "source": [
+ "Each key gets its own independent token bucket. So if user A exhausts their tokens, user B is unaffected — they still have a full bucket of their own:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6d8039f8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tb2 = TokenBucket(1, 0.5)\n",
+ "test_eq(tb2.wait('old'), 0)\n",
+ "test_eq(tb2.wait('new'), 0)\n",
+ "test (tb2.wait('old'), 0, operator.gt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2e48ec6b",
+ "metadata": {},
+ "source": [
+ "Stale keys — inactive longer than the window — are automatically pruned on each wait call:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5276e7bb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tb2.wait('old')\n",
+ "time.sleep(0.6)\n",
+ "tb2.wait('new')\n",
+ "test_eq('old' in tb2.buckets, False)\n",
+ "test_eq('new' in tb2.buckets, True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "661dafaa",
+ "metadata": {},
+ "source": [
+ "Rate-limiting works by deriving a **key** from each request and giving each key its own bucket. The most common choice is to limit by client IP — `client_ip` reads it from the `X-Forwarded-For` header, falling back to `req.client.host` for direct connections.\n",
+ "\n",
+ "This assumes deployment behind a single reverse proxy — specifically [Caddy](https://caddyserver.com/), which is what both [solveit](https://solve.it.com) and [plash](https://pla.sh) use. Caddy [strips any incoming `X-Forwarded-*`](https://caddyserver.com/docs/caddyfile/directives/reverse_proxy#headers) before setting its own, so clients can't spoof the IP.\n",
+ "\n",
+ "**For other setups — a different reverse proxy, multi-hop (e.g. Cloudflare → Caddy), or direct exposure to the internet — write your own key function.** Trusting `X-Forwarded-For` without a stripping proxy in front lets clients trivially spoof their IP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5b3c8231",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "def client_ip(req, **kwargs):\n",
+ " \"Get client IP from `X-Forwarded-For` header (assumes deployment behind a single Caddy reverse proxy)\"\n",
+ " return req.headers.get('x-forwarded-for', '').split(',')[0].strip() or (req.client and req.client.host) or ''"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "557fa33b",
+ "metadata": {},
+ "source": [
+ "`limiter` ties `TokenBucket` to FastHTML's routing. It returns a `before` function — pass it as `before=limiter(...)` on a route to apply rate limiting, or reuse one `limiter(...)` across routes to share a bucket.\n",
+ "\n",
+ "The `key` parameter controls what to limit by. It defaults to `client_ip` (see caveats above), but you can pass a string to match a route/query/form parameter name, or a callable `(req, **kwargs) -> str` for custom logic. \n",
+ "\n",
+ "Use `on_limit` to customize the 429 response — it receives the wait time in seconds and returns the response to send."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "97f054cb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "def limiter(rate:str|tuple, # Rate string ('5/m') or (max_reqs, window_secs) tuple\n",
+ " key:str|Callable=client_ip, # Key to limit by: route param name, or callable(req, **kwargs)\n",
+ " on_limit:Callable=None # Optional callback(wait_secs) to return custom response on 429\n",
+ " ): \n",
+ " \"Create a `before` function that rate-limits requests\"\n",
+ " bucket = TokenBucket(*rate) if isinstance(rate, tuple) else TokenBucket(rate)\n",
+ " def _429(w):\n",
+ " if on_limit: return on_limit(w)\n",
+ " return Response('Too many requests', status_code=429, headers={'Retry-After': str(ceil(w))})\n",
+ " if callable(key):\n",
+ " def before(req):\n",
+ " if w:=bucket.wait(key(req)): return _429(w)\n",
+ " return before\n",
+ " async def before(req):\n",
+ " v = ifnone(req.path_params.get(key), req.query_params.get(key))\n",
+ " if v is None: v = (await req.form()).get(key)\n",
+ " if w:=bucket.wait(str(v or '')): return _429(w)\n",
+ " return before"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7394181b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "---\n",
+ "\n",
+ "### limiter\n",
+ "\n",
+ "```python\n",
+ "\n",
+ "def limiter(\n",
+ " rate:str | tuple, # Rate string ('5/m') or (max_reqs, window_secs) tuple\n",
+ " key:Union=client_ip, # Key to limit by: route param name, or callable(req, **kwargs)\n",
+ " on_limit:Callable=None, # Optional callback(wait_secs) to return custom response on 429\n",
+ "):\n",
+ "\n",
+ "\n",
+ "```\n",
+ "\n",
+ "*Create a `before` function that rate-limits requests*"
+ ],
+ "text/plain": [
+ "def limiter(\n",
+ " rate:str | tuple, # Rate string ('5/m') or (max_reqs, window_secs) tuple\n",
+ " key:Union=client_ip, # Key to limit by: route param name, or callable(req, **kwargs)\n",
+ " on_limit:Callable=None, # Optional callback(wait_secs) to return custom response on 429\n",
+ "):\n",
+ "\"\"\"Create a `before` function that rate-limits requests\"\"\""
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "show_doc(limiter)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3097f9d4",
+ "metadata": {},
+ "source": [
+ "## Example usage"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "19967166",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "app, rt = fast_app()\n",
+ "cli = TestClient(app)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d0e3dd93",
+ "metadata": {},
+ "source": [
+ "### IP-based rate limiting (default)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7255f9dc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@rt(before=limiter('3/m')) # keys on client IP by default\n",
+ "def index(): return 'ok'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a2c4118b",
+ "metadata": {},
+ "source": [
+ "Requests are allowed until the bucket is exhausted, then a 429 is returned with a `Retry-After` header."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "89f5bfcf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for i in range(3): test_eq(cli.get('/').status_code, 200)\n",
+ "r = cli.get('/')\n",
+ "test_eq(r.status_code, 429)\n",
+ "test(r.headers, 'Retry-After', operator.contains)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ac3018b1",
+ "metadata": {},
+ "source": [
+ "### Path-based limiting:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e2c945fc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@rt('/item/{item_id}', before=limiter('1/m', key='item_id'))\n",
+ "def item(item_id: str): return f'item {item_id}'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bc23f534",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_eq(cli.get('/item/abc').status_code, 200)\n",
+ "test_eq(cli.get('/item/abc').status_code, 429)\n",
+ "test_eq(cli.get('/item/xyz').status_code, 200)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5c7a87db",
+ "metadata": {},
+ "source": [
+ "### Parameter-based rate limiting:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6e1b8dc0",
+ "metadata": {},
+ "source": [
+ "Different keys get independent buckets. Here we show `a@test.com` and `b@test.com` each get their own limit."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bbe72069",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@rt('/submit', methods=['POST'], before=limiter('1/m', key='email'))\n",
+ "def submit(email: str): return f'hello {email}'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "40d23160",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_eq(cli.post('/submit', data={'email': 'a@test.com'}).status_code, 200)\n",
+ "test_eq(cli.post('/submit', data={'email': 'a@test.com'}).status_code, 429)\n",
+ "test_eq(cli.post('/submit', data={'email': 'b@test.com'}).status_code, 200)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "70172d98",
+ "metadata": {},
+ "source": [
+ "### Callable-based rate limiting:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dffebde9",
+ "metadata": {},
+ "source": [
+ "For full control, pass a callable as `key`. It receives the Starlette `Request` plus any route `**kwargs`, and should return a string to bucket by. Here we rate-limit by the `x-api-key` header — each API key gets its own bucket:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "197fb532",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@rt(before=limiter('2/m', key=lambda req, **kwargs: req.headers.get('x-api-key', '')))\n",
+ "def custom(): return 'ok'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6ed4a731",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for i in range(2): test_eq(cli.get('/custom', headers={'x-api-key': 'abc'}).status_code, 200)\n",
+ "test_eq(cli.get('/custom', headers={'x-api-key': 'abc'}).status_code, 429)\n",
+ "test_eq(cli.get('/custom', headers={'x-api-key': 'xyz'}).status_code, 200)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bf91ea45",
+ "metadata": {},
+ "source": [
+ "### Shared limits across routes:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bd9e28c8",
+ "metadata": {},
+ "source": [
+ "Save the decorator and apply it to multiple routes to share a single bucket."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1fd5a249",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shared = limiter('2/m')\n",
+ "\n",
+ "@rt(before=shared)\n",
+ "def users(): return 'users'\n",
+ "\n",
+ "@rt(before=shared)\n",
+ "def posts(): return 'posts'\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "42e03753",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_eq(cli.get('/users').status_code, 200)\n",
+ "test_eq(cli.get('/posts').status_code, 200)\n",
+ "test_eq(cli.get('/users').status_code, 429)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c758318c",
+ "metadata": {},
+ "source": [
+ "### Multiple limiters on one route"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f1ffc673",
+ "metadata": {},
+ "source": [
+ "Pass a tuple of limiters as `before` to require all of them to pass. Below, each user gets 1 requests per minute, and the route as a whole is capped at 3 per minute per IP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cc568d35",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@rt('/multi', before=(limiter('3/m'), limiter('1/m', key='user')))\n",
+ "def multi(user:str): return f'hi {user}'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f91911ea",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_eq(cli.get('/multi?user=a').status_code, 200)\n",
+ "test_eq(cli.get('/multi?user=a').status_code, 429) # user limited\n",
+ "test_eq(cli.get('/multi?user=b').status_code, 200)\n",
+ "test_eq(cli.get('/multi?user=c').status_code, 429) # ip limited"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "330b3cba",
+ "metadata": {},
+ "source": [
+ "# Export -"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "84300cc1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "import nbdev; nbdev.nbdev_export()"
+ ]
+ }
+ ],
+ "metadata": {},
+ "nbformat": 4,
+ "nbformat_minor": 5
+}