You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
637 lines
16 KiB
637 lines
16 KiB
// Copyright (C) MongoDB, Inc. 2017-present. |
|
// |
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may |
|
// not use this file except in compliance with the License. You may obtain |
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
package bsonx |
|
|
|
import ( |
|
"errors" |
|
"fmt" |
|
"reflect" |
|
|
|
"go.mongodb.org/mongo-driver/bson/bsoncodec" |
|
"go.mongodb.org/mongo-driver/bson/bsonrw" |
|
"go.mongodb.org/mongo-driver/bson/bsontype" |
|
) |
|
|
|
var primitiveCodecs PrimitiveCodecs |
|
|
|
var tDocument = reflect.TypeOf((Doc)(nil)) |
|
var tArray = reflect.TypeOf((Arr)(nil)) |
|
var tValue = reflect.TypeOf(Val{}) |
|
var tElementSlice = reflect.TypeOf(([]Elem)(nil)) |
|
|
|
// PrimitiveCodecs groups the default bsoncodec value encoders and decoders for
// the primitive types defined in this package. The struct has no fields; the
// codecs are exposed as its methods.
type PrimitiveCodecs struct{}
|
|
|
// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs |
|
// with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created. |
|
func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) { |
|
if rb == nil { |
|
panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil")) |
|
} |
|
|
|
rb. |
|
RegisterTypeEncoder(tDocument, bsoncodec.ValueEncoderFunc(pc.DocumentEncodeValue)). |
|
RegisterTypeEncoder(tArray, bsoncodec.ValueEncoderFunc(pc.ArrayEncodeValue)). |
|
RegisterTypeEncoder(tValue, bsoncodec.ValueEncoderFunc(pc.ValueEncodeValue)). |
|
RegisterTypeEncoder(tElementSlice, bsoncodec.ValueEncoderFunc(pc.ElementSliceEncodeValue)). |
|
RegisterTypeDecoder(tDocument, bsoncodec.ValueDecoderFunc(pc.DocumentDecodeValue)). |
|
RegisterTypeDecoder(tArray, bsoncodec.ValueDecoderFunc(pc.ArrayDecodeValue)). |
|
RegisterTypeDecoder(tValue, bsoncodec.ValueDecoderFunc(pc.ValueDecodeValue)). |
|
RegisterTypeDecoder(tElementSlice, bsoncodec.ValueDecoderFunc(pc.ElementSliceDecodeValue)) |
|
} |
|
|
|
// DocumentEncodeValue is the ValueEncoderFunc for *Document. |
|
func (pc PrimitiveCodecs) DocumentEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { |
|
if !val.IsValid() || val.Type() != tDocument { |
|
return bsoncodec.ValueEncoderError{Name: "DocumentEncodeValue", Types: []reflect.Type{tDocument}, Received: val} |
|
} |
|
|
|
if val.IsNil() { |
|
return vw.WriteNull() |
|
} |
|
|
|
doc := val.Interface().(Doc) |
|
|
|
dw, err := vw.WriteDocument() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
return pc.encodeDocument(ec, dw, doc) |
|
} |
|
|
|
// DocumentDecodeValue is the ValueDecoderFunc for *Document. |
|
func (pc PrimitiveCodecs) DocumentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { |
|
if !val.CanSet() || val.Type() != tDocument { |
|
return bsoncodec.ValueDecoderError{Name: "DocumentDecodeValue", Types: []reflect.Type{tDocument}, Received: val} |
|
} |
|
|
|
return pc.documentDecodeValue(dctx, vr, val.Addr().Interface().(*Doc)) |
|
} |
|
|
|
func (pc PrimitiveCodecs) documentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, doc *Doc) error { |
|
|
|
dr, err := vr.ReadDocument() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
return pc.decodeDocument(dctx, dr, doc) |
|
} |
|
|
|
// ArrayEncodeValue is the ValueEncoderFunc for *Array. |
|
func (pc PrimitiveCodecs) ArrayEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { |
|
if !val.IsValid() || val.Type() != tArray { |
|
return bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: val} |
|
} |
|
|
|
if val.IsNil() { |
|
return vw.WriteNull() |
|
} |
|
|
|
arr := val.Interface().(Arr) |
|
|
|
aw, err := vw.WriteArray() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
for _, val := range arr { |
|
dvw, err := aw.WriteArrayElement() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
err = pc.encodeValue(ec, dvw, val) |
|
|
|
if err != nil { |
|
return err |
|
} |
|
} |
|
|
|
return aw.WriteArrayEnd() |
|
} |
|
|
|
// ArrayDecodeValue is the ValueDecoderFunc for *Array. |
|
func (pc PrimitiveCodecs) ArrayDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { |
|
if !val.CanSet() || val.Type() != tArray { |
|
return bsoncodec.ValueDecoderError{Name: "ArrayDecodeValue", Types: []reflect.Type{tArray}, Received: val} |
|
} |
|
|
|
ar, err := vr.ReadArray() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if val.IsNil() { |
|
val.Set(reflect.MakeSlice(tArray, 0, 0)) |
|
} |
|
val.SetLen(0) |
|
|
|
for { |
|
vr, err := ar.ReadValue() |
|
if err == bsonrw.ErrEOA { |
|
break |
|
} |
|
if err != nil { |
|
return err |
|
} |
|
|
|
var elem Val |
|
err = pc.valueDecodeValue(dc, vr, &elem) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
val.Set(reflect.Append(val, reflect.ValueOf(elem))) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// ElementSliceEncodeValue is the ValueEncoderFunc for []*Element. |
|
func (pc PrimitiveCodecs) ElementSliceEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { |
|
if !val.IsValid() || val.Type() != tElementSlice { |
|
return bsoncodec.ValueEncoderError{Name: "ElementSliceEncodeValue", Types: []reflect.Type{tElementSlice}, Received: val} |
|
} |
|
|
|
if val.IsNil() { |
|
return vw.WriteNull() |
|
} |
|
|
|
return pc.DocumentEncodeValue(ec, vw, val.Convert(tDocument)) |
|
} |
|
|
|
// ElementSliceDecodeValue is the ValueDecoderFunc for []*Element. |
|
func (pc PrimitiveCodecs) ElementSliceDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { |
|
if !val.CanSet() || val.Type() != tElementSlice { |
|
return bsoncodec.ValueDecoderError{Name: "ElementSliceDecodeValue", Types: []reflect.Type{tElementSlice}, Received: val} |
|
} |
|
|
|
if val.IsNil() { |
|
val.Set(reflect.MakeSlice(val.Type(), 0, 0)) |
|
} |
|
|
|
val.SetLen(0) |
|
|
|
dr, err := vr.ReadDocument() |
|
if err != nil { |
|
return err |
|
} |
|
elems := make([]reflect.Value, 0) |
|
for { |
|
key, vr, err := dr.ReadElement() |
|
if err == bsonrw.ErrEOD { |
|
break |
|
} |
|
if err != nil { |
|
return err |
|
} |
|
|
|
var elem Elem |
|
err = pc.elementDecodeValue(dc, vr, key, &elem) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
elems = append(elems, reflect.ValueOf(elem)) |
|
} |
|
|
|
val.Set(reflect.Append(val, elems...)) |
|
return nil |
|
} |
|
|
|
// ValueEncodeValue is the ValueEncoderFunc for *Value. |
|
func (pc PrimitiveCodecs) ValueEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { |
|
if !val.IsValid() || val.Type() != tValue { |
|
return bsoncodec.ValueEncoderError{Name: "ValueEncodeValue", Types: []reflect.Type{tValue}, Received: val} |
|
} |
|
|
|
v := val.Interface().(Val) |
|
|
|
return pc.encodeValue(ec, vw, v) |
|
} |
|
|
|
// ValueDecodeValue is the ValueDecoderFunc for *Value. |
|
func (pc PrimitiveCodecs) ValueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { |
|
if !val.CanSet() || val.Type() != tValue { |
|
return bsoncodec.ValueDecoderError{Name: "ValueDecodeValue", Types: []reflect.Type{tValue}, Received: val} |
|
} |
|
|
|
return pc.valueDecodeValue(dc, vr, val.Addr().Interface().(*Val)) |
|
} |
|
|
|
// encodeDocument is a separate function that we use because CodeWithScope |
|
// returns us a DocumentWriter and we need to do the same logic that we would do |
|
// for a document but cannot use a Codec. |
|
func (pc PrimitiveCodecs) encodeDocument(ec bsoncodec.EncodeContext, dw bsonrw.DocumentWriter, doc Doc) error { |
|
for _, elem := range doc { |
|
dvw, err := dw.WriteDocumentElement(elem.Key) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
err = pc.encodeValue(ec, dvw, elem.Value) |
|
|
|
if err != nil { |
|
return err |
|
} |
|
} |
|
|
|
return dw.WriteDocumentEnd() |
|
} |
|
|
|
// DecodeDocument haves decoding into a Doc from a bsonrw.DocumentReader. |
|
func (pc PrimitiveCodecs) DecodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error { |
|
return pc.decodeDocument(dctx, dr, pdoc) |
|
} |
|
|
|
func (pc PrimitiveCodecs) decodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error { |
|
if *pdoc == nil { |
|
*pdoc = make(Doc, 0) |
|
} |
|
*pdoc = (*pdoc)[:0] |
|
for { |
|
key, vr, err := dr.ReadElement() |
|
if err == bsonrw.ErrEOD { |
|
break |
|
} |
|
if err != nil { |
|
return err |
|
} |
|
|
|
var elem Elem |
|
err = pc.elementDecodeValue(dctx, vr, key, &elem) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
*pdoc = append(*pdoc, elem) |
|
} |
|
return nil |
|
} |
|
|
|
func (pc PrimitiveCodecs) elementDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, key string, elem *Elem) error { |
|
var val Val |
|
switch vr.Type() { |
|
case bsontype.Double: |
|
f64, err := vr.ReadDouble() |
|
if err != nil { |
|
return err |
|
} |
|
val = Double(f64) |
|
case bsontype.String: |
|
str, err := vr.ReadString() |
|
if err != nil { |
|
return err |
|
} |
|
val = String(str) |
|
case bsontype.EmbeddedDocument: |
|
var embeddedDoc Doc |
|
err := pc.documentDecodeValue(dc, vr, &embeddedDoc) |
|
if err != nil { |
|
return err |
|
} |
|
val = Document(embeddedDoc) |
|
case bsontype.Array: |
|
arr := reflect.New(tArray).Elem() |
|
err := pc.ArrayDecodeValue(dc, vr, arr) |
|
if err != nil { |
|
return err |
|
} |
|
val = Array(arr.Interface().(Arr)) |
|
case bsontype.Binary: |
|
data, subtype, err := vr.ReadBinary() |
|
if err != nil { |
|
return err |
|
} |
|
val = Binary(subtype, data) |
|
case bsontype.Undefined: |
|
err := vr.ReadUndefined() |
|
if err != nil { |
|
return err |
|
} |
|
val = Undefined() |
|
case bsontype.ObjectID: |
|
oid, err := vr.ReadObjectID() |
|
if err != nil { |
|
return err |
|
} |
|
val = ObjectID(oid) |
|
case bsontype.Boolean: |
|
b, err := vr.ReadBoolean() |
|
if err != nil { |
|
return err |
|
} |
|
val = Boolean(b) |
|
case bsontype.DateTime: |
|
dt, err := vr.ReadDateTime() |
|
if err != nil { |
|
return err |
|
} |
|
val = DateTime(dt) |
|
case bsontype.Null: |
|
err := vr.ReadNull() |
|
if err != nil { |
|
return err |
|
} |
|
val = Null() |
|
case bsontype.Regex: |
|
pattern, options, err := vr.ReadRegex() |
|
if err != nil { |
|
return err |
|
} |
|
val = Regex(pattern, options) |
|
case bsontype.DBPointer: |
|
ns, pointer, err := vr.ReadDBPointer() |
|
if err != nil { |
|
return err |
|
} |
|
val = DBPointer(ns, pointer) |
|
case bsontype.JavaScript: |
|
js, err := vr.ReadJavascript() |
|
if err != nil { |
|
return err |
|
} |
|
val = JavaScript(js) |
|
case bsontype.Symbol: |
|
symbol, err := vr.ReadSymbol() |
|
if err != nil { |
|
return err |
|
} |
|
val = Symbol(symbol) |
|
case bsontype.CodeWithScope: |
|
code, scope, err := vr.ReadCodeWithScope() |
|
if err != nil { |
|
return err |
|
} |
|
var doc Doc |
|
err = pc.decodeDocument(dc, scope, &doc) |
|
if err != nil { |
|
return err |
|
} |
|
val = CodeWithScope(code, doc) |
|
case bsontype.Int32: |
|
i32, err := vr.ReadInt32() |
|
if err != nil { |
|
return err |
|
} |
|
val = Int32(i32) |
|
case bsontype.Timestamp: |
|
t, i, err := vr.ReadTimestamp() |
|
if err != nil { |
|
return err |
|
} |
|
val = Timestamp(t, i) |
|
case bsontype.Int64: |
|
i64, err := vr.ReadInt64() |
|
if err != nil { |
|
return err |
|
} |
|
val = Int64(i64) |
|
case bsontype.Decimal128: |
|
d128, err := vr.ReadDecimal128() |
|
if err != nil { |
|
return err |
|
} |
|
val = Decimal128(d128) |
|
case bsontype.MinKey: |
|
err := vr.ReadMinKey() |
|
if err != nil { |
|
return err |
|
} |
|
val = MinKey() |
|
case bsontype.MaxKey: |
|
err := vr.ReadMaxKey() |
|
if err != nil { |
|
return err |
|
} |
|
val = MaxKey() |
|
default: |
|
return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type()) |
|
} |
|
|
|
*elem = Elem{Key: key, Value: val} |
|
return nil |
|
} |
|
|
|
// encodeValue does not validation, and the callers must perform validation on val before calling |
|
// this method. |
|
func (pc PrimitiveCodecs) encodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val Val) error { |
|
var err error |
|
switch val.Type() { |
|
case bsontype.Double: |
|
err = vw.WriteDouble(val.Double()) |
|
case bsontype.String: |
|
err = vw.WriteString(val.StringValue()) |
|
case bsontype.EmbeddedDocument: |
|
var encoder bsoncodec.ValueEncoder |
|
encoder, err = ec.LookupEncoder(tDocument) |
|
if err != nil { |
|
break |
|
} |
|
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Document())) |
|
case bsontype.Array: |
|
var encoder bsoncodec.ValueEncoder |
|
encoder, err = ec.LookupEncoder(tArray) |
|
if err != nil { |
|
break |
|
} |
|
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Array())) |
|
case bsontype.Binary: |
|
// TODO: FIX THIS (╯°□°)╯︵ ┻━┻ |
|
subtype, data := val.Binary() |
|
err = vw.WriteBinaryWithSubtype(data, subtype) |
|
case bsontype.Undefined: |
|
err = vw.WriteUndefined() |
|
case bsontype.ObjectID: |
|
err = vw.WriteObjectID(val.ObjectID()) |
|
case bsontype.Boolean: |
|
err = vw.WriteBoolean(val.Boolean()) |
|
case bsontype.DateTime: |
|
err = vw.WriteDateTime(val.DateTime()) |
|
case bsontype.Null: |
|
err = vw.WriteNull() |
|
case bsontype.Regex: |
|
err = vw.WriteRegex(val.Regex()) |
|
case bsontype.DBPointer: |
|
err = vw.WriteDBPointer(val.DBPointer()) |
|
case bsontype.JavaScript: |
|
err = vw.WriteJavascript(val.JavaScript()) |
|
case bsontype.Symbol: |
|
err = vw.WriteSymbol(val.Symbol()) |
|
case bsontype.CodeWithScope: |
|
code, scope := val.CodeWithScope() |
|
|
|
var cwsw bsonrw.DocumentWriter |
|
cwsw, err = vw.WriteCodeWithScope(code) |
|
if err != nil { |
|
break |
|
} |
|
|
|
err = pc.encodeDocument(ec, cwsw, scope) |
|
case bsontype.Int32: |
|
err = vw.WriteInt32(val.Int32()) |
|
case bsontype.Timestamp: |
|
err = vw.WriteTimestamp(val.Timestamp()) |
|
case bsontype.Int64: |
|
err = vw.WriteInt64(val.Int64()) |
|
case bsontype.Decimal128: |
|
err = vw.WriteDecimal128(val.Decimal128()) |
|
case bsontype.MinKey: |
|
err = vw.WriteMinKey() |
|
case bsontype.MaxKey: |
|
err = vw.WriteMaxKey() |
|
default: |
|
err = fmt.Errorf("%T is not a valid BSON type to encode", val.Type()) |
|
} |
|
|
|
return err |
|
} |
|
|
|
func (pc PrimitiveCodecs) valueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val *Val) error { |
|
switch vr.Type() { |
|
case bsontype.Double: |
|
f64, err := vr.ReadDouble() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Double(f64) |
|
case bsontype.String: |
|
str, err := vr.ReadString() |
|
if err != nil { |
|
return err |
|
} |
|
*val = String(str) |
|
case bsontype.EmbeddedDocument: |
|
var embeddedDoc Doc |
|
err := pc.documentDecodeValue(dc, vr, &embeddedDoc) |
|
if err != nil { |
|
return err |
|
} |
|
*val = Document(embeddedDoc) |
|
case bsontype.Array: |
|
arr := reflect.New(tArray).Elem() |
|
err := pc.ArrayDecodeValue(dc, vr, arr) |
|
if err != nil { |
|
return err |
|
} |
|
*val = Array(arr.Interface().(Arr)) |
|
case bsontype.Binary: |
|
data, subtype, err := vr.ReadBinary() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Binary(subtype, data) |
|
case bsontype.Undefined: |
|
err := vr.ReadUndefined() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Undefined() |
|
case bsontype.ObjectID: |
|
oid, err := vr.ReadObjectID() |
|
if err != nil { |
|
return err |
|
} |
|
*val = ObjectID(oid) |
|
case bsontype.Boolean: |
|
b, err := vr.ReadBoolean() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Boolean(b) |
|
case bsontype.DateTime: |
|
dt, err := vr.ReadDateTime() |
|
if err != nil { |
|
return err |
|
} |
|
*val = DateTime(dt) |
|
case bsontype.Null: |
|
err := vr.ReadNull() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Null() |
|
case bsontype.Regex: |
|
pattern, options, err := vr.ReadRegex() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Regex(pattern, options) |
|
case bsontype.DBPointer: |
|
ns, pointer, err := vr.ReadDBPointer() |
|
if err != nil { |
|
return err |
|
} |
|
*val = DBPointer(ns, pointer) |
|
case bsontype.JavaScript: |
|
js, err := vr.ReadJavascript() |
|
if err != nil { |
|
return err |
|
} |
|
*val = JavaScript(js) |
|
case bsontype.Symbol: |
|
symbol, err := vr.ReadSymbol() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Symbol(symbol) |
|
case bsontype.CodeWithScope: |
|
code, scope, err := vr.ReadCodeWithScope() |
|
if err != nil { |
|
return err |
|
} |
|
var scopeDoc Doc |
|
err = pc.decodeDocument(dc, scope, &scopeDoc) |
|
if err != nil { |
|
return err |
|
} |
|
*val = CodeWithScope(code, scopeDoc) |
|
case bsontype.Int32: |
|
i32, err := vr.ReadInt32() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Int32(i32) |
|
case bsontype.Timestamp: |
|
t, i, err := vr.ReadTimestamp() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Timestamp(t, i) |
|
case bsontype.Int64: |
|
i64, err := vr.ReadInt64() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Int64(i64) |
|
case bsontype.Decimal128: |
|
d128, err := vr.ReadDecimal128() |
|
if err != nil { |
|
return err |
|
} |
|
*val = Decimal128(d128) |
|
case bsontype.MinKey: |
|
err := vr.ReadMinKey() |
|
if err != nil { |
|
return err |
|
} |
|
*val = MinKey() |
|
case bsontype.MaxKey: |
|
err := vr.ReadMaxKey() |
|
if err != nil { |
|
return err |
|
} |
|
*val = MaxKey() |
|
default: |
|
return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type()) |
|
} |
|
|
|
return nil |
|
}
|
|
|