You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

1584 lines
42 KiB

package yaml
import (
"bytes"
"context"
"encoding"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"math"
"os"
"path/filepath"
"reflect"
"strconv"
"time"
"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/internal/errors"
"github.com/goccy/go-yaml/parser"
"github.com/goccy/go-yaml/token"
"golang.org/x/xerrors"
)
// Decoder reads and decodes YAML values from an input stream.
type Decoder struct {
reader io.Reader
referenceReaders []io.Reader
anchorNodeMap map[string]ast.Node
anchorValueMap map[string]reflect.Value
toCommentMap CommentMap
opts []DecodeOption
referenceFiles []string
referenceDirs []string
isRecursiveDir bool
isResolvedReference bool
validator StructValidator
disallowUnknownField bool
disallowDuplicateKey bool
useOrderedMap bool
useJSONUnmarshaler bool
parsedFile *ast.File
streamIndex int
}
// NewDecoder returns a new decoder that reads from r.
func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
return &Decoder{
reader: r,
anchorNodeMap: map[string]ast.Node{},
anchorValueMap: map[string]reflect.Value{},
opts: opts,
referenceReaders: []io.Reader{},
referenceFiles: []string{},
referenceDirs: []string{},
isRecursiveDir: false,
isResolvedReference: false,
disallowUnknownField: false,
disallowDuplicateKey: false,
useOrderedMap: false,
}
}
func (d *Decoder) castToFloat(v interface{}) interface{} {
switch vv := v.(type) {
case int:
return float64(vv)
case int8:
return float64(vv)
case int16:
return float64(vv)
case int32:
return float64(vv)
case int64:
return float64(vv)
case uint:
return float64(vv)
case uint8:
return float64(vv)
case uint16:
return float64(vv)
case uint32:
return float64(vv)
case uint64:
return float64(vv)
case float32:
return float64(vv)
case float64:
return vv
case string:
// if error occurred, return zero value
f, _ := strconv.ParseFloat(vv, 64)
return f
}
return 0
}
func (d *Decoder) mergeValueNode(value ast.Node) ast.Node {
if value.Type() == ast.AliasType {
aliasNode := value.(*ast.AliasNode)
aliasName := aliasNode.Value.GetToken().Value
return d.anchorNodeMap[aliasName]
}
return value
}
func (d *Decoder) mapKeyNodeToString(node ast.Node) string {
key := d.nodeToValue(node)
if key == nil {
return "null"
}
if k, ok := key.(string); ok {
return k
}
return fmt.Sprint(key)
}
func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) {
d.setPathToCommentMap(node)
switch n := node.(type) {
case *ast.MappingValueNode:
if n.Key.Type() == ast.MergeKeyType {
d.setToMapValue(d.mergeValueNode(n.Value), m)
} else {
key := d.mapKeyNodeToString(n.Key)
m[key] = d.nodeToValue(n.Value)
}
case *ast.MappingNode:
for _, value := range n.Values {
d.setToMapValue(value, m)
}
case *ast.AnchorNode:
anchorName := n.Name.GetToken().Value
d.anchorNodeMap[anchorName] = n.Value
}
}
func (d *Decoder) setToOrderedMapValue(node ast.Node, m *MapSlice) {
switch n := node.(type) {
case *ast.MappingValueNode:
if n.Key.Type() == ast.MergeKeyType {
d.setToOrderedMapValue(d.mergeValueNode(n.Value), m)
} else {
key := d.mapKeyNodeToString(n.Key)
*m = append(*m, MapItem{Key: key, Value: d.nodeToValue(n.Value)})
}
case *ast.MappingNode:
for _, value := range n.Values {
d.setToOrderedMapValue(value, m)
}
}
}
func (d *Decoder) setPathToCommentMap(node ast.Node) {
if d.toCommentMap == nil {
return
}
commentGroup := node.GetComment()
if commentGroup == nil {
return
}
texts := []string{}
for _, comment := range commentGroup.Comments {
texts = append(texts, comment.Token.Value)
}
if len(texts) == 0 {
return
}
if len(texts) == 1 {
d.toCommentMap[node.GetPath()] = LineComment(texts[0])
} else {
d.toCommentMap[node.GetPath()] = HeadComment(texts...)
}
}
func (d *Decoder) nodeToValue(node ast.Node) interface{} {
d.setPathToCommentMap(node)
switch n := node.(type) {
case *ast.NullNode:
return nil
case *ast.StringNode:
return n.GetValue()
case *ast.IntegerNode:
return n.GetValue()
case *ast.FloatNode:
return n.GetValue()
case *ast.BoolNode:
return n.GetValue()
case *ast.InfinityNode:
return n.GetValue()
case *ast.NanNode:
return n.GetValue()
case *ast.TagNode:
switch token.ReservedTagKeyword(n.Start.Value) {
case token.TimestampTag:
t, _ := d.castToTime(n.Value)
return t
case token.IntegerTag:
i, _ := strconv.Atoi(fmt.Sprint(d.nodeToValue(n.Value)))
return i
case token.FloatTag:
return d.castToFloat(d.nodeToValue(n.Value))
case token.NullTag:
return nil
case token.BinaryTag:
b, _ := base64.StdEncoding.DecodeString(d.nodeToValue(n.Value).(string))
return b
case token.StringTag:
return d.nodeToValue(n.Value)
case token.MappingTag:
return d.nodeToValue(n.Value)
}
case *ast.AnchorNode:
anchorName := n.Name.GetToken().Value
anchorValue := d.nodeToValue(n.Value)
d.anchorNodeMap[anchorName] = n.Value
return anchorValue
case *ast.AliasNode:
aliasName := n.Value.GetToken().Value
node := d.anchorNodeMap[aliasName]
return d.nodeToValue(node)
case *ast.LiteralNode:
return n.Value.GetValue()
case *ast.MappingKeyNode:
return d.nodeToValue(n.Value)
case *ast.MappingValueNode:
if n.Key.Type() == ast.MergeKeyType {
value := d.mergeValueNode(n.Value)
if d.useOrderedMap {
m := MapSlice{}
d.setToOrderedMapValue(value, &m)
return m
}
m := map[string]interface{}{}
d.setToMapValue(value, m)
return m
}
key := d.mapKeyNodeToString(n.Key)
if d.useOrderedMap {
return MapSlice{{Key: key, Value: d.nodeToValue(n.Value)}}
}
return map[string]interface{}{
key: d.nodeToValue(n.Value),
}
case *ast.MappingNode:
if d.useOrderedMap {
m := make(MapSlice, 0, len(n.Values))
for _, value := range n.Values {
d.setToOrderedMapValue(value, &m)
}
return m
}
m := make(map[string]interface{}, len(n.Values))
for _, value := range n.Values {
d.setToMapValue(value, m)
}
return m
case *ast.SequenceNode:
v := make([]interface{}, 0, len(n.Values))
for _, value := range n.Values {
v = append(v, d.nodeToValue(value))
}
return v
}
return nil
}
func (d *Decoder) resolveAlias(node ast.Node) ast.Node {
switch n := node.(type) {
case *ast.MappingNode:
for idx, value := range n.Values {
n.Values[idx] = d.resolveAlias(value).(*ast.MappingValueNode)
}
case *ast.TagNode:
n.Value = d.resolveAlias(n.Value)
case *ast.MappingKeyNode:
n.Value = d.resolveAlias(n.Value)
case *ast.MappingValueNode:
if n.Key.Type() == ast.MergeKeyType && n.Value.Type() == ast.AliasType {
value := d.resolveAlias(n.Value)
keyColumn := n.Key.GetToken().Position.Column
requiredColumn := keyColumn + 2
value.AddColumn(requiredColumn)
n.Value = value
} else {
n.Key = d.resolveAlias(n.Key)
n.Value = d.resolveAlias(n.Value)
}
case *ast.SequenceNode:
for idx, value := range n.Values {
n.Values[idx] = d.resolveAlias(value)
}
case *ast.AliasNode:
aliasName := n.Value.GetToken().Value
return d.resolveAlias(d.anchorNodeMap[aliasName])
}
return node
}
func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
if _, ok := node.(*ast.NullNode); ok {
return nil, nil
}
if anchor, ok := node.(*ast.AnchorNode); ok {
mapNode, ok := anchor.Value.(ast.MapNode)
if ok {
return mapNode, nil
}
return nil, errUnexpectedNodeType(anchor.Value.Type(), ast.MappingType, node.GetToken())
}
if alias, ok := node.(*ast.AliasNode); ok {
aliasName := alias.Value.GetToken().Value
node := d.anchorNodeMap[aliasName]
if node == nil {
return nil, xerrors.Errorf("cannot find anchor by alias name %s", aliasName)
}
mapNode, ok := node.(ast.MapNode)
if ok {
return mapNode, nil
}
return nil, errUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
}
mapNode, ok := node.(ast.MapNode)
if !ok {
return nil, errUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
}
return mapNode, nil
}
func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) {
if _, ok := node.(*ast.NullNode); ok {
return nil, nil
}
if anchor, ok := node.(*ast.AnchorNode); ok {
arrayNode, ok := anchor.Value.(ast.ArrayNode)
if ok {
return arrayNode, nil
}
return nil, errUnexpectedNodeType(anchor.Value.Type(), ast.SequenceType, node.GetToken())
}
if alias, ok := node.(*ast.AliasNode); ok {
aliasName := alias.Value.GetToken().Value
node := d.anchorNodeMap[aliasName]
if node == nil {
return nil, xerrors.Errorf("cannot find anchor by alias name %s", aliasName)
}
arrayNode, ok := node.(ast.ArrayNode)
if ok {
return arrayNode, nil
}
return nil, errUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken())
}
arrayNode, ok := node.(ast.ArrayNode)
if !ok {
return nil, errUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken())
}
return arrayNode, nil
}
func (d *Decoder) fileToNode(f *ast.File) ast.Node {
for _, doc := range f.Docs {
if v := d.nodeToValue(doc.Body); v != nil {
return doc.Body
}
}
return nil
}
func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type) (reflect.Value, error) {
if typ.Kind() != reflect.String {
if !v.Type().ConvertibleTo(typ) {
return reflect.Zero(typ), errTypeMismatch(typ, v.Type())
}
return v.Convert(typ), nil
}
// cast value to string
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.ValueOf(fmt.Sprint(v.Int())), nil
case reflect.Float32, reflect.Float64:
return reflect.ValueOf(fmt.Sprint(v.Float())), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return reflect.ValueOf(fmt.Sprint(v.Uint())), nil
case reflect.Bool:
return reflect.ValueOf(fmt.Sprint(v.Bool())), nil
}
if !v.Type().ConvertibleTo(typ) {
return reflect.Zero(typ), errTypeMismatch(typ, v.Type())
}
return v.Convert(typ), nil
}
type overflowError struct {
dstType reflect.Type
srcNum string
}
func (e *overflowError) Error() string {
return fmt.Sprintf("cannot unmarshal %s into Go value of type %s ( overflow )", e.srcNum, e.dstType)
}
func errOverflow(dstType reflect.Type, num string) *overflowError {
return &overflowError{dstType: dstType, srcNum: num}
}
type typeError struct {
dstType reflect.Type
srcType reflect.Type
structFieldName *string
}
func (e *typeError) Error() string {
if e.structFieldName != nil {
return fmt.Sprintf("cannot unmarshal %s into Go struct field %s of type %s", e.srcType, *e.structFieldName, e.dstType)
}
return fmt.Sprintf("cannot unmarshal %s into Go value of type %s", e.srcType, e.dstType)
}
func errTypeMismatch(dstType, srcType reflect.Type) *typeError {
return &typeError{dstType: dstType, srcType: srcType}
}
type unknownFieldError struct {
err error
}
func (e *unknownFieldError) Error() string {
return e.err.Error()
}
func errUnknownField(msg string, tk *token.Token) *unknownFieldError {
return &unknownFieldError{err: errors.ErrSyntax(msg, tk)}
}
func errUnexpectedNodeType(actual, expected ast.NodeType, tk *token.Token) error {
return errors.ErrSyntax(fmt.Sprintf("%s was used where %s is expected", actual.YAMLName(), expected.YAMLName()), tk)
}
type duplicateKeyError struct {
err error
}
func (e *duplicateKeyError) Error() string {
return e.err.Error()
}
func errDuplicateKey(msg string, tk *token.Token) *duplicateKeyError {
return &duplicateKeyError{err: errors.ErrSyntax(msg, tk)}
}
func (d *Decoder) deleteStructKeys(structType reflect.Type, unknownFields map[string]ast.Node) error {
if structType.Kind() == reflect.Ptr {
structType = structType.Elem()
}
structFieldMap, err := structFieldMap(structType)
if err != nil {
return errors.Wrapf(err, "failed to create struct field map")
}
for j := 0; j < structType.NumField(); j++ {
field := structType.Field(j)
if isIgnoredStructField(field) {
continue
}
structField, exists := structFieldMap[field.Name]
if !exists {
continue
}
if structField.IsInline {
d.deleteStructKeys(field.Type, unknownFields)
} else {
delete(unknownFields, structField.RenderName)
}
}
return nil
}
func (d *Decoder) lastNode(node ast.Node) ast.Node {
switch n := node.(type) {
case *ast.MappingNode:
if len(n.Values) > 0 {
return d.lastNode(n.Values[len(n.Values)-1])
}
case *ast.MappingValueNode:
return d.lastNode(n.Value)
case *ast.SequenceNode:
if len(n.Values) > 0 {
return d.lastNode(n.Values[len(n.Values)-1])
}
}
return node
}
func (d *Decoder) unmarshalableDocument(node ast.Node) []byte {
node = d.resolveAlias(node)
doc := node.String()
last := d.lastNode(node)
if last != nil && last.Type() == ast.LiteralType {
doc += "\n"
}
return []byte(doc)
}
func (d *Decoder) unmarshalableText(node ast.Node) ([]byte, bool) {
node = d.resolveAlias(node)
if node.Type() == ast.AnchorType {
node = node.(*ast.AnchorNode).Value
}
switch n := node.(type) {
case *ast.StringNode:
return []byte(n.Value), true
case *ast.LiteralNode:
return []byte(n.Value.GetToken().Value), true
default:
scalar, ok := n.(ast.ScalarNode)
if ok {
return []byte(fmt.Sprint(scalar.GetValue())), true
}
}
return nil, false
}
type jsonUnmarshaler interface {
UnmarshalJSON([]byte) error
}
func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool {
iface := dst.Addr().Interface()
switch iface.(type) {
case BytesUnmarshalerContext:
return true
case BytesUnmarshaler:
return true
case InterfaceUnmarshalerContext:
return true
case InterfaceUnmarshaler:
return true
case *time.Time:
return true
case *time.Duration:
return true
case encoding.TextUnmarshaler:
return true
case jsonUnmarshaler:
return d.useJSONUnmarshaler
}
return false
}
func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, src ast.Node) error {
iface := dst.Addr().Interface()
if unmarshaler, ok := iface.(BytesUnmarshalerContext); ok {
if err := unmarshaler.UnmarshalYAML(ctx, d.unmarshalableDocument(src)); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
}
if unmarshaler, ok := iface.(BytesUnmarshaler); ok {
if err := unmarshaler.UnmarshalYAML(d.unmarshalableDocument(src)); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
}
if unmarshaler, ok := iface.(InterfaceUnmarshalerContext); ok {
if err := unmarshaler.UnmarshalYAML(ctx, func(v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Type().Kind() != reflect.Ptr {
return errors.ErrDecodeRequiredPointerType
}
if err := d.decodeValue(ctx, rv.Elem(), src); err != nil {
return errors.Wrapf(err, "failed to decode value")
}
return nil
}); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
}
if unmarshaler, ok := iface.(InterfaceUnmarshaler); ok {
if err := unmarshaler.UnmarshalYAML(func(v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Type().Kind() != reflect.Ptr {
return errors.ErrDecodeRequiredPointerType
}
if err := d.decodeValue(ctx, rv.Elem(), src); err != nil {
return errors.Wrapf(err, "failed to decode value")
}
return nil
}); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
}
if _, ok := iface.(*time.Time); ok {
return d.decodeTime(ctx, dst, src)
}
if _, ok := iface.(*time.Duration); ok {
return d.decodeDuration(ctx, dst, src)
}
if unmarshaler, isText := iface.(encoding.TextUnmarshaler); isText {
b, ok := d.unmarshalableText(src)
if ok {
if err := unmarshaler.UnmarshalText(b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalText")
}
return nil
}
}
if d.useJSONUnmarshaler {
if unmarshaler, ok := iface.(jsonUnmarshaler); ok {
jsonBytes, err := YAMLToJSON(d.unmarshalableDocument(src))
if err != nil {
return errors.Wrapf(err, "failed to convert yaml to json")
}
jsonBytes = bytes.TrimRight(jsonBytes, "\n")
if err := unmarshaler.UnmarshalJSON(jsonBytes); err != nil {
return errors.Wrapf(err, "failed to UnmarshalJSON")
}
return nil
}
}
return xerrors.Errorf("does not implemented Unmarshaler")
}
var (
astNodeType = reflect.TypeOf((*ast.Node)(nil)).Elem()
)
func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.Node) error {
if src.Type() == ast.AnchorType {
anchorName := src.(*ast.AnchorNode).Name.GetToken().Value
if _, exists := d.anchorValueMap[anchorName]; !exists {
d.anchorValueMap[anchorName] = dst
}
}
if d.canDecodeByUnmarshaler(dst) {
if err := d.decodeByUnmarshaler(ctx, dst, src); err != nil {
return errors.Wrapf(err, "failed to decode by unmarshaler")
}
return nil
}
valueType := dst.Type()
switch valueType.Kind() {
case reflect.Ptr:
if dst.IsNil() {
return nil
}
if src.Type() == ast.NullType {
// set nil value to pointer
dst.Set(reflect.Zero(valueType))
return nil
}
v := d.createDecodableValue(dst.Type())
if err := d.decodeValue(ctx, v, src); err != nil {
return errors.Wrapf(err, "failed to decode ptr value")
}
dst.Set(d.castToAssignableValue(v, dst.Type()))
case reflect.Interface:
if dst.Type() == astNodeType {
dst.Set(reflect.ValueOf(src))
return nil
}
v := reflect.ValueOf(d.nodeToValue(src))
if v.IsValid() {
dst.Set(v)
}
case reflect.Map:
return d.decodeMap(ctx, dst, src)
case reflect.Array:
return d.decodeArray(ctx, dst, src)
case reflect.Slice:
if mapSlice, ok := dst.Addr().Interface().(*MapSlice); ok {
return d.decodeMapSlice(ctx, mapSlice, src)
}
return d.decodeSlice(ctx, dst, src)
case reflect.Struct:
if mapItem, ok := dst.Addr().Interface().(*MapItem); ok {
return d.decodeMapItem(ctx, mapItem, src)
}
return d.decodeStruct(ctx, dst, src)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v := d.nodeToValue(src)
switch vv := v.(type) {
case int64:
if !dst.OverflowInt(vv) {
dst.SetInt(vv)
return nil
}
case uint64:
if vv <= math.MaxInt64 && !dst.OverflowInt(int64(vv)) {
dst.SetInt(int64(vv))
return nil
}
case float64:
if vv <= math.MaxInt64 && !dst.OverflowInt(int64(vv)) {
dst.SetInt(int64(vv))
return nil
}
default:
return errTypeMismatch(valueType, reflect.TypeOf(v))
}
return errOverflow(valueType, fmt.Sprint(v))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
v := d.nodeToValue(src)
switch vv := v.(type) {
case int64:
if 0 <= vv && !dst.OverflowUint(uint64(vv)) {
dst.SetUint(uint64(vv))
return nil
}
case uint64:
if !dst.OverflowUint(vv) {
dst.SetUint(vv)
return nil
}
case float64:
if 0 <= vv && vv <= math.MaxUint64 && !dst.OverflowUint(uint64(vv)) {
dst.SetUint(uint64(vv))
return nil
}
default:
return errTypeMismatch(valueType, reflect.TypeOf(v))
}
return errOverflow(valueType, fmt.Sprint(v))
}
v := reflect.ValueOf(d.nodeToValue(src))
if v.IsValid() {
convertedValue, err := d.convertValue(v, dst.Type())
if err != nil {
return errors.Wrapf(err, "failed to convert value")
}
dst.Set(convertedValue)
}
return nil
}
func (d *Decoder) createDecodableValue(typ reflect.Type) reflect.Value {
for {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
continue
}
break
}
return reflect.New(typ).Elem()
}
func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type) reflect.Value {
if target.Kind() != reflect.Ptr {
return value
}
maxTryCount := 5
tryCount := 0
for {
if tryCount > maxTryCount {
return value
}
if value.Type().AssignableTo(target) {
break
}
value = value.Addr()
tryCount++
}
return value
}
func (d *Decoder) createDecodedNewValue(
ctx context.Context, typ reflect.Type, defaultVal reflect.Value, node ast.Node,
) (reflect.Value, error) {
if node.Type() == ast.AliasType {
aliasName := node.(*ast.AliasNode).Value.GetToken().Value
newValue := d.anchorValueMap[aliasName]
if newValue.IsValid() {
return newValue, nil
}
}
if node.Type() == ast.NullType {
return reflect.Zero(typ), nil
}
newValue := d.createDecodableValue(typ)
for defaultVal.Kind() == reflect.Ptr {
defaultVal = defaultVal.Elem()
}
if defaultVal.IsValid() && defaultVal.Type().AssignableTo(newValue.Type()) {
newValue.Set(defaultVal)
}
if err := d.decodeValue(ctx, newValue, node); err != nil {
return newValue, errors.Wrapf(err, "failed to decode value")
}
return newValue, nil
}
func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValueNode func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) {
mapNode, err := d.getMapNode(node)
if err != nil {
return nil, errors.Wrapf(err, "failed to get map node")
}
keyMap := map[string]struct{}{}
keyToNodeMap := map[string]ast.Node{}
if mapNode == nil {
return keyToNodeMap, nil
}
mapIter := mapNode.MapRange()
for mapIter.Next() {
keyNode := mapIter.Key()
if keyNode.Type() == ast.MergeKeyType {
if ignoreMergeKey {
continue
}
mergeMap, err := d.keyToNodeMap(mapIter.Value(), ignoreMergeKey, getKeyOrValueNode)
if err != nil {
return nil, errors.Wrapf(err, "failed to get keyToNodeMap by MergeKey node")
}
for k, v := range mergeMap {
if err := d.validateDuplicateKey(keyMap, k, v); err != nil {
return nil, errors.Wrapf(err, "invalid struct key")
}
keyToNodeMap[k] = v
}
} else {
key, ok := d.nodeToValue(keyNode).(string)
if !ok {
return nil, errors.Wrapf(err, "failed to decode map key")
}
if err := d.validateDuplicateKey(keyMap, key, keyNode); err != nil {
return nil, errors.Wrapf(err, "invalid struct key")
}
keyToNodeMap[key] = getKeyOrValueNode(mapIter)
}
}
return keyToNodeMap, nil
}
func (d *Decoder) keyToKeyNodeMap(node ast.Node, ignoreMergeKey bool) (map[string]ast.Node, error) {
m, err := d.keyToNodeMap(node, ignoreMergeKey, func(nodeMap *ast.MapNodeIter) ast.Node { return nodeMap.Key() })
if err != nil {
return nil, errors.Wrapf(err, "failed to get keyToNodeMap")
}
return m, nil
}
func (d *Decoder) keyToValueNodeMap(node ast.Node, ignoreMergeKey bool) (map[string]ast.Node, error) {
m, err := d.keyToNodeMap(node, ignoreMergeKey, func(nodeMap *ast.MapNodeIter) ast.Node { return nodeMap.Value() })
if err != nil {
return nil, errors.Wrapf(err, "failed to get keyToNodeMap")
}
return m, nil
}
func (d *Decoder) setDefaultValueIfConflicted(v reflect.Value, fieldMap StructFieldMap) error {
typ := v.Type()
if typ.Kind() != reflect.Struct {
return nil
}
embeddedStructFieldMap, err := structFieldMap(typ)
if err != nil {
return errors.Wrapf(err, "failed to get struct field map by embedded type")
}
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
if isIgnoredStructField(field) {
continue
}
structField := embeddedStructFieldMap[field.Name]
if !fieldMap.isIncludedRenderName(structField.RenderName) {
continue
}
// if declared same key name, set default value
fieldValue := v.Field(i)
if fieldValue.CanSet() {
fieldValue.Set(reflect.Zero(fieldValue.Type()))
}
}
return nil
}
// This is a subset of the formats allowed by the regular expression
// defined at http://yaml.org/type/timestamp.html.
var allowedTimestampFormats = []string{
"2006-1-2T15:4:5.999999999Z07:00", // RCF3339Nano with short date fields.
"2006-1-2t15:4:5.999999999Z07:00", // RFC3339Nano with short date fields and lower-case "t".
"2006-1-2 15:4:5.999999999", // space separated with no time zone
"2006-1-2", // date only
}
func (d *Decoder) castToTime(src ast.Node) (time.Time, error) {
if src == nil {
return time.Time{}, nil
}
v := d.nodeToValue(src)
if t, ok := v.(time.Time); ok {
return t, nil
}
s, ok := v.(string)
if !ok {
return time.Time{}, errTypeMismatch(reflect.TypeOf(time.Time{}), reflect.TypeOf(v))
}
for _, format := range allowedTimestampFormats {
t, err := time.Parse(format, s)
if err != nil {
// invalid format
continue
}
return t, nil
}
return time.Time{}, nil
}
func (d *Decoder) decodeTime(ctx context.Context, dst reflect.Value, src ast.Node) error {
t, err := d.castToTime(src)
if err != nil {
return errors.Wrapf(err, "failed to convert to time")
}
dst.Set(reflect.ValueOf(t))
return nil
}
func (d *Decoder) castToDuration(src ast.Node) (time.Duration, error) {
if src == nil {
return 0, nil
}
v := d.nodeToValue(src)
if t, ok := v.(time.Duration); ok {
return t, nil
}
s, ok := v.(string)
if !ok {
return 0, errTypeMismatch(reflect.TypeOf(time.Duration(0)), reflect.TypeOf(v))
}
t, err := time.ParseDuration(s)
if err != nil {
return 0, errors.Wrapf(err, "failed to parse duration")
}
return t, nil
}
func (d *Decoder) decodeDuration(ctx context.Context, dst reflect.Value, src ast.Node) error {
t, err := d.castToDuration(src)
if err != nil {
return errors.Wrapf(err, "failed to convert to duration")
}
dst.Set(reflect.ValueOf(t))
return nil
}
// getMergeAliasName support single alias only
func (d *Decoder) getMergeAliasName(src ast.Node) string {
mapNode, err := d.getMapNode(src)
if err != nil {
return ""
}
if mapNode == nil {
return ""
}
mapIter := mapNode.MapRange()
for mapIter.Next() {
key := mapIter.Key()
value := mapIter.Value()
if key.Type() == ast.MergeKeyType && value.Type() == ast.AliasType {
return value.(*ast.AliasNode).Value.GetToken().Value
}
}
return ""
}
func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.Node) error {
if src == nil {
return nil
}
structType := dst.Type()
srcValue := reflect.ValueOf(src)
srcType := srcValue.Type()
if srcType.Kind() == reflect.Ptr {
srcType = srcType.Elem()
srcValue = srcValue.Elem()
}
if structType == srcType {
// dst value implements ast.Node
dst.Set(srcValue)
return nil
}
structFieldMap, err := structFieldMap(structType)
if err != nil {
return errors.Wrapf(err, "failed to create struct field map")
}
ignoreMergeKey := structFieldMap.hasMergeProperty()
keyToNodeMap, err := d.keyToValueNodeMap(src, ignoreMergeKey)
if err != nil {
return errors.Wrapf(err, "failed to get keyToValueNodeMap")
}
var unknownFields map[string]ast.Node
if d.disallowUnknownField {
unknownFields, err = d.keyToKeyNodeMap(src, ignoreMergeKey)
if err != nil {
return errors.Wrapf(err, "failed to get keyToKeyNodeMap")
}
}
aliasName := d.getMergeAliasName(src)
var foundErr error
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
if isIgnoredStructField(field) {
continue
}
structField := structFieldMap[field.Name]
if structField.IsInline {
fieldValue := dst.FieldByName(field.Name)
if structField.IsAutoAlias {
if aliasName != "" {
newFieldValue := d.anchorValueMap[aliasName]
if newFieldValue.IsValid() {
fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type()))
}
}
continue
}
if !fieldValue.CanSet() {
return xerrors.Errorf("cannot set embedded type as unexported field %s.%s", field.PkgPath, field.Name)
}
if fieldValue.Type().Kind() == reflect.Ptr && src.Type() == ast.NullType {
// set nil value to pointer
fieldValue.Set(reflect.Zero(fieldValue.Type()))
continue
}
mapNode := ast.Mapping(nil, false)
for k, v := range keyToNodeMap {
key := &ast.StringNode{BaseNode: &ast.BaseNode{}, Value: k}
mapNode.Values = append(mapNode.Values, ast.MappingValue(nil, key, v))
}
newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), fieldValue, mapNode)
if d.disallowUnknownField {
if err := d.deleteStructKeys(fieldValue.Type(), unknownFields); err != nil {
return errors.Wrapf(err, "cannot delete struct keys")
}
}
if err != nil {
if foundErr != nil {
continue
}
var te *typeError
if xerrors.As(err, &te) {
if te.structFieldName != nil {
fieldName := fmt.Sprintf("%s.%s", structType.Name(), *te.structFieldName)
te.structFieldName = &fieldName
} else {
fieldName := fmt.Sprintf("%s.%s", structType.Name(), field.Name)
te.structFieldName = &fieldName
}
foundErr = te
continue
} else {
foundErr = err
}
continue
}
d.setDefaultValueIfConflicted(newFieldValue, structFieldMap)
fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type()))
continue
}
v, exists := keyToNodeMap[structField.RenderName]
if !exists {
continue
}
delete(unknownFields, structField.RenderName)
fieldValue := dst.FieldByName(field.Name)
if fieldValue.Type().Kind() == reflect.Ptr && src.Type() == ast.NullType {
// set nil value to pointer
fieldValue.Set(reflect.Zero(fieldValue.Type()))
continue
}
newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), fieldValue, v)
if err != nil {
if foundErr != nil {
continue
}
var te *typeError
if xerrors.As(err, &te) {
fieldName := fmt.Sprintf("%s.%s", structType.Name(), field.Name)
te.structFieldName = &fieldName
foundErr = te
} else {
foundErr = err
}
continue
}
fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type()))
}
if foundErr != nil {
return errors.Wrapf(foundErr, "failed to decode value")
}
// Ignore unknown fields when parsing an inline struct (recognized by a nil token).
// Unknown fields are expected (they could be fields from the parent struct).
if len(unknownFields) != 0 && d.disallowUnknownField && src.GetToken() != nil {
for key, node := range unknownFields {
return errUnknownField(fmt.Sprintf(`unknown field "%s"`, key), node.GetToken())
}
}
if d.validator != nil {
if err := d.validator.Struct(dst.Interface()); err != nil {
ev := reflect.ValueOf(err)
if ev.Type().Kind() == reflect.Slice {
for i := 0; i < ev.Len(); i++ {
fieldErr, ok := ev.Index(i).Interface().(FieldError)
if !ok {
continue
}
fieldName := fieldErr.StructField()
structField, exists := structFieldMap[fieldName]
if !exists {
continue
}
node, exists := keyToNodeMap[structField.RenderName]
if exists {
// TODO: to make FieldError message cutomizable
return errors.ErrSyntax(fmt.Sprintf("%s", err), node.GetToken())
} else if t := src.GetToken(); t != nil && t.Prev != nil && t.Prev.Prev != nil {
// A missing required field will not be in the keyToNodeMap
// the error needs to be associated with the parent of the source node
return errors.ErrSyntax(fmt.Sprintf("%s", err), t.Prev.Prev)
}
}
}
return err
}
}
return nil
}
func (d *Decoder) decodeArray(ctx context.Context, dst reflect.Value, src ast.Node) error {
arrayNode, err := d.getArrayNode(src)
if err != nil {
return errors.Wrapf(err, "failed to get array node")
}
if arrayNode == nil {
return nil
}
iter := arrayNode.ArrayRange()
arrayValue := reflect.New(dst.Type()).Elem()
arrayType := dst.Type()
elemType := arrayType.Elem()
idx := 0
var foundErr error
for iter.Next() {
v := iter.Value()
if elemType.Kind() == reflect.Ptr && v.Type() == ast.NullType {
// set nil value to pointer
arrayValue.Index(idx).Set(reflect.Zero(elemType))
} else {
dstValue, err := d.createDecodedNewValue(ctx, elemType, reflect.Value{}, v)
if err != nil {
if foundErr == nil {
foundErr = err
}
continue
} else {
arrayValue.Index(idx).Set(d.castToAssignableValue(dstValue, elemType))
}
}
idx++
}
dst.Set(arrayValue)
if foundErr != nil {
return errors.Wrapf(foundErr, "failed to decode value")
}
return nil
}
func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.Node) error {
arrayNode, err := d.getArrayNode(src)
if err != nil {
return errors.Wrapf(err, "failed to get array node")
}
if arrayNode == nil {
return nil
}
iter := arrayNode.ArrayRange()
sliceType := dst.Type()
sliceValue := reflect.MakeSlice(sliceType, 0, iter.Len())
elemType := sliceType.Elem()
var foundErr error
for iter.Next() {
v := iter.Value()
if elemType.Kind() == reflect.Ptr && v.Type() == ast.NullType {
// set nil value to pointer
sliceValue = reflect.Append(sliceValue, reflect.Zero(elemType))
continue
}
dstValue, err := d.createDecodedNewValue(ctx, elemType, reflect.Value{}, v)
if err != nil {
if foundErr == nil {
foundErr = err
}
continue
}
sliceValue = reflect.Append(sliceValue, d.castToAssignableValue(dstValue, elemType))
}
dst.Set(sliceValue)
if foundErr != nil {
return errors.Wrapf(foundErr, "failed to decode value")
}
return nil
}
func (d *Decoder) decodeMapItem(ctx context.Context, dst *MapItem, src ast.Node) error {
mapNode, err := d.getMapNode(src)
if err != nil {
return errors.Wrapf(err, "failed to get map node")
}
if mapNode == nil {
return nil
}
mapIter := mapNode.MapRange()
if !mapIter.Next() {
return nil
}
key := mapIter.Key()
value := mapIter.Value()
if key.Type() == ast.MergeKeyType {
if err := d.decodeMapItem(ctx, dst, value); err != nil {
return errors.Wrapf(err, "failed to decode map with merge key")
}
return nil
}
*dst = MapItem{
Key: d.nodeToValue(key),
Value: d.nodeToValue(value),
}
return nil
}
func (d *Decoder) validateDuplicateKey(keyMap map[string]struct{}, key interface{}, keyNode ast.Node) error {
k, ok := key.(string)
if !ok {
return nil
}
if d.disallowDuplicateKey {
if _, exists := keyMap[k]; exists {
return errDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken())
}
}
keyMap[k] = struct{}{}
return nil
}
func (d *Decoder) decodeMapSlice(ctx context.Context, dst *MapSlice, src ast.Node) error {
mapNode, err := d.getMapNode(src)
if err != nil {
return errors.Wrapf(err, "failed to get map node")
}
if mapNode == nil {
return nil
}
mapSlice := MapSlice{}
mapIter := mapNode.MapRange()
keyMap := map[string]struct{}{}
for mapIter.Next() {
key := mapIter.Key()
value := mapIter.Value()
if key.Type() == ast.MergeKeyType {
var m MapSlice
if err := d.decodeMapSlice(ctx, &m, value); err != nil {
return errors.Wrapf(err, "failed to decode map with merge key")
}
for _, v := range m {
if err := d.validateDuplicateKey(keyMap, v.Key, value); err != nil {
return errors.Wrapf(err, "invalid map key")
}
mapSlice = append(mapSlice, v)
}
continue
}
k := d.nodeToValue(key)
if err := d.validateDuplicateKey(keyMap, k, key); err != nil {
return errors.Wrapf(err, "invalid map key")
}
mapSlice = append(mapSlice, MapItem{
Key: k,
Value: d.nodeToValue(value),
})
}
*dst = mapSlice
return nil
}
func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node) error {
mapNode, err := d.getMapNode(src)
if err != nil {
return errors.Wrapf(err, "failed to get map node")
}
if mapNode == nil {
return nil
}
mapType := dst.Type()
mapValue := reflect.MakeMap(mapType)
keyType := mapValue.Type().Key()
valueType := mapValue.Type().Elem()
mapIter := mapNode.MapRange()
keyMap := map[string]struct{}{}
var foundErr error
for mapIter.Next() {
key := mapIter.Key()
value := mapIter.Value()
if key.Type() == ast.MergeKeyType {
if err := d.decodeMap(ctx, dst, value); err != nil {
return errors.Wrapf(err, "failed to decode map with merge key")
}
iter := dst.MapRange()
for iter.Next() {
if err := d.validateDuplicateKey(keyMap, iter.Key(), value); err != nil {
return errors.Wrapf(err, "invalid map key")
}
mapValue.SetMapIndex(iter.Key(), iter.Value())
}
continue
}
k := reflect.ValueOf(d.nodeToValue(key))
if k.IsValid() && k.Type().ConvertibleTo(keyType) {
k = k.Convert(keyType)
}
if k.IsValid() {
if err := d.validateDuplicateKey(keyMap, k.Interface(), key); err != nil {
return errors.Wrapf(err, "invalid map key")
}
}
if valueType.Kind() == reflect.Ptr && value.Type() == ast.NullType {
// set nil value to pointer
mapValue.SetMapIndex(k, reflect.Zero(valueType))
continue
}
dstValue, err := d.createDecodedNewValue(ctx, valueType, reflect.Value{}, value)
if err != nil {
if foundErr == nil {
foundErr = err
}
}
if !k.IsValid() {
// expect nil key
mapValue.SetMapIndex(d.createDecodableValue(keyType), d.castToAssignableValue(dstValue, valueType))
continue
}
mapValue.SetMapIndex(k, d.castToAssignableValue(dstValue, valueType))
}
dst.Set(mapValue)
if foundErr != nil {
return errors.Wrapf(foundErr, "failed to decode value")
}
return nil
}
func (d *Decoder) fileToReader(file string) (io.Reader, error) {
reader, err := os.Open(file)
if err != nil {
return nil, errors.Wrapf(err, "failed to open file")
}
return reader, nil
}
func (d *Decoder) isYAMLFile(file string) bool {
ext := filepath.Ext(file)
if ext == ".yml" {
return true
}
if ext == ".yaml" {
return true
}
return false
}
func (d *Decoder) readersUnderDir(dir string) ([]io.Reader, error) {
pattern := fmt.Sprintf("%s/*", dir)
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, errors.Wrapf(err, "failed to get files by %s", pattern)
}
readers := []io.Reader{}
for _, match := range matches {
if !d.isYAMLFile(match) {
continue
}
reader, err := d.fileToReader(match)
if err != nil {
return nil, errors.Wrapf(err, "failed to get reader")
}
readers = append(readers, reader)
}
return readers, nil
}
func (d *Decoder) readersUnderDirRecursive(dir string) ([]io.Reader, error) {
readers := []io.Reader{}
if err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if !d.isYAMLFile(path) {
return nil
}
reader, err := d.fileToReader(path)
if err != nil {
return errors.Wrapf(err, "failed to get reader")
}
readers = append(readers, reader)
return nil
}); err != nil {
return nil, errors.Wrapf(err, "interrupt walk in %s", dir)
}
return readers, nil
}
func (d *Decoder) resolveReference() error {
for _, opt := range d.opts {
if err := opt(d); err != nil {
return errors.Wrapf(err, "failed to exec option")
}
}
for _, file := range d.referenceFiles {
reader, err := d.fileToReader(file)
if err != nil {
return errors.Wrapf(err, "failed to get reader")
}
d.referenceReaders = append(d.referenceReaders, reader)
}
for _, dir := range d.referenceDirs {
if !d.isRecursiveDir {
readers, err := d.readersUnderDir(dir)
if err != nil {
return errors.Wrapf(err, "failed to get readers from under the %s", dir)
}
d.referenceReaders = append(d.referenceReaders, readers...)
} else {
readers, err := d.readersUnderDirRecursive(dir)
if err != nil {
return errors.Wrapf(err, "failed to get readers from under the %s", dir)
}
d.referenceReaders = append(d.referenceReaders, readers...)
}
}
for _, reader := range d.referenceReaders {
bytes, err := ioutil.ReadAll(reader)
if err != nil {
return errors.Wrapf(err, "failed to read buffer")
}
// assign new anchor definition to anchorMap
if _, err := d.parse(bytes); err != nil {
return errors.Wrapf(err, "failed to decode")
}
}
d.isResolvedReference = true
return nil
}
func (d *Decoder) parse(bytes []byte) (*ast.File, error) {
var parseMode parser.Mode
if d.toCommentMap != nil {
parseMode = parser.ParseComments
}
f, err := parser.ParseBytes(bytes, parseMode)
if err != nil {
return nil, errors.Wrapf(err, "failed to parse yaml")
}
normalizedFile := &ast.File{}
for _, doc := range f.Docs {
// try to decode ast.Node to value and map anchor value to anchorMap
if v := d.nodeToValue(doc.Body); v != nil {
normalizedFile.Docs = append(normalizedFile.Docs, doc)
}
}
return normalizedFile, nil
}
func (d *Decoder) isInitialized() bool {
return d.parsedFile != nil
}
func (d *Decoder) decodeInit() error {
if !d.isResolvedReference {
if err := d.resolveReference(); err != nil {
return errors.Wrapf(err, "failed to resolve reference")
}
}
var buf bytes.Buffer
if _, err := io.Copy(&buf, d.reader); err != nil {
return errors.Wrapf(err, "failed to copy from reader")
}
file, err := d.parse(buf.Bytes())
if err != nil {
return errors.Wrapf(err, "failed to decode")
}
d.parsedFile = file
return nil
}
func (d *Decoder) decode(ctx context.Context, v reflect.Value) error {
if len(d.parsedFile.Docs) <= d.streamIndex {
return io.EOF
}
body := d.parsedFile.Docs[d.streamIndex].Body
if body == nil {
return nil
}
if err := d.decodeValue(ctx, v.Elem(), body); err != nil {
return errors.Wrapf(err, "failed to decode value")
}
d.streamIndex++
return nil
}
// Decode reads the next YAML-encoded value from its input
// and stores it in the value pointed to by v.
//
// See the documentation for Unmarshal for details about the
// conversion of YAML into a Go value.
func (d *Decoder) Decode(v interface{}) error {
return d.DecodeContext(context.Background(), v)
}
// DecodeContext reads the next YAML-encoded value from its input
// and stores it in the value pointed to by v with context.Context.
func (d *Decoder) DecodeContext(ctx context.Context, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Type().Kind() != reflect.Ptr {
return errors.ErrDecodeRequiredPointerType
}
if d.isInitialized() {
if err := d.decode(ctx, rv); err != nil {
if err == io.EOF {
return err
}
return errors.Wrapf(err, "failed to decode")
}
return nil
}
if err := d.decodeInit(); err != nil {
return errors.Wrapf(err, "failed to decodeInit")
}
if err := d.decode(ctx, rv); err != nil {
if err == io.EOF {
return err
}
return errors.Wrapf(err, "failed to decode")
}
return nil
}
// DecodeFromNode decodes node into the value pointed to by v.
func (d *Decoder) DecodeFromNode(node ast.Node, v interface{}) error {
return d.DecodeFromNodeContext(context.Background(), node, v)
}
// DecodeFromNodeContext decodes node into the value pointed to by v with context.Context.
func (d *Decoder) DecodeFromNodeContext(ctx context.Context, node ast.Node, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Type().Kind() != reflect.Ptr {
return errors.ErrDecodeRequiredPointerType
}
if !d.isInitialized() {
if err := d.decodeInit(); err != nil {
return errors.Wrapf(err, "failed to decodInit")
}
}
// resolve references to the anchor on the same file
d.nodeToValue(node)
if err := d.decodeValue(ctx, rv.Elem(), node); err != nil {
return errors.Wrapf(err, "failed to decode value")
}
return nil
}