// Copyright 2021 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// run_with_timeout is a utility that can kill a wrapped command after a configurable timeout,
// optionally running a command to collect debugging information first.

package main

import (
	"flag"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"syscall"
	"time"
)

var (
	timeout      = flag.Duration("timeout", 0, "time after which to kill command (example: 60s)")
	onTimeoutCmd = flag.String("on_timeout", "", "command to run with `PID=<pid> sh -c` after timeout.")
)

// usage prints the command-line synopsis, the flag defaults and a short description of the tool
// to stderr, then exits with status 2.
func usage() {
	fmt.Fprintf(os.Stderr, "usage: %s [--timeout N] [--on_timeout CMD] -- command [args...]\n", os.Args[0])
	flag.PrintDefaults()
	fmt.Fprintln(os.Stderr, "run_with_timeout is a utility that can kill a wrapped command after a configurable timeout,")
	fmt.Fprintln(os.Stderr, "optionally running a command to collect debugging information first.")

	os.Exit(2)
}

// main parses the flags, runs the wrapped command through runWithTimeout and exits with status 1
// if the command could not be started, exited unsuccessfully, or timed out.
func main() {
	flag.Usage = usage
	flag.Parse()

	if flag.NArg() < 1 {
		fmt.Fprintf(os.Stderr, "%s: error: command is required\n", os.Args[0])
		usage()
	}

	err := runWithTimeout(flag.Arg(0), flag.Args()[1:], *timeout, *onTimeoutCmd,
		os.Stdin, os.Stdout, os.Stderr)
	if err != nil {
		// runWithTimeout already prefixes every error with what failed ("process exited with
		// error", "on_timeout command ... exited with error", "timed out after ..."), so the
		// message is printed as-is. A bare *exec.ExitError is never returned, which is why no
		// special case for it is needed here.
		fmt.Fprintf(os.Stderr, "%s: error: %s\n", os.Args[0], err.Error())
		os.Exit(1)
	}
}

// concurrentWriter wraps a writer to make it safe to call Write from multiple goroutines.
type concurrentWriter struct {
	w io.Writer
	sync.Mutex
}

// Write writes data to the wrapped writer while holding the lock, so concurrent callers never
// interleave within a single Write. After Close, the data is discarded and reported as fully
// written (like io.Discard), which keeps the io.Writer contract: a short count must never be
// paired with a nil error.
func (c *concurrentWriter) Write(data []byte) (n int, err error) {
	c.Lock()
	defer c.Unlock()
	if c.w == nil {
		return len(data), nil
	}
	return c.w.Write(data)
}

// Close ends the concurrentWriter, causing future calls to Write to silently discard their data.
// It does not close the underlying writer.
func (c *concurrentWriter) Close() {
	c.Lock()
	defer c.Unlock()
	c.w = nil
}

// runWithTimeout runs command with args, connecting it to stdin, stdout and stderr.
//
// If timeout is zero or negative, it waits for the command to exit. It returns the command's
// error, wrapping an *exec.ExitError with "process exited with error".
//
// If timeout elapses first, it proceeds in this order:
//   - it reports the elapsed time on stderr;
//   - if onTimeoutCmdStr is non-empty, it runs it with `sh -c` while the timed-out process is
//     still alive, with PID=<pid of the process> added to the environment;
//   - it sends SIGKILL to the process.
//
// The function then returns a "timed out" error, or the on_timeout command's error if that
// command failed.
//
// Output from both commands goes through locking writers that are closed on return, so nothing
// is written to stdout/stderr after this function returns.
func runWithTimeout(command string, args []string, timeout time.Duration, onTimeoutCmdStr string,
	stdin io.Reader, stdout, stderr io.Writer) error {
	cmd := exec.Command(command, args...)

	// Wrap the writers in a locking writer so that cmd and onTimeoutCmd don't try to write to
	// stdout or stderr concurrently.
	concurrentStdout := &concurrentWriter{w: stdout}
	concurrentStderr := &concurrentWriter{w: stderr}
	defer concurrentStdout.Close()
	defer concurrentStderr.Close()

	cmd.Stdin, cmd.Stdout, cmd.Stderr = stdin, concurrentStdout, concurrentStderr
	if err := cmd.Start(); err != nil {
		return err
	}

	// waitCh will signal the subprocess exited. It is buffered so that the waiting goroutine
	// can always deliver its result and exit, even on the timeout path where nobody receives.
	waitCh := make(chan error, 1)
	go func() {
		waitCh <- cmd.Wait()
	}()

	// timeoutCh will signal the subprocess timed out if timeout was set. A nil channel is
	// never ready in a select, so without a timeout only waitCh can fire.
	var timeoutCh <-chan time.Time
	if timeout > 0 {
		timeoutCh = time.After(timeout)
	}
	startTime := time.Now()

	select {
	case err := <-waitCh:
		if exitErr, ok := err.(*exec.ExitError); ok {
			return fmt.Errorf("process exited with error: %w", exitErr)
		}
		return err
	case <-timeoutCh:
		// Continue below.
	}

	// Process timed out before exiting.
	fmt.Fprintf(concurrentStderr, "%s: process timed out after %s\n", os.Args[0], time.Since(startTime))
	// Kill the process only on return, so the on_timeout command below can still inspect it
	// through $PID. This is best-effort: the process may have exited in the meantime, and that
	// error is ignored. The deferred writer Close calls run after this kill.
	defer cmd.Process.Signal(syscall.SIGKILL)

	if onTimeoutCmdStr != "" {
		fmt.Fprintf(concurrentStderr, "%s: running on_timeout command `%s`\n", os.Args[0], onTimeoutCmdStr)
		onTimeoutCmd := exec.Command("sh", "-c", onTimeoutCmdStr)
		onTimeoutCmd.Stdin, onTimeoutCmd.Stdout, onTimeoutCmd.Stderr = stdin, concurrentStdout, concurrentStderr
		onTimeoutCmd.Env = append(os.Environ(), fmt.Sprintf("PID=%d", cmd.Process.Pid))
		if err := onTimeoutCmd.Run(); err != nil {
			return fmt.Errorf("on_timeout command %q exited with error: %w", onTimeoutCmdStr, err)
		}
	}

	return fmt.Errorf("timed out after %s", timeout.String())
}