| // Copyright (c) 2023, Google Inc. |
| // |
| // Permission to use, copy, modify, and/or distribute this software for any |
| // purpose with or without fee is hereby granted, provided that the above |
| // copyright notice and this permission notice appear in all copies. |
| // |
| // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES |
| // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF |
| // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY |
| // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES |
| // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION |
| // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN |
| // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
| |
| package runner |
| |
| import ( |
| "context" |
| "encoding/binary" |
| "fmt" |
| "io" |
| "net" |
| "os" |
| "sync" |
| "time" |
| ) |
| |
| type shimDispatcher struct { |
| lock sync.Mutex |
| nextShimID uint64 |
| listener *net.TCPListener |
| shims map[uint64]*shimListener |
| err error |
| } |
| |
| func newShimDispatcher() (*shimDispatcher, error) { |
| listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv6loopback}) |
| if err != nil { |
| listener, err = net.ListenTCP("tcp4", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}}) |
| } |
| |
| if err != nil { |
| return nil, err |
| } |
| d := &shimDispatcher{listener: listener, shims: make(map[uint64]*shimListener)} |
| go d.acceptLoop() |
| return d, nil |
| } |
| |
| func (d *shimDispatcher) NewShim() (*shimListener, error) { |
| d.lock.Lock() |
| defer d.lock.Unlock() |
| if d.err != nil { |
| return nil, d.err |
| } |
| |
| l := &shimListener{dispatcher: d, shimID: d.nextShimID, connChan: make(chan net.Conn, 1)} |
| d.shims[l.shimID] = l |
| d.nextShimID++ |
| return l, nil |
| } |
| |
| func (d *shimDispatcher) unregisterShim(l *shimListener) { |
| d.lock.Lock() |
| delete(d.shims, l.shimID) |
| d.lock.Unlock() |
| } |
| |
| func (d *shimDispatcher) acceptLoop() { |
| for { |
| conn, err := d.listener.Accept() |
| if err != nil { |
| // Something went wrong. Shut down the listener. |
| d.closeWithError(err) |
| return |
| } |
| |
| go func() { |
| if err := d.dispatch(conn); err != nil { |
| // To be robust against port scanners, etc., we log a warning |
| // but otherwise treat undispatchable connections as non-fatal. |
| fmt.Fprintf(os.Stderr, "Error dispatching connection: %s\n", err) |
| conn.Close() |
| } |
| }() |
| } |
| } |
| |
| func (d *shimDispatcher) dispatch(conn net.Conn) error { |
| conn.SetReadDeadline(time.Now().Add(*idleTimeout)) |
| var buf [8]byte |
| if _, err := io.ReadFull(conn, buf[:]); err != nil { |
| return err |
| } |
| conn.SetReadDeadline(time.Time{}) |
| |
| shimID := binary.LittleEndian.Uint64(buf[:]) |
| d.lock.Lock() |
| shim, ok := d.shims[shimID] |
| d.lock.Unlock() |
| if !ok { |
| return fmt.Errorf("shim ID %d not found", shimID) |
| } |
| |
| shim.connChan <- conn |
| return nil |
| } |
| |
| func (d *shimDispatcher) Close() error { |
| return d.closeWithError(net.ErrClosed) |
| } |
| |
| func (d *shimDispatcher) closeWithError(err error) error { |
| closeErr := d.listener.Close() |
| |
| d.lock.Lock() |
| shims := d.shims |
| d.shims = make(map[uint64]*shimListener) |
| d.err = err |
| d.lock.Unlock() |
| |
| for _, shim := range shims { |
| shim.closeWithError(err) |
| } |
| return closeErr |
| } |
| |
| type shimListener struct { |
| dispatcher *shimDispatcher |
| shimID uint64 |
| // connChan contains connections from the dispatcher. On fatal error, it is |
| // closed, with the error available in err. |
| connChan chan net.Conn |
| err error |
| lock sync.Mutex |
| } |
| |
| func (l *shimListener) Port() int { |
| return l.dispatcher.listener.Addr().(*net.TCPAddr).Port |
| } |
| |
| func (l *shimListener) IsIPv6() bool { |
| return len(l.dispatcher.listener.Addr().(*net.TCPAddr).IP) == net.IPv6len |
| } |
| |
| func (l *shimListener) ShimID() uint64 { |
| return l.shimID |
| } |
| |
| func (l *shimListener) Close() error { |
| l.dispatcher.unregisterShim(l) |
| l.closeWithError(net.ErrClosed) |
| return nil |
| } |
| |
| func (l *shimListener) closeWithError(err error) { |
| // Multiple threads may close the listener at once, so protect closing with |
| // a lock. |
| l.lock.Lock() |
| if l.err == nil { |
| l.err = err |
| close(l.connChan) |
| } |
| l.lock.Unlock() |
| } |
| |
| func (l *shimListener) Accept(deadline time.Time) (net.Conn, error) { |
| var timerChan <-chan time.Time |
| if !deadline.IsZero() { |
| remaining := time.Until(deadline) |
| if remaining < 0 { |
| return nil, context.DeadlineExceeded |
| } |
| timer := time.NewTimer(remaining) |
| defer timer.Stop() |
| timerChan = timer.C |
| } |
| |
| select { |
| case <-timerChan: |
| return nil, context.DeadlineExceeded |
| case conn, ok := <-l.connChan: |
| if !ok { |
| return nil, l.err |
| } |
| return conn, nil |
| } |
| } |