1// Copyright 2021 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package main
16
17import (
18	"bytes"
19	"io"
20	"regexp"
21	"testing"
22	"time"
23)
24
25func Test_runWithTimeout(t *testing.T) {
26	type args struct {
27		command      string
28		args         []string
29		timeout      time.Duration
30		onTimeoutCmd string
31		stdin        io.Reader
32	}
33	tests := []struct {
34		name       string
35		args       args
36		wantStdout string
37		wantStderr string
38		wantErr    bool
39	}{
40		{
41			name: "no timeout",
42			args: args{
43				command: "echo",
44				args:    []string{"foo"},
45			},
46			wantStdout: "foo\n",
47		},
48		{
49			name: "timeout not reached",
50			args: args{
51				command: "echo",
52				args:    []string{"foo"},
53				timeout: 10 * time.Second,
54			},
55			wantStdout: "foo\n",
56		},
57		{
58			name: "timed out",
59			args: args{
60				command: "sh",
61				args:    []string{"-c", "sleep 10 && echo foo"},
62				timeout: 1 * time.Millisecond,
63			},
64			wantStderr: ".*: process timed out after .*\n",
65			wantErr:    true,
66		},
67		{
68			name: "on_timeout command",
69			args: args{
70				command:      "sh",
71				args:         []string{"-c", "sleep 10 && echo foo"},
72				timeout:      1 * time.Millisecond,
73				onTimeoutCmd: "echo bar",
74			},
75			wantStdout: "bar\n",
76			wantStderr: ".*: process timed out after .*\n.*: running on_timeout command `echo bar`\n",
77			wantErr:    true,
78		},
79	}
80	for _, tt := range tests {
81		t.Run(tt.name, func(t *testing.T) {
82			stdout := &bytes.Buffer{}
83			stderr := &bytes.Buffer{}
84			err := runWithTimeout(tt.args.command, tt.args.args, tt.args.timeout, tt.args.onTimeoutCmd, tt.args.stdin, stdout, stderr)
85			if (err != nil) != tt.wantErr {
86				t.Errorf("runWithTimeout() error = %v, wantErr %v", err, tt.wantErr)
87				return
88			}
89			if gotStdout := stdout.String(); gotStdout != tt.wantStdout {
90				t.Errorf("runWithTimeout() gotStdout = %v, want %v", gotStdout, tt.wantStdout)
91			}
92			if gotStderr := stderr.String(); !regexp.MustCompile(tt.wantStderr).MatchString(gotStderr) {
93				t.Errorf("runWithTimeout() gotStderr = %v, want %v", gotStderr, tt.wantStderr)
94			}
95		})
96	}
97}
98