Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 94 additions & 1 deletion packages/pam/local/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func StartPAMAccess(accessToken, path, reason, durationStr string, port int) {
case AccountTypeAwsIam:
util.PrintErrorMessageAndExit("AWS IAM access not yet supported in the new PAM model")
case AccountTypeWindows:
util.PrintErrorMessageAndExit("Windows/RDP access not yet supported in the new PAM model")
startRDPProxy(httpClient, &pamResponse, displayPath, durationStr, port)
case AccountTypeActiveDirectory:
util.PrintErrorMessageAndExit("Active Directory access not yet supported in the new PAM model")
default:
Expand Down Expand Up @@ -250,6 +250,99 @@ func startDatabaseProxy(httpClient *resty.Client, response *api.PAMAccessRespons
proxy.Run()
}

func startRDPProxy(httpClient *resty.Client, response *api.PAMAccessResponse, path, durationStr string, port int) {
duration, err := time.ParseDuration(durationStr)
if err != nil {
util.HandleError(err, "Failed to parse duration")
return
}

username, ok := response.Metadata["username"]
if !ok {
util.HandleError(fmt.Errorf("PAM response metadata is missing 'username'"), "Failed to start RDP proxy")
return
}

ctx, cancel := context.WithCancel(context.Background())

proxy := &RDPProxyServer{
BaseProxyServer: BaseProxyServer{
httpClient: httpClient,
relayHost: response.RelayHost,
relayClientCert: response.RelayClientCertificate,
relayClientKey: response.RelayClientPrivateKey,
relayServerCertChain: response.RelayServerCertificateChain,
gatewayClientCert: response.GatewayClientCertificate,
gatewayClientKey: response.GatewayClientPrivateKey,
gatewayServerCertChain: response.GatewayServerCertificateChain,
sessionExpiry: time.Now().Add(duration),
sessionId: response.SessionId,
resourceType: response.AccountType,
ctx: ctx,
cancel: cancel,
shutdownCh: make(chan struct{}),
},
}

if err := proxy.ValidateResourceTypeSupported(); err != nil {
util.HandleError(err, "Gateway version outdated")
return
}

if err := proxy.Start(port); err != nil {
Comment thread
bernie-g marked this conversation as resolved.
util.HandleError(err, "Failed to start RDP proxy server")
return
}
Comment thread
bernie-g marked this conversation as resolved.

rdpFilePath, err := writeRDPFile(proxy.port, response.SessionId, username)
if err != nil {
log.Warn().Err(err).Msg("Failed to write .rdp file; proxy still running")
} else {
proxy.rdpFilePath = rdpFilePath
}

folder, account := parsePath(path)

log.Info().Msgf("RDP proxy server listening on port %d", proxy.port)
util.PrintfStderr("\n")
util.PrintfStderr("**********************************************************************\n")
util.PrintfStderr(" RDP Proxy Session Started! \n")
util.PrintfStderr("**********************************************************************\n")
util.PrintfStderr("\n")
if folder != "" {
util.PrintfStderr(" Folder: %s\n", folder)
}
util.PrintfStderr(" Account: %s\n", account)
util.PrintfStderr(" Duration: %s\n", duration.String())
util.PrintfStderr("\n")
util.PrintfStderr("----------------------------------------------------------------------\n")
util.PrintfStderr(" Connection Details \n")
util.PrintfStderr("----------------------------------------------------------------------\n")
util.PrintfStderr("\n")
util.PrintfStderr(" Host: 127.0.0.1\n")
util.PrintfStderr(" Port: %d\n", proxy.port)
util.PrintfStderr(" Username: %s\n", username)
util.PrintfStderr(" Password: (leave blank)\n")
if proxy.rdpFilePath != "" {
util.PrintfStderr("\n")
util.PrintfStderr(" .rdp file: %s\n", proxy.rdpFilePath)
}
util.PrintfStderr("\n")
util.PrintfStderr(" Press Ctrl+C to terminate the session.\n")
util.PrintfStderr("**********************************************************************\n")
util.PrintfStderr("\n")

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Info().Msgf("Received signal %v, initiating graceful shutdown...", sig)
proxy.gracefulShutdown()
}()

proxy.Run()
}

// printDatabaseSessionInfo prints the connection info banner for database sessions
func printDatabaseSessionInfo(config DatabaseDisplayConfig, folder, account string, duration time.Duration, username, database string, port int) {
fmt.Printf("\n")
Expand Down
7 changes: 5 additions & 2 deletions packages/pam/local/base-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,12 @@ func (b *BaseProxyServer) WaitForDisconnect(gatewayErrCh, clientErrCh <-chan err
case <-gatewayErrCh:
b.HandleGatewayDisconnect()
case <-clientErrCh:
// Normal client disconnect, proxy stays running
case <-connCtx.Done():
log.Info().Msg("Connection cancelled by context")
select {
case <-gatewayErrCh:
b.HandleGatewayDisconnect()
default:
}
}
}

Expand Down
12 changes: 4 additions & 8 deletions packages/pam/local/rdp-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) {
connCtx, connCancel := context.WithCancel(p.ctx)
defer connCancel()

done := make(chan struct{}, 2)
gatewayErrCh, clientErrCh := p.NewDisconnectChannels()

go func() {
defer connCancel()
Expand All @@ -278,7 +278,7 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) {
log.Debug().Err(err).Msg("Gateway to client copy ended")
}
}
done <- struct{}{}
gatewayErrCh <- err
}()

go func() {
Expand All @@ -291,14 +291,10 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) {
log.Debug().Err(err).Msg("Client to gateway copy ended")
}
}
done <- struct{}{}
clientErrCh <- err
}()

select {
case <-done:
case <-connCtx.Done():
log.Info().Msg("Connection cancelled by context")
}
p.WaitForDisconnect(gatewayErrCh, clientErrCh, connCtx)
Comment thread
bernie-g marked this conversation as resolved.

log.Info().Msgf("RDP connection closed for client: %s", clientConn.RemoteAddr().String())
}
Expand Down
Loading