You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
336 lines
9.2 KiB
336 lines
9.2 KiB
3 years ago
|
// Copyright (C) MongoDB, Inc. 2017-present.
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||
|
// not use this file except in compliance with the License. You may obtain
|
||
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
package driver
|
||
|
|
||
|
import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"strings"
	"time"

	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
	"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
)
|
||
|
|
||
|
// defaultKmsPort is the TCP port appended to a KMS host name that does not
// already include one.
const defaultKmsPort = 443

// defaultKmsTimeout is the timeout applied to network I/O with a KMS host.
const defaultKmsTimeout = 10 * time.Second
|
||
|
|
||
|
// CollectionInfoFn is a callback used to retrieve collection information.
|
||
|
type CollectionInfoFn func(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error)
|
||
|
|
||
|
// KeyRetrieverFn is a callback used to retrieve keys from the key vault.
|
||
|
type KeyRetrieverFn func(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error)
|
||
|
|
||
|
// MarkCommandFn is a callback used to add encryption markings to a command.
|
||
|
type MarkCommandFn func(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
|
||
|
|
||
|
// CryptOptions specifies options to configure a Crypt instance.
|
||
|
type CryptOptions struct {
|
||
|
CollInfoFn CollectionInfoFn
|
||
|
KeyFn KeyRetrieverFn
|
||
|
MarkFn MarkCommandFn
|
||
|
KmsProviders bsoncore.Document
|
||
|
SchemaMap map[string]bsoncore.Document
|
||
|
TLSConfig map[string]*tls.Config
|
||
|
BypassAutoEncryption bool
|
||
|
}
|
||
|
|
||
|
// Crypt is an interface implemented by types that can encrypt and decrypt instances of
|
||
|
// bsoncore.Document.
|
||
|
//
|
||
|
// Users should rely on the driver's crypt type (used by default) for encryption and decryption
|
||
|
// unless they are perfectly confident in another implementation of Crypt.
|
||
|
type Crypt interface {
|
||
|
// Encrypt encrypts the given command.
|
||
|
Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
|
||
|
// Decrypt decrypts the given command response.
|
||
|
Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error)
|
||
|
// CreateDataKey creates a data key using the given KMS provider and options.
|
||
|
CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error)
|
||
|
// EncryptExplicit encrypts the given value with the given options.
|
||
|
EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error)
|
||
|
// DecryptExplicit decrypts the given encrypted value.
|
||
|
DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error)
|
||
|
// Close cleans up any resources associated with the Crypt instance.
|
||
|
Close()
|
||
|
// BypassAutoEncryption returns true if auto-encryption should be bypassed.
|
||
|
BypassAutoEncryption() bool
|
||
|
}
|
||
|
|
||
|
// crypt consumes the libmongocrypt.MongoCrypt type to iterate the mongocrypt state machine and perform encryption
|
||
|
// and decryption.
|
||
|
type crypt struct {
|
||
|
mongoCrypt *mongocrypt.MongoCrypt
|
||
|
collInfoFn CollectionInfoFn
|
||
|
keyFn KeyRetrieverFn
|
||
|
markFn MarkCommandFn
|
||
|
tlsConfig map[string]*tls.Config
|
||
|
|
||
|
bypassAutoEncryption bool
|
||
|
}
|
||
|
|
||
|
// NewCrypt creates a new Crypt instance configured with the given AutoEncryptionOptions.
|
||
|
func NewCrypt(opts *CryptOptions) (Crypt, error) {
|
||
|
c := &crypt{
|
||
|
collInfoFn: opts.CollInfoFn,
|
||
|
keyFn: opts.KeyFn,
|
||
|
markFn: opts.MarkFn,
|
||
|
tlsConfig: opts.TLSConfig,
|
||
|
bypassAutoEncryption: opts.BypassAutoEncryption,
|
||
|
}
|
||
|
|
||
|
mongocryptOpts := options.MongoCrypt().SetKmsProviders(opts.KmsProviders).SetLocalSchemaMap(opts.SchemaMap)
|
||
|
mc, err := mongocrypt.NewMongoCrypt(mongocryptOpts)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
c.mongoCrypt = mc
|
||
|
return c, nil
|
||
|
}
|
||
|
|
||
|
// Encrypt encrypts the given command.
|
||
|
func (c *crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error) {
|
||
|
if c.bypassAutoEncryption {
|
||
|
return cmd, nil
|
||
|
}
|
||
|
|
||
|
cryptCtx, err := c.mongoCrypt.CreateEncryptionContext(db, cmd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer cryptCtx.Close()
|
||
|
|
||
|
return c.executeStateMachine(ctx, cryptCtx, db)
|
||
|
}
|
||
|
|
||
|
// Decrypt decrypts the given command response.
|
||
|
func (c *crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
|
||
|
cryptCtx, err := c.mongoCrypt.CreateDecryptionContext(cmdResponse)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer cryptCtx.Close()
|
||
|
|
||
|
return c.executeStateMachine(ctx, cryptCtx, "")
|
||
|
}
|
||
|
|
||
|
// CreateDataKey creates a data key using the given KMS provider and options.
|
||
|
func (c *crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error) {
|
||
|
cryptCtx, err := c.mongoCrypt.CreateDataKeyContext(kmsProvider, opts)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer cryptCtx.Close()
|
||
|
|
||
|
return c.executeStateMachine(ctx, cryptCtx, "")
|
||
|
}
|
||
|
|
||
|
// EncryptExplicit encrypts the given value with the given options.
|
||
|
func (c *crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error) {
|
||
|
idx, doc := bsoncore.AppendDocumentStart(nil)
|
||
|
doc = bsoncore.AppendValueElement(doc, "v", val)
|
||
|
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
|
||
|
|
||
|
cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionContext(doc, opts)
|
||
|
if err != nil {
|
||
|
return 0, nil, err
|
||
|
}
|
||
|
defer cryptCtx.Close()
|
||
|
|
||
|
res, err := c.executeStateMachine(ctx, cryptCtx, "")
|
||
|
if err != nil {
|
||
|
return 0, nil, err
|
||
|
}
|
||
|
|
||
|
sub, data := res.Lookup("v").Binary()
|
||
|
return sub, data, nil
|
||
|
}
|
||
|
|
||
|
// DecryptExplicit decrypts the given encrypted value.
|
||
|
func (c *crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error) {
|
||
|
idx, doc := bsoncore.AppendDocumentStart(nil)
|
||
|
doc = bsoncore.AppendBinaryElement(doc, "v", subtype, data)
|
||
|
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
|
||
|
|
||
|
cryptCtx, err := c.mongoCrypt.CreateExplicitDecryptionContext(doc)
|
||
|
if err != nil {
|
||
|
return bsoncore.Value{}, err
|
||
|
}
|
||
|
defer cryptCtx.Close()
|
||
|
|
||
|
res, err := c.executeStateMachine(ctx, cryptCtx, "")
|
||
|
if err != nil {
|
||
|
return bsoncore.Value{}, err
|
||
|
}
|
||
|
|
||
|
return res.Lookup("v"), nil
|
||
|
}
|
||
|
|
||
|
// Close cleans up any resources associated with the Crypt instance.
|
||
|
func (c *crypt) Close() {
|
||
|
c.mongoCrypt.Close()
|
||
|
}
|
||
|
|
||
|
func (c *crypt) BypassAutoEncryption() bool {
|
||
|
return c.bypassAutoEncryption
|
||
|
}
|
||
|
|
||
|
func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Context, db string) (bsoncore.Document, error) {
|
||
|
var err error
|
||
|
for {
|
||
|
state := cryptCtx.State()
|
||
|
switch state {
|
||
|
case mongocrypt.NeedMongoCollInfo:
|
||
|
err = c.collectionInfo(ctx, cryptCtx, db)
|
||
|
case mongocrypt.NeedMongoMarkings:
|
||
|
err = c.markCommand(ctx, cryptCtx, db)
|
||
|
case mongocrypt.NeedMongoKeys:
|
||
|
err = c.retrieveKeys(ctx, cryptCtx)
|
||
|
case mongocrypt.NeedKms:
|
||
|
err = c.decryptKeys(cryptCtx)
|
||
|
case mongocrypt.Ready:
|
||
|
return cryptCtx.Finish()
|
||
|
default:
|
||
|
return nil, fmt.Errorf("invalid Crypt state: %v", state)
|
||
|
}
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
|
||
|
op, err := cryptCtx.NextOperation()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
collInfo, err := c.collInfoFn(ctx, db, op)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if collInfo != nil {
|
||
|
if err = cryptCtx.AddOperationResult(collInfo); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return cryptCtx.CompleteOperation()
|
||
|
}
|
||
|
|
||
|
func (c *crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
|
||
|
op, err := cryptCtx.NextOperation()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
markedCmd, err := c.markFn(ctx, db, op)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if err = cryptCtx.AddOperationResult(markedCmd); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return cryptCtx.CompleteOperation()
|
||
|
}
|
||
|
|
||
|
func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
|
||
|
op, err := cryptCtx.NextOperation()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
keys, err := c.keyFn(ctx, op)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for _, key := range keys {
|
||
|
if err = cryptCtx.AddOperationResult(key); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return cryptCtx.CompleteOperation()
|
||
|
}
|
||
|
|
||
|
func (c *crypt) decryptKeys(cryptCtx *mongocrypt.Context) error {
|
||
|
for {
|
||
|
kmsCtx := cryptCtx.NextKmsContext()
|
||
|
if kmsCtx == nil {
|
||
|
break
|
||
|
}
|
||
|
|
||
|
if err := c.decryptKey(kmsCtx); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return cryptCtx.FinishKmsContexts()
|
||
|
}
|
||
|
|
||
|
func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
|
||
|
host, err := kmsCtx.HostName()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
msg, err := kmsCtx.Message()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// add a port to the address if it's not already present
|
||
|
addr := host
|
||
|
if idx := strings.IndexByte(host, ':'); idx == -1 {
|
||
|
addr = fmt.Sprintf("%s:%d", host, defaultKmsPort)
|
||
|
}
|
||
|
|
||
|
kmsProvider := kmsCtx.KMSProvider()
|
||
|
tlsCfg := c.tlsConfig[kmsProvider]
|
||
|
if tlsCfg == nil {
|
||
|
tlsCfg = &tls.Config{MinVersion: tls.VersionTLS12}
|
||
|
}
|
||
|
conn, err := tls.Dial("tcp", addr, tlsCfg)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer func() {
|
||
|
_ = conn.Close()
|
||
|
}()
|
||
|
|
||
|
if err = conn.SetWriteDeadline(time.Now().Add(defaultKmsTimeout)); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if _, err = conn.Write(msg); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for {
|
||
|
bytesNeeded := kmsCtx.BytesNeeded()
|
||
|
if bytesNeeded == 0 {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
res := make([]byte, bytesNeeded)
|
||
|
bytesRead, err := conn.Read(res)
|
||
|
if err != nil && err != io.EOF {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
}
|