Source file src/net/sendfile_test.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package net
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/sha256"
    11  	"encoding/hex"
    12  	"errors"
    13  	"fmt"
    14  	"internal/poll"
    15  	"io"
    16  	"math/rand"
    17  	"os"
    18  	"runtime"
    19  	"strconv"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  )
    24  
    25  const (
    26  	newton       = "../testdata/Isaac.Newton-Opticks.txt"
    27  	newtonLen    = 567198
    28  	newtonSHA256 = "d4a9ac22462b35e7821a4f2706c211093da678620a8f9997989ee7cf8d507bbd"
    29  )
    30  
    31  func hookSupportsSendfile(t *testing.T) {
    32  	if runtime.GOOS == "windows" {
    33  		origHook := testHookSupportsSendfile
    34  		testHookSupportsSendfile = func() bool { return true }
    35  		t.Cleanup(func() {
    36  			testHookSupportsSendfile = origHook
    37  		})
    38  	}
    39  }
    40  
    41  // expectSendfile runs f, and verifies that internal/poll.SendFile successfully handles
    42  // a write to wantConn during f's execution.
    43  //
    44  // On platforms where supportsSendfile() is false, expectSendfile runs f but does not
    45  // expect a call to SendFile.
    46  func expectSendfile(t *testing.T, wantConn Conn, f func()) {
    47  	t.Helper()
    48  	hookSupportsSendfile(t)
    49  	if !supportsSendfile() {
    50  		f()
    51  		return
    52  	}
    53  	orig := poll.TestHookDidSendFile
    54  	defer func() {
    55  		poll.TestHookDidSendFile = orig
    56  	}()
    57  	var (
    58  		called     bool
    59  		gotHandled bool
    60  		gotFD      *poll.FD
    61  		gotErr     error
    62  	)
    63  	poll.TestHookDidSendFile = func(dstFD *poll.FD, src uintptr, written int64, err error, handled bool) {
    64  		if called {
    65  			t.Error("internal/poll.SendFile called multiple times, want one call")
    66  		}
    67  		called = true
    68  		gotHandled = handled
    69  		gotFD = dstFD
    70  		gotErr = err
    71  	}
    72  	f()
    73  	if !called {
    74  		t.Error("internal/poll.SendFile was not called, want it to be")
    75  		return
    76  	}
    77  	if !gotHandled {
    78  		t.Error("internal/poll.SendFile did not handle the write, want it to, error:", gotErr)
    79  		return
    80  	}
    81  	if &wantConn.(*TCPConn).fd.pfd != gotFD {
    82  		t.Error("internal.poll.SendFile called with unexpected FD")
    83  	}
    84  }
    85  
    86  func TestSendfile(t *testing.T) { testSendfile(t, newton, newtonSHA256, newtonLen, 0) }
    87  func TestSendfileWithExactLimit(t *testing.T) {
    88  	testSendfile(t, newton, newtonSHA256, newtonLen, newtonLen)
    89  }
    90  func TestSendfileWithLimitLargerThanFile(t *testing.T) {
    91  	testSendfile(t, newton, newtonSHA256, newtonLen, newtonLen*2)
    92  }
    93  func TestSendfileWithLargeFile(t *testing.T) {
    94  	// Some platforms are not capable of handling large files with sendfile
    95  	// due to limited system resource, so we only run this test on amd64 and
    96  	// arm64 for the moment.
    97  	if runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64" {
    98  		t.Skip("skipping on non-amd64 and non-arm64 platforms")
    99  	}
   100  	// Also skip it during short testing.
   101  	if testing.Short() {
   102  		t.Skip("Skip it during short testing")
   103  	}
   104  
   105  	// We're using 1<<31 - 1 as the chunk size for sendfile currently,
   106  	// make an edge case file that is 1 byte bigger than that.
   107  	f := createTempFile(t, 1<<31)
   108  	// For big file like this, only verify the transmission of the file,
   109  	// skip the content check.
   110  	testSendfile(t, f.Name(), "", 1<<31, 0)
   111  }
   112  func testSendfile(t *testing.T, filePath, fileHash string, size, limit int64) {
   113  	ln := newLocalListener(t, "tcp")
   114  	defer ln.Close()
   115  
   116  	errc := make(chan error, 1)
   117  	go func(ln Listener) {
   118  		// Wait for a connection.
   119  		conn, err := ln.Accept()
   120  		if err != nil {
   121  			errc <- err
   122  			close(errc)
   123  			return
   124  		}
   125  
   126  		go func() {
   127  			defer close(errc)
   128  			defer conn.Close()
   129  
   130  			f, err := os.Open(filePath)
   131  			if err != nil {
   132  				errc <- err
   133  				return
   134  			}
   135  			defer f.Close()
   136  
   137  			// Return file data using io.Copy, which should use
   138  			// sendFile if available.
   139  			var sbytes int64
   140  			expectSendfile(t, conn, func() {
   141  				if limit > 0 {
   142  					sbytes, err = io.CopyN(conn, f, limit)
   143  					if err == io.EOF && limit > size {
   144  						err = nil
   145  					}
   146  				} else {
   147  					sbytes, err = io.Copy(conn, f)
   148  				}
   149  			})
   150  			if err != nil {
   151  				errc <- err
   152  				return
   153  			}
   154  
   155  			if sbytes != size {
   156  				errc <- fmt.Errorf("sent %d bytes; expected %d", sbytes, size)
   157  				return
   158  			}
   159  		}()
   160  	}(ln)
   161  
   162  	// Connect to listener to retrieve file and verify digest matches
   163  	// expected.
   164  	c, err := Dial("tcp", ln.Addr().String())
   165  	if err != nil {
   166  		t.Fatal(err)
   167  	}
   168  	defer c.Close()
   169  
   170  	h := sha256.New()
   171  	rbytes, err := io.Copy(h, c)
   172  	if err != nil {
   173  		t.Error(err)
   174  	}
   175  
   176  	if rbytes != size {
   177  		t.Errorf("received %d bytes; expected %d", rbytes, size)
   178  	}
   179  
   180  	if len(fileHash) > 0 && hex.EncodeToString(h.Sum(nil)) != newtonSHA256 {
   181  		t.Error("retrieved data hash did not match")
   182  	}
   183  
   184  	for err := range errc {
   185  		t.Error(err)
   186  	}
   187  }
   188  
   189  func TestSendfileParts(t *testing.T) {
   190  	ln := newLocalListener(t, "tcp")
   191  	defer ln.Close()
   192  
   193  	errc := make(chan error, 1)
   194  	go func(ln Listener) {
   195  		// Wait for a connection.
   196  		conn, err := ln.Accept()
   197  		if err != nil {
   198  			errc <- err
   199  			close(errc)
   200  			return
   201  		}
   202  
   203  		go func() {
   204  			defer close(errc)
   205  			defer conn.Close()
   206  
   207  			f, err := os.Open(newton)
   208  			if err != nil {
   209  				errc <- err
   210  				return
   211  			}
   212  			defer f.Close()
   213  
   214  			for i := 0; i < 3; i++ {
   215  				// Return file data using io.CopyN, which should use
   216  				// sendFile if available.
   217  				expectSendfile(t, conn, func() {
   218  					_, err = io.CopyN(conn, f, 3)
   219  				})
   220  				if err != nil {
   221  					errc <- err
   222  					return
   223  				}
   224  			}
   225  		}()
   226  	}(ln)
   227  
   228  	c, err := Dial("tcp", ln.Addr().String())
   229  	if err != nil {
   230  		t.Fatal(err)
   231  	}
   232  	defer c.Close()
   233  
   234  	buf := new(bytes.Buffer)
   235  	buf.ReadFrom(c)
   236  
   237  	if want, have := "Produced ", buf.String(); have != want {
   238  		t.Errorf("unexpected server reply %q, want %q", have, want)
   239  	}
   240  
   241  	for err := range errc {
   242  		t.Error(err)
   243  	}
   244  }
   245  
   246  func TestSendfileSeeked(t *testing.T) {
   247  	ln := newLocalListener(t, "tcp")
   248  	defer ln.Close()
   249  
   250  	const seekTo = 65 << 10
   251  	const sendSize = 10 << 10
   252  
   253  	errc := make(chan error, 1)
   254  	go func(ln Listener) {
   255  		// Wait for a connection.
   256  		conn, err := ln.Accept()
   257  		if err != nil {
   258  			errc <- err
   259  			close(errc)
   260  			return
   261  		}
   262  
   263  		go func() {
   264  			defer close(errc)
   265  			defer conn.Close()
   266  
   267  			f, err := os.Open(newton)
   268  			if err != nil {
   269  				errc <- err
   270  				return
   271  			}
   272  			defer f.Close()
   273  			if _, err := f.Seek(seekTo, io.SeekStart); err != nil {
   274  				errc <- err
   275  				return
   276  			}
   277  
   278  			expectSendfile(t, conn, func() {
   279  				_, err = io.CopyN(conn, f, sendSize)
   280  			})
   281  			if err != nil {
   282  				errc <- err
   283  				return
   284  			}
   285  		}()
   286  	}(ln)
   287  
   288  	c, err := Dial("tcp", ln.Addr().String())
   289  	if err != nil {
   290  		t.Fatal(err)
   291  	}
   292  	defer c.Close()
   293  
   294  	buf := new(bytes.Buffer)
   295  	buf.ReadFrom(c)
   296  
   297  	if buf.Len() != sendSize {
   298  		t.Errorf("Got %d bytes; want %d", buf.Len(), sendSize)
   299  	}
   300  
   301  	for err := range errc {
   302  		t.Error(err)
   303  	}
   304  }
   305  
   306  // Test that sendfile doesn't put a pipe into blocking mode.
   307  func TestSendfilePipe(t *testing.T) {
   308  	switch runtime.GOOS {
   309  	case "plan9", "windows", "js", "wasip1":
   310  		// These systems don't support deadlines on pipes.
   311  		t.Skipf("skipping on %s", runtime.GOOS)
   312  	}
   313  
   314  	t.Parallel()
   315  
   316  	ln := newLocalListener(t, "tcp")
   317  	defer ln.Close()
   318  
   319  	r, w, err := os.Pipe()
   320  	if err != nil {
   321  		t.Fatal(err)
   322  	}
   323  	defer w.Close()
   324  	defer r.Close()
   325  
   326  	copied := make(chan bool)
   327  
   328  	var wg sync.WaitGroup
   329  	wg.Add(1)
   330  	go func() {
   331  		// Accept a connection and copy 1 byte from the read end of
   332  		// the pipe to the connection. This will call into sendfile.
   333  		defer wg.Done()
   334  		conn, err := ln.Accept()
   335  		if err != nil {
   336  			t.Error(err)
   337  			return
   338  		}
   339  		defer conn.Close()
   340  		// The comment above states that this should call into sendfile,
   341  		// but empirically it doesn't seem to do so at this time.
   342  		// If it does, or does on some platforms, this CopyN should be wrapped
   343  		// in expectSendfile.
   344  		_, err = io.CopyN(conn, r, 1)
   345  		if err != nil {
   346  			t.Error(err)
   347  			return
   348  		}
   349  		// Signal the main goroutine that we've copied the byte.
   350  		close(copied)
   351  	}()
   352  
   353  	wg.Add(1)
   354  	go func() {
   355  		// Write 1 byte to the write end of the pipe.
   356  		defer wg.Done()
   357  		_, err := w.Write([]byte{'a'})
   358  		if err != nil {
   359  			t.Error(err)
   360  		}
   361  	}()
   362  
   363  	wg.Add(1)
   364  	go func() {
   365  		// Connect to the server started two goroutines up and
   366  		// discard any data that it writes.
   367  		defer wg.Done()
   368  		conn, err := Dial("tcp", ln.Addr().String())
   369  		if err != nil {
   370  			t.Error(err)
   371  			return
   372  		}
   373  		defer conn.Close()
   374  		io.Copy(io.Discard, conn)
   375  	}()
   376  
   377  	// Wait for the byte to be copied, meaning that sendfile has
   378  	// been called on the pipe.
   379  	<-copied
   380  
   381  	// Set a very short deadline on the read end of the pipe.
   382  	if err := r.SetDeadline(time.Now().Add(time.Microsecond)); err != nil {
   383  		t.Fatal(err)
   384  	}
   385  
   386  	wg.Add(1)
   387  	go func() {
   388  		// Wait for much longer than the deadline and write a byte
   389  		// to the pipe.
   390  		defer wg.Done()
   391  		time.Sleep(50 * time.Millisecond)
   392  		w.Write([]byte{'b'})
   393  	}()
   394  
   395  	// If this read does not time out, the pipe was incorrectly
   396  	// put into blocking mode.
   397  	_, err = r.Read(make([]byte, 1))
   398  	if err == nil {
   399  		t.Error("Read did not time out")
   400  	} else if !os.IsTimeout(err) {
   401  		t.Errorf("got error %v, expected a time out", err)
   402  	}
   403  
   404  	wg.Wait()
   405  }
   406  
   407  // Issue 43822: tests that returns EOF when conn write timeout.
   408  func TestSendfileOnWriteTimeoutExceeded(t *testing.T) {
   409  	ln := newLocalListener(t, "tcp")
   410  	defer ln.Close()
   411  
   412  	errc := make(chan error, 1)
   413  	go func(ln Listener) (retErr error) {
   414  		defer func() {
   415  			errc <- retErr
   416  			close(errc)
   417  		}()
   418  
   419  		conn, err := ln.Accept()
   420  		if err != nil {
   421  			return err
   422  		}
   423  		defer conn.Close()
   424  
   425  		// Set the write deadline in the past(1h ago). It makes
   426  		// sure that it is always write timeout.
   427  		if err := conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)); err != nil {
   428  			return err
   429  		}
   430  
   431  		f, err := os.Open(newton)
   432  		if err != nil {
   433  			return err
   434  		}
   435  		defer f.Close()
   436  
   437  		// We expect this to use sendfile, but as of the time this comment was written
   438  		// poll.SendFile on an FD past its timeout can return an error indicating that
   439  		// it didn't handle the operation, resulting in a non-sendfile retry.
   440  		// So don't use expectSendfile here.
   441  		_, err = io.Copy(conn, f)
   442  		if errors.Is(err, os.ErrDeadlineExceeded) {
   443  			return nil
   444  		}
   445  
   446  		if err == nil {
   447  			err = fmt.Errorf("expected ErrDeadlineExceeded, but got nil")
   448  		}
   449  		return err
   450  	}(ln)
   451  
   452  	conn, err := Dial("tcp", ln.Addr().String())
   453  	if err != nil {
   454  		t.Fatal(err)
   455  	}
   456  	defer conn.Close()
   457  
   458  	n, err := io.Copy(io.Discard, conn)
   459  	if err != nil {
   460  		t.Fatalf("expected nil error, but got %v", err)
   461  	}
   462  	if n != 0 {
   463  		t.Fatalf("expected receive zero, but got %d byte(s)", n)
   464  	}
   465  
   466  	if err := <-errc; err != nil {
   467  		t.Fatal(err)
   468  	}
   469  }
   470  
   471  func BenchmarkSendfileZeroBytes(b *testing.B) {
   472  	var (
   473  		wg          sync.WaitGroup
   474  		ctx, cancel = context.WithCancel(context.Background())
   475  	)
   476  
   477  	defer wg.Wait()
   478  
   479  	ln := newLocalListener(b, "tcp")
   480  	defer ln.Close()
   481  
   482  	tempFile, err := os.CreateTemp(b.TempDir(), "test.txt")
   483  	if err != nil {
   484  		b.Fatalf("failed to create temp file: %v", err)
   485  	}
   486  	defer tempFile.Close()
   487  
   488  	fileName := tempFile.Name()
   489  
   490  	dataSize := b.N
   491  	wg.Add(1)
   492  	go func(f *os.File) {
   493  		defer wg.Done()
   494  
   495  		for i := 0; i < dataSize; i++ {
   496  			if _, err := f.Write([]byte{1}); err != nil {
   497  				b.Errorf("failed to write: %v", err)
   498  				return
   499  			}
   500  			if i%1000 == 0 {
   501  				f.Sync()
   502  			}
   503  		}
   504  	}(tempFile)
   505  
   506  	b.ResetTimer()
   507  	b.ReportAllocs()
   508  
   509  	wg.Add(1)
   510  	go func(ln Listener, fileName string) {
   511  		defer wg.Done()
   512  
   513  		conn, err := ln.Accept()
   514  		if err != nil {
   515  			b.Errorf("failed to accept: %v", err)
   516  			return
   517  		}
   518  		defer conn.Close()
   519  
   520  		f, err := os.OpenFile(fileName, os.O_RDONLY, 0660)
   521  		if err != nil {
   522  			b.Errorf("failed to open file: %v", err)
   523  			return
   524  		}
   525  		defer f.Close()
   526  
   527  		for {
   528  			if ctx.Err() != nil {
   529  				return
   530  			}
   531  
   532  			if _, err := io.Copy(conn, f); err != nil {
   533  				b.Errorf("failed to copy: %v", err)
   534  				return
   535  			}
   536  		}
   537  	}(ln, fileName)
   538  
   539  	conn, err := Dial("tcp", ln.Addr().String())
   540  	if err != nil {
   541  		b.Fatalf("failed to dial: %v", err)
   542  	}
   543  	defer conn.Close()
   544  
   545  	n, err := io.CopyN(io.Discard, conn, int64(dataSize))
   546  	if err != nil {
   547  		b.Fatalf("failed to copy: %v", err)
   548  	}
   549  	if n != int64(dataSize) {
   550  		b.Fatalf("expected %d copied bytes, but got %d", dataSize, n)
   551  	}
   552  
   553  	cancel()
   554  }
   555  
   556  func BenchmarkSendFile(b *testing.B) {
   557  	if runtime.GOOS == "windows" {
   558  		// TODO(panjf2000): Windows has not yet implemented FileConn,
   559  		//		remove this when it's implemented in https://go.dev/issues/9503.
   560  		b.Skipf("skipping on %s", runtime.GOOS)
   561  	}
   562  
   563  	b.Run("file-to-tcp", func(b *testing.B) { benchmarkSendFile(b, "tcp") })
   564  	b.Run("file-to-unix", func(b *testing.B) { benchmarkSendFile(b, "unix") })
   565  }
   566  
   567  func benchmarkSendFile(b *testing.B, proto string) {
   568  	for i := 0; i <= 10; i++ {
   569  		size := 1 << (i + 10)
   570  		bench := sendFileBench{
   571  			proto:     proto,
   572  			chunkSize: size,
   573  		}
   574  		b.Run(strconv.Itoa(size), bench.benchSendFile)
   575  	}
   576  }
   577  
   578  type sendFileBench struct {
   579  	proto     string
   580  	chunkSize int
   581  }
   582  
   583  func (bench sendFileBench) benchSendFile(b *testing.B) {
   584  	fileSize := b.N * bench.chunkSize
   585  	f := createTempFile(b, int64(fileSize))
   586  
   587  	client, server := spawnTestSocketPair(b, bench.proto)
   588  	defer server.Close()
   589  
   590  	cleanUp, err := startTestSocketPeer(b, client, "r", bench.chunkSize, fileSize)
   591  	if err != nil {
   592  		client.Close()
   593  		b.Fatal(err)
   594  	}
   595  	defer cleanUp(b)
   596  
   597  	b.ReportAllocs()
   598  	b.SetBytes(int64(bench.chunkSize))
   599  	b.ResetTimer()
   600  
   601  	// Data go from file to socket via sendfile(2).
   602  	sent, err := io.Copy(server, f)
   603  	if err != nil {
   604  		b.Fatalf("failed to copy data with sendfile, error: %v", err)
   605  	}
   606  	if sent != int64(fileSize) {
   607  		b.Fatalf("bytes sent mismatch, got: %d, want: %d", sent, fileSize)
   608  	}
   609  }
   610  
   611  func createTempFile(tb testing.TB, size int64) *os.File {
   612  	f, err := os.CreateTemp(tb.TempDir(), "sendfile-bench")
   613  	if err != nil {
   614  		tb.Fatalf("failed to create temporary file: %v", err)
   615  	}
   616  	tb.Cleanup(func() {
   617  		f.Close()
   618  	})
   619  
   620  	if _, err := io.CopyN(f, newRandReader(tb), size); err != nil {
   621  		tb.Fatalf("failed to fill the file with random data: %v", err)
   622  	}
   623  	if _, err := f.Seek(0, io.SeekStart); err != nil {
   624  		tb.Fatalf("failed to rewind the file: %v", err)
   625  	}
   626  
   627  	return f
   628  }
   629  
   630  func newRandReader(tb testing.TB) io.Reader {
   631  	seed := time.Now().UnixNano()
   632  	tb.Logf("Deterministic RNG seed based on timestamp: 0x%x", seed)
   633  	return rand.New(rand.NewSource(seed))
   634  }
   635  

View as plain text