Skip to content

Commit 603dbde

Browse files
committed
removing error from NewProtoReader return, fixing tests and lint issues
1 parent 9d64454 commit 603dbde

File tree

16 files changed

+156
-94
lines changed

16 files changed

+156
-94
lines changed

lib/auth/init_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ func TestSignatureAlgorithmSuite(t *testing.T) {
215215
setupInitConfig := func(t *testing.T, capOrigin string, fips, hsm bool) InitConfig {
216216
cfg := setupConfig(t)
217217
cfg.FIPS = fips
218+
if hsm {
219+
cfg.KeyStoreConfig = keystore.HSMTestConfig(t)
220+
}
218221
cfg.AuthPreference.SetOrigin(capOrigin)
219222
if capOrigin != types.OriginDefaults {
220223
cfg.AuthPreference.SetSignatureAlgorithmSuite(types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_UNSPECIFIED)
Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,27 @@
1+
// Teleport
2+
// Copyright (C) 2025 Gravitational, Inc.
3+
//
4+
// This program is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Affero General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// This program is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Affero General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Affero General Public License
15+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
117
package recordingencryption
218

319
import (
420
"context"
521

622
"github.com/gravitational/trace"
723

24+
recordingencryptionv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/recordingencryption/v1"
825
"github.com/gravitational/teleport/api/types"
926
"github.com/gravitational/teleport/lib/services"
1027
)
@@ -26,45 +43,63 @@ func NewClusterConfigService(service services.ClusterConfigurationInternal, reso
2643

2744
// CreateSessionRecordingConfig evaluates RecordingEncryption state before creating the SessionRecordingConfig.
2845
func (s *ClusterConfigService) CreateSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (types.SessionRecordingConfig, error) {
29-
if cfg.GetEncrypted() {
30-
encryption, err := s.resolver.ResolveRecordingEncryption(ctx)
31-
if err != nil {
32-
return nil, trace.Wrap(err)
33-
}
46+
if !cfg.GetEncrypted() {
47+
res, err := s.ClusterConfigurationInternal.CreateSessionRecordingConfig(ctx, cfg)
48+
return res, trace.Wrap(err)
49+
}
3450

35-
cfg.SetEncryptionKeys(GetAgeEncryptionKeys(encryption.GetSpec().ActiveKeys))
51+
var res types.SessionRecordingConfig
52+
_, err := s.resolver.ResolveRecordingEncryption(ctx, func(ctx context.Context, encryption *recordingencryptionv1.RecordingEncryption) error {
53+
cfg.SetEncryptionKeys(getAgeEncryptionKeys(encryption.GetSpec().ActiveKeys))
54+
var err error
55+
res, err = s.ClusterConfigurationInternal.CreateSessionRecordingConfig(ctx, cfg)
56+
return err
57+
})
58+
if err != nil {
59+
return nil, trace.Wrap(err)
3660
}
3761

38-
res, err := s.ClusterConfigurationInternal.CreateSessionRecordingConfig(ctx, cfg)
39-
return res, trace.Wrap(err)
62+
return res, nil
4063
}
4164

4265
// UpdateSessionRecordingConfig evaluates RecordingEncryption state before updating the SessionRecordingConfig.
43-
func (r *ClusterConfigService) UpdateSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (types.SessionRecordingConfig, error) {
44-
if cfg.GetEncrypted() {
45-
encryption, err := r.resolver.ResolveRecordingEncryption(ctx)
46-
if err != nil {
47-
return nil, trace.Wrap(err)
48-
}
66+
func (s *ClusterConfigService) UpdateSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (types.SessionRecordingConfig, error) {
67+
if !cfg.GetEncrypted() {
68+
res, err := s.ClusterConfigurationInternal.UpdateSessionRecordingConfig(ctx, cfg)
69+
return res, trace.Wrap(err)
70+
}
4971

50-
cfg.SetEncryptionKeys(GetAgeEncryptionKeys(encryption.GetSpec().ActiveKeys))
72+
var res types.SessionRecordingConfig
73+
_, err := s.resolver.ResolveRecordingEncryption(ctx, func(ctx context.Context, encryption *recordingencryptionv1.RecordingEncryption) error {
74+
cfg.SetEncryptionKeys(getAgeEncryptionKeys(encryption.GetSpec().ActiveKeys))
75+
var err error
76+
res, err = s.ClusterConfigurationInternal.UpdateSessionRecordingConfig(ctx, cfg)
77+
return err
78+
})
79+
if err != nil {
80+
return nil, trace.Wrap(err)
5181
}
5282

53-
res, err := r.ClusterConfigurationInternal.UpdateSessionRecordingConfig(ctx, cfg)
54-
return res, trace.Wrap(err)
83+
return res, nil
5584
}
5685

5786
// UpsertSessionRecordingConfig evaluates RecordingEncryption state before upserting the SessionRecordingConfig.
58-
func (r *ClusterConfigService) UpsertSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (types.SessionRecordingConfig, error) {
59-
if cfg.GetEncrypted() {
60-
encryption, err := r.resolver.ResolveRecordingEncryption(ctx)
61-
if err != nil {
62-
return nil, trace.Wrap(err)
63-
}
87+
func (s *ClusterConfigService) UpsertSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (types.SessionRecordingConfig, error) {
88+
if !cfg.GetEncrypted() {
89+
res, err := s.ClusterConfigurationInternal.UpsertSessionRecordingConfig(ctx, cfg)
90+
return res, trace.Wrap(err)
91+
}
6492

65-
cfg.SetEncryptionKeys(GetAgeEncryptionKeys(encryption.GetSpec().ActiveKeys))
93+
var res types.SessionRecordingConfig
94+
_, err := s.resolver.ResolveRecordingEncryption(ctx, func(ctx context.Context, encryption *recordingencryptionv1.RecordingEncryption) error {
95+
cfg.SetEncryptionKeys(getAgeEncryptionKeys(encryption.GetSpec().ActiveKeys))
96+
var err error
97+
res, err = s.ClusterConfigurationInternal.UpsertSessionRecordingConfig(ctx, cfg)
98+
return err
99+
})
100+
if err != nil {
101+
return nil, trace.Wrap(err)
66102
}
67103

68-
res, err := r.ClusterConfigurationInternal.UpsertSessionRecordingConfig(ctx, cfg)
69-
return res, trace.Wrap(err)
104+
return res, nil
70105
}

lib/auth/recordingencryption/encryptedio.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
1+
// Teleport
2+
// Copyright (C) 2025 Gravitational, Inc.
3+
//
4+
// This program is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Affero General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// This program is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Affero General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Affero General Public License
15+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
117
package recordingencryption
218

319
import (
420
"context"
521
"io"
622

7-
"github.com/gravitational/trace"
8-
923
"filippo.io/age"
24+
"github.com/gravitational/trace"
1025

1126
"github.com/gravitational/teleport/api/types"
1227
)

lib/auth/recordingencryption/encryptedio_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
// Teleport
2+
// Copyright (C) 2025 Gravitational, Inc.
3+
//
4+
// This program is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Affero General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// This program is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Affero General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Affero General Public License
15+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
117
package recordingencryption_test
218

319
import (
@@ -34,6 +50,7 @@ func TestEncryptedIO(t *testing.T) {
3450

3551
msg := []byte("testing encrypted IO")
3652
_, err = writer.Write(msg)
53+
require.NoError(t, err)
3754

3855
// writer must be closed to ensure data is flushed
3956
err = writer.Close()

lib/auth/recordingencryption/watcher.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func NewWatcher(cfg WatchConfig) (*Watcher, error) {
6666
return nil, trace.BadParameter("cluster config backend is required")
6767
}
6868
if cfg.Logger == nil {
69-
cfg.Logger = slog.With(teleport.ComponentKey, "encryption-watcher")
69+
cfg.Logger = slog.With(teleport.ComponentKey, "recording-encryption-watcher")
7070
}
7171

7272
return &Watcher{

lib/client/player.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,7 @@ func (p *playFromFileStreamer) StreamSessionEvents(
5151
}
5252
defer f.Close()
5353

54-
pr, err := events.NewProtoReader(f, nil)
55-
if err != nil {
56-
errs <- trace.Wrap(err)
57-
return
58-
}
54+
pr := events.NewProtoReader(f, nil)
5955

6056
for i := int64(0); ; i++ {
6157
evt, err := pr.Read(ctx)

lib/events/auditlog.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
579579
return
580580
}
581581

582-
protoReader, err := NewProtoReader(rawSession, l.decrypter)
583-
if err != nil {
584-
e <- trace.Wrap(err)
585-
return
586-
}
582+
protoReader := NewProtoReader(rawSession, l.decrypter)
587583
defer protoReader.Close()
588584

589585
firstEvent := true

lib/events/emitter_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ func TestProtoStreamer(t *testing.T) {
108108
require.NoError(t, err)
109109

110110
for _, part := range parts {
111-
reader, err := events.NewProtoReader(bytes.NewReader(part), nil)
112-
require.NoError(t, err)
111+
reader := events.NewProtoReader(bytes.NewReader(part), nil)
113112
out, err := reader.ReadAll(ctx)
114113
require.NoError(t, err, "part crash %#v", part)
115114
outEvents = append(outEvents, out...)
@@ -257,8 +256,7 @@ func TestExport(t *testing.T) {
257256
_, err := f.Write(part)
258257
require.NoError(t, err)
259258
}
260-
reader, err := events.NewProtoReader(io.MultiReader(readers...), nil)
261-
require.NoError(t, err)
259+
reader := events.NewProtoReader(io.MultiReader(readers...), nil)
262260
outEvents, err := reader.ReadAll(ctx)
263261
require.NoError(t, err)
264262

lib/events/filesessions/fileasync.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -439,11 +439,7 @@ func (u *Uploader) startUpload(ctx context.Context, fileName string) (err error)
439439
return trace.Wrap(err, "uploader could not acquire file lock for %q", sessionFilePath)
440440
}
441441

442-
protoReader, err := events.NewProtoReader(sessionFile, nil)
443-
if err != nil {
444-
return trace.Wrap(err)
445-
}
446-
442+
protoReader := events.NewProtoReader(sessionFile, nil)
447443
upload := &upload{
448444
sessionID: sessionID,
449445
reader: protoReader,

lib/events/filesessions/fileasync_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ func readStream(ctx context.Context, t *testing.T, uploadID string, uploader *ev
669669
var reader *events.ProtoReader
670670
for i, part := range parts {
671671
if i == 0 {
672-
reader = events.NewProtoReader(bytes.NewReader(part))
672+
reader = events.NewProtoReader(bytes.NewReader(part), nil)
673673
} else {
674674
err := reader.Reset(bytes.NewReader(part))
675675
require.NoError(t, err)

lib/events/playback.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func DetectFormat(r io.ReadSeeker) (*Header, error) {
5353
return nil, trace.ConvertSystemError(err)
5454
}
5555
protocolVersion := binary.BigEndian.Uint64(version)
56-
if protocolVersion == ProtoStreamV1 {
56+
if protocolVersion >= ProtoStreamV1 && protocolVersion <= ProtoStreamV2 {
5757
return &Header{
5858
Proto: true,
5959
ProtoVersion: int64(protocolVersion),
@@ -83,11 +83,7 @@ func Export(ctx context.Context, rs io.ReadSeeker, w io.Writer, exportFormat str
8383
}
8484
switch {
8585
case format.Proto:
86-
protoReader, err := NewProtoReader(rs, nil)
87-
if err != nil {
88-
return trace.Wrap(err)
89-
}
90-
86+
protoReader := NewProtoReader(rs, nil)
9187
for {
9288
event, err := protoReader.Read(ctx)
9389
if err != nil {

lib/events/session_writer_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ func TestSessionWriter(t *testing.T) {
7474
require.NoError(t, err)
7575

7676
for _, part := range parts {
77-
reader, err := events.NewProtoReader(bytes.NewReader(part), nil)
78-
require.NoError(t, err)
77+
reader := events.NewProtoReader(bytes.NewReader(part), nil)
7978
out, err := reader.ReadAll(test.ctx)
8079
require.NoError(t, err, "part crash %#v", part)
8180
outEvents = append(outEvents, out...)
@@ -421,8 +420,7 @@ func (a *sessionWriterTest) collectEvents(t *testing.T) []apievents.AuditEvent {
421420
for _, part := range parts {
422421
readers = append(readers, bytes.NewReader(part))
423422
}
424-
reader, err := events.NewProtoReader(io.MultiReader(readers...), nil)
425-
require.NoError(t, err)
423+
reader := events.NewProtoReader(io.MultiReader(readers...), nil)
426424
outEvents, err := reader.ReadAll(a.ctx)
427425
require.NoError(t, err, "failed to read")
428426
t.Logf("Reader stats :%v", reader.GetStats().ToFields())

lib/events/stream.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,12 @@ const (
9191
// of the record header, it consists of the record length
9292
ProtoStreamV1RecordHeaderSize = Int32Size
9393

94+
// AgeHeader prefixes all encrypted recording parts
95+
AgeHeader = "age-encryption.org/"
96+
9497
// uploaderReservePartErrorMessage error message present when
9598
// `ReserveUploadPart` fails.
9699
uploaderReservePartErrorMessage = "uploader failed to reserve upload part"
97-
98-
// ageHeader prefixes all encrypted recording parts
99-
ageHeader = "age-encryption.org/"
100100
)
101101

102102
// An EncryptionWrapper wraps a given io.WriteCloser with encryption.
@@ -930,7 +930,7 @@ func (s *slice) reader() (io.ReadSeeker, error) {
930930
s.buffer.Write(padding)
931931
}
932932
data := s.buffer.Bytes()
933-
encrypted := slices.Equal(data[ProtoStreamV2PartHeaderSize:ProtoStreamV2PartHeaderSize+len(ageHeader)], []byte(ageHeader))
933+
encrypted := slices.Equal(data[ProtoStreamV2PartHeaderSize:ProtoStreamV2PartHeaderSize+len(AgeHeader)], []byte(AgeHeader))
934934
// when the slice was created, the first bytes were reserved
935935
// for the protocol version number and size of the slice in bytes
936936
binary.BigEndian.PutUint64(data[0:], ProtoStreamV2)
@@ -998,12 +998,12 @@ func (s *slice) recordEvent(event protoEvent) error {
998998
}
999999

10001000
// NewProtoReader returns a new proto reader with slice pool
1001-
func NewProtoReader(r io.Reader, decrypter DecryptionWrapper) (*ProtoReader, error) {
1001+
func NewProtoReader(r io.Reader, decrypter DecryptionWrapper) *ProtoReader {
10021002
return &ProtoReader{
10031003
reader: r,
10041004
lastIndex: -1,
10051005
decrypter: decrypter,
1006-
}, nil
1006+
}
10071007
}
10081008

10091009
// SessionReader provides method to read
@@ -1156,6 +1156,9 @@ func (r *ProtoReader) Read(ctx context.Context) (apievents.AuditEvent, error) {
11561156
var encrypted bool
11571157
if protocolVersion > 1 {
11581158
_, err = io.ReadFull(r.reader, r.sizeBytes[:Int64Size])
1159+
if err != nil {
1160+
return nil, r.setError(trace.ConvertSystemError(err))
1161+
}
11591162
flags := r.sizeBytes[0]
11601163
encrypted = flags&ProtoStreamFlagEncrypted != 0
11611164
}

0 commit comments

Comments
 (0)