You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1619 lines
56 KiB
1619 lines
56 KiB
3 years ago
|
package driver
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"go.mongodb.org/mongo-driver/bson"
|
||
|
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
||
|
"go.mongodb.org/mongo-driver/event"
|
||
|
"go.mongodb.org/mongo-driver/internal"
|
||
|
"go.mongodb.org/mongo-driver/mongo/description"
|
||
|
"go.mongodb.org/mongo-driver/mongo/readconcern"
|
||
|
"go.mongodb.org/mongo-driver/mongo/readpref"
|
||
|
"go.mongodb.org/mongo-driver/mongo/writeconcern"
|
||
|
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||
|
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
|
||
|
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
|
||
|
)
|
||
|
|
||
|
// defaultLocalThreshold is the latency window handed to description.LatencySelector when an
// Operation has no explicit Selector and selectServer builds its fallback selector.
const defaultLocalThreshold time.Duration = 15 * time.Millisecond
|
||
|
|
||
|
// dollarCmd holds the five bytes of the ".$cmd" namespace suffix. NOTE(review): its use is
// outside this view — presumably it is appended to a database name to build a legacy
// command namespace.
var dollarCmd = [5]byte{'.', '$', 'c', 'm', 'd'}
|
||
|
|
||
|
var (
	// ErrNoDocCommandResponse is returned when the server signalled that a response
	// was present but no document could be found in it.
	ErrNoDocCommandResponse = errors.New("command returned no documents")

	// ErrMultiDocCommandResponse is returned when a command reply unexpectedly
	// carries more than one document.
	ErrMultiDocCommandResponse = errors.New("command returned multiple documents")

	// ErrReplyDocumentMismatch is returned when the documents in an OP_QUERY reply do
	// not add up to the reply's numberReturned field.
	ErrReplyDocumentMismatch = errors.New("number of documents returned does not match numberReturned field")

	// ErrNonPrimaryReadPref is returned when a read inside a transaction is attempted
	// with any read preference other than primary.
	ErrNonPrimaryReadPref = errors.New("read preference in a transaction must be primary")
)
|
||
|
|
||
|
const (
	// cryptMaxBsonObjectSize is the 2 MiB target batch size Execute uses instead of the
	// server's MaxDocumentSize when automatic client-side encryption is active.
	cryptMaxBsonObjectSize uint32 = 2 * 1024 * 1024

	// cryptMinWireVersion is the lowest wire version that supports automatic encryption.
	cryptMinWireVersion int32 = 8

	// readSnapshotMinWireVersion is the lowest wire version that supports snapshot reads.
	readSnapshotMinWireVersion int32 = 13
)
|
||
|
|
||
|
// RetryablePoolError is implemented by connection-pool errors that can say whether the operation
// that encountered them may be retried. Execute retries the connection checkout when an error
// implementing this interface reports Retryable() == true and retries remain.
type RetryablePoolError interface {
	// Retryable reports whether the pool error permits another attempt.
	Retryable() bool
}
|
||
|
|
||
|
// InvalidOperationError is the error Validate returns when one of an Operation's required fields
// is unset. MissingField holds the name of that field.
type InvalidOperationError struct {
	MissingField string
}

// Error implements the error interface by naming the field that has to be set.
func (e InvalidOperationError) Error() string {
	const suffix = " field must be set on Operation"
	return "the " + e.MissingField + suffix
}
|
||
|
|
||
|
// opReply stores information returned in an OP_REPLY response from the server.
|
||
|
// The err field stores any error that occurred when decoding or validating the OP_REPLY response.
|
||
|
type opReply struct {
|
||
|
responseFlags wiremessage.ReplyFlag
|
||
|
cursorID int64
|
||
|
startingFrom int32
|
||
|
numReturned int32
|
||
|
documents []bsoncore.Document
|
||
|
err error
|
||
|
}
|
||
|
|
||
|
// startedInformation keeps track of all of the information necessary for monitoring started events.
|
||
|
type startedInformation struct {
|
||
|
cmd bsoncore.Document
|
||
|
requestID int32
|
||
|
cmdName string
|
||
|
documentSequenceIncluded bool
|
||
|
connID string
|
||
|
serverConnID *int32
|
||
|
redacted bool
|
||
|
serviceID *primitive.ObjectID
|
||
|
}
|
||
|
|
||
|
// finishedInformation keeps track of all of the information necessary for monitoring success and failure events.
|
||
|
type finishedInformation struct {
|
||
|
cmdName string
|
||
|
requestID int32
|
||
|
response bsoncore.Document
|
||
|
cmdErr error
|
||
|
connID string
|
||
|
serverConnID *int32
|
||
|
startTime time.Time
|
||
|
redacted bool
|
||
|
serviceID *primitive.ObjectID
|
||
|
}
|
||
|
|
||
|
// ResponseInfo contains the context required to parse a server response.
|
||
|
type ResponseInfo struct {
|
||
|
ServerResponse bsoncore.Document
|
||
|
Server Server
|
||
|
Connection Connection
|
||
|
ConnectionDescription description.Server
|
||
|
CurrentIndex int
|
||
|
}
|
||
|
|
||
|
// Operation is used to execute an operation. It contains all of the common code required to
|
||
|
// select a server, transform an operation into a command, write the command to a connection from
|
||
|
// the selected server, read a response from that connection, process the response, and potentially
|
||
|
// retry.
|
||
|
//
|
||
|
// The required fields are Database, CommandFn, and Deployment. All other fields are optional.
|
||
|
//
|
||
|
// While an Operation can be constructed manually, drivergen should be used to generate an
|
||
|
// implementation of an operation instead. This will ensure that there are helpers for constructing
|
||
|
// the operation and that this type isn't configured incorrectly.
|
||
|
type Operation struct {
|
||
|
// CommandFn is used to create the command that will be wrapped in a wire message and sent to
|
||
|
// the server. This function should only add the elements of the command and not start or end
|
||
|
// the enclosing BSON document. Per the command API, the first element must be the name of the
|
||
|
// command to run. This field is required.
|
||
|
CommandFn func(dst []byte, desc description.SelectedServer) ([]byte, error)
|
||
|
|
||
|
// Database is the database that the command will be run against. This field is required.
|
||
|
Database string
|
||
|
|
||
|
// Deployment is the MongoDB Deployment to use. While most of the time this will be multiple
|
||
|
// servers, commands that need to run against a single, preselected server can use the
|
||
|
// SingleServerDeployment type. Commands that need to run on a preselected connection can use
|
||
|
// the SingleConnectionDeployment type.
|
||
|
Deployment Deployment
|
||
|
|
||
|
// ProcessResponseFn is called after a response to the command is returned. The server is
|
||
|
// provided for types like Cursor that are required to run subsequent commands using the same
|
||
|
// server.
|
||
|
ProcessResponseFn func(ResponseInfo) error
|
||
|
|
||
|
// Selector is the server selector that's used during both initial server selection and
|
||
|
// subsequent selection for retries. Depending on the Deployment implementation, the
|
||
|
// SelectServer method may not actually be called.
|
||
|
Selector description.ServerSelector
|
||
|
|
||
|
// ReadPreference is the read preference that will be attached to the command. If this field is
|
||
|
// not specified a default read preference of primary will be used.
|
||
|
ReadPreference *readpref.ReadPref
|
||
|
|
||
|
// ReadConcern is the read concern used when running read commands. This field should not be set
|
||
|
// for write operations. If this field is set, it will be encoded onto the commands sent to the
|
||
|
// server.
|
||
|
ReadConcern *readconcern.ReadConcern
|
||
|
|
||
|
// MinimumReadConcernWireVersion specifies the minimum wire version to add the read concern to
|
||
|
// the command being executed.
|
||
|
MinimumReadConcernWireVersion int32
|
||
|
|
||
|
// WriteConcern is the write concern used when running write commands. This field should not be
|
||
|
// set for read operations. If this field is set, it will be encoded onto the commands sent to
|
||
|
// the server.
|
||
|
WriteConcern *writeconcern.WriteConcern
|
||
|
|
||
|
// MinimumWriteConcernWireVersion specifies the minimum wire version to add the write concern to
|
||
|
// the command being executed.
|
||
|
MinimumWriteConcernWireVersion int32
|
||
|
|
||
|
// Client is the session used with this operation. This can be either an implicit or explicit
|
||
|
// session. If the server selected does not support sessions and Client is specified the
|
||
|
// behavior depends on the session type. If the session is implicit, the session fields will not
|
||
|
// be encoded onto the command. If the session is explicit, an error will be returned. The
|
||
|
// caller is responsible for ensuring that this field is nil if the Deployment does not support
|
||
|
// sessions.
|
||
|
Client *session.Client
|
||
|
|
||
|
// Clock is a cluster clock, different from the one contained within a session.Client. This
|
||
|
// allows updating cluster times for a global cluster clock while allowing individual session's
|
||
|
// cluster clocks to be only updated as far as the last command that's been run.
|
||
|
Clock *session.ClusterClock
|
||
|
|
||
|
// RetryMode specifies how to retry. There are three modes that enable retry: RetryOnce,
|
||
|
// RetryOncePerCommand, and RetryContext. For more information about what these modes do, please
|
||
|
// refer to their definitions. Both RetryMode and Type must be set for retryability to be enabled.
|
||
|
RetryMode *RetryMode
|
||
|
|
||
|
// Type specifies the kind of operation this is. There is only one mode that enables retry: Write.
|
||
|
// For more information about what this mode does, please refer to it's definition. Both Type and
|
||
|
// RetryMode must be set for retryability to be enabled.
|
||
|
Type Type
|
||
|
|
||
|
// Batches contains the documents that are split when executing a write command that potentially
|
||
|
// has more documents than can fit in a single command. This should only be specified for
|
||
|
// commands that are batch compatible. For more information, please refer to the definition of
|
||
|
// Batches.
|
||
|
Batches *Batches
|
||
|
|
||
|
// Legacy sets the legacy type for this operation. There are only 3 types that require legacy
|
||
|
// support: find, getMore, and killCursors. For more information about LegacyOperationKind,
|
||
|
// please refer to it's definition.
|
||
|
Legacy LegacyOperationKind
|
||
|
|
||
|
// CommandMonitor specifies the monitor to use for APM events. If this field is not set,
|
||
|
// no events will be reported.
|
||
|
CommandMonitor *event.CommandMonitor
|
||
|
|
||
|
// Crypt specifies a Crypt object to use for automatic client side encryption and decryption.
|
||
|
Crypt Crypt
|
||
|
|
||
|
// ServerAPI specifies options used to configure the API version sent to the server.
|
||
|
ServerAPI *ServerAPIOptions
|
||
|
|
||
|
// IsOutputAggregate specifies whether this operation is an aggregate with an output stage. If true,
|
||
|
// read preference will not be added to the command on wire versions < 13.
|
||
|
IsOutputAggregate bool
|
||
|
|
||
|
// cmdName is only set when serializing OP_MSG and is used internally in readWireMessage.
|
||
|
cmdName string
|
||
|
}
|
||
|
|
||
|
// shouldEncrypt returns true if this operation should automatically be encrypted.
|
||
|
func (op Operation) shouldEncrypt() bool {
|
||
|
return op.Crypt != nil && !op.Crypt.BypassAutoEncryption()
|
||
|
}
|
||
|
|
||
|
// selectServer handles performing server selection for an operation.
|
||
|
func (op Operation) selectServer(ctx context.Context) (Server, error) {
|
||
|
if err := op.Validate(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
selector := op.Selector
|
||
|
if selector == nil {
|
||
|
rp := op.ReadPreference
|
||
|
if rp == nil {
|
||
|
rp = readpref.Primary()
|
||
|
}
|
||
|
selector = description.CompositeSelector([]description.ServerSelector{
|
||
|
description.ReadPrefSelector(rp),
|
||
|
description.LatencySelector(defaultLocalThreshold),
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return op.Deployment.SelectServer(ctx, selector)
|
||
|
}
|
||
|
|
||
|
// getServerAndConnection should be used to retrieve a Server and Connection to execute an operation.
|
||
|
func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) {
|
||
|
server, err := op.selectServer(ctx)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
|
||
|
// If the provided client session has a pinned connection, it should be used for the operation because this
|
||
|
// indicates that we're in a transaction and the target server is behind a load balancer.
|
||
|
if op.Client != nil && op.Client.PinnedConnection != nil {
|
||
|
return server, op.Client.PinnedConnection, nil
|
||
|
}
|
||
|
|
||
|
// Otherwise, default to checking out a connection from the server's pool.
|
||
|
conn, err := server.Connection(ctx)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
|
||
|
// If we're in load balanced mode and this is the first operation in a transaction, pin the session to a connection.
|
||
|
if conn.Description().LoadBalanced() && op.Client != nil && op.Client.TransactionStarting() {
|
||
|
pinnedConn, ok := conn.(PinnedConnection)
|
||
|
if !ok {
|
||
|
// Close the original connection to avoid a leak.
|
||
|
_ = conn.Close()
|
||
|
return nil, nil, fmt.Errorf("expected Connection used to start a transaction to be a PinnedConnection, but got %T", conn)
|
||
|
}
|
||
|
if err := pinnedConn.PinToTransaction(); err != nil {
|
||
|
// Close the original connection to avoid a leak.
|
||
|
_ = conn.Close()
|
||
|
return nil, nil, fmt.Errorf("error incrementing connection reference count when starting a transaction: %v", err)
|
||
|
}
|
||
|
op.Client.PinnedConnection = pinnedConn
|
||
|
}
|
||
|
|
||
|
return server, conn, nil
|
||
|
}
|
||
|
|
||
|
// Validate validates this operation, ensuring the fields are set properly.
|
||
|
func (op Operation) Validate() error {
|
||
|
if op.CommandFn == nil {
|
||
|
return InvalidOperationError{MissingField: "CommandFn"}
|
||
|
}
|
||
|
if op.Deployment == nil {
|
||
|
return InvalidOperationError{MissingField: "Deployment"}
|
||
|
}
|
||
|
if op.Database == "" {
|
||
|
return InvalidOperationError{MissingField: "Database"}
|
||
|
}
|
||
|
if op.Client != nil && !writeconcern.AckWrite(op.WriteConcern) {
|
||
|
return errors.New("session provided for an unacknowledged write")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Execute runs this operation. The scratch parameter will be used and overwritten (potentially many
|
||
|
// times), this should mainly be used to enable pooling of byte slices.
|
||
|
func (op Operation) Execute(ctx context.Context, scratch []byte) error {
|
||
|
err := op.Validate()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if op.Client != nil {
|
||
|
if err := op.Client.StartCommand(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var retries int
|
||
|
if op.RetryMode != nil {
|
||
|
switch op.Type {
|
||
|
case Write:
|
||
|
if op.Client == nil {
|
||
|
break
|
||
|
}
|
||
|
switch *op.RetryMode {
|
||
|
case RetryOnce, RetryOncePerCommand:
|
||
|
retries = 1
|
||
|
case RetryContext:
|
||
|
retries = -1
|
||
|
}
|
||
|
case Read:
|
||
|
switch *op.RetryMode {
|
||
|
case RetryOnce, RetryOncePerCommand:
|
||
|
retries = 1
|
||
|
case RetryContext:
|
||
|
retries = -1
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var srvr Server
|
||
|
var conn Connection
|
||
|
var res bsoncore.Document
|
||
|
var operationErr WriteCommandError
|
||
|
var prevErr error
|
||
|
batching := op.Batches.Valid()
|
||
|
retryEnabled := op.RetryMode != nil && op.RetryMode.Enabled()
|
||
|
retrySupported := false
|
||
|
first := true
|
||
|
currIndex := 0
|
||
|
|
||
|
// resetForRetry records the error that caused the retry, decrements retries, and resets the
|
||
|
// retry loop variables to request a new server and a new connection for the next attempt.
|
||
|
resetForRetry := func(err error) {
|
||
|
retries--
|
||
|
prevErr = err
|
||
|
// If we got a connection, close it immediately to release pool resources for
|
||
|
// subsequent retries.
|
||
|
if conn != nil {
|
||
|
conn.Close()
|
||
|
}
|
||
|
// Set the server and connection to nil to request a new server and connection.
|
||
|
srvr = nil
|
||
|
conn = nil
|
||
|
}
|
||
|
|
||
|
for {
|
||
|
// If the server or connection are nil, try to select a new server and get a new connection.
|
||
|
if srvr == nil || conn == nil {
|
||
|
srvr, conn, err = op.getServerAndConnection(ctx)
|
||
|
if err != nil {
|
||
|
// If the returned error is retryable and there are retries remaining (negative
|
||
|
// retries means retry indefinitely), then retry the operation. Set the server
|
||
|
// and connection to nil to request a new server and connection.
|
||
|
if rerr, ok := err.(RetryablePoolError); ok && rerr.Retryable() && retries != 0 {
|
||
|
resetForRetry(err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// If this is a retry and there's an error from a previous attempt, return the previous
|
||
|
// error instead of the current connection error.
|
||
|
if prevErr != nil {
|
||
|
return prevErr
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
defer conn.Close()
|
||
|
}
|
||
|
|
||
|
// Run steps that must only be run on the first attempt, but not again for retries.
|
||
|
if first {
|
||
|
// Determine if retries are supported for the current operation on the current server
|
||
|
// description. Per the retryable writes specification, only determine this for the
|
||
|
// first server selected:
|
||
|
//
|
||
|
// If the server selected for the first attempt of a retryable write operation does
|
||
|
// not support retryable writes, drivers MUST execute the write as if retryable writes
|
||
|
// were not enabled.
|
||
|
retrySupported = op.retryable(conn.Description())
|
||
|
|
||
|
// If retries are supported for the current operation on the current server description,
|
||
|
// client retries are enabled, the operation type is write, and we haven't incremented
|
||
|
// the txn number yet, enable retry writes on the session and increment the txn number.
|
||
|
// Calling IncrementTxnNumber() for server descriptions or topologies that do not
|
||
|
// support retries (e.g. standalone topologies) will cause server errors. Only do this
|
||
|
// check for the first attempt to keep retried writes in the same transaction.
|
||
|
if retrySupported && op.RetryMode != nil && op.Type == Write && op.Client != nil {
|
||
|
op.Client.RetryWrite = false
|
||
|
if op.RetryMode.Enabled() {
|
||
|
op.Client.RetryWrite = true
|
||
|
if !op.Client.Committing && !op.Client.Aborting {
|
||
|
op.Client.IncrementTxnNumber()
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
first = false
|
||
|
}
|
||
|
|
||
|
desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()}
|
||
|
scratch = scratch[:0]
|
||
|
if desc.WireVersion == nil || desc.WireVersion.Max < 4 {
|
||
|
switch op.Legacy {
|
||
|
case LegacyFind:
|
||
|
return op.legacyFind(ctx, scratch, srvr, conn, desc)
|
||
|
case LegacyGetMore:
|
||
|
return op.legacyGetMore(ctx, scratch, srvr, conn, desc)
|
||
|
case LegacyKillCursors:
|
||
|
return op.legacyKillCursors(ctx, scratch, srvr, conn, desc)
|
||
|
}
|
||
|
}
|
||
|
if desc.WireVersion == nil || desc.WireVersion.Max < 3 {
|
||
|
switch op.Legacy {
|
||
|
case LegacyListCollections:
|
||
|
return op.legacyListCollections(ctx, scratch, srvr, conn, desc)
|
||
|
case LegacyListIndexes:
|
||
|
return op.legacyListIndexes(ctx, scratch, srvr, conn, desc)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if batching {
|
||
|
targetBatchSize := desc.MaxDocumentSize
|
||
|
maxDocSize := desc.MaxDocumentSize
|
||
|
if op.shouldEncrypt() {
|
||
|
// For client-side encryption, we want the batch to be split at 2 MiB instead of 16MiB.
|
||
|
// If there's only one document in the batch, it can be up to 16MiB, so we set target batch size to
|
||
|
// 2MiB but max document size to 16MiB. This will allow the AdvanceBatch call to create a batch
|
||
|
// with a single large document.
|
||
|
targetBatchSize = cryptMaxBsonObjectSize
|
||
|
}
|
||
|
|
||
|
err = op.Batches.AdvanceBatch(int(desc.MaxBatchCount), int(targetBatchSize), int(maxDocSize))
|
||
|
if err != nil {
|
||
|
// TODO(GODRIVER-982): Should we also be returning operationErr?
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// convert to wire message
|
||
|
if len(scratch) > 0 {
|
||
|
scratch = scratch[:0]
|
||
|
}
|
||
|
wm, startedInfo, err := op.createWireMessage(ctx, scratch, desc, conn)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// set extra data and send event if possible
|
||
|
startedInfo.connID = conn.ID()
|
||
|
startedInfo.cmdName = op.getCommandName(startedInfo.cmd)
|
||
|
op.cmdName = startedInfo.cmdName
|
||
|
startedInfo.redacted = op.redactCommand(startedInfo.cmdName, startedInfo.cmd)
|
||
|
startedInfo.serviceID = conn.Description().ServiceID
|
||
|
startedInfo.serverConnID = conn.ServerConnectionID()
|
||
|
op.publishStartedEvent(ctx, startedInfo)
|
||
|
|
||
|
// get the moreToCome flag information before we compress
|
||
|
moreToCome := wiremessage.IsMsgMoreToCome(wm)
|
||
|
|
||
|
// compress wiremessage if allowed
|
||
|
if compressor, ok := conn.(Compressor); ok && op.canCompress(startedInfo.cmdName) {
|
||
|
wm, err = compressor.CompressWireMessage(wm, nil)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
finishedInfo := finishedInformation{
|
||
|
cmdName: startedInfo.cmdName,
|
||
|
requestID: startedInfo.requestID,
|
||
|
startTime: time.Now(),
|
||
|
connID: startedInfo.connID,
|
||
|
serverConnID: startedInfo.serverConnID,
|
||
|
redacted: startedInfo.redacted,
|
||
|
serviceID: startedInfo.serviceID,
|
||
|
}
|
||
|
|
||
|
// Check if there's enough time to perform a best-case network round trip before the Context
|
||
|
// deadline. If not, create a network error that wraps a context.DeadlineExceeded error and
|
||
|
// skip the actual round trip.
|
||
|
if deadline, ok := ctx.Deadline(); ok && time.Now().Add(srvr.MinRTT()).After(deadline) {
|
||
|
err = op.networkError(context.DeadlineExceeded)
|
||
|
} else {
|
||
|
// roundtrip using either the full roundTripper or a special one for when the moreToCome
|
||
|
// flag is set
|
||
|
var roundTrip = op.roundTrip
|
||
|
if moreToCome {
|
||
|
roundTrip = op.moreToComeRoundTrip
|
||
|
}
|
||
|
res, err = roundTrip(ctx, conn, wm)
|
||
|
|
||
|
if ep, ok := srvr.(ErrorProcessor); ok {
|
||
|
_ = ep.ProcessError(err, conn)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
finishedInfo.response = res
|
||
|
finishedInfo.cmdErr = err
|
||
|
op.publishFinishedEvent(ctx, finishedInfo)
|
||
|
|
||
|
var perr error
|
||
|
switch tt := err.(type) {
|
||
|
case WriteCommandError:
|
||
|
if e := err.(WriteCommandError); retrySupported && op.Type == Write && e.UnsupportedStorageEngine() {
|
||
|
return ErrUnsupportedStorageEngine
|
||
|
}
|
||
|
|
||
|
connDesc := conn.Description()
|
||
|
retryableErr := tt.Retryable(connDesc.WireVersion)
|
||
|
preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9
|
||
|
inTransaction := op.Client != nil &&
|
||
|
!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning()
|
||
|
// If retry is enabled and the operation isn't in a transaction, add a RetryableWriteError label for
|
||
|
// retryable errors from pre-4.4 servers
|
||
|
if retryableErr && preRetryWriteLabelVersion && retryEnabled && !inTransaction {
|
||
|
tt.Labels = append(tt.Labels, RetryableWriteError)
|
||
|
}
|
||
|
|
||
|
// If retries are supported for the current operation on the first server description,
|
||
|
// the error is considered retryable, and there are retries remaining (negative retries
|
||
|
// means retry indefinitely), then retry the operation.
|
||
|
if retrySupported && retryableErr && retries != 0 {
|
||
|
if op.Client != nil && op.Client.Committing {
|
||
|
// Apply majority write concern for retries
|
||
|
op.Client.UpdateCommitTransactionWriteConcern()
|
||
|
op.WriteConcern = op.Client.CurrentWc
|
||
|
}
|
||
|
resetForRetry(tt)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// If the operation isn't being retried, process the response
|
||
|
if op.ProcessResponseFn != nil {
|
||
|
info := ResponseInfo{
|
||
|
ServerResponse: res,
|
||
|
Server: srvr,
|
||
|
Connection: conn,
|
||
|
ConnectionDescription: desc.Server,
|
||
|
CurrentIndex: currIndex,
|
||
|
}
|
||
|
_ = op.ProcessResponseFn(info)
|
||
|
}
|
||
|
|
||
|
if batching && len(tt.WriteErrors) > 0 && currIndex > 0 {
|
||
|
for i := range tt.WriteErrors {
|
||
|
tt.WriteErrors[i].Index += int64(currIndex)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// If batching is enabled and either ordered is the default (which is true) or
|
||
|
// explicitly set to true and we have write errors, return the errors.
|
||
|
if batching && (op.Batches.Ordered == nil || *op.Batches.Ordered) && len(tt.WriteErrors) > 0 {
|
||
|
return tt
|
||
|
}
|
||
|
if op.Client != nil && op.Client.Committing && tt.WriteConcernError != nil {
|
||
|
// When running commitTransaction we return WriteConcernErrors as an Error.
|
||
|
err := Error{
|
||
|
Name: tt.WriteConcernError.Name,
|
||
|
Code: int32(tt.WriteConcernError.Code),
|
||
|
Message: tt.WriteConcernError.Message,
|
||
|
Labels: tt.Labels,
|
||
|
Raw: tt.Raw,
|
||
|
}
|
||
|
// The UnknownTransactionCommitResult label is added to all writeConcernErrors besides unknownReplWriteConcernCode
|
||
|
// and unsatisfiableWriteConcernCode
|
||
|
if err.Code != unknownReplWriteConcernCode && err.Code != unsatisfiableWriteConcernCode {
|
||
|
err.Labels = append(err.Labels, UnknownTransactionCommitResult)
|
||
|
}
|
||
|
if retryableErr && retryEnabled {
|
||
|
err.Labels = append(err.Labels, RetryableWriteError)
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
operationErr.WriteConcernError = tt.WriteConcernError
|
||
|
operationErr.WriteErrors = append(operationErr.WriteErrors, tt.WriteErrors...)
|
||
|
operationErr.Labels = tt.Labels
|
||
|
operationErr.Raw = tt.Raw
|
||
|
case Error:
|
||
|
if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) {
|
||
|
if err := op.Client.ClearPinnedResources(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if e := err.(Error); retrySupported && op.Type == Write && e.UnsupportedStorageEngine() {
|
||
|
return ErrUnsupportedStorageEngine
|
||
|
}
|
||
|
|
||
|
connDesc := conn.Description()
|
||
|
var retryableErr bool
|
||
|
if op.Type == Write {
|
||
|
retryableErr = tt.RetryableWrite(connDesc.WireVersion)
|
||
|
preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9
|
||
|
inTransaction := op.Client != nil &&
|
||
|
!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning()
|
||
|
// If retryWrites is enabled and the operation isn't in a transaction, add a RetryableWriteError label
|
||
|
// for network errors and retryable errors from pre-4.4 servers
|
||
|
if retryEnabled && !inTransaction &&
|
||
|
(tt.HasErrorLabel(NetworkError) || (retryableErr && preRetryWriteLabelVersion)) {
|
||
|
tt.Labels = append(tt.Labels, RetryableWriteError)
|
||
|
}
|
||
|
} else {
|
||
|
retryableErr = tt.RetryableRead()
|
||
|
}
|
||
|
|
||
|
// If retries are supported for the current operation on the first server description,
|
||
|
// the error is considered retryable, and there are retries remaining (negative retries
|
||
|
// means retry indefinitely), then retry the operation.
|
||
|
if retrySupported && retryableErr && retries != 0 {
|
||
|
if op.Client != nil && op.Client.Committing {
|
||
|
// Apply majority write concern for retries
|
||
|
op.Client.UpdateCommitTransactionWriteConcern()
|
||
|
op.WriteConcern = op.Client.CurrentWc
|
||
|
}
|
||
|
resetForRetry(tt)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// If the operation isn't being retried, process the response
|
||
|
if op.ProcessResponseFn != nil {
|
||
|
info := ResponseInfo{
|
||
|
ServerResponse: res,
|
||
|
Server: srvr,
|
||
|
Connection: conn,
|
||
|
ConnectionDescription: desc.Server,
|
||
|
CurrentIndex: currIndex,
|
||
|
}
|
||
|
_ = op.ProcessResponseFn(info)
|
||
|
}
|
||
|
|
||
|
if op.Client != nil && op.Client.Committing && (retryableErr || tt.Code == 50) {
|
||
|
// If we got a retryable error or MaxTimeMSExpired error, we add UnknownTransactionCommitResult.
|
||
|
tt.Labels = append(tt.Labels, UnknownTransactionCommitResult)
|
||
|
}
|
||
|
return tt
|
||
|
case nil:
|
||
|
if moreToCome {
|
||
|
return ErrUnacknowledgedWrite
|
||
|
}
|
||
|
if op.ProcessResponseFn != nil {
|
||
|
info := ResponseInfo{
|
||
|
ServerResponse: res,
|
||
|
Server: srvr,
|
||
|
Connection: conn,
|
||
|
ConnectionDescription: desc.Server,
|
||
|
CurrentIndex: currIndex,
|
||
|
}
|
||
|
perr = op.ProcessResponseFn(info)
|
||
|
}
|
||
|
if perr != nil {
|
||
|
return perr
|
||
|
}
|
||
|
default:
|
||
|
if op.ProcessResponseFn != nil {
|
||
|
info := ResponseInfo{
|
||
|
ServerResponse: res,
|
||
|
Server: srvr,
|
||
|
Connection: conn,
|
||
|
ConnectionDescription: desc.Server,
|
||
|
CurrentIndex: currIndex,
|
||
|
}
|
||
|
_ = op.ProcessResponseFn(info)
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// If we're batching and there are batches remaining, advance to the next batch. This isn't
|
||
|
// a retry, so increment the transaction number, reset the retries number, and don't set
|
||
|
// server or connection to nil to continue using the same connection.
|
||
|
if batching && len(op.Batches.Documents) > 0 {
|
||
|
if retrySupported && op.Client != nil && op.RetryMode != nil {
|
||
|
if *op.RetryMode > RetryNone {
|
||
|
op.Client.IncrementTxnNumber()
|
||
|
}
|
||
|
if *op.RetryMode == RetryOncePerCommand {
|
||
|
retries = 1
|
||
|
}
|
||
|
}
|
||
|
currIndex += len(op.Batches.Current)
|
||
|
op.Batches.ClearBatch()
|
||
|
continue
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
if len(operationErr.WriteErrors) > 0 || operationErr.WriteConcernError != nil {
|
||
|
return operationErr
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Retryable writes are supported if the server supports sessions, the operation is not
|
||
|
// within a transaction, and the write is acknowledged
|
||
|
func (op Operation) retryable(desc description.Server) bool {
|
||
|
switch op.Type {
|
||
|
case Write:
|
||
|
if op.Client != nil && (op.Client.Committing || op.Client.Aborting) {
|
||
|
return true
|
||
|
}
|
||
|
if retryWritesSupported(desc) &&
|
||
|
desc.WireVersion != nil && desc.WireVersion.Max >= 6 &&
|
||
|
op.Client != nil && !(op.Client.TransactionInProgress() || op.Client.TransactionStarting()) &&
|
||
|
writeconcern.AckWrite(op.WriteConcern) {
|
||
|
return true
|
||
|
}
|
||
|
case Read:
|
||
|
if op.Client != nil && (op.Client.Committing || op.Client.Aborting) {
|
||
|
return true
|
||
|
}
|
||
|
if desc.WireVersion != nil && desc.WireVersion.Max >= 6 &&
|
||
|
(op.Client == nil || !(op.Client.TransactionInProgress() || op.Client.TransactionStarting())) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// roundTrip writes a wiremessage to the connection and then reads a wiremessage. The wm parameter
|
||
|
// is reused when reading the wiremessage.
|
||
|
func (op Operation) roundTrip(ctx context.Context, conn Connection, wm []byte) ([]byte, error) {
|
||
|
err := conn.WriteWireMessage(ctx, wm)
|
||
|
if err != nil {
|
||
|
return nil, op.networkError(err)
|
||
|
}
|
||
|
|
||
|
return op.readWireMessage(ctx, conn, wm)
|
||
|
}
|
||
|
|
||
|
func (op Operation) readWireMessage(ctx context.Context, conn Connection, wm []byte) ([]byte, error) {
|
||
|
var err error
|
||
|
|
||
|
wm, err = conn.ReadWireMessage(ctx, wm[:0])
|
||
|
if err != nil {
|
||
|
return nil, op.networkError(err)
|
||
|
}
|
||
|
|
||
|
// If we're using a streamable connection, we set its streaming state based on the moreToCome flag in the server
|
||
|
// response.
|
||
|
if streamer, ok := conn.(StreamerConnection); ok {
|
||
|
streamer.SetStreaming(wiremessage.IsMsgMoreToCome(wm))
|
||
|
}
|
||
|
|
||
|
// decompress wiremessage
|
||
|
wm, err = op.decompressWireMessage(wm)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// decode
|
||
|
res, err := op.decodeResult(wm)
|
||
|
// Update cluster/operation time and recovery tokens before handling the error to ensure we're properly updating
|
||
|
// everything.
|
||
|
op.updateClusterTimes(res)
|
||
|
op.updateOperationTime(res)
|
||
|
op.Client.UpdateRecoveryToken(bson.Raw(res))
|
||
|
|
||
|
// Update snapshot time if operation was a "find", "aggregate" or "distinct".
|
||
|
if op.cmdName == "find" || op.cmdName == "aggregate" || op.cmdName == "distinct" {
|
||
|
op.Client.UpdateSnapshotTime(res)
|
||
|
}
|
||
|
|
||
|
if err != nil {
|
||
|
return res, err
|
||
|
}
|
||
|
|
||
|
// If there is no error, automatically attempt to decrypt all results if client side encryption is enabled.
|
||
|
if op.Crypt != nil {
|
||
|
return op.Crypt.Decrypt(ctx, res)
|
||
|
}
|
||
|
return res, nil
|
||
|
}
|
||
|
|
||
|
// networkError wraps the provided error in an Error with label "NetworkError" and, if a transaction
|
||
|
// is running or committing, the appropriate transaction state labels. The returned error indicates
|
||
|
// the operation should be retried for reads and writes. If err is nil, networkError returns nil.
|
||
|
func (op Operation) networkError(err error) error {
|
||
|
if err == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
labels := []string{NetworkError}
|
||
|
if op.Client != nil {
|
||
|
op.Client.MarkDirty()
|
||
|
}
|
||
|
if op.Client != nil && op.Client.TransactionRunning() && !op.Client.Committing {
|
||
|
labels = append(labels, TransientTransactionError)
|
||
|
}
|
||
|
if op.Client != nil && op.Client.Committing {
|
||
|
labels = append(labels, UnknownTransactionCommitResult)
|
||
|
}
|
||
|
return Error{Message: err.Error(), Labels: labels, Wrapped: err}
|
||
|
}
|
||
|
|
||
|
// moreToComeRoundTrip writes a wiremessage to the provided connection. This is used when an OP_MSG is
|
||
|
// being sent with the moreToCome bit set.
|
||
|
func (op *Operation) moreToComeRoundTrip(ctx context.Context, conn Connection, wm []byte) ([]byte, error) {
|
||
|
err := conn.WriteWireMessage(ctx, wm)
|
||
|
if err != nil {
|
||
|
if op.Client != nil {
|
||
|
op.Client.MarkDirty()
|
||
|
}
|
||
|
err = Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}, Wrapped: err}
|
||
|
}
|
||
|
return bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "ok", 1)), err
|
||
|
}
|
||
|
|
||
|
// decompressWireMessage handles decompressing a wiremessage. If the wiremessage
|
||
|
// is not compressed, this method will return the wiremessage.
|
||
|
func (Operation) decompressWireMessage(wm []byte) ([]byte, error) {
|
||
|
// read the header and ensure this is a compressed wire message
|
||
|
length, reqid, respto, opcode, rem, ok := wiremessage.ReadHeader(wm)
|
||
|
if !ok || len(wm) < int(length) {
|
||
|
return nil, errors.New("malformed wire message: insufficient bytes")
|
||
|
}
|
||
|
if opcode != wiremessage.OpCompressed {
|
||
|
return wm, nil
|
||
|
}
|
||
|
// get the original opcode and uncompressed size
|
||
|
opcode, rem, ok = wiremessage.ReadCompressedOriginalOpCode(rem)
|
||
|
if !ok {
|
||
|
return nil, errors.New("malformed OP_COMPRESSED: missing original opcode")
|
||
|
}
|
||
|
uncompressedSize, rem, ok := wiremessage.ReadCompressedUncompressedSize(rem)
|
||
|
if !ok {
|
||
|
return nil, errors.New("malformed OP_COMPRESSED: missing uncompressed size")
|
||
|
}
|
||
|
// get the compressor ID and decompress the message
|
||
|
compressorID, rem, ok := wiremessage.ReadCompressedCompressorID(rem)
|
||
|
if !ok {
|
||
|
return nil, errors.New("malformed OP_COMPRESSED: missing compressor ID")
|
||
|
}
|
||
|
compressedSize := length - 25 // header (16) + original opcode (4) + uncompressed size (4) + compressor ID (1)
|
||
|
// return the original wiremessage
|
||
|
msg, rem, ok := wiremessage.ReadCompressedCompressedMessage(rem, compressedSize)
|
||
|
if !ok {
|
||
|
return nil, errors.New("malformed OP_COMPRESSED: insufficient bytes for compressed wiremessage")
|
||
|
}
|
||
|
|
||
|
header := make([]byte, 0, uncompressedSize+16)
|
||
|
header = wiremessage.AppendHeader(header, uncompressedSize+16, reqid, respto, opcode)
|
||
|
opts := CompressionOpts{
|
||
|
Compressor: compressorID,
|
||
|
UncompressedSize: uncompressedSize,
|
||
|
}
|
||
|
uncompressed, err := DecompressPayload(msg, opts)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return append(header, uncompressed...), nil
|
||
|
}
|
||
|
|
||
|
func (op Operation) createWireMessage(ctx context.Context, dst []byte,
|
||
|
desc description.SelectedServer, conn Connection) ([]byte, startedInformation, error) {
|
||
|
// If topology is not LoadBalanced, API version is not declared, and wire version is unknown
|
||
|
// or less than 6, use OP_QUERY. Otherwise, use OP_MSG.
|
||
|
if desc.Kind != description.LoadBalanced && op.ServerAPI == nil &&
|
||
|
(desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion) {
|
||
|
return op.createQueryWireMessage(dst, desc)
|
||
|
}
|
||
|
return op.createMsgWireMessage(ctx, dst, desc, conn)
|
||
|
}
|
||
|
|
||
|
func (op Operation) addBatchArray(dst []byte) []byte {
|
||
|
aidx, dst := bsoncore.AppendArrayElementStart(dst, op.Batches.Identifier)
|
||
|
for i, doc := range op.Batches.Current {
|
||
|
dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(i), doc)
|
||
|
}
|
||
|
dst, _ = bsoncore.AppendArrayEnd(dst, aidx)
|
||
|
return dst
|
||
|
}
|
||
|
|
||
|
func (op Operation) createQueryWireMessage(dst []byte, desc description.SelectedServer) ([]byte, startedInformation, error) {
|
||
|
var info startedInformation
|
||
|
flags := op.secondaryOK(desc)
|
||
|
var wmindex int32
|
||
|
info.requestID = wiremessage.NextRequestID()
|
||
|
wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery)
|
||
|
dst = wiremessage.AppendQueryFlags(dst, flags)
|
||
|
// FullCollectionName
|
||
|
dst = append(dst, op.Database...)
|
||
|
dst = append(dst, dollarCmd[:]...)
|
||
|
dst = append(dst, 0x00)
|
||
|
dst = wiremessage.AppendQueryNumberToSkip(dst, 0)
|
||
|
dst = wiremessage.AppendQueryNumberToReturn(dst, -1)
|
||
|
|
||
|
wrapper := int32(-1)
|
||
|
rp, err := op.createReadPref(desc, true)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
if len(rp) > 0 {
|
||
|
wrapper, dst = bsoncore.AppendDocumentStart(dst)
|
||
|
dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query")
|
||
|
}
|
||
|
idx, dst := bsoncore.AppendDocumentStart(dst)
|
||
|
dst, err = op.CommandFn(dst, desc)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
|
||
|
if op.Batches != nil && len(op.Batches.Current) > 0 {
|
||
|
dst = op.addBatchArray(dst)
|
||
|
}
|
||
|
|
||
|
dst, err = op.addReadConcern(dst, desc)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
|
||
|
dst, err = op.addWriteConcern(dst, desc)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
|
||
|
dst, err = op.addSession(dst, desc)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
|
||
|
dst = op.addClusterTime(dst, desc)
|
||
|
dst = op.addServerAPI(dst)
|
||
|
|
||
|
dst, _ = bsoncore.AppendDocumentEnd(dst, idx)
|
||
|
// Command monitoring only reports the document inside $query
|
||
|
info.cmd = dst[idx:]
|
||
|
|
||
|
if len(rp) > 0 {
|
||
|
var err error
|
||
|
dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp)
|
||
|
dst, err = bsoncore.AppendDocumentEnd(dst, wrapper)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
|
||
|
}
|
||
|
|
||
|
func (op Operation) createMsgWireMessage(ctx context.Context, dst []byte, desc description.SelectedServer,
|
||
|
conn Connection) ([]byte, startedInformation, error) {
|
||
|
|
||
|
var info startedInformation
|
||
|
var flags wiremessage.MsgFlag
|
||
|
var wmindex int32
|
||
|
// We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either
|
||
|
// aren't batching or we are encoding the last batch.
|
||
|
if op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) && (op.Batches == nil || len(op.Batches.Documents) == 0) {
|
||
|
flags = wiremessage.MoreToCome
|
||
|
}
|
||
|
// Set the ExhaustAllowed flag if the connection supports streaming. This will tell the server that it can
|
||
|
// respond with the MoreToCome flag and then stream responses over this connection.
|
||
|
if streamer, ok := conn.(StreamerConnection); ok && streamer.SupportsStreaming() {
|
||
|
flags |= wiremessage.ExhaustAllowed
|
||
|
}
|
||
|
|
||
|
info.requestID = wiremessage.NextRequestID()
|
||
|
wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg)
|
||
|
dst = wiremessage.AppendMsgFlags(dst, flags)
|
||
|
// Body
|
||
|
dst = wiremessage.AppendMsgSectionType(dst, wiremessage.SingleDocument)
|
||
|
|
||
|
idx, dst := bsoncore.AppendDocumentStart(dst)
|
||
|
|
||
|
dst, err := op.addCommandFields(ctx, dst, desc)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
dst, err = op.addReadConcern(dst, desc)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
dst, err = op.addWriteConcern(dst, desc)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
dst, err = op.addSession(dst, desc)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
|
||
|
dst = op.addClusterTime(dst, desc)
|
||
|
dst = op.addServerAPI(dst)
|
||
|
|
||
|
dst = bsoncore.AppendStringElement(dst, "$db", op.Database)
|
||
|
rp, err := op.createReadPref(desc, false)
|
||
|
if err != nil {
|
||
|
return dst, info, err
|
||
|
}
|
||
|
if len(rp) > 0 {
|
||
|
dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp)
|
||
|
}
|
||
|
|
||
|
dst, _ = bsoncore.AppendDocumentEnd(dst, idx)
|
||
|
// The command document for monitoring shouldn't include the type 1 payload as a document sequence
|
||
|
info.cmd = dst[idx:]
|
||
|
|
||
|
// add batch as a document sequence if auto encryption is not enabled
|
||
|
// if auto encryption is enabled, the batch will already be an array in the command document
|
||
|
if !op.shouldEncrypt() && op.Batches != nil && len(op.Batches.Current) > 0 {
|
||
|
info.documentSequenceIncluded = true
|
||
|
dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence)
|
||
|
idx, dst = bsoncore.ReserveLength(dst)
|
||
|
|
||
|
dst = append(dst, op.Batches.Identifier...)
|
||
|
dst = append(dst, 0x00)
|
||
|
|
||
|
for _, doc := range op.Batches.Current {
|
||
|
dst = append(dst, doc...)
|
||
|
}
|
||
|
|
||
|
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
|
||
|
}
|
||
|
|
||
|
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
|
||
|
}
|
||
|
|
||
|
// addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document
|
||
|
// has already been added and does not add the final 0 byte.
|
||
|
func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) {
|
||
|
if !op.shouldEncrypt() {
|
||
|
return op.CommandFn(dst, desc)
|
||
|
}
|
||
|
|
||
|
if desc.WireVersion.Max < cryptMinWireVersion {
|
||
|
return dst, errors.New("auto-encryption requires a MongoDB version of 4.2")
|
||
|
}
|
||
|
|
||
|
// create temporary command document
|
||
|
cidx, cmdDst := bsoncore.AppendDocumentStart(nil)
|
||
|
var err error
|
||
|
cmdDst, err = op.CommandFn(cmdDst, desc)
|
||
|
if err != nil {
|
||
|
return dst, err
|
||
|
}
|
||
|
// use a BSON array instead of a type 1 payload because mongocryptd will convert to arrays regardless
|
||
|
if op.Batches != nil && len(op.Batches.Current) > 0 {
|
||
|
cmdDst = op.addBatchArray(cmdDst)
|
||
|
}
|
||
|
cmdDst, _ = bsoncore.AppendDocumentEnd(cmdDst, cidx)
|
||
|
|
||
|
// encrypt the command
|
||
|
encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst)
|
||
|
if err != nil {
|
||
|
return dst, err
|
||
|
}
|
||
|
// append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator)
|
||
|
dst = append(dst, encrypted[4:len(encrypted)-1]...)
|
||
|
return dst, nil
|
||
|
}
|
||
|
|
||
|
// addServerAPI adds the relevant fields for server API specification to the wire message in dst.
|
||
|
func (op Operation) addServerAPI(dst []byte) []byte {
|
||
|
sa := op.ServerAPI
|
||
|
if sa == nil {
|
||
|
return dst
|
||
|
}
|
||
|
|
||
|
dst = bsoncore.AppendStringElement(dst, "apiVersion", sa.ServerAPIVersion)
|
||
|
if sa.Strict != nil {
|
||
|
dst = bsoncore.AppendBooleanElement(dst, "apiStrict", *sa.Strict)
|
||
|
}
|
||
|
if sa.DeprecationErrors != nil {
|
||
|
dst = bsoncore.AppendBooleanElement(dst, "apiDeprecationErrors", *sa.DeprecationErrors)
|
||
|
}
|
||
|
return dst
|
||
|
}
|
||
|
|
||
|
func (op Operation) addReadConcern(dst []byte, desc description.SelectedServer) ([]byte, error) {
|
||
|
if op.MinimumReadConcernWireVersion > 0 && (desc.WireVersion == nil || !desc.WireVersion.Includes(op.MinimumReadConcernWireVersion)) {
|
||
|
return dst, nil
|
||
|
}
|
||
|
rc := op.ReadConcern
|
||
|
client := op.Client
|
||
|
// Starting transaction's read concern overrides all others
|
||
|
if client != nil && client.TransactionStarting() && client.CurrentRc != nil {
|
||
|
rc = client.CurrentRc
|
||
|
}
|
||
|
|
||
|
// start transaction must append afterclustertime IF causally consistent and operation time exists
|
||
|
if rc == nil && client != nil && client.TransactionStarting() && client.Consistent && client.OperationTime != nil {
|
||
|
rc = readconcern.New()
|
||
|
}
|
||
|
|
||
|
if client != nil && client.Snapshot {
|
||
|
if desc.WireVersion.Max < readSnapshotMinWireVersion {
|
||
|
return dst, errors.New("snapshot reads require MongoDB 5.0 or later")
|
||
|
}
|
||
|
rc = readconcern.Snapshot()
|
||
|
}
|
||
|
|
||
|
if rc == nil {
|
||
|
return dst, nil
|
||
|
}
|
||
|
|
||
|
_, data, err := rc.MarshalBSONValue() // always returns a document
|
||
|
if err != nil {
|
||
|
return dst, err
|
||
|
}
|
||
|
|
||
|
if sessionsSupported(desc.WireVersion) && client != nil {
|
||
|
if client.Consistent && client.OperationTime != nil {
|
||
|
data = data[:len(data)-1] // remove the null byte
|
||
|
data = bsoncore.AppendTimestampElement(data, "afterClusterTime", client.OperationTime.T, client.OperationTime.I)
|
||
|
data, _ = bsoncore.AppendDocumentEnd(data, 0)
|
||
|
}
|
||
|
if client.Snapshot && client.SnapshotTime != nil {
|
||
|
data = data[:len(data)-1] // remove the null byte
|
||
|
data = bsoncore.AppendTimestampElement(data, "atClusterTime", client.SnapshotTime.T, client.SnapshotTime.I)
|
||
|
data, _ = bsoncore.AppendDocumentEnd(data, 0)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(data) == bsoncore.EmptyDocumentLength {
|
||
|
return dst, nil
|
||
|
}
|
||
|
return bsoncore.AppendDocumentElement(dst, "readConcern", data), nil
|
||
|
}
|
||
|
|
||
|
func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) ([]byte, error) {
|
||
|
if op.MinimumWriteConcernWireVersion > 0 && (desc.WireVersion == nil || !desc.WireVersion.Includes(op.MinimumWriteConcernWireVersion)) {
|
||
|
return dst, nil
|
||
|
}
|
||
|
wc := op.WriteConcern
|
||
|
if wc == nil {
|
||
|
return dst, nil
|
||
|
}
|
||
|
|
||
|
t, data, err := wc.MarshalBSONValue()
|
||
|
if err == writeconcern.ErrEmptyWriteConcern {
|
||
|
return dst, nil
|
||
|
}
|
||
|
if err != nil {
|
||
|
return dst, err
|
||
|
}
|
||
|
|
||
|
return append(bsoncore.AppendHeader(dst, t, "writeConcern"), data...), nil
|
||
|
}
|
||
|
|
||
|
func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]byte, error) {
|
||
|
client := op.Client
|
||
|
if client == nil || !sessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 {
|
||
|
return dst, nil
|
||
|
}
|
||
|
if err := client.UpdateUseTime(); err != nil {
|
||
|
return dst, err
|
||
|
}
|
||
|
dst = bsoncore.AppendDocumentElement(dst, "lsid", client.SessionID)
|
||
|
|
||
|
var addedTxnNumber bool
|
||
|
if op.Type == Write && client.RetryWrite {
|
||
|
addedTxnNumber = true
|
||
|
dst = bsoncore.AppendInt64Element(dst, "txnNumber", op.Client.TxnNumber)
|
||
|
}
|
||
|
if client.TransactionRunning() || client.RetryingCommit {
|
||
|
if !addedTxnNumber {
|
||
|
dst = bsoncore.AppendInt64Element(dst, "txnNumber", op.Client.TxnNumber)
|
||
|
}
|
||
|
if client.TransactionStarting() {
|
||
|
dst = bsoncore.AppendBooleanElement(dst, "startTransaction", true)
|
||
|
}
|
||
|
dst = bsoncore.AppendBooleanElement(dst, "autocommit", false)
|
||
|
}
|
||
|
|
||
|
return dst, client.ApplyCommand(desc.Server)
|
||
|
}
|
||
|
|
||
|
func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) []byte {
|
||
|
client, clock := op.Client, op.Clock
|
||
|
if (clock == nil && client == nil) || !sessionsSupported(desc.WireVersion) {
|
||
|
return dst
|
||
|
}
|
||
|
clusterTime := clock.GetClusterTime()
|
||
|
if client != nil {
|
||
|
clusterTime = session.MaxClusterTime(clusterTime, client.ClusterTime)
|
||
|
}
|
||
|
if clusterTime == nil {
|
||
|
return dst
|
||
|
}
|
||
|
val, err := clusterTime.LookupErr("$clusterTime")
|
||
|
if err != nil {
|
||
|
return dst
|
||
|
}
|
||
|
return append(bsoncore.AppendHeader(dst, val.Type, "$clusterTime"), val.Value...)
|
||
|
// return bsoncore.AppendDocumentElement(dst, "$clusterTime", clusterTime)
|
||
|
}
|
||
|
|
||
|
// updateClusterTimes updates the cluster times for the session and cluster clock attached to this
|
||
|
// operation. While the session's AdvanceClusterTime may return an error, this method does not
|
||
|
// because an error being returned from this method will not be returned further up.
|
||
|
func (op Operation) updateClusterTimes(response bsoncore.Document) {
|
||
|
// Extract cluster time.
|
||
|
value, err := response.LookupErr("$clusterTime")
|
||
|
if err != nil {
|
||
|
// $clusterTime not included by the server
|
||
|
return
|
||
|
}
|
||
|
clusterTime := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendValueElement(nil, "$clusterTime", value))
|
||
|
|
||
|
sess, clock := op.Client, op.Clock
|
||
|
|
||
|
if sess != nil {
|
||
|
_ = sess.AdvanceClusterTime(bson.Raw(clusterTime))
|
||
|
}
|
||
|
|
||
|
if clock != nil {
|
||
|
clock.AdvanceClusterTime(bson.Raw(clusterTime))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// updateOperationTime updates the operation time on the session attached to this operation. While
|
||
|
// the session's AdvanceOperationTime method may return an error, this method does not because an
|
||
|
// error being returned from this method will not be returned further up.
|
||
|
func (op Operation) updateOperationTime(response bsoncore.Document) {
|
||
|
sess := op.Client
|
||
|
if sess == nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
opTimeElem, err := response.LookupErr("operationTime")
|
||
|
if err != nil {
|
||
|
// operationTime not included by the server
|
||
|
return
|
||
|
}
|
||
|
|
||
|
t, i := opTimeElem.Timestamp()
|
||
|
_ = sess.AdvanceOperationTime(&primitive.Timestamp{
|
||
|
T: t,
|
||
|
I: i,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (op Operation) getReadPrefBasedOnTransaction() (*readpref.ReadPref, error) {
|
||
|
if op.Client != nil && op.Client.TransactionRunning() {
|
||
|
// Transaction's read preference always takes priority
|
||
|
rp := op.Client.CurrentRp
|
||
|
// Reads in a transaction must have read preference primary
|
||
|
// This must not be checked in startTransaction
|
||
|
if rp != nil && !op.Client.TransactionStarting() && rp.Mode() != readpref.PrimaryMode {
|
||
|
return nil, ErrNonPrimaryReadPref
|
||
|
}
|
||
|
return rp, nil
|
||
|
}
|
||
|
return op.ReadPreference, nil
|
||
|
}
|
||
|
|
||
|
func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bool) (bsoncore.Document, error) {
|
||
|
// TODO(GODRIVER-2231): Instead of checking if isOutputAggregate and desc.Server.WireVersion.Max < 13,
|
||
|
// somehow check if supplied readPreference was "overwritten" with primary in description.selectForReplicaSet.
|
||
|
if desc.Server.Kind == description.Standalone || (isOpQuery && desc.Server.Kind != description.Mongos) ||
|
||
|
op.Type == Write || (op.IsOutputAggregate && desc.Server.WireVersion.Max < 13) {
|
||
|
// Don't send read preference for:
|
||
|
// 1. all standalones
|
||
|
// 2. non-mongos when using OP_QUERY
|
||
|
// 3. all writes
|
||
|
// 4. when operation is an aggregate with an output stage, and selected server's wire
|
||
|
// version is < 13
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
idx, doc := bsoncore.AppendDocumentStart(nil)
|
||
|
rp, err := op.getReadPrefBasedOnTransaction()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if rp == nil {
|
||
|
if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
|
||
|
doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred")
|
||
|
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
|
||
|
return doc, nil
|
||
|
}
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
switch rp.Mode() {
|
||
|
case readpref.PrimaryMode:
|
||
|
if desc.Server.Kind == description.Mongos {
|
||
|
return nil, nil
|
||
|
}
|
||
|
if desc.Kind == description.Single {
|
||
|
doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred")
|
||
|
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
|
||
|
return doc, nil
|
||
|
}
|
||
|
doc = bsoncore.AppendStringElement(doc, "mode", "primary")
|
||
|
case readpref.PrimaryPreferredMode:
|
||
|
doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred")
|
||
|
case readpref.SecondaryPreferredMode:
|
||
|
_, ok := rp.MaxStaleness()
|
||
|
if desc.Server.Kind == description.Mongos && isOpQuery && !ok && len(rp.TagSets()) == 0 && rp.HedgeEnabled() == nil {
|
||
|
return nil, nil
|
||
|
}
|
||
|
doc = bsoncore.AppendStringElement(doc, "mode", "secondaryPreferred")
|
||
|
case readpref.SecondaryMode:
|
||
|
doc = bsoncore.AppendStringElement(doc, "mode", "secondary")
|
||
|
case readpref.NearestMode:
|
||
|
doc = bsoncore.AppendStringElement(doc, "mode", "nearest")
|
||
|
}
|
||
|
|
||
|
sets := make([]bsoncore.Document, 0, len(rp.TagSets()))
|
||
|
for _, ts := range rp.TagSets() {
|
||
|
i, set := bsoncore.AppendDocumentStart(nil)
|
||
|
for _, t := range ts {
|
||
|
set = bsoncore.AppendStringElement(set, t.Name, t.Value)
|
||
|
}
|
||
|
set, _ = bsoncore.AppendDocumentEnd(set, i)
|
||
|
sets = append(sets, set)
|
||
|
}
|
||
|
if len(sets) > 0 {
|
||
|
var aidx int32
|
||
|
aidx, doc = bsoncore.AppendArrayElementStart(doc, "tags")
|
||
|
for i, set := range sets {
|
||
|
doc = bsoncore.AppendDocumentElement(doc, strconv.Itoa(i), set)
|
||
|
}
|
||
|
doc, _ = bsoncore.AppendArrayEnd(doc, aidx)
|
||
|
}
|
||
|
|
||
|
if d, ok := rp.MaxStaleness(); ok {
|
||
|
doc = bsoncore.AppendInt32Element(doc, "maxStalenessSeconds", int32(d.Seconds()))
|
||
|
}
|
||
|
|
||
|
if hedgeEnabled := rp.HedgeEnabled(); hedgeEnabled != nil {
|
||
|
var hedgeIdx int32
|
||
|
hedgeIdx, doc = bsoncore.AppendDocumentElementStart(doc, "hedge")
|
||
|
doc = bsoncore.AppendBooleanElement(doc, "enabled", *hedgeEnabled)
|
||
|
doc, err = bsoncore.AppendDocumentEnd(doc, hedgeIdx)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("error creating hedge document: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
|
||
|
return doc, nil
|
||
|
}
|
||
|
|
||
|
func (op Operation) secondaryOK(desc description.SelectedServer) wiremessage.QueryFlag {
|
||
|
if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
|
||
|
return wiremessage.SecondaryOK
|
||
|
}
|
||
|
|
||
|
if rp := op.ReadPreference; rp != nil && rp.Mode() != readpref.PrimaryMode {
|
||
|
return wiremessage.SecondaryOK
|
||
|
}
|
||
|
|
||
|
return 0
|
||
|
}
|
||
|
|
||
|
func (Operation) canCompress(cmd string) bool {
|
||
|
if cmd == internal.LegacyHello || cmd == "hello" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "authenticate" ||
|
||
|
cmd == "createUser" || cmd == "updateUser" || cmd == "copydbSaslStart" || cmd == "copydbgetnonce" || cmd == "copydb" {
|
||
|
return false
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
// decodeOpReply extracts the necessary information from an OP_REPLY wire message.
|
||
|
// includesHeader: specifies whether or not wm includes the message header
|
||
|
// Returns the decoded OP_REPLY. If the err field of the returned opReply is non-nil, an error occurred while decoding
|
||
|
// or validating the response and the other fields are undefined.
|
||
|
func (Operation) decodeOpReply(wm []byte, includesHeader bool) opReply {
|
||
|
var reply opReply
|
||
|
var ok bool
|
||
|
|
||
|
if includesHeader {
|
||
|
wmLength := len(wm)
|
||
|
var length int32
|
||
|
var opcode wiremessage.OpCode
|
||
|
length, _, _, opcode, wm, ok = wiremessage.ReadHeader(wm)
|
||
|
if !ok || int(length) > wmLength {
|
||
|
reply.err = errors.New("malformed wire message: insufficient bytes")
|
||
|
return reply
|
||
|
}
|
||
|
if opcode != wiremessage.OpReply {
|
||
|
reply.err = errors.New("malformed wire message: incorrect opcode")
|
||
|
return reply
|
||
|
}
|
||
|
}
|
||
|
|
||
|
reply.responseFlags, wm, ok = wiremessage.ReadReplyFlags(wm)
|
||
|
if !ok {
|
||
|
reply.err = errors.New("malformed OP_REPLY: missing flags")
|
||
|
return reply
|
||
|
}
|
||
|
reply.cursorID, wm, ok = wiremessage.ReadReplyCursorID(wm)
|
||
|
if !ok {
|
||
|
reply.err = errors.New("malformed OP_REPLY: missing cursorID")
|
||
|
return reply
|
||
|
}
|
||
|
reply.startingFrom, wm, ok = wiremessage.ReadReplyStartingFrom(wm)
|
||
|
if !ok {
|
||
|
reply.err = errors.New("malformed OP_REPLY: missing startingFrom")
|
||
|
return reply
|
||
|
}
|
||
|
reply.numReturned, wm, ok = wiremessage.ReadReplyNumberReturned(wm)
|
||
|
if !ok {
|
||
|
reply.err = errors.New("malformed OP_REPLY: missing numberReturned")
|
||
|
return reply
|
||
|
}
|
||
|
reply.documents, wm, ok = wiremessage.ReadReplyDocuments(wm)
|
||
|
if !ok {
|
||
|
reply.err = errors.New("malformed OP_REPLY: could not read documents from reply")
|
||
|
}
|
||
|
|
||
|
if reply.responseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure {
|
||
|
reply.err = QueryFailureError{
|
||
|
Message: "command failure",
|
||
|
Response: reply.documents[0],
|
||
|
}
|
||
|
return reply
|
||
|
}
|
||
|
if reply.responseFlags&wiremessage.CursorNotFound == wiremessage.CursorNotFound {
|
||
|
reply.err = ErrCursorNotFound
|
||
|
return reply
|
||
|
}
|
||
|
if reply.numReturned != int32(len(reply.documents)) {
|
||
|
reply.err = ErrReplyDocumentMismatch
|
||
|
return reply
|
||
|
}
|
||
|
|
||
|
return reply
|
||
|
}
|
||
|
|
||
|
func (op Operation) decodeResult(wm []byte) (bsoncore.Document, error) {
|
||
|
wmLength := len(wm)
|
||
|
length, _, _, opcode, wm, ok := wiremessage.ReadHeader(wm)
|
||
|
if !ok || int(length) > wmLength {
|
||
|
return nil, errors.New("malformed wire message: insufficient bytes")
|
||
|
}
|
||
|
|
||
|
wm = wm[:wmLength-16] // constrain to just this wiremessage, incase there are multiple in the slice
|
||
|
|
||
|
switch opcode {
|
||
|
case wiremessage.OpReply:
|
||
|
reply := op.decodeOpReply(wm, false)
|
||
|
if reply.err != nil {
|
||
|
return nil, reply.err
|
||
|
}
|
||
|
if reply.numReturned == 0 {
|
||
|
return nil, ErrNoDocCommandResponse
|
||
|
}
|
||
|
if reply.numReturned > 1 {
|
||
|
return nil, ErrMultiDocCommandResponse
|
||
|
}
|
||
|
rdr := reply.documents[0]
|
||
|
if err := rdr.Validate(); err != nil {
|
||
|
return nil, NewCommandResponseError("malformed OP_REPLY: invalid document", err)
|
||
|
}
|
||
|
|
||
|
return rdr, ExtractErrorFromServerResponse(rdr)
|
||
|
case wiremessage.OpMsg:
|
||
|
_, wm, ok = wiremessage.ReadMsgFlags(wm)
|
||
|
if !ok {
|
||
|
return nil, errors.New("malformed wire message: missing OP_MSG flags")
|
||
|
}
|
||
|
|
||
|
var res bsoncore.Document
|
||
|
for len(wm) > 0 {
|
||
|
var stype wiremessage.SectionType
|
||
|
stype, wm, ok = wiremessage.ReadMsgSectionType(wm)
|
||
|
if !ok {
|
||
|
return nil, errors.New("malformed wire message: insuffienct bytes to read section type")
|
||
|
}
|
||
|
|
||
|
switch stype {
|
||
|
case wiremessage.SingleDocument:
|
||
|
res, wm, ok = wiremessage.ReadMsgSectionSingleDocument(wm)
|
||
|
if !ok {
|
||
|
return nil, errors.New("malformed wire message: insufficient bytes to read single document")
|
||
|
}
|
||
|
case wiremessage.DocumentSequence:
|
||
|
// TODO(GODRIVER-617): Implement document sequence returns.
|
||
|
_, _, wm, ok = wiremessage.ReadMsgSectionDocumentSequence(wm)
|
||
|
if !ok {
|
||
|
return nil, errors.New("malformed wire message: insufficient bytes to read document sequence")
|
||
|
}
|
||
|
default:
|
||
|
return nil, fmt.Errorf("malformed wire message: uknown section type %v", stype)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
err := res.Validate()
|
||
|
if err != nil {
|
||
|
return nil, NewCommandResponseError("malformed OP_MSG: invalid document", err)
|
||
|
}
|
||
|
|
||
|
return res, ExtractErrorFromServerResponse(res)
|
||
|
default:
|
||
|
return nil, fmt.Errorf("cannot decode result from %s", opcode)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// getCommandName returns the name of the command from the given BSON document.
|
||
|
func (op Operation) getCommandName(doc []byte) string {
|
||
|
// skip 4 bytes for document length and 1 byte for element type
|
||
|
idx := bytes.IndexByte(doc[5:], 0x00) // look for the 0 byte after the command name
|
||
|
return string(doc[5 : idx+5])
|
||
|
}
|
||
|
|
||
|
func (op *Operation) redactCommand(cmd string, doc bsoncore.Document) bool {
|
||
|
if cmd == "authenticate" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "createUser" ||
|
||
|
cmd == "updateUser" || cmd == "copydbgetnonce" || cmd == "copydbsaslstart" || cmd == "copydb" {
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
if strings.ToLower(cmd) != internal.LegacyHelloLowercase && cmd != "hello" {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// A hello without speculative authentication can be monitored.
|
||
|
_, err := doc.LookupErr("speculativeAuthenticate")
|
||
|
return err == nil
|
||
|
}
|
||
|
|
||
|
// publishStartedEvent publishes a CommandStartedEvent to the operation's command monitor if possible. If the command is
|
||
|
// an unacknowledged write, a CommandSucceededEvent will be published as well. If started events are not being monitored,
|
||
|
// no events are published.
|
||
|
func (op Operation) publishStartedEvent(ctx context.Context, info startedInformation) {
|
||
|
if op.CommandMonitor == nil || op.CommandMonitor.Started == nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Make a copy of the command. Redact if the command is security sensitive and cannot be monitored.
|
||
|
// If there was a type 1 payload for the current batch, convert it to a BSON array.
|
||
|
cmdCopy := bson.Raw{}
|
||
|
if !info.redacted {
|
||
|
cmdCopy = make([]byte, len(info.cmd))
|
||
|
copy(cmdCopy, info.cmd)
|
||
|
if info.documentSequenceIncluded {
|
||
|
cmdCopy = cmdCopy[:len(info.cmd)-1] // remove 0 byte at end
|
||
|
cmdCopy = op.addBatchArray(cmdCopy)
|
||
|
cmdCopy, _ = bsoncore.AppendDocumentEnd(cmdCopy, 0) // add back 0 byte and update length
|
||
|
}
|
||
|
}
|
||
|
|
||
|
started := &event.CommandStartedEvent{
|
||
|
Command: cmdCopy,
|
||
|
DatabaseName: op.Database,
|
||
|
CommandName: info.cmdName,
|
||
|
RequestID: int64(info.requestID),
|
||
|
ConnectionID: info.connID,
|
||
|
ServerConnectionID: info.serverConnID,
|
||
|
ServiceID: info.serviceID,
|
||
|
}
|
||
|
op.CommandMonitor.Started(ctx, started)
|
||
|
}
|
||
|
|
||
|
// publishFinishedEvent publishes either a CommandSucceededEvent or a CommandFailedEvent to the operation's command
|
||
|
// monitor if possible. If success/failure events aren't being monitored, no events are published.
|
||
|
func (op Operation) publishFinishedEvent(ctx context.Context, info finishedInformation) {
|
||
|
success := info.cmdErr == nil
|
||
|
if _, ok := info.cmdErr.(WriteCommandError); ok {
|
||
|
success = true
|
||
|
}
|
||
|
if op.CommandMonitor == nil || (success && op.CommandMonitor.Succeeded == nil) || (!success && op.CommandMonitor.Failed == nil) {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var durationNanos int64
|
||
|
var emptyTime time.Time
|
||
|
if info.startTime != emptyTime {
|
||
|
durationNanos = time.Since(info.startTime).Nanoseconds()
|
||
|
}
|
||
|
|
||
|
finished := event.CommandFinishedEvent{
|
||
|
CommandName: info.cmdName,
|
||
|
RequestID: int64(info.requestID),
|
||
|
ConnectionID: info.connID,
|
||
|
DurationNanos: durationNanos,
|
||
|
ServerConnectionID: info.serverConnID,
|
||
|
ServiceID: info.serviceID,
|
||
|
}
|
||
|
|
||
|
if success {
|
||
|
res := bson.Raw{}
|
||
|
// Only copy the reply for commands that are not security sensitive
|
||
|
if !info.redacted {
|
||
|
res = make([]byte, len(info.response))
|
||
|
copy(res, info.response)
|
||
|
}
|
||
|
successEvent := &event.CommandSucceededEvent{
|
||
|
Reply: res,
|
||
|
CommandFinishedEvent: finished,
|
||
|
}
|
||
|
op.CommandMonitor.Succeeded(ctx, successEvent)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
failedEvent := &event.CommandFailedEvent{
|
||
|
Failure: info.cmdErr.Error(),
|
||
|
CommandFinishedEvent: finished,
|
||
|
}
|
||
|
op.CommandMonitor.Failed(ctx, failedEvent)
|
||
|
}
|
||
|
|
||
|
// sessionsSupported returns true of the given server version indicates that it supports sessions.
|
||
|
func sessionsSupported(wireVersion *description.VersionRange) bool {
|
||
|
return wireVersion != nil && wireVersion.Max >= 6
|
||
|
}
|
||
|
|
||
|
// retryWritesSupported returns true if this description represents a server that supports retryable writes.
|
||
|
func retryWritesSupported(s description.Server) bool {
|
||
|
return s.SessionTimeoutMinutes != 0 && s.Kind != description.Standalone
|
||
|
}
|