Skip to content

Commit 4ad7b99

Browse files
authored
Known Hosts User Input Fix (#1778)
This handles ambiguous cases where the context that determines the block id of a command is not provided.
1 parent d81b6b8 commit 4ad7b99

File tree

3 files changed

+67
-20
lines changed

3 files changed

+67
-20
lines changed

frontend/app/modals/userinputmodal.scss

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
justify-content: space-between;
1414
gap: 1rem;
1515
margin: 0 1rem 1rem 1rem;
16+
max-width: 500px;
1617

1718
font: var(--fixed-font);
1819
color: var(--main-text-color);

pkg/remote/sshclient.go

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,13 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *wshrpc.ConnKe
141141
authSockSigners = append(authSockSigners, authSockSignersExt...)
142142
authSockSignersPtr := &authSockSigners
143143

144-
return func() ([]ssh.Signer, error) {
144+
return func() (outSigner []ssh.Signer, outErr error) {
145+
defer func() {
146+
panicErr := panichandler.PanicHandler("sshclient:publickey-callback", recover())
147+
if panicErr != nil {
148+
outErr = panicErr
149+
}
150+
}()
145151
// try auth sock
146152
if len(*authSockSignersPtr) != 0 {
147153
authSockSigner := (*authSockSignersPtr)[0]
@@ -219,7 +225,13 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *wshrpc.ConnKe
219225
}
220226

221227
func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisplayName string, debugInfo *ConnectionDebugInfo) func() (secret string, err error) {
222-
return func() (secret string, err error) {
228+
return func() (secret string, outErr error) {
229+
defer func() {
230+
panicErr := panichandler.PanicHandler("sshclient:password-callback", recover())
231+
if panicErr != nil {
232+
outErr = panicErr
233+
}
234+
}()
223235
blocklogger.Infof(connCtx, "[conndebug] Password Authentication requested from connection %s...\n", remoteDisplayName)
224236
ctx, cancelFn := context.WithTimeout(connCtx, 60*time.Second)
225237
defer cancelFn()
@@ -244,7 +256,13 @@ func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisp
244256
}
245257

246258
func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteName string, debugInfo *ConnectionDebugInfo) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
247-
return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
259+
return func(name, instruction string, questions []string, echos []bool) (answers []string, outErr error) {
260+
defer func() {
261+
panicErr := panichandler.PanicHandler("sshclient:kbdinteractive-callback", recover())
262+
if panicErr != nil {
263+
outErr = panicErr
264+
}
265+
}()
248266
if len(questions) != len(echos) {
249267
return nil, fmt.Errorf("bad response from server: questions has len %d, echos has len %d", len(questions), len(echos))
250268
}
@@ -332,7 +350,7 @@ func writeToKnownHosts(knownHostsFile string, newLine string, getUserVerificatio
332350
return f.Close()
333351
}
334352

335-
func createUnknownKeyVerifier(knownHostsFile string, hostname string, remote string, key ssh.PublicKey) func() (*userinput.UserInputResponse, error) {
353+
func createUnknownKeyVerifier(ctx context.Context, knownHostsFile string, hostname string, remote string, key ssh.PublicKey) func() (*userinput.UserInputResponse, error) {
336354
base64Key := base64.StdEncoding.EncodeToString(key.Marshal())
337355
queryText := fmt.Sprintf(
338356
"The authenticity of host '%s (%s)' can't be established "+
@@ -349,7 +367,7 @@ func createUnknownKeyVerifier(knownHostsFile string, hostname string, remote str
349367
Title: "Known Hosts Key Missing",
350368
}
351369
return func() (*userinput.UserInputResponse, error) {
352-
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
370+
ctx, cancelFn := context.WithTimeout(ctx, 60*time.Second)
353371
defer cancelFn()
354372
resp, err := userinput.GetUserInput(ctx, request)
355373
if err != nil {
@@ -402,7 +420,7 @@ func lineContainsMatch(line []byte, matches [][]byte) bool {
402420
return false
403421
}
404422

405-
func createHostKeyCallback(sshKeywords *wshrpc.ConnKeywords) (ssh.HostKeyCallback, HostKeyAlgorithms, error) {
423+
func createHostKeyCallback(ctx context.Context, sshKeywords *wshrpc.ConnKeywords) (ssh.HostKeyCallback, HostKeyAlgorithms, error) {
406424
globalKnownHostsFiles := sshKeywords.SshGlobalKnownHostsFile
407425
userKnownHostsFiles := sshKeywords.SshUserKnownHostsFile
408426

@@ -473,7 +491,13 @@ func createHostKeyCallback(sshKeywords *wshrpc.ConnKeywords) (ssh.HostKeyCallbac
473491
}
474492
}
475493

476-
waveHostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
494+
waveHostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) (outErr error) {
495+
defer func() {
496+
panicErr := panichandler.PanicHandler("sshclient:wave-hostkey-callback", recover())
497+
if panicErr != nil {
498+
outErr = panicErr
499+
}
500+
}()
477501
err := basicCallback(hostname, remote, key)
478502
if err == nil {
479503
// success
@@ -493,7 +517,7 @@ func createHostKeyCallback(sshKeywords *wshrpc.ConnKeywords) (ssh.HostKeyCallbac
493517
err := fmt.Errorf("placeholder, should not be returned") // a null value here can cause problems with empty slice
494518
for _, filename := range knownHostsFiles {
495519
newLine := xknownhosts.Line([]string{xknownhosts.Normalize(hostname)}, key)
496-
getUserVerification := createUnknownKeyVerifier(filename, hostname, remote.String(), key)
520+
getUserVerification := createUnknownKeyVerifier(ctx, filename, hostname, remote.String(), key)
497521
err = writeToKnownHosts(filename, newLine, getUserVerification)
498522
if err == nil {
499523
break
@@ -623,7 +647,7 @@ func createClientConfig(connCtx context.Context, sshKeywords *wshrpc.ConnKeyword
623647
authMethods = append(authMethods, authMethod)
624648
}
625649

626-
hostKeyCallback, hostKeyAlgorithms, err := createHostKeyCallback(sshKeywords)
650+
hostKeyCallback, hostKeyAlgorithms, err := createHostKeyCallback(connCtx, sshKeywords)
627651
if err != nil {
628652
return nil, err
629653
}

pkg/userinput/userinput.go

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@ package userinput
55

66
import (
77
"context"
8+
"errors"
89
"fmt"
910
"log"
1011
"sync"
1112
"time"
1213

1314
"github.com/google/uuid"
15+
"github.com/wavetermdev/waveterm/pkg/blocklogger"
1416
"github.com/wavetermdev/waveterm/pkg/genconn"
17+
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
1518
"github.com/wavetermdev/waveterm/pkg/wps"
1619
"github.com/wavetermdev/waveterm/pkg/wstore"
1720
)
@@ -63,22 +66,19 @@ func (ui *UserInputHandler) unregisterChannel(id string) {
6366
delete(ui.Channels, id)
6467
}
6568

66-
func (ui *UserInputHandler) sendRequestToFrontend(request *UserInputRequest, windowId string) {
69+
func (ui *UserInputHandler) sendRequestToFrontend(request *UserInputRequest, scopes []string) {
6770
wps.Broker.Publish(wps.WaveEvent{
6871
Event: wps.Event_UserInput,
6972
Data: request,
70-
Scopes: []string{windowId},
73+
Scopes: scopes,
7174
})
7275
}
7376

74-
func GetUserInput(ctx context.Context, request *UserInputRequest) (*UserInputResponse, error) {
75-
id, uiCh := MainUserInputHandler.registerChannel()
76-
defer MainUserInputHandler.unregisterChannel(id)
77-
request.RequestId = id
78-
deadline, _ := ctx.Deadline()
79-
request.TimeoutMs = int(time.Until(deadline).Milliseconds()) - 500
80-
77+
func determineScopes(ctx context.Context) ([]string, error) {
8178
connData := genconn.GetConnData(ctx)
79+
if connData == nil {
80+
return nil, fmt.Errorf("context did not contain connection info")
81+
}
8282
// resolve windowId from blockId
8383
tabId, err := wstore.DBFindTabForBlockId(ctx, connData.BlockId)
8484
if err != nil {
@@ -93,9 +93,31 @@ func GetUserInput(ctx context.Context, request *UserInputRequest) (*UserInputRes
9393
return nil, fmt.Errorf("unabled to determine window for route: %w", err)
9494
}
9595

96-
MainUserInputHandler.sendRequestToFrontend(request, windowId)
96+
return []string{windowId}, nil
97+
}
98+
99+
func GetUserInput(ctx context.Context, request *UserInputRequest) (*UserInputResponse, error) {
100+
id, uiCh := MainUserInputHandler.registerChannel()
101+
defer MainUserInputHandler.unregisterChannel(id)
102+
request.RequestId = id
103+
request.TimeoutMs = int(utilfn.TimeoutFromContext(ctx, 30*time.Second).Milliseconds())
104+
105+
scopes, scopesErr := determineScopes(ctx)
106+
if scopesErr != nil {
107+
log.Printf("user input scopes could not be found: %v", scopesErr)
108+
blocklogger.Infof(ctx, "user input scopes could not be found: %v", scopesErr)
109+
allWindows, err := wstore.DBGetAllOIDsByType(ctx, "window")
110+
if err != nil {
111+
blocklogger.Infof(ctx, "unable to find windows for user input: %v", err)
112+
return nil, fmt.Errorf("unable to find windows for user input: %v", err)
113+
}
114+
scopes = allWindows
115+
}
116+
117+
MainUserInputHandler.sendRequestToFrontend(request, scopes)
97118

98119
var response *UserInputResponse
120+
var err error
99121
select {
100122
case resp := <-uiCh:
101123
log.Printf("checking received: %v", resp.RequestId)
@@ -105,7 +127,7 @@ func GetUserInput(ctx context.Context, request *UserInputRequest) (*UserInputRes
105127
}
106128

107129
if response.ErrorMsg != "" {
108-
err = fmt.Errorf(response.ErrorMsg)
130+
err = errors.New(response.ErrorMsg)
109131
}
110132

111133
return response, err

0 commit comments

Comments
 (0)