1// Copyright 2023 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package proptools
16
17import (
18	"cmp"
19	"encoding/binary"
20	"fmt"
21	"hash"
22	"hash/fnv"
23	"math"
24	"reflect"
25	"sort"
26	"unsafe"
27)
28
// recordSeparator is the single byte written between the elements of lists and
// the fields of structs/maps, so that moving data from one element to a
// neighbouring one is likely to change the resulting hash.
//
// The value 36 is arbitrary. Note that decimal 36 is '$'; the ASCII record
// separator (RS) control character is 30 (octal 036). The value is left as it
// is so that previously computed hashes stay the same.
var recordSeparator = []byte{36}
33
34func CalculateHash(value interface{}) (uint64, error) {
35	hasher := fnv.New64()
36	ptrs := make(map[uintptr]bool)
37	v := reflect.ValueOf(value)
38	var err error
39	if v.IsValid() {
40		err = calculateHashInternal(hasher, v, ptrs)
41	}
42	return hasher.Sum64(), err
43}
44
45func calculateHashInternal(hasher hash.Hash64, v reflect.Value, ptrs map[uintptr]bool) error {
46	var int64Array [8]byte
47	int64Buf := int64Array[:]
48	binary.LittleEndian.PutUint64(int64Buf, uint64(v.Kind()))
49	hasher.Write(int64Buf)
50	v.IsValid()
51	switch v.Kind() {
52	case reflect.Struct:
53		binary.LittleEndian.PutUint64(int64Buf, uint64(v.NumField()))
54		hasher.Write(int64Buf)
55		for i := 0; i < v.NumField(); i++ {
56			hasher.Write(recordSeparator)
57			err := calculateHashInternal(hasher, v.Field(i), ptrs)
58			if err != nil {
59				return fmt.Errorf("in field %s: %s", v.Type().Field(i).Name, err.Error())
60			}
61		}
62	case reflect.Map:
63		binary.LittleEndian.PutUint64(int64Buf, uint64(v.Len()))
64		hasher.Write(int64Buf)
65		indexes := make([]int, v.Len())
66		keys := make([]reflect.Value, v.Len())
67		values := make([]reflect.Value, v.Len())
68		iter := v.MapRange()
69		for i := 0; iter.Next(); i++ {
70			indexes[i] = i
71			keys[i] = iter.Key()
72			values[i] = iter.Value()
73		}
74		sort.SliceStable(indexes, func(i, j int) bool {
75			return compare_values(keys[indexes[i]], keys[indexes[j]]) < 0
76		})
77		for i := 0; i < v.Len(); i++ {
78			hasher.Write(recordSeparator)
79			err := calculateHashInternal(hasher, keys[indexes[i]], ptrs)
80			if err != nil {
81				return fmt.Errorf("in map: %s", err.Error())
82			}
83			hasher.Write(recordSeparator)
84			err = calculateHashInternal(hasher, keys[indexes[i]], ptrs)
85			if err != nil {
86				return fmt.Errorf("in map: %s", err.Error())
87			}
88		}
89	case reflect.Slice, reflect.Array:
90		binary.LittleEndian.PutUint64(int64Buf, uint64(v.Len()))
91		hasher.Write(int64Buf)
92		for i := 0; i < v.Len(); i++ {
93			hasher.Write(recordSeparator)
94			err := calculateHashInternal(hasher, v.Index(i), ptrs)
95			if err != nil {
96				return fmt.Errorf("in %s at index %d: %s", v.Kind().String(), i, err.Error())
97			}
98		}
99	case reflect.Pointer:
100		if v.IsNil() {
101			int64Buf[0] = 0
102			hasher.Write(int64Buf[:1])
103			return nil
104		}
105		// Hardcoded value to indicate it is a pointer
106		binary.LittleEndian.PutUint64(int64Buf, uint64(0x55))
107		hasher.Write(int64Buf)
108		addr := v.Pointer()
109		if _, ok := ptrs[addr]; ok {
110			// We could make this an error if we want to disallow pointer cycles in the future
111			return nil
112		}
113		ptrs[addr] = true
114		err := calculateHashInternal(hasher, v.Elem(), ptrs)
115		if err != nil {
116			return fmt.Errorf("in pointer: %s", err.Error())
117		}
118	case reflect.Interface:
119		if v.IsNil() {
120			int64Buf[0] = 0
121			hasher.Write(int64Buf[:1])
122		} else {
123			// The only way get the pointer out of an interface to hash it or check for cycles
124			// would be InterfaceData(), but that's deprecated and seems like it has undefined behavior.
125			err := calculateHashInternal(hasher, v.Elem(), ptrs)
126			if err != nil {
127				return fmt.Errorf("in interface: %s", err.Error())
128			}
129		}
130	case reflect.String:
131		strLen := len(v.String())
132		if strLen == 0 {
133			// unsafe.StringData is unspecified in this case
134			int64Buf[0] = 0
135			hasher.Write(int64Buf[:1])
136			return nil
137		}
138		hasher.Write(unsafe.Slice(unsafe.StringData(v.String()), strLen))
139	case reflect.Bool:
140		if v.Bool() {
141			int64Buf[0] = 1
142		} else {
143			int64Buf[0] = 0
144		}
145		hasher.Write(int64Buf[:1])
146	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
147		binary.LittleEndian.PutUint64(int64Buf, v.Uint())
148		hasher.Write(int64Buf)
149	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
150		binary.LittleEndian.PutUint64(int64Buf, uint64(v.Int()))
151		hasher.Write(int64Buf)
152	case reflect.Float32, reflect.Float64:
153		binary.LittleEndian.PutUint64(int64Buf, math.Float64bits(v.Float()))
154		hasher.Write(int64Buf)
155	default:
156		return fmt.Errorf("data may only contain primitives, strings, arrays, slices, structs, maps, and pointers, found: %s", v.Kind().String())
157	}
158	return nil
159}
160
// compare_values defines a total order over two reflected values of the same
// type. It returns a negative number when x sorts before y, zero when they
// compare equal, and a positive number when x sorts after y.
//
// Ordering rules:
//   - integers, unsigned integers, floats and strings use cmp.Compare;
//   - false sorts before true;
//   - pointers are ordered by address;
//   - arrays and structs are compared element by element / field by field,
//     the first difference deciding the result;
//   - a nil interface sorts after any non-nil interface; two non-nil
//     interfaces holding different dynamic types are ordered by type name and
//     then by package path, and otherwise by their dynamic values.
//
// It panics if x and y have different types, if two interface values hold
// distinct dynamic types that cannot be told apart by name and package path,
// or if the kind is not one of those listed above (for example slices, maps,
// channels and functions).
func compare_values(x, y reflect.Value) int {
	if x.Type() != y.Type() {
		panic("Expected equal types")
	}

	switch x.Kind() {
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return cmp.Compare(x.Uint(), y.Uint())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return cmp.Compare(x.Int(), y.Int())
	case reflect.Float32, reflect.Float64:
		return cmp.Compare(x.Float(), y.Float())
	case reflect.String:
		return cmp.Compare(x.String(), y.String())
	case reflect.Bool:
		switch xb, yb := x.Bool(), y.Bool(); {
		case xb == yb:
			return 0
		case xb:
			return 1
		default:
			return -1
		}
	case reflect.Pointer:
		return cmp.Compare(x.Pointer(), y.Pointer())
	case reflect.Array:
		for i := 0; i < x.Len(); i++ {
			if result := compare_values(x.Index(i), y.Index(i)); result != 0 {
				return result
			}
		}
		return 0
	case reflect.Struct:
		for i := 0; i < x.NumField(); i++ {
			if result := compare_values(x.Field(i), y.Field(i)); result != 0 {
				return result
			}
		}
		return 0
	case reflect.Interface:
		switch xNil, yNil := x.IsNil(), y.IsNil(); {
		case xNil && yNil:
			return 0
		case xNil:
			return 1
		case yNil:
			return -1
		}
		xElem, yElem := x.Elem(), y.Elem()
		if xType, yType := xElem.Type(), yElem.Type(); xType != yType {
			// Two interface values may legitimately hold different dynamic
			// types (e.g. the keys of a map[any]T). Order those by type
			// identity instead of panicking in the recursive call.
			if result := cmp.Compare(xType.String(), yType.String()); result != 0 {
				return result
			}
			if result := cmp.Compare(xType.PkgPath(), yType.PkgPath()); result != 0 {
				return result
			}
			// Indistinguishable by name: fall through, the recursive call
			// reports the mismatch by panicking.
		}
		return compare_values(xElem, yElem)
	default:
		panic(fmt.Sprintf("Could not compare types %s and %s", x.Type().String(), y.Type().String()))
	}
}
212