#!/usr/bin/python3
#  OpenVPN 3 Linux client -- Next generation OpenVPN client
#
#  SPDX-License-Identifier: AGPL-3.0-only
#
#  Copyright (C) 2021 - 2023  OpenVPN Inc <sales@openvpn.net>
#  Copyright (C) 2021 - 2023  David Sommerseth <davids@openvpn.net>
#

##
# @file  openvpn3-systemd
#
# @brief OpenVPN 3 session management interface for systemd
#

import argparse
import datetime
import dbus
import errno
import openvpn3
import os
import subprocess
import sys
import systemd.daemon
import time

from dbus.mainloop.glib import DBusGMainLoop
from gi.repository import GLib
from openvpn3 import StatusMajor, StatusMinor, SessionManagerEventType
from openvpn3 import ClientAttentionType

# Refuse to run on anything older than Python 3
if sys.version_info.major < 3:
    print("Python 3.x is required")
    sys.exit(1)


##
#  Main class for managing OpenVPN 3 VPN sessions via systemd
#
#  This class takes care of the main loop of a running VPN session
#  and provides log entries to stdout.  The systemd journal will
#  pick these up and preserve the information.
#
class OpenVPN3systemd(object):
    def __init__(self):
        """Connect to the system bus and the OpenVPN 3 manager services.

        No configuration profile or VPN session is selected here; that
        is done by RetrieveConfig() or RetrieveSession().
        """
        # Nothing is known about a session or profile yet
        self.__error_set = False
        self.__sesspath = None
        self.__session = None
        self.__config = None

        # GLib main loop and the system bus connection attached to it
        self.__mainloop = GLib.MainLoop()
        self.__dbusloop = DBusGMainLoop(set_as_default=True)
        self.__sysbus = dbus.SystemBus(mainloop=self.__dbusloop)

        # Configuration manager and session manager; session manager
        # events are delivered to _SessionManagerEvent()
        self.__cmgr = openvpn3.ConfigurationManager(self.__sysbus)
        self.__smgr = openvpn3.SessionManager(self.__sysbus)
        self.__smgr.SessionManagerCallback(self._SessionManagerEvent)


    ##
    #  Method for updating the systemd status field of the
    #  systemd service
    #
    def sd_notify(self, ready, msg = None):
        """Send a state notification to systemd.

        ready -- a value above 0 adds a "READY=<value>" line
        msg   -- optional notification text, such as "STATUS=..."

        Nothing is sent once the error flag of this object has been set.
        """
        if self.__error_set:
            return

        fields = []
        if ready > 0:
            fields.append('READY=%i\n' % (ready))
        if msg is not None and len(msg) > 0:
            fields.append("%s\n" % (msg))

        systemd.daemon.notify(''.join(fields))


    ##
    #  Simple Log signal callback function.  Called each time a Log event
    #  happens on this session.
    #
    def _LogHandler(self, group, catg, msg):
        loglines = [l for l in msg.split('\n') if len(l) > 0]
        if len(loglines) < 1:
            return

        print('%s %s' % (datetime.datetime.now(), loglines[0]))
        for line in loglines[1:]:
            print('%s%s' % (' ' * 33, line))


    ##
    #  Simple StatusChange callback function.  Called each time a
    #  StatusChange event happens on this session.
    #
    #  This handler will also shutdown this currently running session if
    #  the session is disconnected outside of this program
    #
    def _StatusHandler(self, major, minor, msg):
        """Report a session status change and react to terminal states.

        major -- status group value, converted with StatusMajor()
        minor -- status code value, converted with StatusMinor()
        msg   -- free-text status message; may be empty

        The status is printed to stdout and forwarded to systemd.  The
        main loop is stopped when the client process has stopped, the
        connection got disconnected, the authentication failed or the
        credentials requested by the session could not be provided.
        """
        maj = StatusMajor(major)
        mnr = StatusMinor(minor)
        status_str = "%s: %s" % (maj.name.capitalize(), self._status_minor_describe(mnr))
        if len(msg) > 0:
            status_str += " - %s" % msg

        print('%s [STATUS] %s' % (datetime.datetime.now(), status_str))
        self.sd_notify(1, "STATUS=%s" % status_str)

        if StatusMajor.SESSION == maj and StatusMinor.PROC_STOPPED == mnr:
            # Session got most likely disconnected outside of this program
            self.sd_notify(0, "STATUS=%s\nSTOPPING=1" % status_str)
            self.__mainloop.quit()
        elif StatusMajor.CONNECTION == maj and StatusMinor.CONN_AUTH_FAILED == mnr:
            # ERRNO= is the field name defined by the systemd notify protocol
            self.sd_notify(0, "STATUS=%s\nERRNO=%i\nSTOPPING=1" % (status_str, errno.EACCES))
            # From here on sd_notify() stays silent, which keeps the
            # error state reported above
            self.__error_set = True
            self.__mainloop.quit()
        elif StatusMajor.CONNECTION == maj and StatusMinor.CONN_DISCONNECTED == mnr:
            self.sd_notify(0, "STATUS=%s\nSTOPPING=1" % status_str)
            self.__mainloop.quit()
        elif StatusMajor.CONNECTION == maj and StatusMinor.CFG_REQUIRE_USER == mnr:
            error = self.__request_credentials()
            if error is None:
                self.__session.Connect()
            else:
                try:
                    self.__session.Disconnect()
                except Exception:
                    # Best effort only; the session may already be gone
                    pass
                self.sd_notify(0, "STATUS=Error: %s\nSTOPPING=1" % error)
                print("** ERROR ** %s" % error)
                self.__mainloop.quit()


    def _status_minor_describe(self, minor):
        """Return a human readable description of a StatusMinor value.

        minor -- the StatusMinor value to describe

        A value without a dedicated description is rendered through
        its string representation, so the caller always gets a string.
        """
        descriptions = (
            (StatusMinor.UNSET, "(Unknown status)"),
            (StatusMinor.CFG_ERROR, "Configuration error"),
            (StatusMinor.CFG_OK, "Configuration OK"),
            (StatusMinor.CFG_INLINE_MISSING, "Inline configuration data is missing"),
            (StatusMinor.CFG_REQUIRE_USER, "User need to provide more information"),
            (StatusMinor.CONN_INIT, "VPN connection is initializing"),
            (StatusMinor.CONN_CONNECTING, "VPN connection is starting"),
            (StatusMinor.CONN_CONNECTED, "VPN client is connected"),
            (StatusMinor.CONN_DISCONNECTING, "VPN client is disconnecting"),
            (StatusMinor.CONN_DISCONNECTED, "VPN client has disconnected"),
            (StatusMinor.CONN_FAILED, "VPN connection failed"),
            (StatusMinor.CONN_AUTH_FAILED, "User authentication for VPN session failed"),
            (StatusMinor.CONN_RECONNECTING, "VPN connection is reconnecting"),
            (StatusMinor.CONN_PAUSING, "VPN client is pausing the session"),
            (StatusMinor.CONN_PAUSED, "VPN client has paused the session"),
            (StatusMinor.CONN_RESUMING, "VPN client is resuming the paused session"),
            (StatusMinor.CONN_DONE, "VPN client session has completed"),
            (StatusMinor.SESS_AUTH_USERPASS, "VPN session need account credentials from user"),
            (StatusMinor.SESS_AUTH_URL, "VPN session need web based user authentication to complete"),
            (StatusMinor.SESS_AUTH_CHALLENGE, "VPN session need a challenge response from the user for authentication"),
            (StatusMinor.PROC_STARTED, "VPN client process has started"),
            (StatusMinor.PROC_STOPPED, "VPN client process has stopped"),
            (StatusMinor.PROC_KILLED, "VPN client process has been killed"),
        )

        # Scanned in order with ==, so the first matching entry wins
        for known, text in descriptions:
            if known == minor:
                return text

        # No dedicated text available; avoid handing None to the caller
        return "(Unhandled status: %s)" % (minor,)


    ##
    #  This handles the SessionManagerEvent signals from the session manager.
    #  This is used to catch if the session manager closes the session, which
    #  means this script can stop running.
    #
    def _SessionManagerEvent(self, event):
        if 'OPENVPN3_DEBUG' in os.environ:
            print('%s %s' % (datetime.datetime.now(), str(event)))

        if self.__sesspath == event and \
           SessionManagerEventType.SESS_DESTROYED == event:
              self.__mainloop.quit()

    ##
    #  Method for retrieving an openvpn3.Config object from a configuration
    #  profile name.  This is needed to start a new VPN session
    #
    def RetrieveConfig(self, config):
        if self.__config is not None:
            raise RuntimeError("Configuration profile already retrieved")

        cfglist = self.__cmgr.LookupConfigName(config)
        if 0 == len(cfglist):
            raise RuntimeError("No configuration found")
        elif 1 < len(cfglist):
            raise RuntimeError("More than one configuration profile found.")

        self.__config = self.__cmgr.Retrieve(cfglist[0])
        self.__config.Validate()
        print("Loaded configuration profile %s (path: %s)" % (
            self.__config.GetProperty('name'),
            self.__config.GetPath()))


    ##
    #  Method for retrieving a running VPN session based on the configuration
    #  profile name.  This is needed to restart or stop an on-going VPN
    #  session via systemctl.
    #
    def RetrieveSession(self, config):
        if self.__session is not None:
            raise RuntimeError("Session already retrieved")

        sesslist = self.__smgr.LookupConfigName(config)
        if 0 == len(sesslist):
            raise RuntimeError("No running VPN session for configuration")
        elif 1 < len(sesslist):
            raise RuntimeError("More than one running session uses this configuration profile")

        self.__session = self.__smgr.Retrieve(sesslist[0])
        print("Retrieved session data for profile %s/%s (path: %s)" % (
            self.__session.GetProperty('config_name'),
            self.__session.GetProperty('session_name'),
            self.__session.GetPath()))


    ##
    #  Starts a new VPN session.
    #
    #  Before this method is called, the RetrieveConfig() method must be called
    #
    def Start(self, log_level):
        """Start a new VPN session and run it until it ends.

        log_level -- log verbosity for the session; only applied when
                     greater than 0

        RetrieveConfig() must have been called first.  The call blocks
        in the GLib main loop while the session is running and
        disconnects the session once the main loop has been stopped.

        Raises RuntimeError if a session is already started, if no
        configuration profile has been retrieved, if a session using
        the profile is already active or if the user credentials could
        not be retrieved.  D-Bus errors this method does not recognise
        are passed on to the caller.
        """
        if self.__session is not None:
            raise RuntimeError("Session already started")
        if self.__config is None:
            raise RuntimeError("No configuration profile has been retrieved")

        # Check if a VPN session is already running with this configuration name
        try:
            num_sessions = len(self.__smgr.LookupConfigName(self.__config.GetConfigName()))
        except Exception:
            # If the LookupConfigName call failed, presume everything is fine
            num_sessions = 0
        if num_sessions > 0:
            raise RuntimeError("A VPN session with %s is already active" % self.__config.GetConfigName())

        #  Create a new VPN Session
        self.__session = self.__smgr.NewTunnel(self.__config)
        # NOTE(review): presumably gives the backend process time to
        # start before the session object is used - confirm
        time.sleep(2)
        self.__sesspath = self.__session.GetPath()

        # Set up signal callback handlers and the proper log level
        self.__session.LogCallback(self._LogHandler)
        self.__session.StatusChangeCallback(self._StatusHandler)

        if log_level > 0:
            print(">> Changing log level to %i" % log_level)
            self.__session.SetProperty('log_verbosity', dbus.UInt32(log_level))

        print("Session initiated: %s" % self.__session.GetPath())
        done = False
        while not done:
            try:
                self.__session.Ready()
                print("Starting session connection")
                self.__session.Connect()
                print("Session started successfully")

                # Blocks until one of the signal handlers stops the main loop
                self.__mainloop.run()
                try:
                    print("Disconnecting")
                    self.sd_notify(1, "STATUS=Stopping\nSTOPPING=1")
                    try:
                        self.__session.LogCallback(None)
                    except dbus.exceptions.DBusException:
                        pass
                    self.__session.Disconnect()
                except dbus.exceptions.DBusException as e:
                    # An unknown method error is tolerated here; every
                    # other D-Bus error is handled by the outer handler
                    if str(e).split(':')[0] != 'org.freedesktop.DBus.Error.UnknownMethod':
                        raise
                print("Disconnected")
                # The session has ended; without this the loop would call
                # Ready() again on a session which no longer exists
                done = True
            except dbus.exceptions.DBusException as excep:
                errmsg = str(excep)
                if 'Backend VPN process is not ready' in errmsg:
                    # The backend needs more time; try again shortly
                    time.sleep(0.5)
                elif ' Missing user credentials' in errmsg:
                    error = self.__request_credentials()
                    if error is not None:
                        try:
                            self.__session.Disconnect()
                        except Exception:
                            # Best effort only; the session may already be gone
                            pass
                        raise RuntimeError(error)
                elif 'ERR_PROFILE_SERVER_LOCKED_UNSUPPORTED:' in errmsg:
                    self.__session.Disconnect()
                    raise Exception('Server locked profiles are not supported')
                elif ' Backend VPN process has died' in errmsg:
                    raise Exception('VPN backend process stopped unexpectedly. '
                                      + 'Session has closed. :: ' + errmsg)
                else:
                    # Retrying an error which is not understood would
                    # spin in this loop forever; let the caller report it
                    raise
            except KeyboardInterrupt:
                print("Disconnecting (sigint)")
                self.sd_notify(1, "STATUS=Stopping (SIGINT)\nSTOPPING=1")
                done = True
                try:
                    self.__session.LogCallback(None)
                except dbus.exceptions.DBusException:
                    pass
                try:
                    self.__session.Disconnect()
                except dbus.exceptions.DBusException:
                    pass
            except RuntimeError:
                # A RuntimeError from the session calls above ends the
                # session handling without further reporting
                done = True


    ##
    #  Restart an on-going VPN session
    #
    #  The RetrieveSession() method must be called before calling
    #  this method
    def Restart(self):
        """Restart the VPN session retrieved via RetrieveSession().

        Raises RuntimeError if no session has been retrieved.
        """
        session = self.__session
        if session is None:
            raise RuntimeError("Session not started")

        # Tell systemd what is going on before the restart is requested
        self.sd_notify(1, "STATUS=Restarting session")
        session.Restart()


    ##
    #  Stop an on-going VPN session
    #
    #  The RetrieveSession() method must be called before calling
    #  this method
    def Stop(self):
        """Disconnect the VPN session retrieved via RetrieveSession().

        systemd is notified before and after the disconnect request.
        Raises RuntimeError if no session has been retrieved.
        """
        session = self.__session
        if session is None:
            raise RuntimeError("Session not started")

        self.sd_notify(0, "STOPPING=1\nSTATUS=Disconnecting VPN session\n")
        session.Disconnect()
        self.sd_notify(0, "STOPPING=1\nSTATUS=VPN session stopped\n")


    def __request_credentials(self):
        args_base = [ '/usr/bin/systemd-ask-password','--icon','network-vpn']

        for input_slot in self.__session.FetchUserInputSlots():
            if input_slot.GetTypeGroup()[0] != ClientAttentionType.CREDENTIALS:
                # skip non-user credential requests
                continue

            args = [] + args_base
            if False == input_slot.GetInputMask():
                args += ['--echo']
            args += ['%s:' % (input_slot.GetLabel())]
            try:
                r = subprocess.check_output(args)
                input_slot.ProvideInput(r.decode('utf-8').strip())
                return None
            except subprocess.CalledProcessError as e:
                self.__session.Disconnect()
                return "Could not retrieve user credentials: %s" % str(e)
            except dbus.exceptions.DBusException as e:
                return "VPN session no longer valid: %s" % e.get_dbus_message()


if "__main__" == __name__:
    # Log level 6 is the default when OPENVPN3_DEBUG is set, otherwise -1
    start_log_level = 6 if 'OPENVPN3_DEBUG' in os.environ else -1

    argp = argparse.ArgumentParser('openvpn3-systemd',
                                   description='OpenVPN 3 systemd session management')
    argp.add_argument('config', metavar='CONFIG_PROFILE', type=str,
                      help='Configuration profile name of session which is to be managed')

    # The operating modes all store their name in the same destination
    for mode_name, mode_help in (
            ('start', 'Starts a VPN session with this configuration profile'),
            ('restart', 'Restarts a running VPN session with this configuration profile'),
            ('stop', 'Stops a running VPN session with this configuration profile')):
        argp.add_argument('--' + mode_name, action='store_const', const=mode_name,
                          dest='mode', help=mode_help)

    argp.add_argument('--log-level', type=int, action='store', default=start_log_level,
                      help='Set the log verbosity level')
    args = argp.parse_args()

    if args.mode not in ('start', 'restart', 'stop'):
        print('%s: ** ERROR ** One of --start, --restart or --stop is required' % sys.argv[0])
        sys.exit(1)

    try:
        o3srv = OpenVPN3systemd()
        if args.mode == 'start':
            o3srv.RetrieveConfig(args.config)
            o3srv.Start(args.log_level)
        elif args.mode == 'restart':
            o3srv.RetrieveSession(args.config)
            print('Restarting VPN session')
            o3srv.Restart()
        elif args.mode == 'stop':
            o3srv.RetrieveSession(args.config)
            print('Stopping VPN session')
            o3srv.Stop()
    except Exception as err:
        # RuntimeError is reported the same way as any other error
        print('%s: ** ERROR **  %s' % (sys.argv[0], str(err)))
