diff --git a/packages/pam/local/access.go b/packages/pam/local/access.go index fb482312..785a7c05 100644 --- a/packages/pam/local/access.go +++ b/packages/pam/local/access.go @@ -93,7 +93,7 @@ func StartPAMAccess(accessToken, path, reason, durationStr string, port int) { case AccountTypeAwsIam: util.PrintErrorMessageAndExit("AWS IAM access not yet supported in the new PAM model") case AccountTypeWindows: - util.PrintErrorMessageAndExit("Windows/RDP access not yet supported in the new PAM model") + startRDPProxy(httpClient, &pamResponse, displayPath, durationStr, port) case AccountTypeActiveDirectory: util.PrintErrorMessageAndExit("Active Directory access not yet supported in the new PAM model") default: @@ -250,6 +250,99 @@ func startDatabaseProxy(httpClient *resty.Client, response *api.PAMAccessRespons proxy.Run() } +func startRDPProxy(httpClient *resty.Client, response *api.PAMAccessResponse, path, durationStr string, port int) { + duration, err := time.ParseDuration(durationStr) + if err != nil { + util.HandleError(err, "Failed to parse duration") + return + } + + username, ok := response.Metadata["username"] + if !ok { + util.HandleError(fmt.Errorf("PAM response metadata is missing 'username'"), "Failed to start RDP proxy") + return + } + + ctx, cancel := context.WithCancel(context.Background()) + + proxy := &RDPProxyServer{ + BaseProxyServer: BaseProxyServer{ + httpClient: httpClient, + relayHost: response.RelayHost, + relayClientCert: response.RelayClientCertificate, + relayClientKey: response.RelayClientPrivateKey, + relayServerCertChain: response.RelayServerCertificateChain, + gatewayClientCert: response.GatewayClientCertificate, + gatewayClientKey: response.GatewayClientPrivateKey, + gatewayServerCertChain: response.GatewayServerCertificateChain, + sessionExpiry: time.Now().Add(duration), + sessionId: response.SessionId, + resourceType: response.AccountType, + ctx: ctx, + cancel: cancel, + shutdownCh: make(chan struct{}), + }, + } + + if err := proxy.ValidateResourceTypeSupported(); err != nil { + util.HandleError(err, "Gateway version outdated") + return + } + + if err := proxy.Start(port); err != nil { + util.HandleError(err, "Failed to start RDP proxy server") + return + } + + rdpFilePath, err := writeRDPFile(proxy.port, response.SessionId, username) + if err != nil { + log.Warn().Err(err).Msg("Failed to write .rdp file; proxy still running") + } else { + proxy.rdpFilePath = rdpFilePath + } + + folder, account := parsePath(path) + + log.Info().Msgf("RDP proxy server listening on port %d", proxy.port) + util.PrintfStderr("\n") + util.PrintfStderr("**********************************************************************\n") + util.PrintfStderr(" RDP Proxy Session Started! \n") + util.PrintfStderr("**********************************************************************\n") + util.PrintfStderr("\n") + if folder != "" { + util.PrintfStderr(" Folder: %s\n", folder) + } + util.PrintfStderr(" Account: %s\n", account) + util.PrintfStderr(" Duration: %s\n", duration.String()) + util.PrintfStderr("\n") + util.PrintfStderr("----------------------------------------------------------------------\n") + util.PrintfStderr(" Connection Details \n") + util.PrintfStderr("----------------------------------------------------------------------\n") + util.PrintfStderr("\n") + util.PrintfStderr(" Host: 127.0.0.1\n") + util.PrintfStderr(" Port: %d\n", proxy.port) + util.PrintfStderr(" Username: %s\n", username) + util.PrintfStderr(" Password: (leave blank)\n") + if proxy.rdpFilePath != "" { + util.PrintfStderr("\n") + util.PrintfStderr(" .rdp file: %s\n", proxy.rdpFilePath) + } + util.PrintfStderr("\n") + util.PrintfStderr(" Press Ctrl+C to terminate the session.\n") + util.PrintfStderr("**********************************************************************\n") + util.PrintfStderr("\n") + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-sigChan + log.Info().Msgf("Received signal %v, initiating graceful shutdown...", sig) + proxy.gracefulShutdown() + }() + + proxy.Run() +} + // printDatabaseSessionInfo prints the connection info banner for database sessions func printDatabaseSessionInfo(config DatabaseDisplayConfig, folder, account string, duration time.Duration, username, database string, port int) { fmt.Printf("\n") diff --git a/packages/pam/local/base-proxy.go b/packages/pam/local/base-proxy.go index 0cc52108..394ec20b 100644 --- a/packages/pam/local/base-proxy.go +++ b/packages/pam/local/base-proxy.go @@ -297,9 +297,12 @@ func (b *BaseProxyServer) WaitForDisconnect(gatewayErrCh, clientErrCh <-chan err case <-gatewayErrCh: b.HandleGatewayDisconnect() case <-clientErrCh: - // Normal client disconnect, proxy stays running case <-connCtx.Done(): - log.Info().Msg("Connection cancelled by context") + select { + case <-gatewayErrCh: + b.HandleGatewayDisconnect() + default: + } } } diff --git a/packages/pam/local/rdp-proxy.go b/packages/pam/local/rdp-proxy.go index de021915..c8067a2c 100644 --- a/packages/pam/local/rdp-proxy.go +++ b/packages/pam/local/rdp-proxy.go @@ -266,7 +266,7 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) { connCtx, connCancel := context.WithCancel(p.ctx) defer connCancel() - done := make(chan struct{}, 2) + gatewayErrCh, clientErrCh := p.NewDisconnectChannels() go func() { defer connCancel() @@ -278,7 +278,7 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) { log.Debug().Err(err).Msg("Gateway to client copy ended") } } - done <- struct{}{} + gatewayErrCh <- err }() go func() { @@ -291,14 +291,10 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) { log.Debug().Err(err).Msg("Client to gateway copy ended") } } - done <- struct{}{} + clientErrCh <- err }() - select { - case <-done: - case <-connCtx.Done(): - log.Info().Msg("Connection cancelled by context") - } + p.WaitForDisconnect(gatewayErrCh, clientErrCh, connCtx) log.Info().Msgf("RDP connection closed for client: %s", clientConn.RemoteAddr().String()) }