Skip to content

Commit 97957fe

Browse files
committed
k8s port forwarding
1 parent f5d4271 commit 97957fe

35 files changed

+1367
-225
lines changed

Diff for: README.md

+11-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
- `.proto` file discovery
1818
- Selection of multiple services and methods
1919
- Configuration of TLS, including disabling TLS (plain text)
20-
- Authentication: Basic, Bearer Token, JWT, GCE
20+
- Authentication: Basic, Bearer Token, JWT, GCE
21+
- Kubernetes port forwarding
22+
- Authorization in Google Cloud services
2123
- Input generation for all scalar types
2224
- Input generation for nested and looped messages
2325
- Input generation for enums, including nested
@@ -55,4 +57,11 @@ the `Applications` folder and run from `Applications`.
5557

5658
### Windows
5759

58-
[Download](https://github.com/Forest33/warthog/releases) and run `Warthog*-windows-x86-64.exe`.
60+
[Download](https://github.com/Forest33/warthog/releases) and run `Warthog*-windows-x86-64.exe`.
61+
62+
## Google Cloud services authorization
63+
- Enable Kubernetes Engine API and check quota for your project at [https://console.developers.google.com/apis/api/container](https://console.developers.google.com/apis/api/container)
64+
- Install gcloud CLI from [https://cloud.google.com/sdk/](https://cloud.google.com/sdk/) and run
65+
````
66+
gcloud beta auth application-default login
67+
````

Diff for: adapter/database/settings.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
const (
1414
settingsTable = "settings"
1515
settingsTableFields = `window_width, window_height, window_x, window_y, single_instance, connect_timeout,
16-
request_timeout, non_blocking_connection, sort_methods_by_name, max_loop_depth`
16+
request_timeout, k8s_request_timeout, non_blocking_connection, sort_methods_by_name, max_loop_depth`
1717
)
1818

1919
// SettingsRepository object capable of interacting with SettingsRepository
@@ -38,6 +38,7 @@ type settingsDTO struct {
3838
SingleInstance bool `db:"single_instance"`
3939
ConnectTimeout int `db:"connect_timeout"`
4040
RequestTimeout int `db:"request_timeout"`
41+
K8SRequestTimeout int `db:"k8s_request_timeout"`
4142
NonBlockingConnection bool `db:"non_blocking_connection"`
4243
SortMethodsByName bool `db:"sort_methods_by_name"`
4344
MaxLoopDepth int `db:"max_loop_depth"`
@@ -52,6 +53,7 @@ func (dto *settingsDTO) entity() *entity.Settings {
5253
SingleInstance: &dto.SingleInstance,
5354
ConnectTimeout: &dto.ConnectTimeout,
5455
RequestTimeout: &dto.RequestTimeout,
56+
K8SRequestTimeout: &dto.K8SRequestTimeout,
5557
NonBlockingConnection: &dto.NonBlockingConnection,
5658
SortMethodsByName: &dto.SortMethodsByName,
5759
MaxLoopDepth: &dto.MaxLoopDepth,
@@ -74,7 +76,7 @@ func (repo *SettingsRepository) Get() (*entity.Settings, error) {
7476
func (repo *SettingsRepository) Update(in *entity.Settings) (*entity.Settings, error) {
7577
dto := &settingsDTO{}
7678
attrs := make([]string, 0, 10)
77-
mapper := make(map[string]interface{}, 10)
79+
mapper := make(map[string]interface{}, 11)
7880

7981
if in.WindowWidth > 0 {
8082
attrs = append(attrs, "window_width = :window_width")
@@ -104,6 +106,10 @@ func (repo *SettingsRepository) Update(in *entity.Settings) (*entity.Settings, e
104106
attrs = append(attrs, "request_timeout = :request_timeout")
105107
mapper["request_timeout"] = in.RequestTimeout
106108
}
109+
if in.K8SRequestTimeout != nil {
110+
attrs = append(attrs, "k8s_request_timeout = :k8s_request_timeout")
111+
mapper["k8s_request_timeout"] = in.K8SRequestTimeout
112+
}
107113
if in.NonBlockingConnection != nil {
108114
attrs = append(attrs, "non_blocking_connection = :non_blocking_connection")
109115
mapper["non_blocking_connection"] = in.NonBlockingConnection

Diff for: adapter/grpc/client.go

+16-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ type Client struct {
3434
queryCtx context.Context
3535
queryCancel context.CancelFunc
3636
queryStartTime time.Time
37-
cancelQueryMux sync.Mutex
37+
connectionMux sync.RWMutex
3838
requestCh chan *dynamic.Message
3939
responseCh chan *entity.QueryResponse
4040
closeStreamCh chan struct{}
@@ -90,7 +90,11 @@ func (c *Client) Connect(addr string, auth *entity.Auth, opts ...ClientOpt) erro
9090
dialOptions = append(dialOptions, grpc.WithBlock())
9191
if *c.cfg.ConnectTimeout > 0 {
9292
ctx, cancel = context.WithTimeout(c.ctx, time.Second*time.Duration(*c.cfg.ConnectTimeout))
93-
defer cancel()
93+
defer func() {
94+
c.connectionMux.Lock()
95+
cancel()
96+
c.connectionMux.Unlock()
97+
}()
9498
}
9599
}
96100

@@ -139,9 +143,19 @@ func (c *Client) loadTLSCredentials() (credentials.TransportCredentials, error)
139143

140144
// Close closes connection to gRPC server
141145
func (c *Client) Close() {
146+
c.connectionMux.Lock()
147+
defer c.connectionMux.Unlock()
148+
142149
if c.conn != nil {
143150
if err := c.conn.Close(); err != nil {
144151
c.log.Error().Msgf("failed to close connection: %v", err)
145152
}
146153
}
147154
}
155+
156+
func (c *Client) isConnected() bool {
157+
c.connectionMux.RLock()
158+
defer c.connectionMux.RUnlock()
159+
160+
return c.conn != nil
161+
}

Diff for: adapter/grpc/query.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,15 @@ func (c *Client) createMessage(method *entity.Method, data map[string]interface{
7171
}
7272

7373
// Query executes a gRPC request
74-
func (c *Client) Query(method *entity.Method, data map[string]interface{}, metadata []string) {
74+
func (c *Client) Query(method *entity.Method, data map[string]interface{}, metadata []string) error {
75+
if !c.isConnected() {
76+
return entity.ErrNotConnected
77+
}
78+
7579
ms, err := c.createMessage(method, data, metadata)
7680
if err != nil {
7781
c.responseError(err, "")
78-
return
82+
return err
7983
}
8084

8185
var isNew bool
@@ -92,12 +96,14 @@ func (c *Client) Query(method *entity.Method, data map[string]interface{}, metad
9296
}
9397

9498
if err != nil {
95-
return
99+
return err
96100
}
97101

98102
if isNew || method.Type == entity.MethodTypeClientStream || method.Type == entity.MethodTypeBidiStream {
99103
c.request(ms)
100104
}
105+
106+
return nil
101107
}
102108

103109
// CancelQuery aborting a running gRPC request

Diff for: adapter/k8s/auth.go

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package k8s
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"fmt"
7+
8+
"golang.org/x/oauth2"
9+
auth "golang.org/x/oauth2/google"
10+
"google.golang.org/api/container/v1"
11+
"k8s.io/client-go/rest"
12+
"k8s.io/client-go/tools/clientcmd"
13+
"k8s.io/client-go/tools/clientcmd/api"
14+
15+
"github.com/forest33/warthog/business/entity"
16+
)
17+
18+
var gcsScopes = []string{"https://www.googleapis.com/auth/cloud-platform"}
19+
20+
func (c *Client) gcsAuth(ctx context.Context, r *entity.GCSAuth) (*rest.Config, error) {
21+
var token *oauth2.Token
22+
23+
credentials, err := auth.FindDefaultCredentials(ctx, gcsScopes...)
24+
if err == nil {
25+
token, err = credentials.TokenSource.Token()
26+
if err != nil {
27+
return nil, err
28+
}
29+
}
30+
31+
containerService, _ := container.NewService(ctx)
32+
33+
name := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", r.Project, r.Location, r.Cluster)
34+
resp, err := containerService.Projects.Locations.Clusters.Get(name).Do()
35+
if err != nil {
36+
return nil, err
37+
}
38+
39+
cert, err := base64.StdEncoding.DecodeString(resp.MasterAuth.ClusterCaCertificate)
40+
if err != nil {
41+
return nil, err
42+
}
43+
44+
apiConfig := &api.Config{
45+
APIVersion: "v1",
46+
Kind: "Config",
47+
Clusters: map[string]*api.Cluster{
48+
r.Cluster: {
49+
CertificateAuthorityData: cert,
50+
Server: fmt.Sprintf("https://%s", resp.Endpoint),
51+
},
52+
},
53+
Contexts: map[string]*api.Context{
54+
r.Cluster: {
55+
Cluster: r.Cluster,
56+
AuthInfo: r.Cluster,
57+
},
58+
},
59+
CurrentContext: r.Cluster,
60+
}
61+
62+
restConfig, err := clientcmd.NewNonInteractiveClientConfig(
63+
*apiConfig,
64+
r.Cluster,
65+
&clientcmd.ConfigOverrides{
66+
CurrentContext: r.Cluster,
67+
},
68+
nil,
69+
).ClientConfig()
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
restConfig.BearerToken = token.AccessToken
75+
76+
return restConfig, nil
77+
}

Diff for: adapter/k8s/client.go

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package k8s
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"fmt"
7+
"net/http"
8+
"net/url"
9+
"os"
10+
"path/filepath"
11+
"time"
12+
13+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
14+
"k8s.io/apimachinery/pkg/util/runtime"
15+
"k8s.io/client-go/kubernetes"
16+
"k8s.io/client-go/rest"
17+
"k8s.io/client-go/tools/clientcmd"
18+
"k8s.io/client-go/tools/portforward"
19+
"k8s.io/client-go/transport/spdy"
20+
21+
"github.com/forest33/warthog/business/entity"
22+
"github.com/forest33/warthog/pkg/logger"
23+
)
24+
25+
// Client object capable of interacting with Client
26+
type Client struct {
27+
ctx context.Context
28+
cfg *entity.Settings
29+
log *logger.Zerolog
30+
}
31+
32+
// New creates a new Client
33+
func New(ctx context.Context, log *logger.Zerolog) *Client {
34+
return &Client{
35+
ctx: ctx,
36+
log: log,
37+
}
38+
}
39+
40+
// SetSettings sets application settings
41+
func (c *Client) SetSettings(cfg *entity.Settings) {
42+
c.cfg = cfg
43+
}
44+
45+
// PortForward port forward
46+
func (c *Client) PortForward(r *entity.K8SPortForward) (entity.PortForwardControl, error) {
47+
config, client, err := c.createClient(r.ClientConfig)
48+
if err != nil {
49+
return nil, err
50+
}
51+
52+
var (
53+
podName string
54+
)
55+
56+
if r.PodName != "" {
57+
podName = r.PodName
58+
} else if r.PodNameSelector != "" {
59+
if podName, err = c.findPod(client, r.Namespace, r.PodNameSelector); err != nil {
60+
return nil, err
61+
}
62+
} else {
63+
return nil, entity.ErrK8SPodNotFound
64+
}
65+
66+
ctrl := &PortForwardControl{
67+
stopCh: make(chan struct{}, 1),
68+
out: &bytes.Buffer{},
69+
errOut: &bytes.Buffer{},
70+
}
71+
72+
readyCh := make(chan struct{})
73+
74+
writeError := func(err error) {
75+
if _, err := ctrl.errOut.Write([]byte(err.Error())); err != nil {
76+
c.log.Error().Msgf("failed write to error stream: %v", err)
77+
}
78+
}
79+
80+
if r.ErrHandler != nil {
81+
runtime.ErrorHandlers = []func(error){r.ErrHandler}
82+
}
83+
84+
go func() {
85+
path := fmt.Sprintf("/api/v1/namespaces/%s/pods/%s/portforward", r.Namespace, podName)
86+
87+
u, err := url.Parse(config.Host)
88+
if err != nil {
89+
writeError(err)
90+
return
91+
}
92+
93+
transport, upgrader, err := spdy.RoundTripperFor(config)
94+
if err != nil {
95+
writeError(err)
96+
return
97+
}
98+
99+
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, http.MethodPost, &url.URL{Scheme: "https", Path: path, Host: u.Host})
100+
fw, err := portforward.New(dialer, []string{fmt.Sprintf("%d:%d", r.LocalPort, r.PodPort)}, ctrl.stopCh, readyCh, ctrl.out, ctrl.errOut)
101+
if err != nil {
102+
writeError(err)
103+
return
104+
}
105+
106+
if err := fw.ForwardPorts(); err != nil {
107+
writeError(err)
108+
}
109+
}()
110+
111+
<-readyCh
112+
113+
return ctrl, nil
114+
}
115+
116+
func (c *Client) createClient(cfg *entity.K8SClientConfig) (*rest.Config, *kubernetes.Clientset, error) {
117+
var (
118+
restConfig *rest.Config
119+
err error
120+
)
121+
122+
if cfg.GCSAuth != nil && cfg.GCSAuth.Enabled {
123+
restConfig, err = c.gcsAuth(c.ctx, cfg.GCSAuth)
124+
if err != nil {
125+
return nil, nil, err
126+
}
127+
} else {
128+
kubeConfig := cfg.KubeConfigFile
129+
if kubeConfig == "" {
130+
if home := homeDir(); home != "" {
131+
kubeConfig = filepath.Join(home, ".kube", "config")
132+
}
133+
}
134+
135+
restConfig, err = clientcmd.BuildConfigFromFlags("", kubeConfig)
136+
if err != nil {
137+
return nil, nil, err
138+
}
139+
140+
restConfig.BearerToken = cfg.BearerToken
141+
}
142+
143+
client, err := kubernetes.NewForConfig(restConfig)
144+
if err != nil {
145+
return nil, nil, err
146+
}
147+
148+
return restConfig, client, nil
149+
}
150+
151+
func (c *Client) findPod(client *kubernetes.Clientset, namespace, selector string) (string, error) {
152+
ctx, cancel := context.WithTimeout(c.ctx, time.Duration(*c.cfg.K8SRequestTimeout)*time.Second)
153+
defer cancel()
154+
155+
pods, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{LabelSelector: selector})
156+
if err != nil {
157+
return "", err
158+
}
159+
160+
if len(pods.Items) == 0 {
161+
return "", entity.ErrK8SPodNotFound
162+
}
163+
164+
return pods.Items[0].Name, nil
165+
}
166+
167+
func homeDir() string {
168+
if h := os.Getenv("HOME"); h != "" {
169+
return h
170+
}
171+
return os.Getenv("USERPROFILE")
172+
}

0 commit comments

Comments
 (0)