diff --git a/fastboot/FastbootDevice.go b/fastboot/FastbootDevice.go index 4ae6546..9824fb5 100644 --- a/fastboot/FastbootDevice.go +++ b/fastboot/FastbootDevice.go @@ -26,16 +26,20 @@ var Error = struct { DeviceNotFound error }{ VarNotFound: errors.New("variable not found"), - DeviceNotFound: gousb.ErrorNotFound, + DeviceNotFound: errors.New("device not found"), } type FastbootDevice struct { - Serial string - Device *gousb.Device + Device *gousb.Device + Context *gousb.Context + In *gousb.InEndpoint + Out *gousb.OutEndpoint + Unclaim func() } -func FindDevices(ctx *gousb.Context) ([]FastbootDevice, error) { - var fastbootDevices []FastbootDevice +func FindDevices() ([]*FastbootDevice, error) { + ctx := gousb.NewContext() + var fastbootDevices []*FastbootDevice devs, err := ctx.OpenDevices(func(desc *gousb.DeviceDesc) bool { for _, cfg := range desc.Configs { for _, ifc := range cfg.Interfaces { @@ -44,101 +48,81 @@ func FindDevices(ctx *gousb.Context) ([]FastbootDevice, error) { } } } - return false + return true }) - if err != nil { + if err != nil && len(devs) == 0 { return nil, err } for _, dev := range devs { - serial, err := dev.SerialNumber() + intf, done, err := dev.DefaultInterface() if err != nil { continue } - fastbootDevices = append(fastbootDevices, FastbootDevice{Serial: serial, Device: dev}) + inEndpoint, err := intf.InEndpoint(0x81) + if err != nil { + continue + } + outEndpoint, err := intf.OutEndpoint(0x01) + if err != nil { + continue + } + fastbootDevices = append(fastbootDevices, &FastbootDevice{ + Device: dev, + Context: ctx, + In: inEndpoint, + Out: outEndpoint, + Unclaim: done, + }) } return fastbootDevices, nil } -func FindDevice(ctx *gousb.Context, serial string) (FastbootDevice, error) { - var fastbootDevice FastbootDevice - devs, err := ctx.OpenDevices(func(desc *gousb.DeviceDesc) bool { - for _, cfg := range desc.Configs { - for _, ifc := range cfg.Interfaces { - for _, alt := range ifc.AltSettings { - return alt.Protocol == 0x03 && alt.Class == 0xff && alt.SubClass == 0x42 - } - } - } - return false - }) +func FindDevice(serial string) (*FastbootDevice, error) { + devs, err := FindDevices() if err != nil { - return fastbootDevice, err + return &FastbootDevice{}, err } for _, dev := range devs { - serialNumber, err := dev.SerialNumber() - if err != nil { + s, e := dev.Device.SerialNumber() + if e != nil { continue } - if serial != serialNumber { + if serial != s { continue } - return FastbootDevice{Serial: serial, Device: dev}, nil + return dev, nil } - return fastbootDevice, Error.DeviceNotFound + return &FastbootDevice{}, Error.DeviceNotFound +} + +func (d *FastbootDevice) Close() { + d.Unclaim() + d.Device.Close() + d.Context.Close() } func (d *FastbootDevice) Send(data []byte) error { - intf, done, err := d.Device.DefaultInterface() - if err != nil { - return nil - } - defer done() - - outEndpoint, err := intf.OutEndpoint(0x01) - if err != nil { - return nil - } - - _, err = outEndpoint.Write(data) + _, err := d.Out.Write(data) return err } func (d *FastbootDevice) GetMaxPacketSize() (int, error) { - intf, done, err := d.Device.DefaultInterface() - if err != nil { - return 0, err - } - defer done() - - outEndpoint, err := intf.OutEndpoint(0x01) - if err != nil { - return 0, err - } - - return outEndpoint.Desc.MaxPacketSize, nil + return d.Out.Desc.MaxPacketSize, nil } func (d *FastbootDevice) Recv() (FastbootResponseStatus, []byte, error) { - intf, done, err := d.Device.DefaultInterface() - if err != nil { - return Status.FAIL, nil, err - } - defer done() - - inEndpoint, err := intf.InEndpoint(0x81) - if err != nil { - return Status.FAIL, nil, err - } - var data []byte - buf := make([]byte, inEndpoint.Desc.MaxPacketSize) - n, _ := inEndpoint.Read(buf) + buf := make([]byte, d.In.Desc.MaxPacketSize) + n, err := d.In.Read(buf) + if err != nil { + return Status.FAIL, []byte{}, err + } data = append(data, buf[:n]...) var status FastbootResponseStatus switch string(data[:4]) { @@ -154,8 +138,11 @@ func (d *FastbootDevice) Recv() (FastbootResponseStatus, []byte, error) { return status, data[4:], nil } -func (d *FastbootDevice) GerVar(variable string) (string, error) { - d.Send([]byte(fmt.Sprintf("getvar:%s", variable))) +func (d *FastbootDevice) GetVar(variable string) (string, error) { + err := d.Send([]byte(fmt.Sprintf("getvar:%s", variable))) + if err != nil { + return "", err + } status, resp, err := d.Recv() if status == Status.FAIL { err = Error.VarNotFound