diff --git a/cmd/demo/demo.go b/cmd/demo/demo.go index 61ef9f2..023dba3 100644 --- a/cmd/demo/demo.go +++ b/cmd/demo/demo.go @@ -6,8 +6,10 @@ import ( "fmt" "io/ioutil" "log" + "time" adb "github.com/zach-klippenstein/goadb" + "github.com/zach-klippenstein/goadb/util" ) var port = flag.Int("p", adb.AdbPort, "") @@ -47,10 +49,33 @@ func main() { PrintDeviceInfoAndError(adb.DeviceWithSerial(serial)) } + fmt.Println() + fmt.Println("Watching for device state changes.") + watcher, err := adb.NewDeviceWatcher(adb.ClientConfig{}) + for event := range watcher.C() { + fmt.Printf("\t[%s]%+v\n", time.Now(), event) + } + if watcher.Err() != nil { + printErr(watcher.Err()) + } + //fmt.Println("Killing server…") //client.KillServer() } +func printErr(err error) { + switch err := err.(type) { + case *util.Err: + fmt.Println(err.Error()) + if err.Cause != nil { + fmt.Print("caused by ") + printErr(err.Cause) + } + default: + fmt.Println("error:", err) + } +} + func PrintDeviceInfoAndError(descriptor adb.DeviceDescriptor) { device := adb.NewDeviceClient(adb.ClientConfig{}, descriptor) if err := PrintDeviceInfo(device); err != nil { diff --git a/device_watcher.go b/device_watcher.go new file mode 100644 index 0000000..3f97335 --- /dev/null +++ b/device_watcher.go @@ -0,0 +1,214 @@ +package goadb + +import ( + "log" + "runtime" + "strings" + "sync/atomic" + + "github.com/zach-klippenstein/goadb/util" + "github.com/zach-klippenstein/goadb/wire" +) + +/* +DeviceWatcher publishes device status change events. +If the server dies while listening for events, it restarts the server. +*/ +type DeviceWatcher struct { + *deviceWatcherImpl +} + +type DeviceStateChangedEvent struct { + Serial string + OldState string + NewState string +} + +type deviceWatcherImpl struct { + config ClientConfig + + // If an error occurs, it is stored here and eventChan is close immediately after. + err atomic.Value + + eventChan chan DeviceStateChangedEvent + + // Function to start the server if it's not running or dies. + startServer func() error +} + +func NewDeviceWatcher(config ClientConfig) (*DeviceWatcher, error) { + watcher := &DeviceWatcher{&deviceWatcherImpl{ + config: config.sanitized(), + eventChan: make(chan DeviceStateChangedEvent), + startServer: StartServer, + }} + + runtime.SetFinalizer(watcher, func(watcher *DeviceWatcher) { + watcher.Shutdown() + }) + + go publishDevices(watcher.deviceWatcherImpl) + + return watcher, nil +} + +/* +C returns a channel than can be received on to get events. +If an unrecoverable error occurs, or Shutdown is called, the channel will be closed. +*/ +func (w *DeviceWatcher) C() <-chan DeviceStateChangedEvent { + return w.eventChan +} + +// Err returns the error that caused the channel returned by C to be closed, if C is closed. +// If C is not closed, its return value is undefined. +func (w *DeviceWatcher) Err() error { + if err, ok := w.err.Load().(error); ok { + return err + } + return nil +} + +// Shutdown stops the watcher from listening for events and closes the channel returned +// from C. +func (w *DeviceWatcher) Shutdown() { + // TODO(z): Implement. +} + +func (w *deviceWatcherImpl) reportErr(err error) { + w.err.Store(err) +} + +/* +publishDevices reads device lists from scanner, calculates diffs, and publishes events on +eventChan. +Returns when scanner returns an error. +Doesn't refer directly to a *DeviceWatcher so it can be GCed (which will, +in turn, close Scanner and stop this goroutine). + +TODO: to support shutdown, spawn a new goroutine each time a server connection is established. +This goroutine should read messages and send them to a message channel. Can write errors directly +to errVal. publisHDevicesUntilError should take the msg chan and the scanner and select on the msg chan and stop chan, and if the stop +chan sends, close the scanner and return true. If the msg chan closes, just return false. +publishDevices can look at ret val: if false and err == EOF, reconnect. If false and other error, report err +and abort. If true, report no error and stop. +*/ +func publishDevices(watcher *deviceWatcherImpl) { + defer close(watcher.eventChan) + + var lastKnownStates map[string]string + finished := false + + for { + scanner, err := connectToTrackDevices(watcher.config.Dialer) + if err != nil { + watcher.reportErr(err) + return + } + + finished, err = publishDevicesUntilError(scanner, watcher.eventChan, &lastKnownStates) + + if finished { + scanner.Close() + return + } + + if util.HasErrCode(err, util.ConnectionResetError) { + // The server died, restart and reconnect. + log.Println("[DeviceWatcher] server died, restarting…") + if err := watcher.startServer(); err != nil { + log.Println("[DeviceWatcher] error restarting server, giving up") + watcher.reportErr(err) + return + } // Else server should be running, continue listening. + } else { + // Unknown error, don't retry. + watcher.reportErr(err) + return + } + } +} + +func connectToTrackDevices(dialer Dialer) (wire.Scanner, error) { + conn, err := dialer.Dial() + if err != nil { + return nil, err + } + + if err := wire.SendMessageString(conn, "host:track-devices"); err != nil { + conn.Close() + return nil, err + } + + if err := wire.ReadStatusFailureAsError(conn, "host:track-devices"); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +func publishDevicesUntilError(scanner wire.Scanner, eventChan chan<- DeviceStateChangedEvent, lastKnownStates *map[string]string) (finished bool, err error) { + for { + msg, err := scanner.ReadMessage() + if err != nil { + return false, err + } + + deviceStates, err := parseDeviceStates(string(msg)) + if err != nil { + return false, err + } + + for _, event := range calculateStateDiffs(*lastKnownStates, deviceStates) { + eventChan <- event + } + *lastKnownStates = deviceStates + } +} + +func parseDeviceStates(msg string) (states map[string]string, err error) { + states = make(map[string]string) + + for lineNum, line := range strings.Split(msg, "\n") { + if len(line) == 0 { + continue + } + + fields := strings.Split(line, "\t") + if len(fields) != 2 { + err = util.Errorf(util.ParseError, "invalid device state line %d: %s", lineNum, line) + return + } + + serial, state := fields[0], fields[1] + states[serial] = state + } + + return +} + +func calculateStateDiffs(oldStates, newStates map[string]string) (events []DeviceStateChangedEvent) { + for serial, oldState := range oldStates { + newState, ok := newStates[serial] + + if oldState != newState { + if ok { + // Device present in both lists: state changed. + events = append(events, DeviceStateChangedEvent{serial, oldState, newState}) + } else { + // Device only present in old list: device removed. + events = append(events, DeviceStateChangedEvent{serial, oldState, ""}) + } + } + } + + for serial, newState := range newStates { + if _, ok := oldStates[serial]; !ok { + // Device only present in new list: device added. + events = append(events, DeviceStateChangedEvent{serial, "", newState}) + } + } + + return events +} diff --git a/device_watcher_test.go b/device_watcher_test.go new file mode 100644 index 0000000..fb1597e --- /dev/null +++ b/device_watcher_test.go @@ -0,0 +1,232 @@ +package goadb + +import ( + "log" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/util" + "github.com/zach-klippenstein/goadb/wire" +) + +func TestParseDeviceStatesSingle(t *testing.T) { + states, err := parseDeviceStates(`192.168.56.101:5555 emulator-state +`) + + assert.NoError(t, err) + assert.Len(t, states, 1) + assert.Equal(t, "emulator-state", states["192.168.56.101:5555"]) +} + +func TestParseDeviceStatesMultiple(t *testing.T) { + states, err := parseDeviceStates(`192.168.56.101:5555 emulator-state +0x0x0x0x usb-state +`) + + assert.NoError(t, err) + assert.Len(t, states, 2) + assert.Equal(t, "emulator-state", states["192.168.56.101:5555"]) + assert.Equal(t, "usb-state", states["0x0x0x0x"]) +} + +func TestParseDeviceStatesMalformed(t *testing.T) { + _, err := parseDeviceStates(`192.168.56.101:5555 emulator-state +0x0x0x0x +`) + + assert.True(t, util.HasErrCode(err, util.ParseError)) + assert.Equal(t, "invalid device state line 1: 0x0x0x0x", err.(*util.Err).Message) +} + +func TestCalculateStateDiffsUnchangedEmpty(t *testing.T) { + oldStates := map[string]string{} + newStates := map[string]string{} + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Empty(t, diffs) +} + +func TestCalculateStateDiffsUnchangedNonEmpty(t *testing.T) { + oldStates := map[string]string{ + "1": "device", + "2": "device", + } + newStates := map[string]string{ + "1": "device", + "2": "device", + } + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Empty(t, diffs) +} + +func TestCalculateStateDiffsOneAdded(t *testing.T) { + oldStates := map[string]string{} + newStates := map[string]string{ + "serial": "added", + } + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Equal(t, []DeviceStateChangedEvent{ + DeviceStateChangedEvent{"serial", "", "added"}, + }, diffs) +} + +func TestCalculateStateDiffsOneRemoved(t *testing.T) { + oldStates := map[string]string{ + "serial": "removed", + } + newStates := map[string]string{} + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Equal(t, []DeviceStateChangedEvent{ + DeviceStateChangedEvent{"serial", "removed", ""}, + }, diffs) +} + +func TestCalculateStateDiffsOneAddedOneUnchanged(t *testing.T) { + oldStates := map[string]string{ + "1": "device", + } + newStates := map[string]string{ + "1": "device", + "2": "added", + } + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Equal(t, []DeviceStateChangedEvent{ + DeviceStateChangedEvent{"2", "", "added"}, + }, diffs) +} + +func TestCalculateStateDiffsOneRemovedOneUnchanged(t *testing.T) { + oldStates := map[string]string{ + "1": "removed", + "2": "device", + } + newStates := map[string]string{ + "2": "device", + } + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Equal(t, []DeviceStateChangedEvent{ + DeviceStateChangedEvent{"1", "removed", ""}, + }, diffs) +} + +func TestCalculateStateDiffsOneAddedOneRemoved(t *testing.T) { + oldStates := map[string]string{ + "1": "removed", + } + newStates := map[string]string{ + "2": "added", + } + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Equal(t, []DeviceStateChangedEvent{ + DeviceStateChangedEvent{"1", "removed", ""}, + DeviceStateChangedEvent{"2", "", "added"}, + }, diffs) +} + +func TestCalculateStateDiffsOneChangedOneUnchanged(t *testing.T) { + oldStates := map[string]string{ + "1": "oldState", + "2": "device", + } + newStates := map[string]string{ + "1": "newState", + "2": "device", + } + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Equal(t, []DeviceStateChangedEvent{ + DeviceStateChangedEvent{"1", "oldState", "newState"}, + }, diffs) +} + +func TestCalculateStateDiffsMultipleChangedMultipleUnchanged(t *testing.T) { + oldStates := map[string]string{ + "1": "oldState", + "2": "oldState", + } + newStates := map[string]string{ + "1": "newState", + "2": "newState", + } + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Equal(t, []DeviceStateChangedEvent{ + DeviceStateChangedEvent{"1", "oldState", "newState"}, + DeviceStateChangedEvent{"2", "oldState", "newState"}, + }, diffs) +} + +func TestCalculateStateDiffsOneAddedOneRemovedOneChanged(t *testing.T) { + oldStates := map[string]string{ + "1": "oldState", + "2": "removed", + } + newStates := map[string]string{ + "1": "newState", + "3": "added", + } + + diffs := calculateStateDiffs(oldStates, newStates) + + assert.Equal(t, []DeviceStateChangedEvent{ + DeviceStateChangedEvent{"1", "oldState", "newState"}, + DeviceStateChangedEvent{"2", "removed", ""}, + DeviceStateChangedEvent{"3", "", "added"}, + }, diffs) +} + +func TestPublishDevicesRestartsServer(t *testing.T) { + starter := &MockServerStarter{} + dialer := &MockServer{ + Status: wire.StatusSuccess, + Errs: []error{ + nil, nil, nil, // Successful dial. + util.Errorf(util.ConnectionResetError, "failed first read"), + util.Errorf(util.ServerNotAvailable, "failed redial"), + }, + } + watcher := deviceWatcherImpl{ + config: ClientConfig{dialer}, + eventChan: make(chan DeviceStateChangedEvent), + startServer: starter.StartServer, + } + + publishDevices(&watcher) + + assert.Empty(t, dialer.Errs) + assert.Equal(t, []string{"host:track-devices"}, dialer.Requests) + assert.Equal(t, []string{"Dial", "SendMessage", "ReadStatus", "ReadMessage", "Dial"}, dialer.Trace) + err := watcher.err.Load().(*util.Err) + assert.Equal(t, util.ServerNotAvailable, err.Code) + assert.Equal(t, 1, starter.startCount) +} + +type MockServerStarter struct { + startCount int + err error +} + +func (s *MockServerStarter) StartServer() error { + log.Printf("Starting mock server") + if s.err == nil { + s.startCount += 1 + return nil + } else { + return s.err + } +} diff --git a/dialer.go b/dialer.go index 1d1a45a..225a79b 100644 --- a/dialer.go +++ b/dialer.go @@ -54,7 +54,7 @@ func (d *netDialer) Dial() (*wire.Conn, error) { address := fmt.Sprintf("%s:%d", host, port) netConn, err := net.Dial("tcp", address) if err != nil { - return nil, util.WrapErrorf(err, util.NetworkError, "error dialing %s", address) + return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error dialing %s", address) } conn := &wire.Conn{ diff --git a/host_client.go b/host_client.go index e2be71a..fd8c7a1 100644 --- a/host_client.go +++ b/host_client.go @@ -1,4 +1,3 @@ -// TODO(z): Implement TrackDevices. package goadb import ( diff --git a/host_client_test.go b/host_client_test.go index fb2ec71..3acacdf 100644 --- a/host_client_test.go +++ b/host_client_test.go @@ -25,27 +25,47 @@ func TestGetServerVersion(t *testing.T) { assert.Equal(t, 10, v) } +// MockServer implements Dialer, Scanner, and Sender. type MockServer struct { + // Each time an operation is performed, if this slice is non-empty, the head element + // of this slice is returned and removed from the slice. If the head is nil, it is removed + // but not returned. + Errs []error + Status wire.StatusCode - // Messages are sent in order, each preceded by a length header. - Messages []string + // Messages are returned from read calls in order, each preceded by a length header. + Messages []string + nextMsgIndex int - // Each request is appended to this slice. + // Each message passed to a send call is appended to this slice. Requests []string - nextMsgIndex int + // Each time an operaiton is performed, its name is appended to this slice. + Trace []string } func (s *MockServer) Dial() (*wire.Conn, error) { + s.logMethod("Dial") + if err := s.getNextErrToReturn(); err != nil { + return nil, err + } return wire.NewConn(s, s), nil } func (s *MockServer) ReadStatus() (wire.StatusCode, error) { + s.logMethod("ReadStatus") + if err := s.getNextErrToReturn(); err != nil { + return "", err + } return s.Status, nil } func (s *MockServer) ReadMessage() ([]byte, error) { + s.logMethod("ReadMessage") + if err := s.getNextErrToReturn(); err != nil { + return nil, err + } if s.nextMsgIndex >= len(s.Messages) { return nil, util.WrapErrorf(io.EOF, util.NetworkError, "") } @@ -55,6 +75,11 @@ func (s *MockServer) ReadMessage() ([]byte, error) { } func (s *MockServer) ReadUntilEof() ([]byte, error) { + s.logMethod("ReadUntilEof") + if err := s.getNextErrToReturn(); err != nil { + return nil, err + } + var data []string for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ { data = append(data, s.Messages[s.nextMsgIndex]) @@ -63,18 +88,40 @@ func (s *MockServer) ReadUntilEof() ([]byte, error) { } func (s *MockServer) SendMessage(msg []byte) error { + s.logMethod("SendMessage") + if err := s.getNextErrToReturn(); err != nil { + return err + } s.Requests = append(s.Requests, string(msg)) return nil } func (s *MockServer) NewSyncScanner() wire.SyncScanner { + s.logMethod("NewSyncScanner") return nil } func (s *MockServer) NewSyncSender() wire.SyncSender { + s.logMethod("NewSyncSender") return nil } func (s *MockServer) Close() error { + s.logMethod("Close") + if err := s.getNextErrToReturn(); err != nil { + return err + } return nil } + +func (s *MockServer) getNextErrToReturn() (err error) { + if len(s.Errs) > 0 { + err = s.Errs[0] + s.Errs = s.Errs[1:] + } + return +} + +func (s *MockServer) logMethod(name string) { + s.Trace = append(s.Trace, name) +} diff --git a/server_controller.go b/server_controller.go index 9025a25..8038bac 100644 --- a/server_controller.go +++ b/server_controller.go @@ -8,9 +8,6 @@ import ( /* StartServer ensures there is a server running. - -Currently implemented by just running - adb start-server */ func StartServer() error { cmd := exec.Command("adb", "start-server") diff --git a/util/doc.go b/util/doc.go new file mode 100644 index 0000000..cf6a582 --- /dev/null +++ b/util/doc.go @@ -0,0 +1,4 @@ +/* +Contains code shared between the different sub-packages in this project. +*/ +package util diff --git a/util/errcode_string.go b/util/errcode_string.go index 5bb24c6..7ac33f7 100644 --- a/util/errcode_string.go +++ b/util/errcode_string.go @@ -4,9 +4,9 @@ package util import "fmt" -const _ErrCode_name = "AssertionErrorParseErrorServerNotAvailableNetworkErrorAdbErrorDeviceNotFoundFileNoExistError" +const _ErrCode_name = "AssertionErrorParseErrorServerNotAvailableNetworkErrorConnectionResetErrorAdbErrorDeviceNotFoundFileNoExistError" -var _ErrCode_index = [...]uint8{0, 14, 24, 42, 54, 62, 76, 92} +var _ErrCode_index = [...]uint8{0, 14, 24, 42, 54, 74, 82, 96, 112} func (i ErrCode) String() string { if i+1 >= ErrCode(len(_ErrCode_index)) { diff --git a/util/error.go b/util/error.go index ebed6a8..1d4e2e9 100644 --- a/util/error.go +++ b/util/error.go @@ -2,7 +2,19 @@ package util import "fmt" -// Err is the implementation of error that all goadb functions return. +/* +Err is the implementation of error that all goadb functions return. + +Best Practice + +External errors should be wrapped using WrapErrorf, as soon as they are known about. + +Intermediate code should pass *Errs up until they will be returned outside the library. +Errors should *not* be wrapped at every return site. + +Just before returning an *Err outside the library, it can be wrapped again, preserving the +ErrCode (e.g. with WrapErrf). +*/ type Err struct { // Code is the high-level "type" of error. Code ErrCode @@ -22,10 +34,12 @@ type ErrCode byte const ( AssertionError ErrCode = iota ParseError ErrCode = iota - // The server was not available on the request port and could not be started. + // The server was not available on the requested port. ServerNotAvailable ErrCode = iota // General network error communicating with the server. NetworkError ErrCode = iota + // The connection to the server was reset in the middle of an operation. Server probably died. + ConnectionResetError ErrCode = iota // The server returned an error message, but we couldn't parse it. AdbError ErrCode = iota // The server returned a "device not found" error. diff --git a/wire/conn.go b/wire/conn.go index c976ebc..21f781f 100644 --- a/wire/conn.go +++ b/wire/conn.go @@ -20,7 +20,11 @@ For most cases, usage looks something like: For some messages, the server will return more than one message (but still a single status). Generally, after calling ReadStatus once, you should call ReadMessage until -it returns an io.EOF error. +it returns an io.EOF error. Note: the protocol docs seem to suggest that connections will be +kept open for multiple commands, but this is not the case. The official client closes +a connection immediately after its read the response, in most cases. The docs might be +referring to the connection between the adb server and the device, but I haven't confirmed +that. For most commands, the server will close the connection after sending the response. You should still always call Close() when you're done with the connection. diff --git a/wire/scanner.go b/wire/scanner.go index 23f4bb0..bf5d0c5 100644 --- a/wire/scanner.go +++ b/wire/scanner.go @@ -8,6 +8,8 @@ import ( "github.com/zach-klippenstein/goadb/util" ) +// TODO(zach): All EOF errors returned from networoking calls should use ConnectionResetError. + // StatusCodes are returned by the server. If the code indicates failure, the // next message will be the error. type StatusCode string @@ -70,7 +72,7 @@ func (s *realScanner) ReadMessage() ([]byte, error) { length, err := s.readLength() if err != nil { - return nil, util.WrapErrorf(err, util.NetworkError, "error reading message length") + return nil, err } data := make([]byte, length) @@ -103,9 +105,7 @@ func (s *realScanner) Close() error { func (s *realScanner) readLength() (int, error) { lengthHex := make([]byte, 4) n, err := io.ReadFull(s.reader, lengthHex) - if err != nil && err != io.ErrUnexpectedEOF { - return 0, util.WrapErrorf(err, util.NetworkError, "error reading length") - } else if err == io.ErrUnexpectedEOF { + if err != nil { return 0, errIncompleteMessage("length", n, 4) } diff --git a/wire/scanner_test.go b/wire/scanner_test.go index 6226404..5d53d4b 100644 --- a/wire/scanner_test.go +++ b/wire/scanner_test.go @@ -98,7 +98,7 @@ func NewEofBuffer(str string) *TestReader { func assertEof(t *testing.T, s *realScanner) { msg, err := s.ReadMessage() - assert.True(t, util.HasErrCode(err, util.NetworkError)) + assert.True(t, util.HasErrCode(err, util.ConnectionResetError)) assert.Nil(t, msg) } diff --git a/wire/sync.go b/wire/sync_conn.go similarity index 100% rename from wire/sync.go rename to wire/sync_conn.go diff --git a/wire/util.go b/wire/util.go index 2da81cb..89c7bb6 100644 --- a/wire/util.go +++ b/wire/util.go @@ -60,7 +60,7 @@ func adbServerError(request string, serverMsg string) error { func errIncompleteMessage(description string, actual int, expected int) error { return &util.Err{ - Code: util.NetworkError, + Code: util.ConnectionResetError, Message: fmt.Sprintf("incomplete %s: read %d bytes, expecting %d", description, actual, expected), Details: struct { ActualReadBytes int