Skip to content

Commit 3408192

Browse files
committed
WIP
1 parent 5c0974c commit 3408192

File tree

1 file changed

+189
-98
lines changed

1 file changed

+189
-98
lines changed

pkg/client/client.go

Lines changed: 189 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package client
33
import (
44
"context"
55
"fmt"
6-
"log"
6+
"sync"
77

88
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
99
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
@@ -15,150 +15,247 @@ import (
1515
type EncryptedClient struct {
1616
client *dynamodb.Client
1717
materialsProvider provider.CryptographicMaterialsProvider
18-
primaryKeyInfo *utils.PrimaryKeyInfo
1918
primaryKeyCache map[string]*utils.PrimaryKeyInfo
19+
lock sync.RWMutex
2020
}
2121

2222
// NewEncryptedClient creates a new instance of EncryptedClient.
2323
func NewEncryptedClient(client *dynamodb.Client, materialsProvider provider.CryptographicMaterialsProvider) *EncryptedClient {
2424
return &EncryptedClient{
2525
client: client,
2626
materialsProvider: materialsProvider,
27-
primaryKeyInfo: nil,
2827
primaryKeyCache: make(map[string]*utils.PrimaryKeyInfo),
28+
lock: sync.RWMutex{},
2929
}
3030
}
3131

3232
// PutItem encrypts an item and puts it into a DynamoDB table.
3333
func (ec *EncryptedClient) PutItem(ctx context.Context, input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
34+
// Encrypt the item, excluding primary keys
35+
encryptedItem, err := ec.encryptItem(ctx, *input.TableName, input.Item)
36+
if err != nil {
37+
return nil, fmt.Errorf("failed to encrypt item: %v", err)
38+
}
39+
40+
// Create a new PutItemInput with the encrypted item
41+
encryptedInput := &dynamodb.PutItemInput{
42+
TableName: input.TableName,
43+
Item: encryptedItem,
44+
}
45+
46+
// Put the encrypted item into the DynamoDB table
47+
return ec.client.PutItem(ctx, encryptedInput)
48+
}
49+
50+
// GetItem retrieves an item from a DynamoDB table and decrypts it.
51+
func (ec *EncryptedClient) GetItem(ctx context.Context, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
52+
// First, retrieve the encrypted item from DynamoDB
53+
encryptedOutput, err := ec.client.GetItem(ctx, input)
54+
if err != nil {
55+
return nil, fmt.Errorf("error retrieving encrypted item: %v", err)
56+
}
57+
58+
// Check if item is found
59+
if encryptedOutput.Item == nil {
60+
return nil, fmt.Errorf("item not found")
61+
}
62+
63+
// Decrypt the item, excluding primary keys
64+
decryptedItem, err := ec.decryptItem(ctx, *input.TableName, encryptedOutput.Item)
65+
if err != nil {
66+
return nil, fmt.Errorf("failed to decrypt item: %v", err)
67+
}
68+
69+
// Create a new GetItemOutput with the decrypted item
70+
decryptedOutput := &dynamodb.GetItemOutput{
71+
Item: decryptedItem,
72+
}
73+
74+
return decryptedOutput, nil
75+
}
76+
77+
// Query executes a Query operation on DynamoDB and decrypts the returned items.
78+
func (ec *EncryptedClient) Query(ctx context.Context, input *dynamodb.QueryInput) (*dynamodb.QueryOutput, error) {
79+
encryptedOutput, err := ec.client.Query(ctx, input)
80+
if err != nil {
81+
return nil, fmt.Errorf("error querying encrypted items: %v", err)
82+
}
83+
3484
tableName := *input.TableName
3585

36-
// Cache check for primary key info
37-
pkInfo, ok := ec.primaryKeyCache[tableName]
38-
if !ok {
39-
var err error
40-
pkInfo, err = ec.getPrimaryKeyInfo(ctx, tableName)
41-
if err != nil {
42-
return nil, fmt.Errorf("error fetching primary key info: %v", err)
86+
// Decrypt the items in the response
87+
for i, item := range encryptedOutput.Items {
88+
decryptedItem, decryptErr := ec.decryptItem(ctx, tableName, item)
89+
if decryptErr != nil {
90+
return nil, decryptErr
4391
}
92+
encryptedOutput.Items[i] = decryptedItem
4493
}
45-
partitionKeyValue := input.Item[pkInfo.PartitionKey].(*types.AttributeValueMemberS).Value
46-
var sortKeyValue string
47-
if pkInfo.SortKey != "" && input.Item[pkInfo.SortKey] != nil {
48-
sortKeyValue = input.Item[pkInfo.SortKey].(*types.AttributeValueMemberS).Value
94+
95+
return encryptedOutput, nil
96+
}
97+
98+
// Scan executes a Scan operation on DynamoDB and decrypts the returned items.
99+
func (ec *EncryptedClient) Scan(ctx context.Context, input *dynamodb.ScanInput) (*dynamodb.ScanOutput, error) {
100+
encryptedOutput, err := ec.client.Scan(ctx, input)
101+
if err != nil {
102+
return nil, fmt.Errorf("error scanning encrypted items: %v", err)
49103
}
50104

51-
// Construct and hash the material name
52-
rawMaterialName := tableName + "-" + partitionKeyValue
53-
if sortKeyValue != "" {
54-
rawMaterialName += "-" + sortKeyValue
105+
tableName := *input.TableName
106+
107+
// Decrypt the items in the response
108+
for i, item := range encryptedOutput.Items {
109+
decryptedItem, decryptErr := ec.decryptItem(ctx, tableName, item)
110+
if decryptErr != nil {
111+
return nil, decryptErr
112+
}
113+
encryptedOutput.Items[i] = decryptedItem
55114
}
56115

57-
materialName := utils.HashString(rawMaterialName)
116+
return encryptedOutput, nil
117+
}
58118

59-
// Generate and store new material
60-
encryptionMaterials, err := ec.materialsProvider.EncryptionMaterials(context.Background(), materialName)
119+
// BatchWriteItem performs batch write operations, encrypting any items to be put.
120+
func (ec *EncryptedClient) BatchWriteItem(ctx context.Context, input *dynamodb.BatchWriteItemInput) (*dynamodb.BatchWriteItemOutput, error) {
121+
// Iterate over each table's write requests
122+
for tableName, writeRequests := range input.RequestItems {
123+
for i, writeRequest := range writeRequests {
124+
if writeRequest.PutRequest != nil {
125+
// Encrypt the item for PutRequest
126+
encryptedItem, err := ec.encryptItem(ctx, tableName, writeRequest.PutRequest.Item)
127+
if err != nil {
128+
return nil, err
129+
}
130+
input.RequestItems[tableName][i].PutRequest.Item = encryptedItem
131+
}
132+
}
133+
}
134+
135+
return ec.client.BatchWriteItem(ctx, input)
136+
}
137+
138+
// BatchGetItem retrieves a batch of items from DynamoDB and decrypts them.
139+
func (ec *EncryptedClient) BatchGetItem(ctx context.Context, input *dynamodb.BatchGetItemInput) (*dynamodb.BatchGetItemOutput, error) {
140+
encryptedOutput, err := ec.client.BatchGetItem(ctx, input)
61141
if err != nil {
62-
log.Fatalf("Failed to generate encryption materials: %v", err)
142+
return nil, fmt.Errorf("error batch getting encrypted items: %v", err)
63143
}
64144

65-
// Create a new item map to hold encrypted attributes
66-
encryptedItem := make(map[string]types.AttributeValue)
145+
// Decrypt the items in the response for each table
146+
for tableName, result := range encryptedOutput.Responses {
147+
for i, item := range result {
148+
decryptedItem, decryptErr := ec.decryptItem(ctx, tableName, item)
149+
if decryptErr != nil {
150+
return nil, decryptErr
151+
}
152+
encryptedOutput.Responses[tableName][i] = decryptedItem
153+
}
154+
}
155+
156+
return encryptedOutput, nil
157+
}
158+
159+
// getPrimaryKeyInfo lazily loads and caches primary key information in a thread-safe manner.
160+
func (ec *EncryptedClient) getPrimaryKeyInfo(ctx context.Context, tableName string) (*utils.PrimaryKeyInfo, error) {
161+
ec.lock.RLock()
162+
pkInfo, exists := ec.primaryKeyCache[tableName]
163+
ec.lock.RUnlock()
164+
165+
if exists {
166+
return pkInfo, nil
167+
}
168+
169+
ec.lock.Lock()
170+
defer ec.lock.Unlock()
171+
172+
pkInfo, exists = ec.primaryKeyCache[tableName]
173+
if exists {
174+
return pkInfo, nil
175+
}
176+
177+
pkInfo, err := utils.TableInfo(ctx, ec.client, tableName)
178+
if err != nil {
179+
return nil, err
180+
}
181+
182+
ec.primaryKeyCache[tableName] = pkInfo
183+
184+
return pkInfo, nil
185+
}
186+
187+
// encryptItem encrypts a DynamoDB item's attributes, excluding primary keys.
188+
func (ec *EncryptedClient) encryptItem(ctx context.Context, tableName string, item map[string]types.AttributeValue) (map[string]types.AttributeValue, error) {
189+
// Fetch primary key info to exclude these attributes from encryption
190+
pkInfo, err := ec.getPrimaryKeyInfo(ctx, tableName)
191+
if err != nil {
192+
return nil, err
193+
}
194+
195+
// Generate and fetch encryption materials
196+
materialName := ec.constructMaterialName(item, pkInfo)
197+
encryptionMaterials, err := ec.materialsProvider.EncryptionMaterials(ctx, materialName)
198+
if err != nil {
199+
return nil, fmt.Errorf("failed to fetch encryption materials: %v", err)
200+
}
67201

68-
// Encrypt attribute values, excluding primary keys
69-
for key, value := range input.Item {
202+
encryptedItem := make(map[string]types.AttributeValue)
203+
for key, value := range item {
204+
// Exclude primary keys from encryption
70205
if key == pkInfo.PartitionKey || key == pkInfo.SortKey {
71-
// Copy primary key attributes as is
72206
encryptedItem[key] = value
73207
continue
74208
}
75209

76-
// Convert attribute value to bytes
77210
rawData, err := utils.AttributeValueToBytes(value)
78211
if err != nil {
79212
return nil, fmt.Errorf("error converting attribute value to bytes: %v", err)
80213
}
81214

82-
// Encrypt the data
83215
encryptedData, err := encryptionMaterials.EncryptionKey().Encrypt(rawData, []byte(key))
84216
if err != nil {
85217
return nil, fmt.Errorf("error encrypting attribute value: %v", err)
86218
}
87219

88-
// Store the encrypted data as a binary attribute value
89220
encryptedItem[key] = &types.AttributeValueMemberB{Value: encryptedData}
90221
}
91222

92-
// Create a new PutItemInput with the encrypted item
93-
encryptedInput := &dynamodb.PutItemInput{
94-
TableName: input.TableName,
95-
Item: encryptedItem,
96-
}
97-
98-
// Put the encrypted item into the DynamoDB table
99-
return ec.client.PutItem(ctx, encryptedInput)
223+
return encryptedItem, nil
100224
}
101225

102-
// GetItem retrieves an item from a DynamoDB table and decrypts it.
103-
func (ec *EncryptedClient) GetItem(ctx context.Context, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
104-
// First, retrieve the encrypted item from DynamoDB
105-
encryptedOutput, err := ec.client.GetItem(ctx, input)
226+
// decryptItem decrypts a DynamoDB item's attributes, excluding primary keys.
227+
func (ec *EncryptedClient) decryptItem(ctx context.Context, tableName string, item map[string]types.AttributeValue) (map[string]types.AttributeValue, error) {
228+
// Fetch primary key info to identify these attributes
229+
pkInfo, err := ec.getPrimaryKeyInfo(ctx, tableName)
106230
if err != nil {
107-
return nil, fmt.Errorf("error retrieving encrypted item: %v", err)
108-
}
109-
110-
// Check if item is found
111-
if encryptedOutput.Item == nil {
112-
return nil, fmt.Errorf("item not found")
113-
}
114-
115-
tableName := *input.TableName
116-
117-
// Cache check for primary key info
118-
pkInfo, ok := ec.primaryKeyCache[tableName]
119-
if !ok {
120-
var err error
121-
pkInfo, err = ec.getPrimaryKeyInfo(ctx, tableName)
122-
if err != nil {
123-
return nil, fmt.Errorf("error fetching primary key info: %v", err)
124-
}
125-
}
126-
partitionKeyValue := input.Key[pkInfo.PartitionKey].(*types.AttributeValueMemberS).Value
127-
var sortKeyValue string
128-
if pkInfo.SortKey != "" && input.Key[pkInfo.SortKey] != nil {
129-
sortKeyValue = input.Key[pkInfo.SortKey].(*types.AttributeValueMemberS).Value
130-
}
131-
132-
// Construct and hash the material name
133-
rawMaterialName := tableName + "-" + partitionKeyValue
134-
if sortKeyValue != "" {
135-
rawMaterialName += "-" + sortKeyValue
231+
return nil, err
136232
}
137233

138-
materialName := utils.HashString(rawMaterialName)
139-
140-
// Fetch decryption materials
234+
// Construct the material name based on primary keys
235+
materialName := ec.constructMaterialName(item, pkInfo)
141236
decryptionMaterials, err := ec.materialsProvider.DecryptionMaterials(ctx, materialName, 0)
142237
if err != nil {
143238
return nil, fmt.Errorf("failed to fetch decryption materials: %v", err)
144239
}
145240

146-
// Decrypt each attribute value, excluding primary keys
147241
decryptedItem := make(map[string]types.AttributeValue)
148-
for key, value := range encryptedOutput.Item {
242+
for key, value := range item {
243+
// Copy primary key attributes as is
149244
if key == pkInfo.PartitionKey || key == pkInfo.SortKey {
150-
// Copy primary key attributes as is
151245
decryptedItem[key] = value
152246
continue
153247
}
154248

155-
// Decrypt the data
156-
rawData, err := decryptionMaterials.DecryptionKey().Decrypt(value.(*types.AttributeValueMemberB).Value, []byte(key))
249+
encryptedData, ok := value.(*types.AttributeValueMemberB)
250+
if !ok {
251+
return nil, fmt.Errorf("expected binary data for encrypted attribute value")
252+
}
253+
254+
rawData, err := decryptionMaterials.DecryptionKey().Decrypt(encryptedData.Value, []byte(key))
157255
if err != nil {
158256
return nil, fmt.Errorf("error decrypting attribute value: %v", err)
159257
}
160258

161-
// Convert bytes back to AttributeValue
162259
decryptedValue, err := utils.BytesToAttributeValue(rawData)
163260
if err != nil {
164261
return nil, fmt.Errorf("error converting bytes to attribute value: %v", err)
@@ -167,28 +264,22 @@ func (ec *EncryptedClient) GetItem(ctx context.Context, input *dynamodb.GetItemI
167264
decryptedItem[key] = decryptedValue
168265
}
169266

170-
// Create a new GetItemOutput with the decrypted item
171-
decryptedOutput := &dynamodb.GetItemOutput{
172-
Item: decryptedItem,
173-
}
174-
175-
return decryptedOutput, nil
267+
return decryptedItem, nil
176268
}
177269

178-
func (ec *EncryptedClient) getPrimaryKeyInfo(ctx context.Context, tableName string) (*utils.PrimaryKeyInfo, error) {
179-
if ec.primaryKeyInfo != nil {
180-
return ec.primaryKeyInfo, nil
270+
// constructMaterialName constructs a material name based on an item's primary key.
271+
func (ec *EncryptedClient) constructMaterialName(item map[string]types.AttributeValue, pkInfo *utils.PrimaryKeyInfo) string {
272+
partitionKeyValue := item[pkInfo.PartitionKey].(*types.AttributeValueMemberS).Value
273+
sortKeyValue := ""
274+
if pkInfo.SortKey != "" && item[pkInfo.SortKey] != nil {
275+
sortKeyValue = item[pkInfo.SortKey].(*types.AttributeValueMemberS).Value
181276
}
182277

183-
// Fetch the table info since it's not yet cached
184-
pkInfo, err := utils.TableInfo(ctx, ec.client, tableName)
185-
if err != nil {
186-
return nil, err
278+
// rawMaterialName := pkInfo.TableName + "-" + partitionKeyValue
279+
rawMaterialName := partitionKeyValue
280+
if sortKeyValue != "" {
281+
rawMaterialName += "-" + sortKeyValue
187282
}
188283

189-
// Cache the table info for future use
190-
ec.primaryKeyInfo = pkInfo
191-
ec.primaryKeyCache[tableName] = pkInfo
192-
193-
return pkInfo, nil
284+
return utils.HashString(rawMaterialName)
194285
}

0 commit comments

Comments
 (0)