diff --git a/fasthtml/_modidx.py b/fasthtml/_modidx.py index 412a2a14..2577e3ab 100644 --- a/fasthtml/_modidx.py +++ b/fasthtml/_modidx.py @@ -254,6 +254,17 @@ 'fasthtml.pico.PicoBusy': ('api/pico.html#picobusy', 'fasthtml/pico.py'), 'fasthtml.pico.Search': ('api/pico.html#search', 'fasthtml/pico.py'), 'fasthtml.pico.set_pico_cls': ('api/pico.html#set_pico_cls', 'fasthtml/pico.py')}, + 'fasthtml.ratelimit': { 'fasthtml.ratelimit.TokenBucket': ('api/ratelimit.html#tokenbucket', 'fasthtml/ratelimit.py'), + 'fasthtml.ratelimit.TokenBucket.__init__': ( 'api/ratelimit.html#tokenbucket.__init__', + 'fasthtml/ratelimit.py'), + 'fasthtml.ratelimit.TokenBucket.__repr__': ( 'api/ratelimit.html#tokenbucket.__repr__', + 'fasthtml/ratelimit.py'), + 'fasthtml.ratelimit.TokenBucket._prune': ( 'api/ratelimit.html#tokenbucket._prune', + 'fasthtml/ratelimit.py'), + 'fasthtml.ratelimit.TokenBucket.wait': ('api/ratelimit.html#tokenbucket.wait', 'fasthtml/ratelimit.py'), + 'fasthtml.ratelimit.client_ip': ('api/ratelimit.html#client_ip', 'fasthtml/ratelimit.py'), + 'fasthtml.ratelimit.limiter': ('api/ratelimit.html#limiter', 'fasthtml/ratelimit.py'), + 'fasthtml.ratelimit.parse_rate': ('api/ratelimit.html#parse_rate', 'fasthtml/ratelimit.py')}, 'fasthtml.starlette': {}, 'fasthtml.stripe_otp': { 'fasthtml.stripe_otp.Payment': ('explains/stripe.html#payment', 'fasthtml/stripe_otp.py'), 'fasthtml.stripe_otp._search_app': ('explains/stripe.html#_search_app', 'fasthtml/stripe_otp.py'), diff --git a/fasthtml/core.py b/fasthtml/core.py index 52ba7419..4c8ca3bd 100644 --- a/fasthtml/core.py +++ b/fasthtml/core.py @@ -20,7 +20,7 @@ from fastcore.style import S from types import UnionType, SimpleNamespace as ns, GenericAlias -from typing import get_args, get_origin, Union, Mapping, List, Any +from typing import get_args, get_origin, Union, Mapping, List, Any, Callable from datetime import datetime,date from dataclasses import dataclass from inspect import Parameter,get_annotations @@ -664,7 +664,7 @@ def add_route(self:FastHTML, route): # %% ../nbs/api/00_core.ipynb #26b147ba @patch -def _endp(self:FastHTML, f, body_wrap, before=None): +def _endp(self:FastHTML, f, body_wrap, before:Optional[Callable|tuple]=None): "Create endpoint wrapper with before/after middleware processing" sig = signature_ex(f, True) for n,p in sig.parameters.items(): (msg:=_check_anno(n,p.annotation)) and warn(msg) @@ -679,7 +679,8 @@ async def _f(req): else: bf,skip = b,[] if not any(re.fullmatch(r, req.url.path) for r in skip): resp = await _wrap_call(bf, req, _params(bf)) - if not resp and before: resp = await _wrap_call(before, req, _params(before)) + for b in listify(before): + if not resp: resp = await _wrap_call(b, req, _params(b)) req.body_wrap = body_wrap if not resp: resp = await _wrap_call(f, req, sig.parameters) for a in self.after: @@ -734,7 +735,7 @@ def nested_name(f): # %% ../nbs/api/00_core.ipynb #72760b09 @patch -def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before=None): +def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before:Optional[Callable|tuple]=None): "Add HTTP route to FastHTML app with automatic method detection" n,fn,p = name,nested_name(func),None if callable(path) else path if methods: m = [methods] if isinstance(methods,str) else methods @@ -751,7 +752,7 @@ def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body # %% ../nbs/api/00_core.ipynb #f5cb2c2b @patch -def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before=None): +def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before:Optional[Callable|tuple]=None): "Add a route at `path`" def f(func): return self._add_route(func, path, methods, name=name, include_in_schema=include_in_schema, body_wrap=body_wrap, host=host, before=before) diff --git a/fasthtml/ratelimit.py b/fasthtml/ratelimit.py new file mode 100644 index 00000000..dc067b23 --- /dev/null +++ b/fasthtml/ratelimit.py @@ -0,0 +1,74 @@ +"""Simple token-bucket rate limiting for FastHTML routes""" + +# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/api/07_ratelimit.ipynb. + +# %% auto #0 +__all__ = ['parse_rate', 'TokenBucket', 'client_ip', 'limiter'] + +# %% ../nbs/api/07_ratelimit.ipynb #b33aec8a +import re,time +from math import ceil +from typing import Callable +from starlette.responses import Response +from fastcore.utils import * + +# %% ../nbs/api/07_ratelimit.ipynb #1078d10b +_units = dict(s=1, m=60, h=3600, d=86400) + +def parse_rate(s): + "Parse rate string like `'5/m'`, `'1 per day'`, `'100/2h'`" + m = re.match(r'(\d+)\s*(?:/|per)\s*(\d+)?\s*([smhd])', s, re.I) + if not m: raise ValueError(f"Invalid rate: {s}") + n,mult,unit = m.groups() + return int(n), _units[unit.lower()] * int(mult or 1) + +# %% ../nbs/api/07_ratelimit.ipynb #7b28079a +class TokenBucket: + "Token-bucket rate limiter" + def __init__(self, + max_reqs:str|int, # Rate string ('5/m') or max requests per window + window_secs:int=None # Window in seconds (required if `max_reqs` is int) + ): + if window_secs is None: max_reqs,window_secs = parse_rate(max_reqs) + store_attr() + self.rate = max_reqs / window_secs + self.buckets = {} + def __repr__(self): return f'TokenBucket({self.max_reqs}, {self.window_secs})' + + def _prune(self): + cutoff = time.time() - self.window_secs + self.buckets = {k:(t,ts) for k,(t,ts) in self.buckets.items() if ts > cutoff} + def wait(self, key): + "Return 0 if allowed, else seconds to wait" + self._prune() + now = time.time() + tokens, last = self.buckets.get(key, (self.max_reqs, now)) + tokens = min(self.max_reqs, tokens + (now - last) * self.rate) + if tokens < 1: return (1 - tokens) / self.rate + self.buckets[key] = (tokens - 1, now) + return 0 + +# %% ../nbs/api/07_ratelimit.ipynb #5b3c8231 +def client_ip(req, **kwargs): + "Get client IP from `X-Forwarded-For` header (assumes deployment behind a single Caddy reverse proxy)" + return req.headers.get('x-forwarded-for', '').split(',')[0].strip() or (req.client and req.client.host) or '' + +# %% ../nbs/api/07_ratelimit.ipynb #97f054cb +def limiter(rate:str|tuple, # Rate string ('5/m') or (max_reqs, window_secs) tuple + key:str|Callable=client_ip, # Key to limit by: route param name, or callable(req, **kwargs) + on_limit:Callable=None # Optional callback(wait_secs) to return custom response on 429 + ): + "Create a `before` function that rate-limits requests" + bucket = TokenBucket(*rate) if isinstance(rate, tuple) else TokenBucket(rate) + def _429(w): + if on_limit: return on_limit(w) + return Response('Too many requests', status_code=429, headers={'Retry-After': str(ceil(w))}) + if callable(key): + def before(req): + if w:=bucket.wait(key(req)): return _429(w) + return before + async def before(req): + v = ifnone(req.path_params.get(key), req.query_params.get(key)) + if v is None: v = (await req.form()).get(key) + if w:=bucket.wait(str(v or '')): return _429(w) + return before diff --git a/nbs/api/00_core.ipynb b/nbs/api/00_core.ipynb index 3a2da93e..ad636f22 100644 --- a/nbs/api/00_core.ipynb +++ b/nbs/api/00_core.ipynb @@ -51,7 +51,7 @@ "from fastcore.style import S\n", "\n", "from types import UnionType, SimpleNamespace as ns, GenericAlias\n", - "from typing import get_args, get_origin, Union, Mapping, List, Any\n", + "from typing import get_args, get_origin, Union, Mapping, List, Any, Callable\n", "from datetime import datetime,date\n", "from dataclasses import dataclass\n", "from inspect import Parameter,get_annotations\n", @@ -1966,7 +1966,7 @@ "source": [ "#| export\n", "@patch\n", - "def _endp(self:FastHTML, f, body_wrap, before=None):\n", + "def _endp(self:FastHTML, f, body_wrap, before:Optional[Callable|tuple]=None):\n", " \"Create endpoint wrapper with before/after middleware processing\"\n", " sig = signature_ex(f, True)\n", " for n,p in sig.parameters.items(): (msg:=_check_anno(n,p.annotation)) and warn(msg)\n", @@ -1981,7 +1981,8 @@ " else: bf,skip = b,[]\n", " if not any(re.fullmatch(r, req.url.path) for r in skip):\n", " resp = await _wrap_call(bf, req, _params(bf))\n", - " if not resp and before: resp = await _wrap_call(before, req, _params(before))\n", + " for b in listify(before):\n", + " if not resp: resp = await _wrap_call(b, req, _params(b))\n", " req.body_wrap = body_wrap\n", " if not resp: resp = await _wrap_call(f, req, sig.parameters)\n", " for a in self.after:\n", @@ -2118,7 +2119,7 @@ "source": [ "#| export\n", "@patch\n", - "def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before=None):\n", + "def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before:Optional[Callable|tuple]=None):\n", " \"Add HTTP route to FastHTML app with automatic method detection\"\n", " n,fn,p = name,nested_name(func),None if callable(path) else path\n", " if methods: m = [methods] if isinstance(methods,str) else methods\n", @@ -2143,7 +2144,7 @@ "source": [ "#| export\n", "@patch\n", - "def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before=None):\n", + "def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before:Optional[Callable|tuple]=None):\n", " \"Add a route at `path`\"\n", " def f(func):\n", " return self._add_route(func, path, methods, name=name, include_in_schema=include_in_schema, body_wrap=body_wrap, host=host, before=before)\n", @@ -3519,6 +3520,14 @@ "with cli: test_eq(cli.get('/').text, 'fired')" ] }, + { + "cell_type": "markdown", + "id": "b9fdbdd8", + "metadata": {}, + "source": [ + "A route-level `before` can set session data, short-circuit with a response, or raise to block access:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -3541,6 +3550,58 @@ "test_eq(cli.get('/dashboard?user=other').status_code, 403)" ] }, + { + "cell_type": "markdown", + "id": "35615416", + "metadata": {}, + "source": [ + "Multiple `before` functions run in order, each receiving the injected request params:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b751acb", + "metadata": {}, + "outputs": [], + "source": [ + "def set_a(sess): sess['a'] = 1\n", + "def set_b(sess): sess['b'] = 2\n", + "\n", + "@rt('/multi', before=(set_a, set_b))\n", + "def get(sess): return f\"{sess.get('a')},{sess.get('b')}\"\n", + "\n", + "test_eq(cli.get('/multi').text, '1,2')" + ] + }, + { + "cell_type": "markdown", + "id": "49cc8791", + "metadata": {}, + "source": [ + "A truthy return from any `before` becomes the response; remaining `before`s and the handler are skipped (return `None` to continue):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83188ab0", + "metadata": {}, + "outputs": [], + "source": [ + "calls = []\n", + "def first(sess): \n", + " calls.append('first')\n", + " return 'stop'\n", + "def second(sess): calls.append('second')\n", + "\n", + "@rt('/chain', before=(first, second))\n", + "def get(sess): return 'handler'\n", + "\n", + "test_eq(cli.get('/chain').text, 'stop')\n", + "test_eq(calls, ['first'])" + ] + }, { "cell_type": "markdown", "id": "6a014add", diff --git a/nbs/api/07_ratelimit.ipynb b/nbs/api/07_ratelimit.ipynb new file mode 100644 index 00000000..d032be5c --- /dev/null +++ b/nbs/api/07_ratelimit.ipynb @@ -0,0 +1,711 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6de889b6", + "metadata": {}, + "source": [ + "# Rate Limiting\n", + "> Simple token-bucket rate limiting for FastHTML routes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec249d73", + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp ratelimit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b33aec8a", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "import re,time\n", + "from math import ceil\n", + "from typing import Callable\n", + "from starlette.responses import Response\n", + "from fastcore.utils import *" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1e39599", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "from starlette.testclient import TestClient\n", + "from nbdev.showdoc import show_doc\n", + "from fastcore.test import *\n", + "from fasthtml.fastapp import *" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1078d10b", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "_units = dict(s=1, m=60, h=3600, d=86400)\n", + "\n", + "def parse_rate(s):\n", + " \"Parse rate string like `'5/m'`, `'1 per day'`, `'100/2h'`\"\n", + " m = re.match(r'(\\d+)\\s*(?:/|per)\\s*(\\d+)?\\s*([smhd])', s, re.I)\n", + " if not m: raise ValueError(f\"Invalid rate: {s}\")\n", + " n,mult,unit = m.groups()\n", + " return int(n), _units[unit.lower()] * int(mult or 1)" + ] + }, + { + "cell_type": "markdown", + "id": "23e2936b", + "metadata": {}, + "source": [ + "`parse_rate` accepts strings in the form `\"{count}/{window}\"` or `\"{count} per {window}\"`, where:\n", + "\n", + "- **count** — integer number of allowed requests (e.g. `5`, `100`)\n", + "- **window** — optional multiplier + unit: `s`econds, `m`inutes, `h`ours, `d`ays\n", + "\n", + "Examples: `'5/m'`, `'100/2h'`, `'1 per day'`, `5 per 2 hours`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19fecdf0", + "metadata": {}, + "outputs": [], + "source": [ + "test_eq(parse_rate('5/m'), (5, 60))\n", + "test_eq(parse_rate('100/2h'), (100, 7200))\n", + "test_eq(parse_rate('1 per day'), (1, 86400))\n", + "test_eq(parse_rate('5 per 2 hours'), (5, 7200))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b28079a", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class TokenBucket:\n", + " \"Token-bucket rate limiter\"\n", + " def __init__(self,\n", + " max_reqs:str|int, # Rate string ('5/m') or max requests per window\n", + " window_secs:int=None # Window in seconds (required if `max_reqs` is int)\n", + " ):\n", + " if window_secs is None: max_reqs,window_secs = parse_rate(max_reqs)\n", + " store_attr()\n", + " self.rate = max_reqs / window_secs\n", + " self.buckets = {}\n", + " def __repr__(self): return f'TokenBucket({self.max_reqs}, {self.window_secs})'\n", + "\n", + " def _prune(self):\n", + " cutoff = time.time() - self.window_secs\n", + " self.buckets = {k:(t,ts) for k,(t,ts) in self.buckets.items() if ts > cutoff}\n", + " def wait(self, key):\n", + " \"Return 0 if allowed, else seconds to wait\"\n", + " self._prune()\n", + " now = time.time()\n", + " tokens, last = self.buckets.get(key, (self.max_reqs, now))\n", + " tokens = min(self.max_reqs, tokens + (now - last) * self.rate)\n", + " if tokens < 1: return (1 - tokens) / self.rate\n", + " self.buckets[key] = (tokens - 1, now)\n", + " return 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f27963b0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "### TokenBucket\n", + "\n", + "```python\n", + "\n", + "def TokenBucket(\n", + " max_reqs:str | int, # Rate string ('5/m') or max requests per window\n", + " window_secs:int=None, # Window in seconds (required if `max_reqs` is int)\n", + "):\n", + "\n", + "\n", + "```\n", + "\n", + "*Token-bucket rate limiter*" + ], + "text/plain": [ + "def TokenBucket(\n", + " max_reqs:str | int, # Rate string ('5/m') or max requests per window\n", + " window_secs:int=None, # Window in seconds (required if `max_reqs` is int)\n", + "):\n", + "\"\"\"Token-bucket rate limiter\"\"\"" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "show_doc(TokenBucket)" + ] + }, + { + "cell_type": "markdown", + "id": "349d2064", + "metadata": {}, + "source": [ + "Implements the [token bucket algorithm](https://en.wikipedia.org/wiki/Token_bucket). Tokens are added at a steady rate and consumed by requests. When the bucket is empty, requests are rejected and told how long to wait." + ] + }, + { + "cell_type": "markdown", + "id": "f0028515", + "metadata": {}, + "source": [ + "You can create a bucket with either a rate string or explicit values:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "417a59a6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TokenBucket(3, 10)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tb = TokenBucket(3, 10)\n", + "tb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "623ac5d8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TokenBucket(3, 10)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tb = TokenBucket('3/10s')\n", + "tb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38712997", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "### TokenBucket.wait\n", + "\n", + "```python\n", + "\n", + "def wait(\n", + " key\n", + "):\n", + "\n", + "\n", + "```\n", + "\n", + "*Return 0 if allowed, else seconds to wait*" + ], + "text/plain": [ + "def wait(\n", + " key\n", + "):\n", + "\"\"\"Return 0 if allowed, else seconds to wait\"\"\"" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "show_doc(TokenBucket.wait)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff831ed1", + "metadata": {}, + "outputs": [], + "source": [ + "test_eq(tb.wait('x'), 0)\n", + "test_eq(tb.wait('x'), 0)\n", + "test_eq(tb.wait('x'), 0)\n", + "test(tb.wait('x'), 0, operator.gt)" + ] + }, + { + "cell_type": "markdown", + "id": "73587949", + "metadata": {}, + "source": [ + "Each key gets its own independent token bucket. So if user A exhausts their tokens, user B is unaffected — they still have a full bucket of their own:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d8039f8", + "metadata": {}, + "outputs": [], + "source": [ + "tb2 = TokenBucket(1, 0.5)\n", + "test_eq(tb2.wait('old'), 0)\n", + "test_eq(tb2.wait('new'), 0)\n", + "test (tb2.wait('old'), 0, operator.gt)" + ] + }, + { + "cell_type": "markdown", + "id": "2e48ec6b", + "metadata": {}, + "source": [ + "Stale keys — inactive longer than the window — are automatically pruned on each wait call:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5276e7bb", + "metadata": {}, + "outputs": [], + "source": [ + "tb2.wait('old')\n", + "time.sleep(0.6)\n", + "tb2.wait('new')\n", + "test_eq('old' in tb2.buckets, False)\n", + "test_eq('new' in tb2.buckets, True)" + ] + }, + { + "cell_type": "markdown", + "id": "661dafaa", + "metadata": {}, + "source": [ + "Rate-limiting works by deriving a **key** from each request and giving each key its own bucket. The most common choice is to limit by client IP — `client_ip` reads it from the `X-Forwarded-For` header, falling back to `req.client.host` for direct connections.\n", + "\n", + "This assumes deployment behind a single reverse proxy — specifically [Caddy](https://caddyserver.com/), which is what both [solveit](https://solve.it.com) and [plash](https://pla.sh) use. Caddy [strips any incoming `X-Forwarded-*`](https://caddyserver.com/docs/caddyfile/directives/reverse_proxy#headers) before setting its own, so clients can't spoof the IP.\n", + "\n", + "**For other setups — a different reverse proxy, multi-hop (e.g. Cloudflare → Caddy), or direct exposure to the internet — write your own key function.** Trusting `X-Forwarded-For` without a stripping proxy in front lets clients trivially spoof their IP." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b3c8231", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def client_ip(req, **kwargs):\n", + " \"Get client IP from `X-Forwarded-For` header (assumes deployment behind a single Caddy reverse proxy)\"\n", + " return req.headers.get('x-forwarded-for', '').split(',')[0].strip() or (req.client and req.client.host) or ''" + ] + }, + { + "cell_type": "markdown", + "id": "557fa33b", + "metadata": {}, + "source": [ + "`limiter` ties `TokenBucket` to FastHTML's routing. It returns a `before` function — pass it as `before=limiter(...)` on a route to apply rate limiting, or reuse one `limiter(...)` across routes to share a bucket.\n", + "\n", + "The `key` parameter controls what to limit by. It defaults to `client_ip` (see caveats above), but you can pass a string to match a route/query/form parameter name, or a callable `(req, **kwargs) -> str` for custom logic. \n", + "\n", + "Use `on_limit` to customize the 429 response — it receives the wait time in seconds and returns the response to send." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97f054cb", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def limiter(rate:str|tuple, # Rate string ('5/m') or (max_reqs, window_secs) tuple\n", + " key:str|Callable=client_ip, # Key to limit by: route param name, or callable(req, **kwargs)\n", + " on_limit:Callable=None # Optional callback(wait_secs) to return custom response on 429\n", + " ): \n", + " \"Create a `before` function that rate-limits requests\"\n", + " bucket = TokenBucket(*rate) if isinstance(rate, tuple) else TokenBucket(rate)\n", + " def _429(w):\n", + " if on_limit: return on_limit(w)\n", + " return Response('Too many requests', status_code=429, headers={'Retry-After': str(ceil(w))})\n", + " if callable(key):\n", + " def before(req):\n", + " if w:=bucket.wait(key(req)): return _429(w)\n", + " return before\n", + " async def before(req):\n", + " v = ifnone(req.path_params.get(key), req.query_params.get(key))\n", + " if v is None: v = (await req.form()).get(key)\n", + " if w:=bucket.wait(str(v or '')): return _429(w)\n", + " return before" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7394181b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "### limiter\n", + "\n", + "```python\n", + "\n", + "def limiter(\n", + " rate:str | tuple, # Rate string ('5/m') or (max_reqs, window_secs) tuple\n", + " key:Union=client_ip, # Key to limit by: route param name, or callable(req, **kwargs)\n", + " on_limit:Callable=None, # Optional callback(wait_secs) to return custom response on 429\n", + "):\n", + "\n", + "\n", + "```\n", + "\n", + "*Create a `before` function that rate-limits requests*" + ], + "text/plain": [ + "def limiter(\n", + " rate:str | tuple, # Rate string ('5/m') or (max_reqs, window_secs) tuple\n", + " key:Union=client_ip, # Key to limit by: route param name, or callable(req, **kwargs)\n", + " on_limit:Callable=None, # Optional callback(wait_secs) to return custom response on 429\n", + "):\n", + "\"\"\"Create a `before` function that rate-limits requests\"\"\"" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "show_doc(limiter)" + ] + }, + { + "cell_type": "markdown", + "id": "3097f9d4", + "metadata": {}, + "source": [ + "## Example usage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19967166", + "metadata": {}, + "outputs": [], + "source": [ + "app, rt = fast_app()\n", + "cli = TestClient(app)" + ] + }, + { + "cell_type": "markdown", + "id": "d0e3dd93", + "metadata": {}, + "source": [ + "### IP-based rate limiting (default)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7255f9dc", + "metadata": {}, + "outputs": [], + "source": [ + "@rt(before=limiter('3/m')) # keys on client IP by default\n", + "def index(): return 'ok'" + ] + }, + { + "cell_type": "markdown", + "id": "a2c4118b", + "metadata": {}, + "source": [ + "Requests are allowed until the bucket is exhausted, then a 429 is returned with a `Retry-After` header." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89f5bfcf", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(3): test_eq(cli.get('/').status_code, 200)\n", + "r = cli.get('/')\n", + "test_eq(r.status_code, 429)\n", + "test(r.headers, 'Retry-After', operator.contains)" + ] + }, + { + "cell_type": "markdown", + "id": "ac3018b1", + "metadata": {}, + "source": [ + "### Path-based limiting:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2c945fc", + "metadata": {}, + "outputs": [], + "source": [ + "@rt('/item/{item_id}', before=limiter('1/m', key='item_id'))\n", + "def item(item_id: str): return f'item {item_id}'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc23f534", + "metadata": {}, + "outputs": [], + "source": [ + "test_eq(cli.get('/item/abc').status_code, 200)\n", + "test_eq(cli.get('/item/abc').status_code, 429)\n", + "test_eq(cli.get('/item/xyz').status_code, 200)" + ] + }, + { + "cell_type": "markdown", + "id": "5c7a87db", + "metadata": {}, + "source": [ + "### Parameter-based rate limiting:" + ] + }, + { + "cell_type": "markdown", + "id": "6e1b8dc0", + "metadata": {}, + "source": [ + "Different keys get independent buckets. Here we show `a@test.com` and `b@test.com` each get their own limit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbe72069", + "metadata": {}, + "outputs": [], + "source": [ + "@rt('/submit', methods=['POST'], before=limiter('1/m', key='email'))\n", + "def submit(email: str): return f'hello {email}'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40d23160", + "metadata": {}, + "outputs": [], + "source": [ + "test_eq(cli.post('/submit', data={'email': 'a@test.com'}).status_code, 200)\n", + "test_eq(cli.post('/submit', data={'email': 'a@test.com'}).status_code, 429)\n", + "test_eq(cli.post('/submit', data={'email': 'b@test.com'}).status_code, 200)" + ] + }, + { + "cell_type": "markdown", + "id": "70172d98", + "metadata": {}, + "source": [ + "### Callable-based rate limiting:" + ] + }, + { + "cell_type": "markdown", + "id": "dffebde9", + "metadata": {}, + "source": [ + "For full control, pass a callable as `key`. It receives the Starlette `Request` plus any route `**kwargs`, and should return a string to bucket by. Here we rate-limit by the `x-api-key` header — each API key gets its own bucket:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "197fb532", + "metadata": {}, + "outputs": [], + "source": [ + "@rt(before=limiter('2/m', key=lambda req, **kwargs: req.headers.get('x-api-key', '')))\n", + "def custom(): return 'ok'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ed4a731", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(2): test_eq(cli.get('/custom', headers={'x-api-key': 'abc'}).status_code, 200)\n", + "test_eq(cli.get('/custom', headers={'x-api-key': 'abc'}).status_code, 429)\n", + "test_eq(cli.get('/custom', headers={'x-api-key': 'xyz'}).status_code, 200)" + ] + }, + { + "cell_type": "markdown", + "id": "bf91ea45", + "metadata": {}, + "source": [ + "### Shared limits across routes:" + ] + }, + { + "cell_type": "markdown", + "id": "bd9e28c8", + "metadata": {}, + "source": [ + "Save the decorator and apply it to multiple routes to share a single bucket." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fd5a249", + "metadata": {}, + "outputs": [], + "source": [ + "shared = limiter('2/m')\n", + "\n", + "@rt(before=shared)\n", + "def users(): return 'users'\n", + "\n", + "@rt(before=shared)\n", + "def posts(): return 'posts'\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42e03753", + "metadata": {}, + "outputs": [], + "source": [ + "test_eq(cli.get('/users').status_code, 200)\n", + "test_eq(cli.get('/posts').status_code, 200)\n", + "test_eq(cli.get('/users').status_code, 429)" + ] + }, + { + "cell_type": "markdown", + "id": "c758318c", + "metadata": {}, + "source": [ + "### Multiple limiters on one route" + ] + }, + { + "cell_type": "markdown", + "id": "f1ffc673", + "metadata": {}, + "source": [ + "Pass a tuple of limiters as `before` to require all of them to pass. Below, each user gets 1 requests per minute, and the route as a whole is capped at 3 per minute per IP." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc568d35", + "metadata": {}, + "outputs": [], + "source": [ + "@rt('/multi', before=(limiter('3/m'), limiter('1/m', key='user')))\n", + "def multi(user:str): return f'hi {user}'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f91911ea", + "metadata": {}, + "outputs": [], + "source": [ + "test_eq(cli.get('/multi?user=a').status_code, 200)\n", + "test_eq(cli.get('/multi?user=a').status_code, 429) # user limited\n", + "test_eq(cli.get('/multi?user=b').status_code, 200)\n", + "test_eq(cli.get('/multi?user=c').status_code, 429) # ip limited" + ] + }, + { + "cell_type": "markdown", + "id": "330b3cba", + "metadata": {}, + "source": [ + "# Export -" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84300cc1", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "import nbdev; nbdev.nbdev_export()" + ] + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +}