Add server with basic auth stuff

This commit is contained in:
Tulir Asokan
2020-10-31 21:53:46 +02:00
parent d3adedf3df
commit 9151f4cb6d
54 changed files with 3415 additions and 419 deletions

View File

@ -0,0 +1,38 @@
# maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from ..config import Config
from .auth import (routes as auth_routes, init as auth_init,
token_middleware, widget_secret_middleware)
from .fed_connector import init as init_fed_connector
from .packs import routes as packs_routes, init as packs_init
from .setup import routes as setup_routes
integrations_app = web.Application()
integrations_app.add_routes(auth_routes)
packs_app = web.Application(middlewares=[widget_secret_middleware])
packs_app.add_routes(packs_routes)
setup_app = web.Application(middlewares=[token_middleware])
setup_app.add_routes(setup_routes)
def init(config: Config) -> None:
init_fed_connector()
auth_init(config)
packs_init(config)

216
sticker/server/api/auth.py Normal file
View File

@ -0,0 +1,216 @@
# maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Callable, Awaitable, Optional, TYPE_CHECKING
import logging
import json
from mautrix.client import Client
from mautrix.types import UserID
from mautrix.util.logging import TraceLogger
from aiohttp import web, hdrs, ClientError, ClientSession
from yarl import URL
from ..database import AccessToken, User
from ..config import Config
from .errors import Error
from . import fed_connector
if TYPE_CHECKING:
from typing import TypedDict
class OpenIDPayload(TypedDict):
access_token: str
token_type: str
matrix_server_name: str
expires_in: int
class OpenIDResponse(TypedDict):
sub: str
Handler = Callable[[web.Request], Awaitable[web.Response]]
log: TraceLogger = logging.getLogger("mau.api.auth")
routes = web.RouteTableDef()
config: Config
def get_ip(request: web.Request) -> str:
if config["server.trust_forward_headers"]:
try:
return request.headers["X-Forwarded-For"]
except KeyError:
pass
return request.remote
def get_auth_header(request: web.Request) -> str:
try:
auth = request.headers["Authorization"]
if not auth.startswith("Bearer "):
raise Error.invalid_auth_header
return auth[len("Bearer "):]
except KeyError:
raise Error.missing_auth_header
async def get_user(request: web.Request) -> Tuple[User, AccessToken]:
auth = get_auth_header(request)
try:
token_id, token_val = auth.split(":")
token_id = int(token_id)
except ValueError:
raise Error.invalid_auth_token
token = await AccessToken.get(token_id)
if not token or not token.check(token_val):
raise Error.invalid_auth_token
elif token.expired:
raise Error.auth_token_expired
await token.update_ip(get_ip(request))
return await User.get(token.user_id), token
@web.middleware
async def token_middleware(request: web.Request, handler: Handler) -> web.Response:
if request.method == hdrs.METH_OPTIONS:
return await handler(request)
user, token = await get_user(request)
request["user"] = user
request["token"] = token
return await handler(request)
async def get_widget_user(request: web.Request) -> User:
try:
user_id = UserID(request.headers["X-Matrix-User-ID"])
except KeyError:
raise Error.missing_user_id_header
user = await User.get(user_id)
if user is None:
raise Error.user_not_found
return user
@web.middleware
async def widget_secret_middleware(request: web.Request, handler: Handler) -> web.Response:
if request.method == hdrs.METH_OPTIONS:
return await handler(request)
user = await get_widget_user(request)
request["user"] = user
return await handler(request)
account_cors_headers = {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "OPTIONS, GET, POST",
"Access-Control-Allow-Headers": "Authorization, Content-Type",
}
@routes.get("/account")
async def get_auth(request: web.Request) -> web.Response:
user, token = await get_user(request)
return web.json_response({"user_id": token.user_id}, headers=account_cors_headers)
async def check_openid_token(homeserver: str, token: str) -> Optional[UserID]:
server_info = await fed_connector.resolve_server_name(homeserver)
headers = {"Host": server_info.host_header}
userinfo_url = URL.build(scheme="https", host=server_info.host, port=server_info.port,
path="/_matrix/federation/v1/openid/userinfo",
query={"access_token": token})
try:
async with fed_connector.http.get(userinfo_url, headers=headers) as resp:
data: 'OpenIDResponse' = await resp.json()
return UserID(data["sub"])
except (ClientError, json.JSONDecodeError, KeyError, ValueError) as e:
log.debug(f"Failed to check OpenID token from {homeserver}", exc_info=True)
return None
@routes.route(hdrs.METH_OPTIONS, "/account/register")
@routes.route(hdrs.METH_OPTIONS, "/account/logout")
@routes.route(hdrs.METH_OPTIONS, "/account")
async def cors_token(_: web.Request) -> web.Response:
return web.Response(status=200, headers=account_cors_headers)
async def resolve_client_well_known(server_name: str) -> str:
url = URL.build(scheme="https", host=server_name, port=443, path="/.well-known/matrix/client")
async with ClientSession() as sess, sess.get(url) as resp:
data = await resp.json()
return data["m.homeserver"]["base_url"]
@routes.post("/account/register")
async def exchange_token(request: web.Request) -> web.Response:
try:
data: 'OpenIDPayload' = await request.json()
except json.JSONDecodeError:
raise Error.request_not_json
try:
matrix_server_name = data["matrix_server_name"]
access_token = data["access_token"]
except KeyError:
raise Error.invalid_openid_payload
log.trace(f"Validating OpenID token from {matrix_server_name}")
user_id = await check_openid_token(matrix_server_name, access_token)
if user_id is None:
raise Error.invalid_openid_token
_, homeserver = Client.parse_user_id(user_id)
if homeserver != data["matrix_server_name"]:
raise Error.homeserver_mismatch
permissions = config.get_permissions(user_id)
if not permissions.access:
raise Error.no_access
try:
log.trace(f"Trying to resolve {matrix_server_name}'s client .well-known")
homeserver_url = await resolve_client_well_known(matrix_server_name)
log.trace(f"Got {homeserver_url} from {matrix_server_name}'s client .well-known")
except (ClientError, json.JSONDecodeError, KeyError, ValueError, TypeError):
log.trace(f"Failed to resolve {matrix_server_name}'s client .well-known", exc_info=True)
raise Error.client_well_known_error
user = await User.get(user_id)
if user is None:
log.debug(f"Creating user {user_id} with homeserver client URL {homeserver_url}")
user = User.new(user_id, homeserver_url=homeserver_url)
await user.insert()
elif user.homeserver_url != homeserver_url:
log.debug(f"Updating {user_id}'s homeserver client URL from {user.homeserver_url} "
f"to {homeserver_url}")
await user.set_homeserver_url(homeserver_url)
token = await user.new_access_token(get_ip(request))
return web.json_response({
"user_id": user_id,
"token": token,
"permissions": permissions._asdict(),
}, headers=account_cors_headers)
@routes.post("/account/logout")
async def logout(request: web.Request) -> web.Response:
user, token = await get_user(request)
await token.delete()
return web.json_response({}, status=204, headers=account_cors_headers)
def init(cfg: Config) -> None:
global config
config = cfg

View File

@ -0,0 +1,110 @@
# maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
import json
from aiohttp import web
class _ErrorMeta:
def __init__(self, *args, **kwargs) -> None:
pass
@staticmethod
def _make_error(errcode: str, error: str) -> Dict[str, str]:
return {
"body": json.dumps({
"error": error,
"errcode": errcode,
}).encode("utf-8"),
"content_type": "application/json",
"headers": {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "OPTIONS, GET, POST, PUT, DELETE, HEAD",
"Access-Control-Allow-Headers": "Authorization, Content-Type",
}
}
@property
def request_not_json(self) -> web.HTTPException:
return web.HTTPBadRequest(**self._make_error("M_NOT_JSON",
"Request body is not valid JSON"))
@property
def missing_auth_header(self) -> web.HTTPException:
return web.HTTPForbidden(**self._make_error("M_MISSING_TOKEN",
"Missing authorization header"))
@property
def missing_user_id_header(self) -> web.HTTPException:
return web.HTTPForbidden(**self._make_error("NET.MAUNIUM_MISSING_USER_ID",
"Missing user ID header"))
@property
def user_not_found(self) -> web.HTTPException:
return web.HTTPNotFound(**self._make_error("NET.MAUNIUM_USER_NOT_FOUND",
"User not found"))
@property
def invalid_auth_header(self) -> web.HTTPException:
return web.HTTPForbidden(**self._make_error("M_UNKNOWN_TOKEN",
"Invalid authorization header"))
@property
def invalid_auth_token(self) -> web.HTTPException:
return web.HTTPForbidden(**self._make_error("M_UNKNOWN_TOKEN",
"Invalid authorization token"))
@property
def auth_token_expired(self) -> web.HTTPException:
return web.HTTPForbidden(**self._make_error("NET.MAUNIUM_TOKEN_EXPIRED",
"Authorization token has expired"))
@property
def invalid_openid_payload(self) -> web.HTTPException:
return web.HTTPBadRequest(**self._make_error("M_BAD_JSON", "Missing one or more "
"fields in OpenID payload"))
@property
def invalid_openid_token(self) -> web.HTTPException:
return web.HTTPForbidden(**self._make_error("M_UNKNOWN_TOKEN",
"Invalid OpenID token"))
@property
def no_access(self) -> web.HTTPException:
return web.HTTPUnauthorized(**self._make_error(
"M_UNAUTHORIZED",
"You are not authorized to access this maunium-stickerpicker instance"))
@property
def homeserver_mismatch(self) -> web.HTTPException:
return web.HTTPUnauthorized(**self._make_error(
"M_UNAUTHORIZED", "Request matrix_server_name and OpenID sub homeserver don't match"))
@property
def pack_not_found(self) -> web.HTTPException:
return web.HTTPNotFound(**self._make_error("NET.MAUNIUM_PACK_NOT_FOUND",
"Sticker pack not found"))
@property
def client_well_known_error(self) -> web.HTTPException:
return web.HTTPForbidden(**self._make_error("NET.MAUNIUM_CLIENT_WELL_KNOWN_ERROR",
"Failed to resolve homeserver URL "
"from client .well-known"))
class Error(metaclass=_ErrorMeta):
pass

View File

@ -0,0 +1,110 @@
from typing import Tuple, Any, NamedTuple, Dict, Optional
from time import time
import ipaddress
import logging
import asyncio
import json
from mautrix.util.logging import TraceLogger
from aiohttp import ClientRequest, TCPConnector, ClientSession, ClientTimeout, ClientError
from aiohttp.client_proto import ResponseHandler
from yarl import URL
import aiodns
log: TraceLogger = logging.getLogger("mau.federation")
class ResolvedServerName(NamedTuple):
host_header: str
host: str
port: int
expire: int
class ServerNameSplit(NamedTuple):
host: str
port: Optional[int]
is_ip: bool
dns_resolver: aiodns.DNSResolver
http: ClientSession
server_name_cache: Dict[str, ResolvedServerName] = {}
class MatrixFederationTCPConnector(TCPConnector):
"""An extension to aiohttp's TCPConnector that correctly sets the TLS SNI for Matrix federation
requests, where the TCP host may not match the SNI/Host header."""
async def _wrap_create_connection(self, *args: Any, server_hostname: str, req: ClientRequest,
**kwargs: Any) -> Tuple[asyncio.Transport, ResponseHandler]:
split = parse_server_name(req.headers["Host"])
return await super()._wrap_create_connection(*args, server_hostname=split.host,
req=req, **kwargs)
def parse_server_name(name: str) -> ServerNameSplit:
port_split = name.rsplit(":", 1)
if len(port_split) == 2 and port_split[1].isdecimal():
name, port = port_split
else:
port = None
try:
ipaddress.ip_address(name)
is_ip = True
except ValueError:
is_ip = False
res = ServerNameSplit(host=name, port=port, is_ip=is_ip)
log.trace(f"Parsed server name {name} into {res}")
return res
async def resolve_server_name(server_name: str) -> ResolvedServerName:
try:
cached = server_name_cache[server_name]
if cached.expire > int(time()):
log.trace(f"Using cached server name resolution for {server_name}: {cached}")
return cached
except KeyError:
log.trace(f"No cached server name resolution for {server_name}")
host_header = server_name
hostname, port, is_ip = parse_server_name(host_header)
ttl = 86400
if port is None and not is_ip:
well_known_url = URL.build(scheme="https", host=host_header, port=443,
path="/.well-known/matrix/server")
try:
log.trace(f"Requesting {well_known_url} to resolve {server_name}'s .well-known")
async with http.get(well_known_url) as resp:
if resp.status == 200:
well_known_data = await resp.json()
host_header = well_known_data["m.server"]
log.debug(f"Got {host_header} from {server_name}'s .well-known")
hostname, port, is_ip = parse_server_name(host_header)
else:
log.trace(f"Got non-200 status {resp.status} from {server_name}'s .well-known")
except (ClientError, json.JSONDecodeError, KeyError, ValueError) as e:
log.debug(f"Failed to fetch .well-known for {server_name}: {e}")
if port is None and not is_ip:
log.trace(f"Querying SRV at _matrix._tcp.{host_header}")
res = await dns_resolver.query(f"_matrix._tcp.{host_header}", "SRV")
if res:
hostname = res[0].host
port = res[0].port
ttl = max(res[0].ttl, 300)
log.debug(f"Got {hostname}:{port} from {host_header}'s Matrix SRV record")
else:
log.trace(f"No SRV records found at _matrix._tcp.{host_header}")
result = ResolvedServerName(host_header=host_header, host=hostname, port=port or 8448,
expire=int(time()) + ttl)
server_name_cache[server_name] = result
log.debug(f"Resolved server name {server_name} -> {result}")
return result
def init():
global http, dns_resolver
dns_resolver = aiodns.DNSResolver(loop=asyncio.get_running_loop())
http = ClientSession(timeout=ClientTimeout(total=10),
connector=MatrixFederationTCPConnector())

View File

@ -0,0 +1,52 @@
# maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from ..database import User
from ..config import Config
from .errors import Error
routes = web.RouteTableDef()
config: Config
@routes.get("/index.json")
async def get_packs(req: web.Request) -> web.Response:
user: User = req["user"]
packs = await user.get_packs()
return web.json_response({
"homeserver_url": user.homeserver_url,
"is_sticker_server": True,
"packs": [f"{pack.id}.json" for pack in packs],
})
@routes.get("/{pack_id}.json")
async def get_pack(req: web.Request) -> web.Response:
user: User = req["user"]
pack = await user.get_pack(req.match_info["pack_id"])
if pack is None:
raise Error.pack_not_found
stickers = await pack.get_stickers()
return web.json_response({
**pack.to_dict(),
"stickers": [sticker.to_dict() for sticker in stickers],
})
def init(cfg: Config) -> None:
global config
config = cfg

View File

@ -0,0 +1,19 @@
# maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
routes = web.RouteTableDef()