Skip to content

Commit

Permalink
Merge pull request #150 from gscert/main
Browse files Browse the repository at this point in the history
Add support for providing a custom get S2A stream function to the TLSClientConfigFactory
  • Loading branch information
zeromath authored Oct 1, 2024
2 parents 3b283c9 + 80f23db commit f1fc08d
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 7 deletions.
5 changes: 3 additions & 2 deletions internal/v2/s2av2.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,9 @@ func NewClientTLSConfig(
tokenManager tokenmanager.AccessTokenManager,
verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
serverName string,
serverAuthorizationPolicy []byte) (*tls.Config, error) {
s2AStream, err := createStream(ctx, s2av2Address, transportCreds, nil)
serverAuthorizationPolicy []byte,
getStream func(context.Context, string) (stream.S2AStream, error)) (*tls.Config, error) {
s2AStream, err := createStream(ctx, s2av2Address, transportCreds, getStream)
if err != nil {
grpclog.Infof("Failed to connect to S2Av2: %v", err)
return nil, err
Expand Down
28 changes: 26 additions & 2 deletions internal/v2/s2av2_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/google/s2a-go/internal/tokenmanager"
"github.com/google/s2a-go/internal/v2/fakes2av2"
"github.com/google/s2a-go/retry"
"github.com/google/s2a-go/stream"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"

Expand Down Expand Up @@ -422,7 +423,7 @@ func TestNewClientTlsConfigWithTokenManager(t *testing.T) {
}
ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
defer cancel()
config, err := NewClientTLSConfig(ctx, s2AAddr, nil, accessTokenManager, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, "test_server_name", nil)
config, err := NewClientTLSConfig(ctx, s2AAddr, nil, accessTokenManager, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, "test_server_name", nil, nil)
if err != nil {
t.Errorf("NewClientTLSConfig() failed: %v", err)
}
Expand All @@ -442,7 +443,7 @@ func TestNewClientTlsConfigWithoutTokenManager(t *testing.T) {
var tokenManager tokenmanager.AccessTokenManager
ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
defer cancel()
config, err := NewClientTLSConfig(ctx, s2AAddr, nil, tokenManager, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, "test_server_name", nil)
config, err := NewClientTLSConfig(ctx, s2AAddr, nil, tokenManager, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, "test_server_name", nil, nil)
if err != nil {
t.Errorf("NewClientTLSConfig() failed: %v", err)
}
Expand All @@ -455,3 +456,26 @@ func TestNewClientTlsConfigWithoutTokenManager(t *testing.T) {
t.Errorf("tls.Config has unexpected certificate: got: %v, want: %v", got, want)
}
}

func TestNewClientTlsConfigWithCustomS2AStream(t *testing.T) {
os.Unsetenv(accessTokenEnvVariable)
s2aAddr := startFakeS2A(t, "TestNewClientTlsConfig_token")
var tokenManager tokenmanager.AccessTokenManager
ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
t.Cleanup(func() {
cancel()
})

getStreamFuncCalled := false
_, err := NewClientTLSConfig(ctx, s2aAddr, nil, tokenManager, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, "test_server_name", nil, func(ctx context.Context, s2av2Address string) (stream.S2AStream, error) {
getStreamFuncCalled = true
return s2ATestStream{debug: "test S2A stream"}, nil
})
if err != nil {
t.Errorf("NewClientTLSConfig() failed: %v", err)
}

if !getStreamFuncCalled {
t.Errorf("custom getStream function was called = %v, want: %v", getStreamFuncCalled, true)
}
}
13 changes: 12 additions & 1 deletion internal/v2/s2av2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,18 @@ func (x s2ATestStream) Send(m *s2av2pb.SessionReq) error {
}

func (x s2ATestStream) Recv() (*s2av2pb.SessionResp, error) {
return nil, nil
return &s2av2pb.SessionResp{
RespOneof: &s2av2pb.SessionResp_GetTlsConfigurationResp{
GetTlsConfigurationResp: &s2av2pb.GetTlsConfigurationResp{
TlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration_{
ClientTlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration{
MinTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
MaxTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
},
},
},
},
}, nil
}

func (x s2ATestStream) CloseSend() error {
Expand Down
2 changes: 1 addition & 1 deletion internal/v2/tlsconfigstore/tlsconfigstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func GetTLSConfigurationForClient(serverHostname string, s2AStream stream.S2AStr
return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
}

// Extract TLS configiguration from SessionResp.
// Extract TLS configuration from SessionResp.
tlsConfig := resp.GetGetTlsConfigurationResp().GetClientTlsConfiguration()

var cert tls.Certificate
Expand Down
6 changes: 5 additions & 1 deletion s2a.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/google/s2a-go/internal/tokenmanager"
"github.com/google/s2a-go/internal/v2"
"github.com/google/s2a-go/retry"
"github.com/google/s2a-go/stream"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -330,6 +331,7 @@ func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, err
tokenManager: nil,
verificationMode: getVerificationMode(opts.VerificationMode),
serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
getStream: opts.getS2AStream,
}, nil
}
return &s2aTLSClientConfigFactory{
Expand All @@ -338,6 +340,7 @@ func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, err
tokenManager: tokenManager,
verificationMode: getVerificationMode(opts.VerificationMode),
serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
getStream: opts.getS2AStream,
}, nil
}

Expand All @@ -347,6 +350,7 @@ type s2aTLSClientConfigFactory struct {
tokenManager tokenmanager.AccessTokenManager
verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
serverAuthorizationPolicy []byte
getStream func(context.Context, string) (stream.S2AStream, error)
}

func (f *s2aTLSClientConfigFactory) Build(
Expand All @@ -355,7 +359,7 @@ func (f *s2aTLSClientConfigFactory) Build(
if opts != nil && opts.ServerName != "" {
serverName = opts.ServerName
}
return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy, f.getStream)
}

func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
Expand Down

0 comments on commit f1fc08d

Please sign in to comment.