#!/usr/bin/env python
#
# Copyright 2010, 2011 Michael Ossmann
# Copyright 2015 Travis Goodspeed
#
# This file was forked from Project Ubertooth as a DFU client for the
# TYT MD380, an amateur radio for the DMR protocol on the UHF bands.
# This script implements a lot of poorly understood extensions unique
# to the MD380.
#
#
#
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; see the file COPYING.  If not, write to
# the Free Software Foundation, Inc., 51 Franklin Street,
# Boston, MA 02110-1301, USA.

# http://pyusb.sourceforge.net/docs/1.0/tutorial.html

import struct
import sys
import time
from optparse import OptionParser
import dfu_suffix


# The tricky thing is that *THREE* different applications all show up
# as this same VID/PID pair.
#
# 1. The Tytera application image.
# 2. The Tytera bootloader at 0x08000000
# 3. The mask-rom bootloader from the STM32F405.
# USB vendor/product IDs shared by all three images listed above.
md380_vendor, md380_product = 0x0483, 0xdf11


#application_offset = 0x08000000
#ram_offset = 0x20000000
#application_size   = 0x00040000


class Enumeration(object):
    """Base class for small tables of named integer codes.

    A subclass defines a class attribute ``map`` of ``{code: name}`` and then
    calls ``create_from_map()``.  That creates one instance per entry,
    publishes it as ``Subclass.<name>`` and replaces the name in ``map`` with
    the instance, so ``Subclass.map[code]`` decodes a raw code.
    """

    def __init__(self, id, name):
        self._id = id
        self._name = name
        # Publish as Subclass.<name> and register under the numeric code.
        setattr(self.__class__, name, self)
        self.map[id] = self

    def __int__(self):
        return self.id

    # Let instances stand in wherever a true integer is required
    # (e.g. '%x' formatting on Python 3, sequence indexing).
    __index__ = __int__

    def __repr__(self):
        return self.name

    @property
    def id(self):
        return self._id

    @property
    def name(self):
        return self._name

    @classmethod
    def create_from_map(cls):
        """Instantiate every entry of ``cls.map``; safe to call repeatedly."""
        # Iterate over a snapshot: the constructor rewrites map values.
        for code, name in list(cls.map.items()):
            if isinstance(name, Enumeration):
                continue  # already created by an earlier call
            cls(code, name)
        
class Request(Enumeration):
    # DFU class-specific bRequest codes; the position in the tuple is the
    # numeric code.
    map = dict(enumerate((
        'DETACH',      # 0
        'DNLOAD',      # 1
        'UPLOAD',      # 2
        'GETSTATUS',   # 3
        'CLRSTATUS',   # 4
        'GETSTATE',    # 5
        'ABORT',       # 6
    )))

Request.create_from_map()

class State(Enumeration):
    # DFU device states as reported by GETSTATE/GETSTATUS; the position in
    # the tuple is the numeric code.
    map = dict(enumerate((
        'appIDLE',                  # 0
        'appDETACH',                # 1
        'dfuIDLE',                  # 2
        'dfuDNLOAD_SYNC',           # 3
        'dfuDNBUSY',                # 4
        'dfuDNLOAD_IDLE',           # 5
        'dfuMANIFEST_SYNC',         # 6
        'dfuMANIFEST',              # 7
        'dfuMANIFEST_WAIT_RESET',   # 8
        'dfuUPLOAD_IDLE',           # 9
        'dfuERROR',                 # 10
    )))

State.create_from_map()

class Status(Enumeration):
    # DFU bStatus codes from GETSTATUS; the position in the tuple is the
    # numeric code.
    map = dict(enumerate((
        'OK',               # 0x00
        'errTARGET',        # 0x01
        'errFILE',          # 0x02
        'errWRITE',         # 0x03
        'errERASE',         # 0x04
        'errCHECK_ERASED',  # 0x05
        'errPROG',          # 0x06
        'errVERIFY',        # 0x07
        'errADDRESS',       # 0x08
        'errNOTDONE',       # 0x09
        'errFIRMWARE',      # 0x0A
        'errVENDOR',        # 0x0B
        'errUSBR',          # 0x0C
        'errPOR',           # 0x0D
        'errUNKNOWN',       # 0x0E
        'errSTALLEDPKT',    # 0x0F
    )))

Status.create_from_map()

class DFU(object):
    """DFU client for the MD380, including its vendor extensions.

    Wraps a pyusb device and issues class requests on interface 0.  The
    command helpers (set_address, erase_block, md380_custom) return True
    when the radio reports dfuDNLOAD_IDLE afterwards, and False otherwise.
    """

    # Print progress messages from the command helpers when True.
    verbose = False

    # bmRequestType for class requests to an interface: host-to-device and
    # device-to-host.
    _REQ_OUT = 0x21
    _REQ_IN = 0xA1

    def __init__(self, device, alt):
        """Wrap *device* and select alternate setting *alt* of interface 0."""
        device.set_interface_altsetting(interface=0, alternate_setting=alt)
        self._device = device

    @staticmethod
    def _le32(value):
        """Return the low 32 bits of *value* as four little-endian bytes."""
        return [(value >> shift) & 0xFF for shift in (0, 8, 16, 24)]

    @staticmethod
    def _poll_timeout(raw):
        """Decode bwPollTimeout (ms) from an unpacked 6-byte GETSTATUS reply.

        The field occupies bytes 1..3, least-significant byte first.
        """
        return raw[1] | (raw[2] << 8) | (raw[3] << 16)

    def _command(self, payload, settle=0):
        """Send *payload* as a block-0 DNLOAD and return the resulting status.

        The first GETSTATUS advances the device's state machine; the second
        reports the state it ended up in.  *settle* seconds are slept
        between the two reads.
        """
        self._device.ctrl_transfer(self._REQ_OUT, Request.DNLOAD, 0, 0, payload)
        self.get_status()  # this changes state
        if settle:
            time.sleep(settle)
        return self.get_status()  # this gets the status

    def detach(self):
        """Send DFU_DETACH."""
        self._device.ctrl_transfer(self._REQ_OUT, Request.DETACH, 0, 0, None)

    def download(self, block_number, data):
        """Send *data* as DNLOAD block *block_number*."""
        self._device.ctrl_transfer(self._REQ_OUT, Request.DNLOAD, block_number, 0, data)

    def set_address(self, address):
        """Set the radio's address pointer (command byte 0x21 + 32-bit LE address).

        On success returns the device to dfuIDLE and returns True.
        """
        status = self._command([0x21] + self._le32(address))
        if status[2] != State.dfuDNLOAD_IDLE:
            if self.verbose:
                print("Failed to set pointer.")
            return False
        if self.verbose:
            print("Set pointer to 0x%08x." % address)
        self.enter_dfu_mode()
        return True

    def erase_block(self, address):
        """Erase the block at *address* (command byte 0x41 + 32-bit LE address).

        On success returns the device to dfuIDLE and returns True.
        """
        status = self._command([0x41] + self._le32(address))
        if status[2] != State.dfuDNLOAD_IDLE:
            if self.verbose:
                print("Failed to erase block.")
            return False
        if self.verbose:
            print("Erased 0x%08x." % address)
        self.enter_dfu_mode()
        return True

    def md380_custom(self, a, b):
        """Sends a secret MD380 command.

        The two-byte command (a, b) is sent as a block-0 DNLOAD.  A failure
        is always printed, regardless of ``verbose``.
        """
        a &= 0xFF
        b &= 0xFF
        status = self._command([a, b], settle=0.1)
        if status[2] != State.dfuDNLOAD_IDLE:
            print("Failed to send custom %02x %02x." % (a, b))
            return False
        if self.verbose:
            print("Sent custom %02x %02x." % (a, b))
        self.enter_dfu_mode()
        return True

    def md380_reboot(self):
        """Sends the MD380's secret reboot command (0x91 0x05)."""
        self._device.ctrl_transfer(self._REQ_OUT, Request.DNLOAD, 0, 0, [0x91, 0x05])
        try:
            self.get_status()  # this changes state
        except Exception:
            # Best effort: the status read may fail once the radio starts
            # rebooting, which is not an error for this command.
            pass
        return True

    def upload(self, block_number, length):
        """Read up to *length* bytes of UPLOAD block *block_number*."""
        if self.verbose:
            print("Fetching block 0x%x." % block_number)
        return self._device.ctrl_transfer(self._REQ_IN, Request.UPLOAD,
                                          block_number, 0, length)

    def get_command(self):
        """UPLOAD block 0 (at most 32 bytes), then read the status."""
        data = self._device.ctrl_transfer(self._REQ_IN, Request.UPLOAD, 0, 0, 32)
        self.get_status()
        return data

    def get_status(self):
        """Issue GETSTATUS.

        Returns a tuple (Status, poll timeout in ms, State, last status byte).
        """
        status_packed = self._device.ctrl_transfer(self._REQ_IN, Request.GETSTATUS, 0, 0, 6)
        raw = struct.unpack('<BBBBBB', status_packed)
        return (Status.map[raw[0]], self._poll_timeout(raw),
                State.map[raw[4]], raw[5])

    def clear_status(self):
        """Send CLRSTATUS."""
        self._device.ctrl_transfer(self._REQ_OUT, Request.CLRSTATUS, 0, 0, None)

    def get_state(self):
        """Issue GETSTATE and return the decoded State."""
        state_packed = self._device.ctrl_transfer(self._REQ_IN, Request.GETSTATE, 0, 0, 1)
        return State.map[struct.unpack('<B', state_packed)[0]]

    def abort(self):
        """Send ABORT."""
        self._device.ctrl_transfer(self._REQ_OUT, Request.ABORT, 0, 0, None)

    def enter_dfu_mode(self):
        """Poll the state and take the step each state needs until dfuIDLE.

        NOTE(review): there is no overall timeout; a device stuck in a
        waiting state keeps this loop polling.
        """
        action_map = {
            State.dfuDNLOAD_SYNC: self.abort,
            State.dfuDNLOAD_IDLE: self.abort,
            State.dfuMANIFEST_SYNC: self.abort,
            State.dfuMANIFEST: self.abort,
            State.dfuUPLOAD_IDLE: self.abort,
            State.dfuERROR: self.clear_status,
            State.appIDLE: self.detach,
            State.appDETACH: self._wait,
            State.dfuDNBUSY: self._wait,
            State.dfuMANIFEST_WAIT_RESET: self._wait,
        }

        while True:
            state = self.get_state()
            if state == State.dfuIDLE:
                break
            action_map[state]()

    def _wait(self):
        time.sleep(0.1)

def download(dfu, data, flash_address):
    """Write *data* with plain DNLOAD requests of 256 bytes each.

    *flash_address* must be aligned to a 4 KiB sector; the first block
    number is flash_address / 256.  The final block is padded with 0xFF.
    A '.' is printed per block.

    Raises Exception if *flash_address* is not sector aligned.
    """
    block_size = 1 << 8
    sector_size = 1 << 12
    if flash_address & (sector_size - 1) != 0:
        raise Exception('Download must start at flash sector boundary')

    # Sector alignment implies block alignment; '//' keeps this an int on
    # Python 3 as well.
    block_number = flash_address // block_size

    try:
        # Index into data instead of re-slicing the remainder each pass.
        for offset in range(0, len(data), block_size):
            packet = data[offset:offset + block_size]
            if len(packet) < block_size:
                packet += b'\xFF' * (block_size - len(packet))
            dfu.download(block_number, packet)
            dfu.get_status()
            sys.stdout.write('.')
            sys.stdout.flush()
            block_number += 1
    finally:
        print("")

def download_codeplug(dfu, data):
    """Downloads a codeplug to the MD380.

    Puts the radio into programming mode and erases the blocks at 0x00000,
    0x10000, 0x20000 and 0x30000.  Then it points the radio at address 0
    and sends *data* as 1 KiB DNLOAD blocks starting at block 2; the last
    block is padded with 0xFF.

    Raises Exception if the radio enters dfuERROR while writing.
    """
    block_size = 1024

    dfu.md380_custom(0x91, 0x01)  # Programming Mode
    dfu.md380_custom(0x91, 0x01)  # Programming Mode
    # dfu.md380_custom(0xa2, 0x01)  # Returns "DR780...", seems to crash client.
    # hexdump(dfu.get_command())    # Gets a string.
    dfu.md380_custom(0xa2, 0x02)
    hexdump(dfu.get_command())  # Gets a string.
    time.sleep(2)
    dfu.md380_custom(0xa2, 0x02)
    dfu.md380_custom(0xa2, 0x03)
    dfu.md380_custom(0xa2, 0x04)
    dfu.md380_custom(0xa2, 0x07)

    for address in (0x00000000, 0x00010000, 0x00020000, 0x00030000):
        dfu.erase_block(address)

    dfu.set_address(0x00000000)  # Zero address, used by configuration tool.
    dfu.get_status()

    block_number = 2
    try:
        # Index into data instead of re-slicing the remainder each pass.
        for offset in range(0, len(data), block_size):
            packet = data[offset:offset + block_size]
            if len(packet) < block_size:
                packet += b'\xFF' * (block_size - len(packet))
            dfu.download(block_number, packet)
            # Poll until the radio has taken the block.  A device in
            # dfuERROR stays there, so bail out instead of spinning forever.
            while True:
                status, timeout, state, discarded = dfu.get_status()
                if state == State.dfuDNLOAD_IDLE:
                    break
                if state == State.dfuERROR:
                    raise Exception('Write failed at block %i: %s'
                                    % (block_number, status))
            sys.stdout.write('.')
            sys.stdout.flush()
            block_number += 1
    finally:
        print("")

def hexdump(string):
    """Print *string* as grouped hex, for testing, and return the printed text.

    *string* may be any iterable of byte values (ints) or of one-character
    strings.  Bytes are grouped four per word, with a wider gap after every
    16 bytes and a line break after every 32.
    """
    parts = []
    for count, byte in enumerate(string, 1):
        if isinstance(byte, str):
            byte = ord(byte)
        parts.append("%02x" % byte)
        if count % 4 == 0:
            parts.append(" ")
        if count % 16 == 0:
            parts.append("   ")
        if count % 32 == 0:
            parts.append("\n")
    buf = "".join(parts)
    print(buf)
    return buf

def test(dfu):
    """Debug helper: upload blocks 2-4 (1 KiB each) and hex-dump them.

    I'm trying to get this to dump code memory.  Seems stuck in SPI Flash
    region.

    Raises Exception if a block comes back short.
    """
    # dfu.md380_custom(0x91,0x01); #Programming Mode
    # dfu.md380_custom(0x91,0x31); #Mystery!
    # dfu.md380_custom(0x91,0x02); #Mystery!
    # dfu.set_address(0x0800d000);

    block_size = 1024
    try:
        for block_number in range(2, 5):
            data = dfu.upload(block_number, block_size)
            status, timeout, state, discarded = dfu.get_status()
            # '%x' needs integers; convert the Status/State objects explicitly
            # through their __int__ rather than relying on implicit conversion.
            print("Status is: %x %x %x %x"
                  % (int(status), timeout, int(state), discarded))
            if len(data) != block_size:
                raise Exception('Upload failed to read full block.  Got %i bytes.' % len(data))
            hexdump(data)
            time.sleep(.1)
    finally:
        print("Done.")

def upload_bootloader(dfu, filename):
    """Dump the bootloader (0xC000 bytes) to *filename*, or hex-dump it.

    Pass ``filename=None`` to print a hex dump instead of writing a file.
    A short read is reported but whatever arrived is still saved.  The
    output file is always closed, even if the upload fails.
    """
    # dfu.set_address(0x00000000); # Address is ignored, so it doesn't really matter.

    # Bootloader stretches from 0x08000000 to 0x0800C000, but our address
    # and block number are ignored, so we set the block size to 0xC000 to
    # yank the entire thing in one go.  The application comes later, I think.
    block_size = 0xC000

    out = None
    if filename is not None:
        out = open(filename, 'wb')

    print("Dumping bootloader.  This only works in radio mode, not programming mode.")
    try:
        data = dfu.upload(2, block_size)
        dfu.get_status()
        if len(data) == block_size:
            print("Got it all!")
        else:
            print("Only got %i bytes.  Older versions would give it all." % len(data))
        if out is not None:
            out.write(data)
        else:
            hexdump(data)
    finally:
        if out is not None:
            out.close()
        print("Done.")


def upload_codeplug(dfu, filename):
    """Uploads a codeplug from the radio to the host.

    Enters programming mode, points the radio at address 0 and saves
    UPLOAD blocks 2 through 0x101 (1 KiB each, 256 KiB total) to
    *filename*.  The file is closed even if a read fails.

    Raises Exception if the radio returns a short block; the partially
    written file is left on disk.
    """
    dfu.md380_custom(0x91, 0x01)  # Programming Mode
    dfu.md380_custom(0x91, 0x01)  # Programming Mode
    # dfu.md380_custom(0xa2, 0x01)  # Returns "DR780...", seems to crash client.
    # hexdump(dfu.get_command())    # Gets a string.
    dfu.md380_custom(0xa2, 0x02)
    dfu.get_command()  # Gets a string.
    time.sleep(2)
    dfu.md380_custom(0xa2, 0x02)
    dfu.md380_custom(0xa2, 0x03)
    dfu.md380_custom(0xa2, 0x04)
    dfu.md380_custom(0xa2, 0x07)

    dfu.set_address(0x00000000)  # Zero address, used by configuration tool.

    block_size = 1024
    with open(filename, 'wb') as f:
        try:
            # Codeplug region is 0 to 3ffffff, but only the first 256k are used.
            for block_number in range(2, 0x102):
                data = dfu.upload(block_number, block_size)
                dfu.get_status()
                sys.stdout.write('.')
                sys.stdout.flush()
                if len(data) != block_size:
                    raise Exception('Upload failed to read full block.  Got %i bytes.' % len(data))
                f.write(data)
        finally:
            print("Done.")


def upload(dfu, flash_address, length, path):
    """Experimental: dump memory to *path* using 16 KiB UPLOAD blocks.

    *flash_address* must be 16 KiB aligned.  It is validated and reported,
    but NOTE(review): the radio's pointer is always set to 0x08001000 and
    reading always starts at block 2, so it does not select what is read.
    Blocks are read until at least *length* bytes have been written.

    Raises Exception on a misaligned address or a short block.
    """
    block_size = 1 << 14

    print("Address: 0x%08x" % flash_address)
    print("Block Size:    0x%04x" % block_size)

    if flash_address & (block_size - 1) != 0:
        raise Exception('Upload must start at block boundary')

    # '//' keeps the block number an int on Python 3 as well.
    print("Block Number:    0x%04x" % (flash_address // block_size))

    cmds = dfu.get_command()
    print("%i supported commands." % len(cmds))
    for cmd in cmds:
        print("Command %02x is supported by UPLOAD." % cmd)

    # NOTE(review): formerly commented "RAM", but RAM is at 0x20000000 per
    # the notes at the top of this file; 0x08001000 is in the flash range.
    dfu.set_address(0x08001000)
    block_number = 2

    with open(path, 'wb') as f:
        try:
            while length > 0:
                data = dfu.upload(block_number, block_size)
                status, timeout, state, discarded = dfu.get_status()
                # Convert Status/State objects explicitly for '%x'.
                print("Status is: %x %x %x %x"
                      % (int(status), timeout, int(state), discarded))
                sys.stdout.write('.')
                sys.stdout.flush()
                if len(data) != block_size:
                    raise Exception('Upload failed to read full block.  Got %i bytes.' % len(data))
                f.write(data)
                block_number += 1
                length -= len(data)
        finally:
            print("")



def detach(dfu):
    """Send DFU_DETACH if the radio is in dfuIDLE; otherwise report its state."""
    # Query once so the state reported is the one that was tested.
    state = dfu.get_state()
    if state == State.dfuIDLE:
        dfu.detach()
        print('Detached')
    else:
        print('In unexpected state: %s' % state)

def init_dfu(alt=0):
    """Open the first attached MD380 and return a DFU client in dfuIDLE.

    *alt* selects the alternate setting of interface 0.

    Raises RuntimeError if no device with the MD380 VID/PID is found, or if
    driving it into DFU mode fails with a pipe error.
    """
    # Imported here so this also works when the file is used as a module,
    # not only via the __main__ block.
    import usb.core

    dev = usb.core.find(idVendor=md380_vendor,
                        idProduct=md380_product)
    if dev is None:
        raise RuntimeError('Device not found')

    dfu = DFU(dev, alt)
    dev.default_timeout = 3000

    try:
        dfu.enter_dfu_mode()
    except usb.core.USBError as e:
        if e.args and e.args[0] == 'Pipe error':
            raise RuntimeError('Failed to enter DFU mode. Is bootloader running?')
        raise  # keep the original traceback

    return dfu

def usage():
    """Print command-line help for every command handled by the __main__ block."""
    print("""
Usage: md380-dfu <command> <arguments>

Write a codeplug to the radio (a .dfu file has its DFU suffix stripped).
    md380-dfu write <filename.bin>

Read a codeplug and write it to a file.
    md380-dfu read <filename.bin>

Dump the bootloader from Flash memory.
    md380-dfu readboot <filename.bin>

Add a DFU suffix to a file, writing a copy with a .dfu extension.
    md380-dfu sign <filename.bin>

Detach the bootloader and execute the application firmware.
    md380-dfu detach

Close the bootloader session.
    md380-dfu reboot

Abort the current DFU operation.
    md380-dfu abort

Modification of firmware is not yet supported, but will come soon.
""")

if __name__ == '__main__':
    import os

    # pyusb is imported only by the commands that talk to the radio, so
    # 'sign' works without it installed.
    try:
        if len(sys.argv) == 3:
            command, filename = sys.argv[1], sys.argv[2]
            if command == 'read':
                import usb.core
                dfu = init_dfu()
                upload_codeplug(dfu, filename)
                print('Read complete')
            elif command == 'readboot':
                import usb.core
                dfu = init_dfu()
                upload_bootloader(dfu, filename)
            elif command == 'write':
                import usb.core
                with open(filename, 'rb') as f:
                    data = f.read()

                if filename.endswith('.dfu'):
                    # The suffix's vendor/product fields are not checked.
                    suf_len, vendor, product = dfu_suffix.check_suffix(data)
                    # len(data) - suf_len, not -suf_len: data[:-0] would be empty.
                    firmware = data[:len(data) - suf_len]
                else:
                    firmware = data
                dfu = init_dfu()

                download_codeplug(dfu, firmware)
                print('Write complete')

            elif command == 'sign':
                with open(filename, 'rb') as f:
                    firmware = f.read()

                data = dfu_suffix.add_suffix(firmware, md380_vendor, md380_product)

                # Replace the extension, whatever its length (or append one).
                dfu_file = os.path.splitext(filename)[0] + '.dfu'
                with open(dfu_file, 'wb') as f:
                    f.write(data)
                print("Signed file written: %s" % dfu_file)

            else:
                usage()

        elif len(sys.argv) == 2:
            command = sys.argv[1]
            if command in ('detach', 'test', 'reboot', 'abort'):
                import usb.core
                dfu = init_dfu()
                if command == 'detach':
                    dfu.set_address(0x08000000)  # Radio Application
                    detach(dfu)
                elif command == 'test':
                    test(dfu)
                elif command == 'reboot':
                    dfu.md380_reboot()
                else:
                    dfu.abort()
            else:
                usage()
        else:
            usage()
    except RuntimeError as e:
        print(e.args[0] if e.args else e)
        sys.exit(1)
    except Exception as e:
        print(e)
        sys.exit(1)
