Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions fasthtml/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,17 @@
'fasthtml.pico.PicoBusy': ('api/pico.html#picobusy', 'fasthtml/pico.py'),
'fasthtml.pico.Search': ('api/pico.html#search', 'fasthtml/pico.py'),
'fasthtml.pico.set_pico_cls': ('api/pico.html#set_pico_cls', 'fasthtml/pico.py')},
'fasthtml.ratelimit': { 'fasthtml.ratelimit.TokenBucket': ('api/ratelimit.html#tokenbucket', 'fasthtml/ratelimit.py'),
'fasthtml.ratelimit.TokenBucket.__init__': ( 'api/ratelimit.html#tokenbucket.__init__',
'fasthtml/ratelimit.py'),
'fasthtml.ratelimit.TokenBucket.__repr__': ( 'api/ratelimit.html#tokenbucket.__repr__',
'fasthtml/ratelimit.py'),
'fasthtml.ratelimit.TokenBucket._prune': ( 'api/ratelimit.html#tokenbucket._prune',
'fasthtml/ratelimit.py'),
'fasthtml.ratelimit.TokenBucket.wait': ('api/ratelimit.html#tokenbucket.wait', 'fasthtml/ratelimit.py'),
'fasthtml.ratelimit.client_ip': ('api/ratelimit.html#client_ip', 'fasthtml/ratelimit.py'),
'fasthtml.ratelimit.limiter': ('api/ratelimit.html#limiter', 'fasthtml/ratelimit.py'),
'fasthtml.ratelimit.parse_rate': ('api/ratelimit.html#parse_rate', 'fasthtml/ratelimit.py')},
'fasthtml.starlette': {},
'fasthtml.stripe_otp': { 'fasthtml.stripe_otp.Payment': ('explains/stripe.html#payment', 'fasthtml/stripe_otp.py'),
'fasthtml.stripe_otp._search_app': ('explains/stripe.html#_search_app', 'fasthtml/stripe_otp.py'),
Expand Down
11 changes: 6 additions & 5 deletions fasthtml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from fastcore.style import S

from types import UnionType, SimpleNamespace as ns, GenericAlias
from typing import get_args, get_origin, Union, Mapping, List, Any
from typing import get_args, get_origin, Union, Mapping, List, Any, Callable
from datetime import datetime,date
from dataclasses import dataclass
from inspect import Parameter,get_annotations
Expand Down Expand Up @@ -664,7 +664,7 @@ def add_route(self:FastHTML, route):

# %% ../nbs/api/00_core.ipynb #26b147ba
@patch
def _endp(self:FastHTML, f, body_wrap, before=None):
def _endp(self:FastHTML, f, body_wrap, before:Optional[Callable|tuple]=None):
"Create endpoint wrapper with before/after middleware processing"
sig = signature_ex(f, True)
for n,p in sig.parameters.items(): (msg:=_check_anno(n,p.annotation)) and warn(msg)
Expand All @@ -679,7 +679,8 @@ async def _f(req):
else: bf,skip = b,[]
if not any(re.fullmatch(r, req.url.path) for r in skip):
resp = await _wrap_call(bf, req, _params(bf))
if not resp and before: resp = await _wrap_call(before, req, _params(before))
for b in listify(before):
if not resp: resp = await _wrap_call(b, req, _params(b))
req.body_wrap = body_wrap
if not resp: resp = await _wrap_call(f, req, sig.parameters)
for a in self.after:
Expand Down Expand Up @@ -734,7 +735,7 @@ def nested_name(f):

# %% ../nbs/api/00_core.ipynb #72760b09
@patch
def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before=None):
def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before:Optional[Callable|tuple]=None):
"Add HTTP route to FastHTML app with automatic method detection"
n,fn,p = name,nested_name(func),None if callable(path) else path
if methods: m = [methods] if isinstance(methods,str) else methods
Expand All @@ -751,7 +752,7 @@ def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body

# %% ../nbs/api/00_core.ipynb #f5cb2c2b
@patch
def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before=None):
def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before:Optional[Callable|tuple]=None):
"Add a route at `path`"
def f(func):
return self._add_route(func, path, methods, name=name, include_in_schema=include_in_schema, body_wrap=body_wrap, host=host, before=before)
Expand Down
74 changes: 74 additions & 0 deletions fasthtml/ratelimit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Simple token-bucket rate limiting for FastHTML routes"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/api/07_ratelimit.ipynb.

# %% auto #0
__all__ = ['parse_rate', 'TokenBucket', 'client_ip', 'limiter']

# %% ../nbs/api/07_ratelimit.ipynb #b33aec8a
import re,time
from math import ceil
from typing import Callable
from starlette.responses import Response
from fastcore.utils import *

# %% ../nbs/api/07_ratelimit.ipynb #1078d10b
_units = dict(s=1, m=60, h=3600, d=86400)

def parse_rate(s):
"Parse rate string like `'5/m'`, `'1 per day'`, `'100/2h'`"
m = re.match(r'(\d+)\s*(?:/|per)\s*(\d+)?\s*([smhd])', s, re.I)
if not m: raise ValueError(f"Invalid rate: {s}")
n,mult,unit = m.groups()
return int(n), _units[unit.lower()] * int(mult or 1)

# %% ../nbs/api/07_ratelimit.ipynb #7b28079a
class TokenBucket:
"Token-bucket rate limiter"
def __init__(self,
max_reqs:str|int, # Rate string ('5/m') or max requests per window
window_secs:int=None # Window in seconds (required if `max_reqs` is int)
):
if window_secs is None: max_reqs,window_secs = parse_rate(max_reqs)
store_attr()
self.rate = max_reqs / window_secs
self.buckets = {}
def __repr__(self): return f'TokenBucket({self.max_reqs}, {self.window_secs})'

def _prune(self):
cutoff = time.time() - self.window_secs
self.buckets = {k:(t,ts) for k,(t,ts) in self.buckets.items() if ts > cutoff}
def wait(self, key):
"Return 0 if allowed, else seconds to wait"
self._prune()
now = time.time()
tokens, last = self.buckets.get(key, (self.max_reqs, now))
tokens = min(self.max_reqs, tokens + (now - last) * self.rate)
if tokens < 1: return (1 - tokens) / self.rate
self.buckets[key] = (tokens - 1, now)
return 0

# %% ../nbs/api/07_ratelimit.ipynb #5b3c8231
def client_ip(req, **kwargs):
"Get client IP from `X-Forwarded-For` header (assumes deployment behind a single Caddy reverse proxy)"
return req.headers.get('x-forwarded-for', '').split(',')[0].strip() or (req.client and req.client.host) or ''

# %% ../nbs/api/07_ratelimit.ipynb #97f054cb
def limiter(rate:str|tuple, # Rate string ('5/m') or (max_reqs, window_secs) tuple
key:str|Callable=client_ip, # Key to limit by: route param name, or callable(req, **kwargs)
on_limit:Callable=None # Optional callback(wait_secs) to return custom response on 429
):
"Create a `before` function that rate-limits requests"
bucket = TokenBucket(*rate) if isinstance(rate, tuple) else TokenBucket(rate)
def _429(w):
if on_limit: return on_limit(w)
return Response('Too many requests', status_code=429, headers={'Retry-After': str(ceil(w))})
if callable(key):
def before(req):
if w:=bucket.wait(key(req)): return _429(w)
return before
async def before(req):
v = ifnone(req.path_params.get(key), req.query_params.get(key))
if v is None: v = (await req.form()).get(key)
if w:=bucket.wait(str(v or '')): return _429(w)
return before
71 changes: 66 additions & 5 deletions nbs/api/00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"from fastcore.style import S\n",
"\n",
"from types import UnionType, SimpleNamespace as ns, GenericAlias\n",
"from typing import get_args, get_origin, Union, Mapping, List, Any\n",
"from typing import get_args, get_origin, Union, Mapping, List, Any, Callable\n",
"from datetime import datetime,date\n",
"from dataclasses import dataclass\n",
"from inspect import Parameter,get_annotations\n",
Expand Down Expand Up @@ -1966,7 +1966,7 @@
"source": [
"#| export\n",
"@patch\n",
"def _endp(self:FastHTML, f, body_wrap, before=None):\n",
"def _endp(self:FastHTML, f, body_wrap, before:Optional[Callable|tuple]=None):\n",
" \"Create endpoint wrapper with before/after middleware processing\"\n",
" sig = signature_ex(f, True)\n",
" for n,p in sig.parameters.items(): (msg:=_check_anno(n,p.annotation)) and warn(msg)\n",
Expand All @@ -1981,7 +1981,8 @@
" else: bf,skip = b,[]\n",
" if not any(re.fullmatch(r, req.url.path) for r in skip):\n",
" resp = await _wrap_call(bf, req, _params(bf))\n",
" if not resp and before: resp = await _wrap_call(before, req, _params(before))\n",
" for b in listify(before):\n",
" if not resp: resp = await _wrap_call(b, req, _params(b))\n",
" req.body_wrap = body_wrap\n",
" if not resp: resp = await _wrap_call(f, req, sig.parameters)\n",
" for a in self.after:\n",
Expand Down Expand Up @@ -2118,7 +2119,7 @@
"source": [
"#| export\n",
"@patch\n",
"def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before=None):\n",
"def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap, host=None, before:Optional[Callable|tuple]=None):\n",
" \"Add HTTP route to FastHTML app with automatic method detection\"\n",
" n,fn,p = name,nested_name(func),None if callable(path) else path\n",
" if methods: m = [methods] if isinstance(methods,str) else methods\n",
Expand All @@ -2143,7 +2144,7 @@
"source": [
"#| export\n",
"@patch\n",
"def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before=None):\n",
"def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None, host=None, before:Optional[Callable|tuple]=None):\n",
" \"Add a route at `path`\"\n",
" def f(func):\n",
" return self._add_route(func, path, methods, name=name, include_in_schema=include_in_schema, body_wrap=body_wrap, host=host, before=before)\n",
Expand Down Expand Up @@ -3519,6 +3520,14 @@
"with cli: test_eq(cli.get('/').text, 'fired')"
]
},
{
"cell_type": "markdown",
"id": "b9fdbdd8",
"metadata": {},
"source": [
"A route-level `before` can set session data, short-circuit with a response, or raise to block access:"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -3541,6 +3550,58 @@
"test_eq(cli.get('/dashboard?user=other').status_code, 403)"
]
},
{
"cell_type": "markdown",
"id": "35615416",
"metadata": {},
"source": [
"Multiple `before` functions run in order, each receiving the injected request params:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b751acb",
"metadata": {},
"outputs": [],
"source": [
"def set_a(sess): sess['a'] = 1\n",
"def set_b(sess): sess['b'] = 2\n",
"\n",
"@rt('/multi', before=(set_a, set_b))\n",
"def get(sess): return f\"{sess.get('a')},{sess.get('b')}\"\n",
"\n",
"test_eq(cli.get('/multi').text, '1,2')"
]
},
{
"cell_type": "markdown",
"id": "49cc8791",
"metadata": {},
"source": [
"A truthy return from any `before` becomes the response; remaining `before`s and the handler are skipped (return `None` to continue):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83188ab0",
"metadata": {},
"outputs": [],
"source": [
"calls = []\n",
"def first(sess): \n",
" calls.append('first')\n",
" return 'stop'\n",
"def second(sess): calls.append('second')\n",
"\n",
"@rt('/chain', before=(first, second))\n",
"def get(sess): return 'handler'\n",
"\n",
"test_eq(cli.get('/chain').text, 'stop')\n",
"test_eq(calls, ['first'])"
]
},
{
"cell_type": "markdown",
"id": "6a014add",
Expand Down
Loading
Loading