#
# Copyright (C) 2022 Red Hat, Inc.
# SPDX-License-Identifier: GPL-3.0-or-later


import array
import asyncio
import contextlib
import getpass
import logging
import os
import shutil
import socket
import subprocess
from os.path import basename
from tempfile import TemporaryDirectory
from typing import Sequence

from cockpit._vendor import ferny
from cockpit._vendor.bei.bootloader import make_bootloader
from cockpit._vendor.systemd_ctypes import Bus, Variant, bus

from .beipack import BridgeBeibootHelper
from .jsonutil import JsonObject, get_str
from .packages import BridgeConfig
from .peer import ConfiguredPeer, Peer, PeerError
from .polkit import PolkitAgent
from .router import Router, RoutingError, RoutingRule

logger = logging.getLogger(__name__)


def sudo_supports_askpass(sudo_path: str) -> bool:
    """Report whether the sudo binary at *sudo_path* accepts the ``-A`` option.

    The probe runs ``<sudo_path> -A --help`` with all standard streams
    discarded and looks only at the exit status: zero means ``-A`` is
    understood.  If the program cannot be executed at all (``OSError``),
    that is treated the same as "not supported".
    """
    probe = [sudo_path, '-A', '--help']
    try:
        completed = subprocess.run(
            probe,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError:
        return False
    return completed.returncode == 0


def is_valid_superuser_config(config: BridgeConfig) -> bool:
    """Decide whether *config* can be offered as a superuser bridge.

    A config qualifies only if it is marked privileged and the first word of
    its ``spawn`` command resolves via ``PATH``.  If that command is ``sudo``,
    it must also support ``-A`` (askpass).
    """
    if not config.privileged:
        return False

    executable = shutil.which(config.spawn[0])
    if executable is None:
        return False

    # Only sudo needs the extra askpass capability probe.
    is_sudo = basename(executable) == 'sudo'
    return not is_sudo or sudo_supports_askpass(executable)


class SuperuserPeer(ConfiguredPeer):
    """Peer running the privileged (superuser) bridge described by a ``BridgeConfig``.

    The bridge is either spawned as a subprocess (e.g. through sudo) or, when
    ``config.method`` is ``'StartTransientUnit'``, started as a transient
    systemd unit over the system bus.  Authentication prompts are answered by
    ``responder``.
    """

    responder: ferny.AskpassHandler

    def __init__(self, router: Router, config: BridgeConfig, responder: ferny.AskpassHandler):
        super().__init__(router, config)
        self.responder = responder

    async def start_transient_unit(self, args: 'Sequence[str]', stderr: object) -> asyncio.Transport:
        """Run ``args`` as root in a transient systemd service connected to this peer.

        One end of a fresh socketpair becomes the service's stdin/stdout.  The
        other end becomes this peer's transport.  ``stderr`` is handed to the
        service as its stderr file descriptor.

        Raises:
            FileNotFoundError: ``args[0]`` cannot be resolved via ``PATH``.

        Errors raised by the D-Bus call itself (e.g. a failed authorization)
        propagate to the caller unchanged.
        """
        # Resolve the executable up front.  An unresolvable command would
        # otherwise be sent to systemd as ``None`` inside ExecStart.
        executable = shutil.which(args[0])
        if executable is None:
            raise FileNotFoundError(f'{args[0]}: command not found')

        loop = asyncio.get_running_loop()
        unit_name = f"cockpit-superuser-{os.getpid()}.service"
        ours, theirs = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)

        try:
            # Close our copy of ``theirs`` once the call has completed and
            # systemd has received it.  If we kept it open, ``ours`` would
            # never see EOF when the privileged bridge exits, and the fd would
            # leak.
            with theirs:
                system = Bus.default_system()
                msg = system.message_new_method_call(
                    "org.freedesktop.systemd1",
                    "/org/freedesktop/systemd1",
                    "org.freedesktop.systemd1.Manager",
                    "StartTransientUnit",
                    "ssa(sv)a(sa(sv))",
                    unit_name,
                    "fail",
                    [
                        ("Description", {"t": "s", "v": "Cockpit privileged bridge"}),
                        ("Type", {"t": "s", "v": "exec"}),
                        ("User", {"t": "s", "v": "root"}),
                        ("StandardInputFileDescriptor", {"t": "h", "v": theirs}),
                        ("StandardOutputFileDescriptor", {"t": "h", "v": theirs}),
                        ("StandardErrorFileDescriptor", {"t": "h", "v": stderr}),
                        ("ExecStart", {"t": "a(sasb)", "v": [(executable, args, False)]}),
                    ],
                    [],
                )
                msg.set_allow_interactive_authorization(True)

                # We fire and forget.  If we catch an authentication error, that'll be
                # raised as a BusError which will propagate upwards to our caller (ie:
                # the Start method) and get displayed as an error dialog.  If we get
                # other errors (like failure to spawn the named executable for some
                # reason or unusual exit codes) then we will see the stderr output but
                # otherwise won't get any notification about it.  The amount of work
                # required to do this "properly" is quite high and it's not super
                # useful.
                await system.call_async(msg)

            transport, protocol = await loop.create_connection(lambda: self, sock=ours)
        except BaseException:
            # No transport owns ``ours`` yet.  Don't leak it on failure or
            # cancellation.
            ours.close()
            raise

        assert protocol is self
        return transport

    async def do_connect_transport(self) -> None:
        """Start the privileged bridge and run its authentication conversation.

        Raises:
            PeerError: ``'authentication-failed'`` if the ferny interaction
                agent reports an ``InteractionError``.
        """
        async with contextlib.AsyncExitStack() as context:
            if self.config.polkit:
                logger.debug('connecting polkit superuser peer transport %r', self.args)
                # The polkit agent stays registered until this method returns.
                await context.enter_async_context(PolkitAgent(self.responder))
            else:
                logger.debug('connecting non-polkit superuser peer transport %r', self.args)

            responders: 'list[ferny.InteractionHandler]' = [self.responder]

            if '# cockpit-bridge' in self.args:
                # The spawned command expects a bootloader on its stdin; the
                # beiboot helper then answers its requests for further steps.
                logger.debug('going to beiboot superuser bridge %r', self.args)
                helper = BridgeBeibootHelper(self, ['--privileged'])
                responders.append(helper)
                stage1 = make_bootloader(helper.steps, gadgets=ferny.BEIBOOT_GADGETS).encode()
            else:
                stage1 = None

            agent = ferny.InteractionAgent(responders)

            if 'SUDO_ASKPASS=ferny-askpass' in self.env:
                # Replace the placeholder with the path of a real askpass
                # helper.  It lives in a temporary directory that is removed
                # when this method returns.  Other configured environment
                # entries are passed through unchanged, as in the else branch.
                tmpdir = context.enter_context(TemporaryDirectory())
                ferny_askpass = ferny.write_askpass_to_tmpdir(tmpdir)
                env: Sequence[str] = [
                    f'SUDO_ASKPASS={ferny_askpass}' if item == 'SUDO_ASKPASS=ferny-askpass' else item
                    for item in self.env
                ]
            else:
                env = self.env

            # The bridge's stderr is connected to the interaction agent in both
            # start modes.
            if self.config.method == 'StartTransientUnit':
                transport = await self.start_transient_unit(self.args, stderr=agent)
            else:
                transport = await self.spawn(self.args, env, stderr=agent, start_new_session=True)

            if stage1 is not None:
                transport.write(stage1)

            try:
                await agent.communicate()
            except ferny.InteractionError as exc:
                raise PeerError('authentication-failed', message=str(exc)) from exc


class CockpitResponder(ferny.AskpassHandler):
    """Askpass handler that also understands the private ``cockpit.send-stderr`` command.

    On ``cockpit.send-stderr`` it sends a copy of our own fd 2 back over the
    socket received as the first passed file descriptor.
    """

    commands = ('ferny.askpass', 'cockpit.send-stderr')

    async def do_custom_command(
        self, command: str, args: 'tuple[object, ...]', fds: 'list[int]', stderr: str
    ) -> None:
        # Every other custom command is ignored.
        if command != 'cockpit.send-stderr':
            return

        # The socket object takes over the first fd.  Removing it from ``fds``
        # presumably keeps the caller from closing it a second time.
        with socket.socket(fileno=fds[0]) as channel:
            del fds[0]
            passed_fds = array.array("i", [2])
            # Same as socket.send_fds(channel, [b'\0'], [2]), which needs Python 3.9.
            channel.sendmsg([b'\0'], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, passed_fds)])


class AuthorizeResponder(CockpitResponder):
    """Answer the first askpass prompt by requesting the password from the router.

    Only one non-interactive attempt is made.  Every later prompt is rejected
    by returning ``None``.
    """

    def __init__(self, router: Router):
        self.router = router
        # Becomes True once the single authorization request has been issued.
        self.authorize_attempted = False

    async def do_askpass(self, messages: str, prompt: str, hint: str) -> 'str | None':
        if self.authorize_attempted:
            logger.info("noninteractive authorize during init already attempted, rejecting")
            return None
        self.authorize_attempted = True

        # The plain1 challenge carries the user name as hex-encoded bytes.
        # Encode as UTF-8 rather than ASCII so non-ASCII user names don't raise
        # UnicodeEncodeError.  ASCII names produce the same output as before.
        hexuser = getpass.getuser().encode('utf-8').hex()
        password = await self.router.request_authorization(f'plain1:{hexuser}')
        # translate "no password" from authorize protocol (empty string) to ferny protocol (None)
        return None if password == '' else password


class SuperuserRoutingRule(RoutingRule, CockpitResponder, bus.Object, interface='cockpit.Superuser'):
    """Route ``superuser`` requests to the privileged bridge and control it over D-Bus.

    The ``current`` property holds one of these values:
    - ``'none'``: no privileged bridge is running.
    - ``'init'``: a bridge is being started.
    - ``'root'``: this bridge is already privileged.
    - the name of the running bridge config.

    Password prompts from the bridge are sent out via the ``prompt`` signal
    and answered through the ``answer()`` method.
    """

    superuser_configs: Sequence[BridgeConfig] = ()
    pending_prompt: 'asyncio.Future[str] | None'
    peer: 'SuperuserPeer | None'

    # D-Bus signals
    prompt = bus.Interface.Signal('s', 's', 's', 'b', 's')  # message, prompt, default, echo, error

    # D-Bus properties
    bridges = bus.Interface.Property('as', value=[])
    current = bus.Interface.Property('s', value='none')
    methods = bus.Interface.Property('a{sv}', value={})

    # RoutingRule
    def apply_rule(self, options: JsonObject) -> 'Peer | None':
        """Return the superuser peer for requests that ask for one.

        Raises:
            RoutingError: ``'access-denied'`` if superuser access is required
                but no privileged peer is active.
        """
        superuser = options.get('superuser')

        if not superuser or self.current == 'root':
            # superuser not requested, or already superuser?  Next rule.
            return None
        elif self.peer or superuser == 'try':
            # superuser requested and active?  Return it.
            # 'try' requested?  Either return the peer, or None.
            return self.peer
        else:
            # superuser requested, but not active?  That's an error.
            raise RoutingError('access-denied')

    # ferny.AskpassHandler
    async def do_askpass(self, messages: str, prompt: str, hint: str) -> 'str | None':
        """Forward an askpass prompt to the client and wait for ``answer()``.

        The wait ends with ``CancelledError`` if ``cancel_prompt()`` is called.
        """
        assert self.pending_prompt is None
        echo = hint == "confirm"
        self.pending_prompt = asyncio.get_running_loop().create_future()
        try:
            logger.debug('prompting for %s', prompt)
            # with sudo, all stderr messages are treated as warning/errors by the UI
            # (such as the lecture or "wrong password"), so pass them in the "error" field
            self.prompt('', prompt, '', echo, messages)
            return await self.pending_prompt
        finally:
            self.pending_prompt = None

    def __init__(self, router: Router, *, privileged: bool = False):
        super().__init__(router)

        self.pending_prompt = None
        self.peer = None
        self.startup = None

        if privileged or os.getuid() == 0:
            self.current = 'root'

    def peer_done(self) -> None:
        """Done callback of the superuser peer: reset to the unprivileged state."""
        self.current = 'none'
        self.peer = None

    async def go(self, name: str, responder: ferny.AskpassHandler) -> None:
        """Start the superuser bridge named *name* (``'any'`` picks the first config).

        Raises:
            bus.BusError: a bridge is already running, *name* is unknown, the
                start was cancelled, or the peer failed with an ``OSError`` or
                ``PeerError``.
        """
        if self.current != 'none':
            raise bus.BusError('cockpit.Superuser.Error', 'Superuser bridge already running')

        assert self.peer is None
        assert self.startup is None

        for config in self.superuser_configs:
            if name in (config.name, 'any'):
                break
        else:
            raise bus.BusError('cockpit.Superuser.Error', f'Unknown superuser bridge type "{name}"')

        self.current = 'init'
        self.peer = SuperuserPeer(self.router, config, responder)
        self.peer.add_done_callback(self.peer_done)

        try:
            await self.peer.start(init_host=self.router.init_host)
        except asyncio.CancelledError:
            raise bus.BusError('cockpit.Superuser.Error.Cancelled', 'Operation aborted') from None
        except (OSError, PeerError) as exc:
            raise bus.BusError('cockpit.Superuser.Error', str(exc)) from exc

        self.current = self.peer.config.name

    def set_configs(self, configs: Sequence[BridgeConfig]) -> None:
        """Replace the available superuser configs, keeping only the usable ones.

        A running bridge whose config is no longer present is shut down.
        """
        logger.debug("set_configs() with %d items", len(configs))
        configs = [config for config in configs if is_valid_superuser_config(config)]
        self.superuser_configs = tuple(configs)
        self.bridges = [config.name for config in self.superuser_configs]
        self.methods = {c.label: Variant({'label': Variant(c.label)}, 'a{sv}') for c in configs if c.label}

        logger.debug("  bridges are now %s", self.bridges)

        # If the currently active bridge config is not in the new set of configs, stop it
        if self.peer is not None:
            if self.peer.config not in self.superuser_configs:
                logger.debug("  stopping superuser bridge '%s': it disappeared from configs", self.peer.config.name)
                self.shutdown()

    def cancel_prompt(self) -> None:
        """Cancel the outstanding prompt, if any."""
        if self.pending_prompt is not None:
            self.pending_prompt.cancel()
            self.pending_prompt = None

    def shutdown(self) -> None:
        """Cancel any prompt and close the superuser peer."""
        self.cancel_prompt()

        if self.peer is not None:
            self.peer.close()

        # close() should have disconnected the peer immediately
        assert self.peer is None

    # Connect-on-startup functionality
    def init(self, params: JsonObject) -> None:
        """Start connecting a superuser bridge in the background, authorizing non-interactively."""
        name = get_str(params, 'id', 'any')
        responder = AuthorizeResponder(self.router)
        self._init_task = asyncio.create_task(self.go(name, responder))
        self._init_task.add_done_callback(self._init_done)

    def _init_done(self, task: 'asyncio.Task[None]') -> None:
        # Task.exception() raises CancelledError on a cancelled task.  That
        # would abort this callback before the router is told init is over.
        exc = None if task.cancelled() else task.exception()
        logger.debug('superuser init done! %s', exc)
        self.router.write_control(command='superuser-init-done')
        del self._init_task

    # D-Bus methods
    @bus.Interface.Method(in_types=['s'])  # type: ignore[misc]
    async def start(self, name: str) -> None:
        await self.go(name, self)

    @bus.Interface.Method()  # type: ignore[misc]
    def stop(self) -> None:
        self.shutdown()

    @bus.Interface.Method(in_types=['s'])  # type: ignore[misc]
    def answer(self, reply: str) -> None:
        # The future may already be resolved if a second answer arrives before
        # do_askpass() resumes and clears it.  set_result() would then raise
        # InvalidStateError.
        if self.pending_prompt is not None and not self.pending_prompt.done():
            logger.debug('responding to pending prompt')
            self.pending_prompt.set_result(reply)
        else:
            logger.debug('got Answer, but no prompt pending')
