Implemented device event watcher and improved errors.

This commit is contained in:
Zach Klippenstein 2015-05-17 12:33:41 -07:00
parent 946c9e8ff8
commit 164ab27d25
15 changed files with 556 additions and 20 deletions

View file

@ -6,8 +6,10 @@ import (
"fmt"
"io/ioutil"
"log"
"time"
adb "github.com/zach-klippenstein/goadb"
"github.com/zach-klippenstein/goadb/util"
)
var port = flag.Int("p", adb.AdbPort, "")
@ -47,10 +49,33 @@ func main() {
PrintDeviceInfoAndError(adb.DeviceWithSerial(serial))
}
fmt.Println()
fmt.Println("Watching for device state changes.")
watcher, err := adb.NewDeviceWatcher(adb.ClientConfig{})
for event := range watcher.C() {
fmt.Printf("\t[%s]%+v\n", time.Now(), event)
}
if watcher.Err() != nil {
printErr(watcher.Err())
}
//fmt.Println("Killing server…")
//client.KillServer()
}
func printErr(err error) {
switch err := err.(type) {
case *util.Err:
fmt.Println(err.Error())
if err.Cause != nil {
fmt.Print("caused by ")
printErr(err.Cause)
}
default:
fmt.Println("error:", err)
}
}
func PrintDeviceInfoAndError(descriptor adb.DeviceDescriptor) {
device := adb.NewDeviceClient(adb.ClientConfig{}, descriptor)
if err := PrintDeviceInfo(device); err != nil {

214
device_watcher.go Normal file
View file

@ -0,0 +1,214 @@
package goadb
import (
"log"
"runtime"
"strings"
"sync/atomic"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
/*
DeviceWatcher publishes device status change events.
If the server dies while listening for events, it restarts the server.
*/
type DeviceWatcher struct {
*deviceWatcherImpl
}
type DeviceStateChangedEvent struct {
Serial string
OldState string
NewState string
}
type deviceWatcherImpl struct {
config ClientConfig
// If an error occurs, it is stored here and eventChan is close immediately after.
err atomic.Value
eventChan chan DeviceStateChangedEvent
// Function to start the server if it's not running or dies.
startServer func() error
}
func NewDeviceWatcher(config ClientConfig) (*DeviceWatcher, error) {
watcher := &DeviceWatcher{&deviceWatcherImpl{
config: config.sanitized(),
eventChan: make(chan DeviceStateChangedEvent),
startServer: StartServer,
}}
runtime.SetFinalizer(watcher, func(watcher *DeviceWatcher) {
watcher.Shutdown()
})
go publishDevices(watcher.deviceWatcherImpl)
return watcher, nil
}
/*
C returns a channel than can be received on to get events.
If an unrecoverable error occurs, or Shutdown is called, the channel will be closed.
*/
func (w *DeviceWatcher) C() <-chan DeviceStateChangedEvent {
return w.eventChan
}
// Err returns the error that caused the channel returned by C to be closed, if C is closed.
// If C is not closed, its return value is undefined.
func (w *DeviceWatcher) Err() error {
if err, ok := w.err.Load().(error); ok {
return err
}
return nil
}
// Shutdown stops the watcher from listening for events and closes the channel returned
// from C.
func (w *DeviceWatcher) Shutdown() {
// TODO(z): Implement.
}
func (w *deviceWatcherImpl) reportErr(err error) {
w.err.Store(err)
}
/*
publishDevices reads device lists from scanner, calculates diffs, and publishes events on
eventChan.
Returns when scanner returns an error.
Doesn't refer directly to a *DeviceWatcher so it can be GCed (which will,
in turn, close Scanner and stop this goroutine).
TODO: to support shutdown, spawn a new goroutine each time a server connection is established.
This goroutine should read messages and send them to a message channel. Can write errors directly
to errVal. publisHDevicesUntilError should take the msg chan and the scanner and select on the msg chan and stop chan, and if the stop
chan sends, close the scanner and return true. If the msg chan closes, just return false.
publishDevices can look at ret val: if false and err == EOF, reconnect. If false and other error, report err
and abort. If true, report no error and stop.
*/
func publishDevices(watcher *deviceWatcherImpl) {
defer close(watcher.eventChan)
var lastKnownStates map[string]string
finished := false
for {
scanner, err := connectToTrackDevices(watcher.config.Dialer)
if err != nil {
watcher.reportErr(err)
return
}
finished, err = publishDevicesUntilError(scanner, watcher.eventChan, &lastKnownStates)
if finished {
scanner.Close()
return
}
if util.HasErrCode(err, util.ConnectionResetError) {
// The server died, restart and reconnect.
log.Println("[DeviceWatcher] server died, restarting…")
if err := watcher.startServer(); err != nil {
log.Println("[DeviceWatcher] error restarting server, giving up")
watcher.reportErr(err)
return
} // Else server should be running, continue listening.
} else {
// Unknown error, don't retry.
watcher.reportErr(err)
return
}
}
}
func connectToTrackDevices(dialer Dialer) (wire.Scanner, error) {
conn, err := dialer.Dial()
if err != nil {
return nil, err
}
if err := wire.SendMessageString(conn, "host:track-devices"); err != nil {
conn.Close()
return nil, err
}
if err := wire.ReadStatusFailureAsError(conn, "host:track-devices"); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
func publishDevicesUntilError(scanner wire.Scanner, eventChan chan<- DeviceStateChangedEvent, lastKnownStates *map[string]string) (finished bool, err error) {
for {
msg, err := scanner.ReadMessage()
if err != nil {
return false, err
}
deviceStates, err := parseDeviceStates(string(msg))
if err != nil {
return false, err
}
for _, event := range calculateStateDiffs(*lastKnownStates, deviceStates) {
eventChan <- event
}
*lastKnownStates = deviceStates
}
}
func parseDeviceStates(msg string) (states map[string]string, err error) {
states = make(map[string]string)
for lineNum, line := range strings.Split(msg, "\n") {
if len(line) == 0 {
continue
}
fields := strings.Split(line, "\t")
if len(fields) != 2 {
err = util.Errorf(util.ParseError, "invalid device state line %d: %s", lineNum, line)
return
}
serial, state := fields[0], fields[1]
states[serial] = state
}
return
}
func calculateStateDiffs(oldStates, newStates map[string]string) (events []DeviceStateChangedEvent) {
for serial, oldState := range oldStates {
newState, ok := newStates[serial]
if oldState != newState {
if ok {
// Device present in both lists: state changed.
events = append(events, DeviceStateChangedEvent{serial, oldState, newState})
} else {
// Device only present in old list: device removed.
events = append(events, DeviceStateChangedEvent{serial, oldState, ""})
}
}
}
for serial, newState := range newStates {
if _, ok := oldStates[serial]; !ok {
// Device only present in new list: device added.
events = append(events, DeviceStateChangedEvent{serial, "", newState})
}
}
return events
}

232
device_watcher_test.go Normal file
View file

@ -0,0 +1,232 @@
package goadb
import (
"log"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
func TestParseDeviceStatesSingle(t *testing.T) {
states, err := parseDeviceStates(`192.168.56.101:5555 emulator-state
`)
assert.NoError(t, err)
assert.Len(t, states, 1)
assert.Equal(t, "emulator-state", states["192.168.56.101:5555"])
}
func TestParseDeviceStatesMultiple(t *testing.T) {
states, err := parseDeviceStates(`192.168.56.101:5555 emulator-state
0x0x0x0x usb-state
`)
assert.NoError(t, err)
assert.Len(t, states, 2)
assert.Equal(t, "emulator-state", states["192.168.56.101:5555"])
assert.Equal(t, "usb-state", states["0x0x0x0x"])
}
func TestParseDeviceStatesMalformed(t *testing.T) {
_, err := parseDeviceStates(`192.168.56.101:5555 emulator-state
0x0x0x0x
`)
assert.True(t, util.HasErrCode(err, util.ParseError))
assert.Equal(t, "invalid device state line 1: 0x0x0x0x", err.(*util.Err).Message)
}
func TestCalculateStateDiffsUnchangedEmpty(t *testing.T) {
oldStates := map[string]string{}
newStates := map[string]string{}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Empty(t, diffs)
}
func TestCalculateStateDiffsUnchangedNonEmpty(t *testing.T) {
oldStates := map[string]string{
"1": "device",
"2": "device",
}
newStates := map[string]string{
"1": "device",
"2": "device",
}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Empty(t, diffs)
}
func TestCalculateStateDiffsOneAdded(t *testing.T) {
oldStates := map[string]string{}
newStates := map[string]string{
"serial": "added",
}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Equal(t, []DeviceStateChangedEvent{
DeviceStateChangedEvent{"serial", "", "added"},
}, diffs)
}
func TestCalculateStateDiffsOneRemoved(t *testing.T) {
oldStates := map[string]string{
"serial": "removed",
}
newStates := map[string]string{}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Equal(t, []DeviceStateChangedEvent{
DeviceStateChangedEvent{"serial", "removed", ""},
}, diffs)
}
func TestCalculateStateDiffsOneAddedOneUnchanged(t *testing.T) {
oldStates := map[string]string{
"1": "device",
}
newStates := map[string]string{
"1": "device",
"2": "added",
}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Equal(t, []DeviceStateChangedEvent{
DeviceStateChangedEvent{"2", "", "added"},
}, diffs)
}
func TestCalculateStateDiffsOneRemovedOneUnchanged(t *testing.T) {
oldStates := map[string]string{
"1": "removed",
"2": "device",
}
newStates := map[string]string{
"2": "device",
}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Equal(t, []DeviceStateChangedEvent{
DeviceStateChangedEvent{"1", "removed", ""},
}, diffs)
}
func TestCalculateStateDiffsOneAddedOneRemoved(t *testing.T) {
oldStates := map[string]string{
"1": "removed",
}
newStates := map[string]string{
"2": "added",
}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Equal(t, []DeviceStateChangedEvent{
DeviceStateChangedEvent{"1", "removed", ""},
DeviceStateChangedEvent{"2", "", "added"},
}, diffs)
}
func TestCalculateStateDiffsOneChangedOneUnchanged(t *testing.T) {
oldStates := map[string]string{
"1": "oldState",
"2": "device",
}
newStates := map[string]string{
"1": "newState",
"2": "device",
}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Equal(t, []DeviceStateChangedEvent{
DeviceStateChangedEvent{"1", "oldState", "newState"},
}, diffs)
}
func TestCalculateStateDiffsMultipleChangedMultipleUnchanged(t *testing.T) {
oldStates := map[string]string{
"1": "oldState",
"2": "oldState",
}
newStates := map[string]string{
"1": "newState",
"2": "newState",
}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Equal(t, []DeviceStateChangedEvent{
DeviceStateChangedEvent{"1", "oldState", "newState"},
DeviceStateChangedEvent{"2", "oldState", "newState"},
}, diffs)
}
func TestCalculateStateDiffsOneAddedOneRemovedOneChanged(t *testing.T) {
oldStates := map[string]string{
"1": "oldState",
"2": "removed",
}
newStates := map[string]string{
"1": "newState",
"3": "added",
}
diffs := calculateStateDiffs(oldStates, newStates)
assert.Equal(t, []DeviceStateChangedEvent{
DeviceStateChangedEvent{"1", "oldState", "newState"},
DeviceStateChangedEvent{"2", "removed", ""},
DeviceStateChangedEvent{"3", "", "added"},
}, diffs)
}
func TestPublishDevicesRestartsServer(t *testing.T) {
starter := &MockServerStarter{}
dialer := &MockServer{
Status: wire.StatusSuccess,
Errs: []error{
nil, nil, nil, // Successful dial.
util.Errorf(util.ConnectionResetError, "failed first read"),
util.Errorf(util.ServerNotAvailable, "failed redial"),
},
}
watcher := deviceWatcherImpl{
config: ClientConfig{dialer},
eventChan: make(chan DeviceStateChangedEvent),
startServer: starter.StartServer,
}
publishDevices(&watcher)
assert.Empty(t, dialer.Errs)
assert.Equal(t, []string{"host:track-devices"}, dialer.Requests)
assert.Equal(t, []string{"Dial", "SendMessage", "ReadStatus", "ReadMessage", "Dial"}, dialer.Trace)
err := watcher.err.Load().(*util.Err)
assert.Equal(t, util.ServerNotAvailable, err.Code)
assert.Equal(t, 1, starter.startCount)
}
type MockServerStarter struct {
startCount int
err error
}
func (s *MockServerStarter) StartServer() error {
log.Printf("Starting mock server")
if s.err == nil {
s.startCount += 1
return nil
} else {
return s.err
}
}

View file

@ -54,7 +54,7 @@ func (d *netDialer) Dial() (*wire.Conn, error) {
address := fmt.Sprintf("%s:%d", host, port)
netConn, err := net.Dial("tcp", address)
if err != nil {
return nil, util.WrapErrorf(err, util.NetworkError, "error dialing %s", address)
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error dialing %s", address)
}
conn := &wire.Conn{

View file

@ -1,4 +1,3 @@
// TODO(z): Implement TrackDevices.
package goadb
import (

View file

@ -25,27 +25,47 @@ func TestGetServerVersion(t *testing.T) {
assert.Equal(t, 10, v)
}
// MockServer implements Dialer, Scanner, and Sender.
type MockServer struct {
// Each time an operation is performed, if this slice is non-empty, the head element
// of this slice is returned and removed from the slice. If the head is nil, it is removed
// but not returned.
Errs []error
Status wire.StatusCode
// Messages are sent in order, each preceded by a length header.
// Messages are returned from read calls in order, each preceded by a length header.
Messages []string
nextMsgIndex int
// Each request is appended to this slice.
// Each message passed to a send call is appended to this slice.
Requests []string
nextMsgIndex int
// Each time an operaiton is performed, its name is appended to this slice.
Trace []string
}
func (s *MockServer) Dial() (*wire.Conn, error) {
s.logMethod("Dial")
if err := s.getNextErrToReturn(); err != nil {
return nil, err
}
return wire.NewConn(s, s), nil
}
func (s *MockServer) ReadStatus() (wire.StatusCode, error) {
s.logMethod("ReadStatus")
if err := s.getNextErrToReturn(); err != nil {
return "", err
}
return s.Status, nil
}
func (s *MockServer) ReadMessage() ([]byte, error) {
s.logMethod("ReadMessage")
if err := s.getNextErrToReturn(); err != nil {
return nil, err
}
if s.nextMsgIndex >= len(s.Messages) {
return nil, util.WrapErrorf(io.EOF, util.NetworkError, "")
}
@ -55,6 +75,11 @@ func (s *MockServer) ReadMessage() ([]byte, error) {
}
func (s *MockServer) ReadUntilEof() ([]byte, error) {
s.logMethod("ReadUntilEof")
if err := s.getNextErrToReturn(); err != nil {
return nil, err
}
var data []string
for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ {
data = append(data, s.Messages[s.nextMsgIndex])
@ -63,18 +88,40 @@ func (s *MockServer) ReadUntilEof() ([]byte, error) {
}
func (s *MockServer) SendMessage(msg []byte) error {
s.logMethod("SendMessage")
if err := s.getNextErrToReturn(); err != nil {
return err
}
s.Requests = append(s.Requests, string(msg))
return nil
}
func (s *MockServer) NewSyncScanner() wire.SyncScanner {
s.logMethod("NewSyncScanner")
return nil
}
func (s *MockServer) NewSyncSender() wire.SyncSender {
s.logMethod("NewSyncSender")
return nil
}
func (s *MockServer) Close() error {
s.logMethod("Close")
if err := s.getNextErrToReturn(); err != nil {
return err
}
return nil
}
func (s *MockServer) getNextErrToReturn() (err error) {
if len(s.Errs) > 0 {
err = s.Errs[0]
s.Errs = s.Errs[1:]
}
return
}
func (s *MockServer) logMethod(name string) {
s.Trace = append(s.Trace, name)
}

View file

@ -8,9 +8,6 @@ import (
/*
StartServer ensures there is a server running.
Currently implemented by just running
adb start-server
*/
func StartServer() error {
cmd := exec.Command("adb", "start-server")

4
util/doc.go Normal file
View file

@ -0,0 +1,4 @@
/*
Contains code shared between the different sub-packages in this project.
*/
package util

View file

@ -4,9 +4,9 @@ package util
import "fmt"
const _ErrCode_name = "AssertionErrorParseErrorServerNotAvailableNetworkErrorAdbErrorDeviceNotFoundFileNoExistError"
const _ErrCode_name = "AssertionErrorParseErrorServerNotAvailableNetworkErrorConnectionResetErrorAdbErrorDeviceNotFoundFileNoExistError"
var _ErrCode_index = [...]uint8{0, 14, 24, 42, 54, 62, 76, 92}
var _ErrCode_index = [...]uint8{0, 14, 24, 42, 54, 74, 82, 96, 112}
func (i ErrCode) String() string {
if i+1 >= ErrCode(len(_ErrCode_index)) {

View file

@ -2,7 +2,19 @@ package util
import "fmt"
// Err is the implementation of error that all goadb functions return.
/*
Err is the implementation of error that all goadb functions return.
Best Practice
External errors should be wrapped using WrapErrorf, as soon as they are known about.
Intermediate code should pass *Errs up until they will be returned outside the library.
Errors should *not* be wrapped at every return site.
Just before returning an *Err outside the library, it can be wrapped again, preserving the
ErrCode (e.g. with WrapErrf).
*/
type Err struct {
// Code is the high-level "type" of error.
Code ErrCode
@ -22,10 +34,12 @@ type ErrCode byte
const (
AssertionError ErrCode = iota
ParseError ErrCode = iota
// The server was not available on the request port and could not be started.
// The server was not available on the requested port.
ServerNotAvailable ErrCode = iota
// General network error communicating with the server.
NetworkError ErrCode = iota
// The connection to the server was reset in the middle of an operation. Server probably died.
ConnectionResetError ErrCode = iota
// The server returned an error message, but we couldn't parse it.
AdbError ErrCode = iota
// The server returned a "device not found" error.

View file

@ -20,7 +20,11 @@ For most cases, usage looks something like:
For some messages, the server will return more than one message (but still a single
status). Generally, after calling ReadStatus once, you should call ReadMessage until
it returns an io.EOF error.
it returns an io.EOF error. Note: the protocol docs seem to suggest that connections will be
kept open for multiple commands, but this is not the case. The official client closes
a connection immediately after its read the response, in most cases. The docs might be
referring to the connection between the adb server and the device, but I haven't confirmed
that.
For most commands, the server will close the connection after sending the response.
You should still always call Close() when you're done with the connection.

View file

@ -8,6 +8,8 @@ import (
"github.com/zach-klippenstein/goadb/util"
)
// TODO(zach): All EOF errors returned from networoking calls should use ConnectionResetError.
// StatusCodes are returned by the server. If the code indicates failure, the
// next message will be the error.
type StatusCode string
@ -70,7 +72,7 @@ func (s *realScanner) ReadMessage() ([]byte, error) {
length, err := s.readLength()
if err != nil {
return nil, util.WrapErrorf(err, util.NetworkError, "error reading message length")
return nil, err
}
data := make([]byte, length)
@ -103,9 +105,7 @@ func (s *realScanner) Close() error {
func (s *realScanner) readLength() (int, error) {
lengthHex := make([]byte, 4)
n, err := io.ReadFull(s.reader, lengthHex)
if err != nil && err != io.ErrUnexpectedEOF {
return 0, util.WrapErrorf(err, util.NetworkError, "error reading length")
} else if err == io.ErrUnexpectedEOF {
if err != nil {
return 0, errIncompleteMessage("length", n, 4)
}

View file

@ -98,7 +98,7 @@ func NewEofBuffer(str string) *TestReader {
func assertEof(t *testing.T, s *realScanner) {
msg, err := s.ReadMessage()
assert.True(t, util.HasErrCode(err, util.NetworkError))
assert.True(t, util.HasErrCode(err, util.ConnectionResetError))
assert.Nil(t, msg)
}

View file

@ -60,7 +60,7 @@ func adbServerError(request string, serverMsg string) error {
func errIncompleteMessage(description string, actual int, expected int) error {
return &util.Err{
Code: util.NetworkError,
Code: util.ConnectionResetError,
Message: fmt.Sprintf("incomplete %s: read %d bytes, expecting %d", description, actual, expected),
Details: struct {
ActualReadBytes int