1// Copyright 2021 Google Inc. All rights reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package main 16 17import ( 18 "bytes" 19 "io" 20 "regexp" 21 "testing" 22 "time" 23) 24 25func Test_runWithTimeout(t *testing.T) { 26 type args struct { 27 command string 28 args []string 29 timeout time.Duration 30 onTimeoutCmd string 31 stdin io.Reader 32 } 33 tests := []struct { 34 name string 35 args args 36 wantStdout string 37 wantStderr string 38 wantErr bool 39 }{ 40 { 41 name: "no timeout", 42 args: args{ 43 command: "echo", 44 args: []string{"foo"}, 45 }, 46 wantStdout: "foo\n", 47 }, 48 { 49 name: "timeout not reached", 50 args: args{ 51 command: "echo", 52 args: []string{"foo"}, 53 timeout: 10 * time.Second, 54 }, 55 wantStdout: "foo\n", 56 }, 57 { 58 name: "timed out", 59 args: args{ 60 command: "sh", 61 args: []string{"-c", "sleep 10 && echo foo"}, 62 timeout: 1 * time.Millisecond, 63 }, 64 wantStderr: ".*: process timed out after .*\n", 65 wantErr: true, 66 }, 67 { 68 name: "on_timeout command", 69 args: args{ 70 command: "sh", 71 args: []string{"-c", "sleep 10 && echo foo"}, 72 timeout: 1 * time.Millisecond, 73 onTimeoutCmd: "echo bar", 74 }, 75 wantStdout: "bar\n", 76 wantStderr: ".*: process timed out after .*\n.*: running on_timeout command `echo bar`\n", 77 wantErr: true, 78 }, 79 } 80 for _, tt := range tests { 81 t.Run(tt.name, func(t *testing.T) { 82 stdout := &bytes.Buffer{} 83 stderr := &bytes.Buffer{} 84 err := runWithTimeout(tt.args.command, tt.args.args, tt.args.timeout, tt.args.onTimeoutCmd, tt.args.stdin, stdout, stderr) 85 if (err != nil) != tt.wantErr { 86 t.Errorf("runWithTimeout() error = %v, wantErr %v", err, tt.wantErr) 87 return 88 } 89 if gotStdout := stdout.String(); gotStdout != tt.wantStdout { 90 t.Errorf("runWithTimeout() gotStdout = %v, want %v", gotStdout, tt.wantStdout) 91 } 92 if gotStderr := stderr.String(); !regexp.MustCompile(tt.wantStderr).MatchString(gotStderr) { 93 t.Errorf("runWithTimeout() gotStderr = %v, want %v", gotStderr, tt.wantStderr) 94 } 95 }) 96 } 97} 98