Merge pull request #8 from zach-klippenstein/fix-reads
Fix a bunch of bugs with the new SyncSender stuff.

Commit: aea57fe58f

dialer.go (18 changed lines)

@@ -2,6 +2,7 @@ package goadb

 import (
     "fmt"
+    "io"
     "net"
     "runtime"

@@ -66,17 +67,22 @@ func (d *netDialer) Dial() (*wire.Conn, error) {
         }
     }

-    conn := &wire.Conn{
-        Scanner: wire.NewScanner(netConn),
-        Sender:  wire.NewSender(netConn),
-    }
+    // net.Conn can't be closed more than once, but wire.Conn will try to close both sender and scanner,
+    // so we need to wrap it to make it safe.
+    safeConn := wire.MultiCloseable(netConn)

     // Prevent leaking the network connection, not sure if TCPConn does this itself.
-    runtime.SetFinalizer(netConn, func(conn *net.TCPConn) {
+    // Note that the network connection may still be in use after the conn isn't (scanners/senders
+    // can give their underlying connections to other scanner/sender types), so we can't
+    // set the finalizer on conn.
+    runtime.SetFinalizer(safeConn, func(conn io.ReadWriteCloser) {
         conn.Close()
     })

-    return conn, nil
+    return &wire.Conn{
+        Scanner: wire.NewScanner(safeConn),
+        Sender:  wire.NewSender(safeConn),
+    }, nil
 }

 func roundTripSingleResponse(d Dialer, req string) ([]byte, error) {
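
For context on why Dial now wraps the connection: wire.Conn closes both its Scanner and its Sender, and each of those closes the stream it was given, so the raw net.Conn would otherwise see two Close calls. Closing a TCP connection twice is not harmless, as this small standalone sketch (not part of the commit) shows; MultiCloseable, added in wire/util.go further down, absorbs the repeated call instead.

package main

import (
    "fmt"
    "net"
)

func main() {
    // A throwaway loopback listener just so we have a real net.Conn to close twice.
    ln, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        panic(err)
    }
    defer ln.Close()

    conn, err := net.Dial("tcp", ln.Addr().String())
    if err != nil {
        panic(err)
    }

    fmt.Println(conn.Close()) // <nil>
    fmt.Println(conn.Close()) // an error along the lines of "use of closed network connection"
}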

@@ -1,7 +1,6 @@
 package goadb

 import (
-    "fmt"
     "io"

     "github.com/zach-klippenstein/goadb/util"

@@ -15,6 +14,9 @@ type syncFileReader struct {

     // Reader for the current chunk only.
     chunkReader io.Reader
+
+    // False until the DONE chunk is encountered.
+    eof bool
 }

 var _ io.ReadCloser = &syncFileReader{}

@@ -26,18 +28,30 @@ func newSyncFileReader(s wire.SyncScanner) (r io.ReadCloser, err error) {

     // Read the header for the first chunk to consume any errors.
     if _, err = r.Read([]byte{}); err != nil {
-        r.Close()
-        return nil, err
+        if err == io.EOF {
+            // EOF means the file was empty. This still means the file was opened successfully,
+            // and the next time the caller does a read they'll get the EOF and handle it themselves.
+            err = nil
+        } else {
+            r.Close()
+            return nil, err
+        }
     }
     return
 }

 func (r *syncFileReader) Read(buf []byte) (n int, err error) {
+    if r.eof {
+        return 0, io.EOF
+    }
+
     if r.chunkReader == nil {
         chunkReader, err := readNextChunk(r.scanner)
         if err != nil {
-            // If this is EOF, we've read the last chunk.
-            // Either way, we want to pass it up to the caller.
+            if err == io.EOF {
+                // We just read the last chunk, set our flag before passing it up.
+                r.eof = true
+            }
             return 0, err
         }
         r.chunkReader = chunkReader

@@ -82,7 +96,7 @@ func readNextChunk(r wire.SyncScanner) (io.Reader, error) {
     case wire.StatusSyncDone:
         return nil, io.EOF
     default:
-        return nil, fmt.Errorf("expected chunk id '%s' or '%s', but got '%s'",
+        return nil, util.Errorf(util.AssertionError, "expected chunk id '%s' or '%s', but got '%s'",
             wire.StatusSyncData, wire.StatusSyncDone, []byte(status))
     }
 }
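
The new eof flag matters because helpers like ioutil.ReadAll and io.Copy keep calling Read until they see io.EOF, and an empty file now yields io.EOF on the very first call rather than an error from newSyncFileReader. A rough caller-side sketch of the contract the reader has to honour (drain and the strings.Reader stand-in are illustrative, not goadb API):

package main

import (
    "bytes"
    "fmt"
    "io"
    "strings"
)

// drain reads r to the end the way goadb callers consume a synced file; once EOF
// has been reported, every further Read must keep returning io.EOF instead of
// touching the wire again.
func drain(r io.Reader) ([]byte, error) {
    var out bytes.Buffer
    buf := make([]byte, 4096)
    for {
        n, err := r.Read(buf)
        out.Write(buf[:n])
        if err == io.EOF {
            return out.Bytes(), nil // EOF is the normal end-of-file signal, not a failure
        }
        if err != nil {
            return out.Bytes(), err
        }
    }
}

func main() {
    // strings.Reader stands in for the sync file reader here.
    data, err := drain(strings.NewReader("hello"))
    fmt.Println(string(data), err) // hello <nil>

    // An empty file: the very first Read returns io.EOF, and so does every later one.
    data, err = drain(strings.NewReader(""))
    fmt.Println(len(data), err) // 0 <nil>
}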

@@ -8,6 +8,7 @@ import (
     "github.com/stretchr/testify/assert"
     "github.com/zach-klippenstein/goadb/util"
     "github.com/zach-klippenstein/goadb/wire"
+    "io/ioutil"
 )

 func TestReadNextChunk(t *testing.T) {

@@ -44,7 +45,7 @@ func TestReadNextChunkInvalidChunkId(t *testing.T) {

     // Read 1st chunk
     _, err := readNextChunk(s)
-    assert.EqualError(t, err, "expected chunk id 'DATA' or 'DONE', but got 'ATAD'")
+    assert.EqualError(t, err, "AssertionError: expected chunk id 'DATA' or 'DONE', but got 'ATAD'")
 }

 func TestReadMultipleCalls(t *testing.T) {

@@ -91,6 +92,20 @@ func TestReadError(t *testing.T) {
     assert.EqualError(t, err, "AdbError: server error for read-chunk request: fail ({Request:read-chunk ServerMsg:fail})")
 }

+func TestReadEmpty(t *testing.T) {
+    s := wire.NewSyncScanner(strings.NewReader(
+        "DONE"))
+    r, err := newSyncFileReader(s)
+    assert.NoError(t, err)
+
+    // Multiple read calls that return EOF is a valid case.
+    for i := 0; i < 5; i++ {
+        data, err := ioutil.ReadAll(r)
+        assert.NoError(t, err)
+        assert.Empty(t, data)
+    }
+}
+
 func TestReadErrorNotFound(t *testing.T) {
     s := wire.NewSyncScanner(strings.NewReader(
         "FAIL\031\000\000\000No such file or directory"))
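
The string literals fed to wire.NewSyncScanner in these tests are hand-encoded sync packets: a 4-byte status ID such as DATA, DONE or FAIL, followed (for statuses that carry a payload) by a 4-byte little-endian length and then the payload itself. "FAIL\031\000\000\000No such file or directory" is therefore a FAIL packet with a 25-byte message, since octal 031 is 25. A small sketch of building such a packet by hand, assuming that layout (syncPacket is made up for illustration):

package main

import (
    "encoding/binary"
    "fmt"
)

// syncPacket builds a 4-byte status ID plus a length-prefixed payload, matching
// the byte strings the tests feed to wire.NewSyncScanner.
func syncPacket(status string, payload []byte) []byte {
    header := make([]byte, 8)
    copy(header, status) // e.g. "FAIL", "DATA"
    binary.LittleEndian.PutUint32(header[4:], uint32(len(payload)))
    return append(header, payload...)
}

func main() {
    pkt := syncPacket("FAIL", []byte("No such file or directory"))
    fmt.Printf("%q\n", pkt) // "FAIL\x19\x00\x00\x00No such file or directory" (0x19 == 25)
}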

@@ -43,21 +43,31 @@ func encodePathAndMode(path string, mode os.FileMode) []byte {

 // Write writes the min of (len(buf), 64k).
 func (w *syncFileWriter) Write(buf []byte) (n int, err error) {
-    // Writes < 64k have a one-to-one mapping to chunks.
-    // If buffer is larger than the max, we'll return the max size and leave it up to the
-    // caller to handle correctly.
-    if len(buf) > wire.SyncMaxChunkSize {
-        buf = buf[:wire.SyncMaxChunkSize]
-    }
-
-    if err := w.sender.SendOctetString(wire.StatusSyncData); err != nil {
-        return 0, err
-    }
-    if err := w.sender.SendBytes(buf); err != nil {
-        return 0, err
-    }
-
-    return len(buf), nil
+    written := 0
+
+    // If buf > 64k we'll have to send multiple chunks.
+    // TODO Refactor this into something that can coalesce smaller writes into a single chunk.
+    for len(buf) > 0 {
+        // Writes < 64k have a one-to-one mapping to chunks.
+        // If buffer is larger than the max, we'll return the max size and leave it up to the
+        // caller to handle correctly.
+        partialBuf := buf
+        if len(partialBuf) > wire.SyncMaxChunkSize {
+            partialBuf = partialBuf[:wire.SyncMaxChunkSize]
+        }
+
+        if err := w.sender.SendOctetString(wire.StatusSyncData); err != nil {
+            return written, err
+        }
+        if err := w.sender.SendBytes(partialBuf); err != nil {
+            return written, err
+        }
+
+        written += len(partialBuf)
+        buf = buf[len(partialBuf):]
+    }
+
+    return written, nil
 }

 func (w *syncFileWriter) Close() error {

@@ -66,7 +76,7 @@ func (w *syncFileWriter) Close() error {
     }

     if err := w.sender.SendOctetString(wire.StatusSyncDone); err != nil {
-        return util.WrapErrf(err, "error closing file writer")
+        return util.WrapErrf(err, "error sending done chunk to close stream")
     }
     if err := w.sender.SendTime(w.mtime); err != nil {
         return util.WrapErrf(err, "error writing file modification time")
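
The behavioural change here is that Write now loops until the whole buffer has been sent instead of silently truncating at wire.SyncMaxChunkSize, so callers such as io.Copy no longer lose data on large files. The splitting logic in isolation looks roughly like this (standalone sketch; maxChunkSize is a stand-in for wire.SyncMaxChunkSize, assumed to be 64 KiB):

package main

import "fmt"

const maxChunkSize = 64 * 1024 // stand-in for wire.SyncMaxChunkSize

// splitChunks mirrors the loop in syncFileWriter.Write: take at most
// maxChunkSize bytes per chunk until the buffer is exhausted.
func splitChunks(buf []byte) [][]byte {
    var chunks [][]byte
    for len(buf) > 0 {
        partial := buf
        if len(partial) > maxChunkSize {
            partial = partial[:maxChunkSize]
        }
        chunks = append(chunks, partial)
        buf = buf[len(partial):]
    }
    return chunks
}

func main() {
    data := make([]byte, maxChunkSize+1)
    chunks := splitChunks(data)
    fmt.Println(len(chunks), len(chunks[0]), len(chunks[1])) // 2 65536 1
}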

@@ -42,18 +42,26 @@ func TestFileWriterWriteLargeChunk(t *testing.T) {
     var buf bytes.Buffer
     writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose)

+    // Send just enough data to get 2 chunks.
     data := make([]byte, wire.SyncMaxChunkSize+1)
     n, err := writer.Write(data)

     assert.NoError(t, err)
-    assert.Equal(t, wire.SyncMaxChunkSize, n)
-    assert.Equal(t, 8 + wire.SyncMaxChunkSize, buf.Len())
+    assert.Equal(t, wire.SyncMaxChunkSize+1, n)
+    assert.Equal(t, 8 + 8 + wire.SyncMaxChunkSize+1, buf.Len())

-    expectedHeader := []byte("DATA0000")
+    // First header.
+    chunk := buf.Bytes()[:8+wire.SyncMaxChunkSize]
+    expectedHeader := []byte("DATA----")
     binary.LittleEndian.PutUint32(expectedHeader[4:], wire.SyncMaxChunkSize)
-    assert.Equal(t, expectedHeader, buf.Bytes()[:8])
+    assert.Equal(t, expectedHeader, chunk[:8])
+    assert.Equal(t, data[:wire.SyncMaxChunkSize], chunk[8:])

-    assert.Equal(t, string(data[:wire.SyncMaxChunkSize]), buf.String()[8:])
+    // Second header.
+    chunk = buf.Bytes()[wire.SyncMaxChunkSize+8 : wire.SyncMaxChunkSize+8+1]
+    expectedHeader = []byte("DATA\000\000\000\000")
+    binary.LittleEndian.PutUint32(expectedHeader[4:], 1)
+    assert.Equal(t, expectedHeader, chunk[:8])
 }

 func TestFileWriterCloseEmpty(t *testing.T) {
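
As a quick check of the sizes the updated test expects: each chunk is an 8-byte header (the 4-byte "DATA" ID plus a 4-byte little-endian length) followed by its payload, so writing wire.SyncMaxChunkSize+1 bytes produces one full chunk of 8 + wire.SyncMaxChunkSize bytes and a second chunk of 8 + 1 bytes, which is exactly the 8 + 8 + wire.SyncMaxChunkSize + 1 total that the buf.Len() assertion above checks.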

util.go (2 changed lines)

@@ -26,7 +26,7 @@ func wrapClientError(err error, client interface{}, operation string, args ...in
         return nil
     }
     if _, ok := err.(*util.Err); !ok {
-        panic("err is not a *util.Err")
+        panic("err is not a *util.Err: " + err.Error())
     }

     clientType := reflect.TypeOf(client)

@@ -176,7 +176,12 @@ func ErrorWithCauseChain(err error) string {
             break
         }
     }
-    buffer.WriteString(err.Error())
+
+    if err != nil {
+        buffer.WriteString(err.Error())
+    } else {
+        buffer.WriteString("<err=nil>")
+    }

     return buffer.String()
 }
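
The guard makes ErrorWithCauseChain safe to call on the happy path as well: the removed line called err.Error() unconditionally, which panics when err is a nil interface. Under the behaviour asserted by the test that follows, a nil error now renders as a placeholder, as in this minimal sketch (assuming the util package's exported API is otherwise unchanged):

package main

import (
    "fmt"

    "github.com/zach-klippenstein/goadb/util"
)

func main() {
    var err error // stays nil on the happy path
    fmt.Println(util.ErrorWithCauseChain(err)) // <err=nil>
}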

@@ -23,6 +23,8 @@ caused by AssertionError: err2
 caused by err3`

     assert.Equal(t, expected, ErrorWithCauseChain(err))
+
+    assert.Equal(t, "<err=nil>", ErrorWithCauseChain(nil))
 }

 func TestCombineErrors(t *testing.T) {

@@ -68,8 +68,7 @@ func (s *realSyncSender) SendBytes(data []byte) error {
     if err := s.SendInt32(int32(length)); err != nil {
         return util.WrapErrorf(err, util.NetworkError, "error sending data length on sync sender")
     }
-    return util.WrapErrorf(writeFully(s.Writer, data),
-        util.NetworkError, "error sending data on sync sender")
+    return writeFully(s.Writer, data)
 }

 func (s *realSyncSender) Close() error {

wire/util.go (20 changed lines)

@@ -5,6 +5,8 @@ import (
     "io"
     "regexp"

+    "sync"
+
     "github.com/zach-klippenstein/goadb/util"
 )

@@ -80,3 +82,21 @@ func writeFully(w io.Writer, data []byte) error {
     }
     return nil
 }
+
+// MultiCloseable wraps c in a ReadWriteCloser that can be safely closed multiple times.
+func MultiCloseable(c io.ReadWriteCloser) io.ReadWriteCloser {
+    return &multiCloseable{ReadWriteCloser: c}
+}
+
+type multiCloseable struct {
+    io.ReadWriteCloser
+    closeOnce sync.Once
+    err       error
+}
+
+func (c *multiCloseable) Close() error {
+    c.closeOnce.Do(func() {
+        c.err = c.ReadWriteCloser.Close()
+    })
+    return c.err
+}
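
The sync.Once guarantees that the wrapped Close body runs exactly once, with every later call returning the cached result. A self-contained sketch of the same pattern (countingCloser is invented for the demo; only multiCloseable's shape follows the code above):

package main

import (
    "fmt"
    "io"
    "strings"
    "sync"
)

// countingCloser is a throwaway io.ReadWriteCloser that records how often Close is called.
type countingCloser struct {
    io.Reader
    io.Writer
    closes int
}

func (c *countingCloser) Close() error {
    c.closes++
    return nil
}

// multiCloseable mirrors the type added in wire/util.go above.
type multiCloseable struct {
    io.ReadWriteCloser
    closeOnce sync.Once
    err       error
}

func (c *multiCloseable) Close() error {
    c.closeOnce.Do(func() {
        c.err = c.ReadWriteCloser.Close()
    })
    return c.err
}

func main() {
    raw := &countingCloser{Reader: strings.NewReader(""), Writer: io.Discard}
    safe := &multiCloseable{ReadWriteCloser: raw}

    // e.g. the Scanner closes the connection, then the Sender closes it again.
    fmt.Println(safe.Close(), safe.Close()) // <nil> <nil>
    fmt.Println(raw.closes)                 // 1: the underlying Close ran only once
}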