1// Copyright 2022 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package starlark_fmt
16
17import (
18	"fmt"
19	"reflect"
20	"sort"
21	"strconv"
22	"strings"
23)
24
// indent is the number of spaces emitted for each indentation level.
const indent = 4
28
29// Indention returns an indent string of the specified level.
30func Indention(level int) string {
31	if level < 0 {
32		panic(fmt.Errorf("indent level cannot be less than 0, but got %d", level))
33	}
34	return strings.Repeat(" ", level*indent)
35}
36
37func PrintAny(value any, indentLevel int) string {
38	return printAnyRecursive(reflect.ValueOf(value), indentLevel)
39}
40
41func printAnyRecursive(value reflect.Value, indentLevel int) string {
42	switch value.Type().Kind() {
43	case reflect.String:
44		val := value.String()
45		if strings.Contains(val, "\"") || strings.Contains(val, "\n") {
46			return `'''` + val + `'''`
47		}
48		return `"` + val + `"`
49	case reflect.Bool:
50		if value.Bool() {
51			return "True"
52		} else {
53			return "False"
54		}
55	case reflect.Int:
56		return fmt.Sprintf("%d", value.Int())
57	case reflect.Slice:
58		if value.Len() == 0 {
59			return "[]"
60		} else if value.Len() == 1 {
61			return "[" + printAnyRecursive(value.Index(0), indentLevel) + "]"
62		}
63		list := make([]string, 0, value.Len()+2)
64		list = append(list, "[")
65		innerIndent := Indention(indentLevel + 1)
66		for i := 0; i < value.Len(); i++ {
67			list = append(list, innerIndent+printAnyRecursive(value.Index(i), indentLevel+1)+`,`)
68		}
69		list = append(list, Indention(indentLevel)+"]")
70		return strings.Join(list, "\n")
71	case reflect.Map:
72		if value.Len() == 0 {
73			return "{}"
74		}
75		items := make([]string, 0, value.Len())
76		for _, key := range value.MapKeys() {
77			items = append(items, fmt.Sprintf(`%s%s: %s,`, Indention(indentLevel+1), printAnyRecursive(key, indentLevel+1), printAnyRecursive(value.MapIndex(key), indentLevel+1)))
78		}
79		sort.Strings(items)
80		return fmt.Sprintf(`{
81%s
82%s}`, strings.Join(items, "\n"), Indention(indentLevel))
83	case reflect.Struct:
84		if value.NumField() == 0 {
85			return "struct()"
86		}
87		items := make([]string, 0, value.NumField()+2)
88		items = append(items, "struct(")
89		for i := 0; i < value.NumField(); i++ {
90			if value.Type().Field(i).Anonymous {
91				panic("anonymous fields aren't supported")
92			}
93			name := value.Type().Field(i).Name
94			items = append(items, fmt.Sprintf(`%s%s = %s,`, Indention(indentLevel+1), name, printAnyRecursive(value.Field(i), indentLevel+1)))
95		}
96		items = append(items, Indention(indentLevel)+")")
97		return strings.Join(items, "\n")
98	default:
99		panic("Unhandled kind: " + value.Kind().String())
100	}
101}
102
// PrintBool returns the Starlark literal for item: "True" or "False".
func PrintBool(item bool) string {
	literal := "False"
	if item {
		literal = "True"
	}
	return literal
}
111
112// PrintsStringList returns a Starlark-compatible string of a list of Strings/Labels.
113func PrintStringList(items []string, indentLevel int) string {
114	return PrintList(items, indentLevel, func(s string) string {
115		if strings.Contains(s, "\"") {
116			return `'''%s'''`
117		}
118		return `"%s"`
119	})
120}
121
122// PrintList returns a Starlark-compatible string of list formmated as requested.
123func PrintList(items []string, indentLevel int, formatString func(string) string) string {
124	if len(items) == 0 {
125		return "[]"
126	} else if len(items) == 1 {
127		return fmt.Sprintf("["+formatString(items[0])+"]", items[0])
128	}
129	list := make([]string, 0, len(items)+2)
130	list = append(list, "[")
131	innerIndent := Indention(indentLevel + 1)
132	for _, item := range items {
133		list = append(list, fmt.Sprintf(`%s`+formatString(item)+`,`, innerIndent, item))
134	}
135	list = append(list, Indention(indentLevel)+"]")
136	return strings.Join(list, "\n")
137}
138
139// PrintStringListDict returns a Starlark-compatible string formatted as dictionary with
140// string keys and list of string values.
141func PrintStringListDict(dict map[string][]string, indentLevel int) string {
142	formattedValueDict := make(map[string]string, len(dict))
143	for k, v := range dict {
144		formattedValueDict[k] = PrintStringList(v, indentLevel+1)
145	}
146	return PrintDict(formattedValueDict, indentLevel)
147}
148
149// PrintBoolDict returns a starlark-compatible string containing a dictionary with string keys and
150// values printed with no additional formatting.
151func PrintBoolDict(dict map[string]bool, indentLevel int) string {
152	formattedValueDict := make(map[string]string, len(dict))
153	for k, v := range dict {
154		formattedValueDict[k] = PrintBool(v)
155	}
156	return PrintDict(formattedValueDict, indentLevel)
157}
158
159// PrintStringIntDict returns a Starlark-compatible string formatted as dictionary with
160// string keys and int values.
161func PrintStringIntDict(dict map[string]int, indentLevel int) string {
162	valDict := make(map[string]string, len(dict))
163	for k, v := range dict {
164		valDict[k] = strconv.Itoa(v)
165	}
166	return PrintDict(valDict, indentLevel)
167}
168
169// PrintStringStringDict returns a Starlark-compatible string formatted as dictionary with
170// string keys and string values.
171func PrintStringStringDict(dict map[string]string, indentLevel int) string {
172	valDict := make(map[string]string, len(dict))
173	for k, v := range dict {
174		valDict[k] = fmt.Sprintf(`"%s"`, v)
175	}
176	return PrintDict(valDict, indentLevel)
177}
178
179// PrintDict returns a starlark-compatible string containing a dictionary with string keys and
180// values printed with no additional formatting.
181func PrintDict(dict map[string]string, indentLevel int) string {
182	if len(dict) == 0 {
183		return "{}"
184	}
185	items := make([]string, 0, len(dict))
186	for k, v := range dict {
187		items = append(items, fmt.Sprintf(`%s"%s": %s,`, Indention(indentLevel+1), k, v))
188	}
189	sort.Strings(items)
190	return fmt.Sprintf(`{
191%s
192%s}`, strings.Join(items, "\n"), Indention(indentLevel))
193}
194