/*
 * pbltool.c - Talks to the PBL of the Amstrad E3 (Delta)
 *
 * Copyright 2005-2006 Jonathan McDowell <noodles@earth.li>
 * Copyright 2006 Mark Underwood <basicmark@yahoo.com>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; version 2 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#define _GNU_SOURCE

#include <fcntl.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/poll.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <termios.h>
#include <unistd.h>

#include <errno.h>
#include <limits.h>

/* Serial device opened when no -f option is given. */
#define DEFAULT_SERIAL	"/dev/ttyS0"
/* Number of attempts made to send a packet before giving up. */
#define RETRY_MAX	3
/* Set to 1 by the -d option to enable extra diagnostic output. */
int debug = 0;

/*
 * Wait up to timeout_ms for a byte to arrive on fd and store it in *c.
 * The serial port is opened with O_NDELAY, so a bare read() can fail with
 * EAGAIN between bytes; polling first avoids treating that as an error.
 * Returns 0 on success, -1 on timeout, end-of-file or read error.
 */
static int readbyte_timeout(int fd, unsigned char *c, int timeout_ms)
{
	struct pollfd p;
	ssize_t n;

	for (;;) {
		p.fd = fd;
		p.events = POLLIN;
		p.revents = 0;
		if (poll(&p, 1, timeout_ms) <= 0)
			return -1;
		n = read(fd, c, 1);
		if (n == 1)
			return 0;
		if (n == 0)
			return -1;
		if (errno != EAGAIN && errno != EINTR)
			return -1;
	}
}

/*
 * Write all len bytes of data to fd. If the non-blocking port reports that
 * its output queue is full, wait (up to 10 s at a time) for room.
 * Returns 0 on success, -1 on error or timeout.
 */
static int writeall(int fd, const unsigned char *data, size_t len)
{
	struct pollfd p;
	ssize_t n;

	while (len > 0) {
		n = write(fd, data, len);
		if (n > 0) {
			data += n;
			len -= (size_t)n;
			continue;
		}
		if (n < 0 && errno != EAGAIN && errno != EINTR)
			return -1;
		p.fd = fd;
		p.events = POLLOUT;
		p.revents = 0;
		if (poll(&p, 1, 10000) <= 0)
			return -1;
	}
	return 0;
}

/*
 * Send one command packet to the PBL and read its reply back into buf.
 *
 * Frame layout (request and reply): 0x02 start byte, 0x00 "not compressed"
 * flag, 16-bit little-endian payload length, payload, then an 8-bit
 * additive checksum of the two length bytes plus the payload.
 *
 * buf: on entry, the len-byte request payload. On success it is overwritten
 *      with the reply payload. On some failures it may be partially
 *      overwritten.
 * max: capacity of buf in bytes; longer replies are rejected.
 *
 * Returns 0 on success.
 * Returns 1 if no reply started within 10 s.
 * Returns 2 if the packet could not be sent or the reply was malformed or
 * too long. For a bad reply, remaining input is first drained until the
 * line has been quiet for 1 s.
 *
 * The reply's checksum byte is read but not enforced. With -d, both the
 * received value and one computed the same way as the request checksum are
 * printed.
 */
int sendpacket(int fd, unsigned char *buf, size_t len, size_t max)
{
	enum { BYTE_TIMEOUT_MS = 1000 };
	struct pollfd p;
	unsigned char hdr[4];
	unsigned char c = 0;
	uint8_t checksum;
	size_t replylen = 0;
	size_t got;
	size_t i;
	ssize_t n;
	int ret = 0;

	/* Header: start byte, not compressed, length low, length high. */
	hdr[0] = 2;
	hdr[1] = 0;
	hdr[2] = len & 0xFF;
	hdr[3] = (len >> 8) & 0xFF;

	/* The checksum covers the length bytes and the payload. */
	checksum = hdr[2] + hdr[3];
	for (i = 0; i < len; i++)
		checksum += buf[i];

	if (writeall(fd, hdr, sizeof(hdr)) != 0 ||
	    writeall(fd, buf, len) != 0 ||
	    writeall(fd, &checksum, 1) != 0) {
		printf("Failed to send packet.\n");
		return 2;
	}

	/* Give the PBL up to 10 s to start replying. */
	p.fd = fd;
	p.events = POLLIN;
	p.revents = 0;
	if (poll(&p, 1, 10000) == 0)
		return 1;

	if (readbyte_timeout(fd, &c, BYTE_TIMEOUT_MS) != 0 || c != 2) {
		printf("Didn't get expected 0x02 header -- 0x%02X.\n", c);
		ret = 2;
	}

	if (!ret && (readbyte_timeout(fd, &c, BYTE_TIMEOUT_MS) != 0 ||
		     c != 0)) {
		printf("Compressed return block?\n");
		ret = 2;
	}

	if (!ret) {
		if (readbyte_timeout(fd, &c, BYTE_TIMEOUT_MS) != 0) {
			printf("Couldn't read low byte of reply len.\n");
			ret = 2;
		} else {
			replylen = c;
			checksum = c;
		}
	}

	if (!ret) {
		if (readbyte_timeout(fd, &c, BYTE_TIMEOUT_MS) != 0) {
			printf("Couldn't read high byte of reply len.\n");
			ret = 2;
		} else {
			replylen |= (size_t)c << 8;
			checksum += c;
		}
	}

	/*
	 * Refuse replies larger than the caller's buffer rather than leaving
	 * the payload unread and reporting success with stale data.
	 */
	if (!ret && replylen > max) {
		printf("Reply of %zu bytes doesn't fit in %zu byte buffer.\n",
				replylen, max);
		ret = 2;
	}

	/* Read the payload in as few read() calls as the port allows. */
	got = 0;
	while (!ret && got < replylen) {
		n = read(fd, buf + got, replylen - got);
		if (n > 0) {
			if (debug) {
				printf("Got %zd of %zu bytes.\n",
					n, replylen - got);
			}
			got += (size_t)n;
			continue;
		}
		if (n == 0 || (errno != EAGAIN && errno != EINTR)) {
			ret = 2;
		} else {
			p.fd = fd;
			p.events = POLLIN;
			p.revents = 0;
			if (poll(&p, 1, BYTE_TIMEOUT_MS) <= 0)
				ret = 2;
		}
		if (ret) {
			printf("Short reply: got %zu of %zu bytes.\n",
					got, replylen);
		}
	}

	if (!ret) {
		for (i = 0; i < replylen; i++)
			checksum += buf[i];
		if (readbyte_timeout(fd, &c, BYTE_TIMEOUT_MS) != 0) {
			printf("Couldn't read reply checksum.\n");
			ret = 2;
		} else if (debug) {
			printf("Checksum = 0x%02X (computed 0x%02X)\n",
					c, checksum);
		}
	}

	/*
	 * Flush input if error, so the next packet starts on a clean line.
	 */
	if (ret > 1) {
		while (readbyte_timeout(fd, &c, 1000) == 0)
			printf("Flushing: 0x%02X\n", c);
	}

	return ret;
}


/*
 * Upload the contents of file to target memory at address start using PBL
 * command 5, in chunks of up to 1024 bytes. Each chunk is attempted up to
 * RETRY_MAX times before the transfer is abandoned.
 *
 * Multi-byte header fields are copied in host byte order.
 * NOTE(review): this matches the target only on little-endian hosts.
 *
 * Returns 0 on success. Returns -1 if the file cannot be opened or read,
 * or if a chunk could not be sent.
 */
int writeblock(int fd, uint32_t start, char *file)
{
	unsigned char data[1024];
	unsigned char buf[16384 + 8];
	uint32_t addr;
	uint16_t blocklen;
	ssize_t nread;
	size_t count = 0;
	int imgfd, ret, tries, err;
	int total_noof_retrys = 0;

	if (file == NULL) {
		printf("No file given.\n");
		return -1;
	}
	imgfd = open(file, O_RDONLY);
	if (imgfd == -1) {
		printf("Can't open %s: %s\n", file, strerror(errno));
		return -1;
	}

	/* Keep read()'s result signed: -1 must not wrap to a huge length. */
	while ((nread = read(imgfd, data, sizeof(data))) > 0) {
		blocklen = (uint16_t)nread;
		addr = start + (uint32_t)count;
		tries = 0;
		do {
			/*
			 * sendpacket() stores its reply in buf, so the request
			 * is rebuilt before every attempt.
			 */
			buf[0] = 5;
			buf[1] = 0;
			memcpy(buf + 2, &addr, sizeof(addr));
			memcpy(buf + 6, &blocklen, sizeof(blocklen));
			memcpy(buf + 8, data, blocklen);
			ret = sendpacket(fd, buf, blocklen + 8, sizeof(buf));
			if (ret) {
				tries++;
				total_noof_retrys++;
			}
		} while ((ret != 0) && (tries != RETRY_MAX));
		if (ret) {
			printf("Sending file fail after %zu bytes "
				"and %d retrys. Aborting.\n", count, tries);
			close(imgfd);
			return -1;
		}
		count += blocklen;
		printf("\r%zu bytes transfered so far (%zu KB)",
				count, count / 1024);
		/* No newline, so flush to make the progress line visible. */
		fflush(stdout);
	}
	printf("\n");

	if (nread < 0) {
		err = errno;
		printf("Error reading %s: %s\n", file, strerror(err));
		close(imgfd);
		return -1;
	}

	close(imgfd);
	printf("%zu bytes transfered with %d retrys\n",
			count, total_noof_retrys);
	return 0;
}

/*
 * Send PBL command 4 ("exec") with address start, then print the
 * sendpacket() status.
 */
void execute(int fd, uint32_t start)
{
	unsigned char buf[10];
	int ret;

	buf[0] = 4;
	buf[1] = 0;
	/* memcpy avoids a misaligned, aliasing uint32_t store into buf. */
	memcpy(buf + 2, &start, sizeof(start));
	ret = sendpacket(fd, buf, 6, sizeof(buf));
	printf("sendpacket: %d\n", ret);
}

/*
 * Map a numeric bit rate (e.g. 115200) to its termios Bxxx speed code.
 * Unknown rates produce a message on stderr and a return value of 0.
 */
static int getbaudrate(int speed)
{
	static const struct {
		int bps;
		int code;
	} rates[] = {
		{ 50, B50 },         { 75, B75 },
		{ 110, B110 },       { 134, B134 },
		{ 150, B150 },       { 200, B200 },
		{ 300, B300 },       { 600, B600 },
		{ 1200, B1200 },     { 1800, B1800 },
		{ 2400, B2400 },     { 4800, B4800 },
		{ 9600, B9600 },     { 19200, B19200 },
		{ 38400, B38400 },   { 57600, B57600 },
		{ 115200, B115200 }, { 230400, B230400 },
	};
	size_t n;

	for (n = 0; n < sizeof(rates) / sizeof(rates[0]); n++) {
		if (rates[n].bps == speed)
			return rates[n].code;
	}

	fprintf(stderr, "%d baud is not supported\n", speed);

	return 0;
}

/*
 * Ask the PBL (command 9) to switch to baudrate. Then reconfigure the local
 * serial port to the same rate: 8 data bits, receiver on, modem lines
 * ignored, no input/output/line processing.
 *
 * Unsupported rates are rejected before anything is sent. getbaudrate()
 * returns 0 for them; on Linux 0 is B0, and selecting B0 hangs up the line.
 */
void setbaud(int fd, uint32_t baudrate)
{
	unsigned char buf[10];
	struct termios serialterm;
	int speed;

	speed = getbaudrate((int)baudrate);
	if (speed == 0)
		return;

	buf[0] = 9;
	buf[1] = 0;
	memcpy(buf + 2, &baudrate, sizeof(baudrate));
	printf("sendpacket: %d\n",
		sendpacket(fd, buf, 6,
			sizeof(buf)));

	/*
	 * NOTE(review): the local port is switched even if sendpacket()
	 * reported a failure. It is not clear from here whether the PBL
	 * answers at the old or the new rate, so the result is only printed.
	 */
	if (tcgetattr(fd, &serialterm) == -1) {
		fprintf(stderr, "Failed to set up serial port\n");
		return;
	}
	serialterm.c_cflag = CS8 | CLOCAL | CREAD;
	serialterm.c_lflag = 0;
	serialterm.c_oflag = 0;
	serialterm.c_iflag = IGNPAR;

	cfsetspeed(&serialterm, speed);
	if (tcsetattr(fd, TCSANOW, &serialterm) == -1)
		fprintf(stderr, "Failed to set up serial port\n");
}

void readblock(int fd, uint32_t start, size_t len, char *file)
{
	unsigned char buf[10];
	unsigned char c;
	int i;
	int fdout;

	fdout = open(file, O_CREAT | O_WRONLY, 0644);

	for (i = 0; i < len; i++) {
		if ((i % 16) == 0) {
			printf("\n0x%08X ", start + i);
		}
		buf[0] = 3;
		buf[1] = 0;
		*(uint32_t *) (buf + 2) = start + i;
		*(uint32_t *) (buf + 6) = 1;
		if (sendpacket(fd, buf, 10, sizeof(buf)) == 0) {
			c = *(uint32_t *)(buf + 2);
			printf("%02X ", c);
			write(fdout, &c, 1);
		}
	}
	printf("\n");
	close(fdout);
}

void readflashblock(int fd, uint32_t start, size_t len, char *file)
{
	uint32_t checksum;
	unsigned char buf[10];
	unsigned char c;
	int i;
	int fdout;

	fdout = open(file, O_CREAT | O_WRONLY, 0644);

	if (start > 0x00400000) {
		buf[0] = 3;
		buf[1] = 0;
		buf[2] = 0x00;
		buf[3] = 0x00;
		buf[4] = 0x40;
		buf[5] = 0x00;
		*(uint32_t *) (buf + 6) = start - 0x00400000;
		sendpacket(fd, buf, 10, sizeof(buf));
		checksum = *(uint32_t *)(buf + 2);
	} else {
		checksum = 0;
	}

	for (i = 0; i < len; i++) {
		if ((i % 16) == 0) {
			printf("\n0x%08X ", start + i);
		}
		buf[0] = 3;
		buf[1] = 0;
		buf[2] = 0x00;
		buf[3] = 0x00;
		buf[4] = 0x40;
		buf[5] = 0x00;
		*(uint32_t *) (buf + 6) = start - 0x00400000 + i + 1;
		if (sendpacket(fd, buf, 10, sizeof(buf)) == 0) {
			c = *(uint32_t *)(buf + 2) - checksum;
			printf("%02X ", c);
			checksum = *(uint32_t *)(buf + 2);
			write(fdout, &c, 1);
		}
	}
	printf("\n");
	close(fdout);
}

/*
 * Send PBL command 6 with the address start, length size (truncated to
 * 32 bits) and a trailing 16-bit value of 1. The request is 12 bytes long.
 *
 * Returns the sendpacket() status (0 on success).
 */
int eraseflash(int fd, uint32_t start, size_t size)
{
	/* 2-byte command + 4-byte start + 4-byte size + 2-byte flag. */
	unsigned char buf[12];
	uint32_t size32 = (uint32_t)size;
	uint16_t one = 1;

	buf[0] = 6;
	buf[1] = 0;
	memcpy(buf + 2, &start, sizeof(start));
	memcpy(buf + 6, &size32, sizeof(size32));
	memcpy(buf + 10, &one, sizeof(one));
	return sendpacket(fd, buf, sizeof(buf), sizeof(buf));
}

/*
 * Program the contents of file into flash at address start using PBL
 * command 14, in chunks of up to 1024 bytes. Progress is printed, along
 * with the sendpacket() status and the first four reply bytes per chunk.
 *
 * NOTE(review): a failed chunk is not retried, because it is unclear
 * whether the PBL may already have programmed it. The transfer is aborted
 * instead.
 *
 * Returns 0 on success. Returns -1 if the file cannot be opened or read,
 * or if a chunk fails.
 */
int programflash(int fd, uint32_t start, char *file)
{
	unsigned char buf[16384 + 10];
	uint32_t addr;
	uint16_t blocklen;
	ssize_t nread;
	size_t count = 0;
	int imgfd;
	int ret;
	int err;

	if (file == NULL) {
		printf("No file given.\n");
		return -1;
	}
	imgfd = open(file, O_RDONLY);
	if (imgfd == -1) {
		printf("Can't open %s: %s\n", file, strerror(errno));
		return -1;
	}

	/* Keep read()'s result signed: -1 must not wrap to a huge length. */
	while ((nread = read(imgfd, buf + 10, 1024)) > 0) {
		blocklen = (uint16_t)nread;
		addr = start + (uint32_t)count;

		buf[0] = 14;
		buf[1] = 0;
		memcpy(buf + 2, &addr, sizeof(addr));
		buf[6] = 1;
		buf[7] = 0;
		memcpy(buf + 8, &blocklen, sizeof(blocklen));
		printf("Trying to program %d bytes starting %02X "
				"at 0x%08X.\n",
				blocklen,
				buf[10],
				(unsigned)addr);
		ret = sendpacket(fd, buf, blocklen + 10, sizeof(buf));
		printf("sendpacket: %d ", ret);
		printf("%02X %02X %02X %02X\n",
			buf[0], buf[1], buf[2], buf[3]);
		if (ret != 0) {
			printf("Programming failed at 0x%08X. Aborting.\n",
					(unsigned)addr);
			close(imgfd);
			return -1;
		}
		count += blocklen;
	}

	if (nread < 0) {
		err = errno;
		printf("Error reading %s: %s\n", file, strerror(err));
		close(imgfd);
		return -1;
	}

	close(imgfd);
	return 0;
}

/*
 * Write the single byte c to fd. Before each attempt, wait until fd is
 * writable. Failed writes are reported and retried. The function gives up
 * silently only if poll() times out.
 */
void sendbyte(int fd, unsigned char c)
{
	struct pollfd pfd;
	int written;

	for (;;) {
		pfd.fd = fd;
		pfd.events = POLLOUT;
		if (poll(&pfd, 1, INT_MAX) == 0)
			return;

		written = write(fd, &c, 1);
		if (written == -1)
			printf("I/O error on write: %s",
				strerror(errno));
		if (written == 1)
			return;
	}
}

/*
 * Wait until fd is readable, then read and return one byte. read() errors
 * are reported and retried.
 *
 * If poll() times out (after INT_MAX ms), 0 is returned. Previously the
 * byte was returned uninitialised in that case.
 *
 * NOTE(review): if read() keeps returning 0 (end-of-file / hang-up) this
 * loops indefinitely.
 */
unsigned char recvbyte(int fd)
{
	int i;
	unsigned char b = 0;

	do {
		struct pollfd p;
		p.fd = fd;
		p.events = POLLIN;

		if (poll(&p, 1, INT_MAX) == 0)
			break;
		i = read(fd, &b, 1);

		if (i == -1)
			printf("I/O error on read: %s",
				strerror(errno));
	} while (i != 1);
	return b;
}

/*
 * Minimal interactive terminal. While it runs, the console is in raw mode.
 * Bytes from the serial port fd are copied to stdout, and bytes typed on
 * stdin are sent to the serial port.
 *
 * The session ends on CTRL+C (0x03), on EOF or error on stdin, or when the
 * serial port hangs up or fails. The console settings are restored on exit
 * if raw mode was entered.
 */
void dodgyterm(int fd)
{
	struct termios oldconsole;
	struct termios console;
	bool raw = false;

	printf("Serial terminal starting (CTRL+C to quit)\n");
	/* Get the banner out before raw write()s to the same terminal. */
	fflush(stdout);

	/* Put the console into raw mode. */
	if (tcgetattr(0, &oldconsole) == 0) {
		console = oldconsole;
		cfmakeraw(&console);
		raw = (tcsetattr(0, TCSANOW, &console) == 0);
	}
	if (!raw)
		fprintf(stderr, "Failed to put console into raw mode.");

	/* Wait for input on either device. */
	for (;;) {
		struct pollfd p[2];
		unsigned char b;
		ssize_t n;

		p[0].fd = fd;
		p[0].events = POLLIN;
		p[0].revents = 0;
		p[1].fd = 0;
		p[1].events = POLLIN;
		p[1].revents = 0;

		if (poll(p, 2, -1) < 0) {
			if (errno == EINTR)
				continue;
			break;
		}

		/* Serial port -> stdout. */
		if (p[0].revents & (POLLIN | POLLHUP | POLLERR | POLLNVAL)) {
			n = read(fd, &b, 1);
			if (n == 1) {
				write(STDOUT_FILENO, &b, 1);
			} else if (n == 0 ||
				   (errno != EAGAIN && errno != EINTR)) {
				/* Port hung up or failed. */
				break;
			}
		}

		/* stdin -> serial port; CTRL+C or EOF ends the session. */
		if (p[1].revents & (POLLIN | POLLHUP | POLLERR | POLLNVAL)) {
			if (read(0, &b, 1) != 1)
				break;
			if (b == 3)
				break;
			sendbyte(fd, b);
		}
	}

	/* Put the console back the way it was. */
	if (raw)
		tcsetattr(0, TCSANOW, &oldconsole);
}

/*
 * Print the program version and terminate with a failure exit status.
 */
void help(void)
{
	fputs("pbltool v0.1\n", stdout);
	exit(EXIT_FAILURE);
}

/*
 * Return the next argument of the line the caller is tokenising with
 * strtok(). The caller starts the scan with strtok(buffer, " \n"). If
 * there is no further argument, print a message naming the missing "what"
 * and return NULL.
 */
static char *next_word(const char *what)
{
	char *tok = strtok(NULL, " \n");

	if (tok == NULL)
		printf("Missing %s.\n", what);
	return tok;
}

/*
 * Like next_word(), but parse the argument as an unsigned 32-bit number.
 * Accepted forms are decimal, 0x-prefixed hex or 0-prefixed octal. On
 * success, store the value in *out and return 0. Otherwise print a
 * message and return -1.
 */
static int next_number(const char *what, uint32_t *out)
{
	char *tok = next_word(what);
	char *end;
	unsigned long val;

	if (tok == NULL)
		return -1;

	errno = 0;
	val = strtoul(tok, &end, 0);
	if (end == tok || *end != '\0' || errno == ERANGE ||
	    val > UINT32_MAX) {
		printf("Bad %s: %s\n", what, tok);
		return -1;
	}
	*out = (uint32_t)val;
	return 0;
}

/*
 * Parse and run one command.
 *
 * command is the first token of the line, already split off by the caller
 * with strtok(buffer, " \n"). Any further arguments are taken from that
 * same strtok() scan. buffer is the line itself; it is accepted for
 * interface compatibility and not otherwise used.
 *
 * A NULL command (blank line) is ignored and returns 0. Otherwise returns
 * 0 on success, or -1 for:
 *   - an unknown command,
 *   - missing or malformed arguments,
 *   - a failure reported by the underlying operation.
 */
int process_command(int fd, char *command, char *buffer)
{
	unsigned char buf[128];
	uint32_t start, len, baud, pages;
	char *file;
	int ret = 0;
	int count = 0;
	int i;

	(void)buffer;

	if (command == NULL)
		return 0;

	printf("Got command %s\n", command);
	if (strcmp(command, "meminfo") == 0) {
		/* Reply byte 2 = number of entries to query with command 12. */
		buf[0] = 0x0B;
		buf[1] = 0;
		if (sendpacket(fd, buf, 2, sizeof(buf)) == 0)
			count = buf[2];
		else
			ret = -1;

		for (i = 0; i < count; i++) {
			uint32_t w[4];
			uint16_t h[2];

			buf[0] = 12;
			buf[1] = 0;
			buf[2] = i;
			buf[3] = 0;
			if (sendpacket(fd, buf, 4, sizeof(buf)) == 0) {
				memcpy(w, buf + 2, sizeof(w));
				memcpy(h, buf + 18, sizeof(h));
				printf("%d 0x%08X 0x%08X 0x%08X 0x%08X "
					"0x%04X 0x%04X\n",
					i,
					(unsigned)w[0], (unsigned)w[1],
					(unsigned)w[2], (unsigned)w[3],
					(unsigned)h[0], (unsigned)h[1]);
			} else {
				ret = -1;
			}
		}
	} else if (strcmp(command, "read") == 0) {
		if (next_number("start address", &start) != 0 ||
		    next_number("length", &len) != 0 ||
		    (file = next_word("file name")) == NULL)
			ret = -1;
		else
			readblock(fd, start, len, file);
	} else if (strcmp(command, "write") == 0) {
		if (next_number("start address", &start) != 0 ||
		    (file = next_word("file name")) == NULL)
			ret = -1;
		else if (writeblock(fd, start, file) != 0)
			ret = -1;
	} else if (strcmp(command, "exec") == 0) {
		if (next_number("start address", &start) != 0)
			ret = -1;
		else
			execute(fd, start);
	} else if (strcmp(command, "setbaud") == 0) {
		if (next_number("baud rate", &baud) != 0)
			ret = -1;
		else
			setbaud(fd, baud);
	} else if (strcmp(command, "readflash") == 0) {
		/*
		 * readflash/writeflash addresses are offsets from 0x400000.
		 * NOTE(review): presumably the flash base; confirm.
		 */
		if (next_number("start address", &start) != 0 ||
		    next_number("length", &len) != 0 ||
		    (file = next_word("file name")) == NULL)
			ret = -1;
		else
			readflashblock(fd, start + 0x400000, len, file);
	} else if (strcmp(command, "writeflash") == 0) {
		if (next_number("start address", &start) != 0 ||
		    (file = next_word("file name")) == NULL)
			ret = -1;
		else if (programflash(fd, start + 0x400000, file) != 0)
			ret = -1;
	} else if (strcmp(command, "eraseflash") == 0) {
		/*
		 * NOTE(review): unlike readflash/writeflash, no 0x400000
		 * offset is added here; confirm against the PBL protocol.
		 */
		if (next_number("start address", &start) != 0 ||
		    next_number("page count", &pages) != 0)
			ret = -1;
		else if (eraseflash(fd, start, pages * 0x4000) != 0)
			ret = -1;
	} else if (strcmp(command, "help") == 0) {
		printf("\nmeminfo\n");
		printf("read <start addr> <len> <file>\n");
		printf("write <start addr> <file>\n");
		printf("exec <start addr>\n");
		printf("setbaud <baud rate>\n");
		printf("readflash <start addr> <len> <file>\n");
		printf("writeflash <start addr> <file>\n");
		printf("eraseflash <start addr> <number of 16k pages>\n");
		printf("help\n");
		printf("quit\n");
		printf("exit\n");
	} else if ((strcmp(command, "quit") == 0) ||
			(strcmp(command, "exit") == 0)) {
		printf("Bye!\n");
	} else {
		printf("Unknown command!\n");
		ret = -1;
	}
	return ret;
}

/*
 * Interactive command loop. Prompt with the PBL version, read a line from
 * stdin and pass it to process_command(). Blank lines are ignored.
 *
 * The loop ends on "quit" or "exit", on EOF of stdin (e.g. CTRL+D), or on
 * an "exec" that process_command() accepted. After a successful exec,
 * dodgyterm() is started on the serial port.
 */
void command_line(int fd, char major_version, char minor_version)
{
	char *buffer = NULL;
	size_t bufsize = 0;
	char *command;
	int ret;
	bool start_term = false;

	for (;;) {
		/* Version bytes come from the wire; print them unsigned. */
		printf("PBL v%d.%d> ", (unsigned char)major_version,
				(unsigned char)minor_version);
		fflush(stdout);

		/*
		 * getline() returns the line length or -1. bufsize is only
		 * the allocation size, so it can't be used to detect blank
		 * lines or EOF.
		 */
		if (getline(&buffer, &bufsize, stdin) == -1) {
			printf("\n");
			break;
		}

		command = strtok(buffer, " \n");
		if (command == NULL)
			continue;

		ret = process_command(fd, command, buffer);

		if ((strcmp(command, "quit") == 0) ||
		    (strcmp(command, "exit") == 0))
			break;
		if ((strcmp(command, "exec") == 0) && (ret == 0)) {
			start_term = true;
			break;
		}
	}

	if (start_term) {
		printf("Starting dodgyterm term...\n");
		dodgyterm(fd);
	}

	free(buffer);
}

/*
 * Run the commands in the script file fname, one per line, through
 * process_command(). Blank lines are skipped.
 *
 * Processing stops at "quit" or "exit", or at an "exec" that
 * process_command() accepted. After a successful exec, dodgyterm() is
 * started on the serial port.
 */
void script_file(int fd, char *fname)
{
	FILE *script_fd;
	char *buffer = NULL;
	size_t bufsize = 0;
	char *command;
	int ret;
	bool start_term = false;

	script_fd = fopen(fname, "r");
	if (script_fd == NULL) {
		printf("Can't open script file %s\n", fname);
		return;
	}

	while (getline(&buffer, &bufsize, script_fd) != -1) {
		command = strtok(buffer, " \n");
		if (command == NULL)
			continue;

		ret = process_command(fd, command, buffer);

		if ((strcmp(command, "quit") == 0) ||
		    (strcmp(command, "exit") == 0))
			break;
		/*
		 * After exec the target is presumably running the new code,
		 * so later PBL commands would not make sense.
		 */
		if ((strcmp(command, "exec") == 0) && (ret == 0)) {
			start_term = true;
			break;
		}
	}

	fclose(script_fd);
	free(buffer);

	if (start_term) {
		printf("Starting dodgyterm term...\n");
		dodgyterm(fd);
	}
}

/*
 * Print command-line help, one option per line, and exit successfully.
 */
void usage(void)
{
	static const char *const lines[] = {
		"Usage: pbltool [-d] [-f <serial port>] [-h] [-n] [script file]",
		"\n\t-d                 Enable debugging output.",
		"\t-f <serial port>   Set the serial port to use.",
		"\t-h                 Show this help.",
		"\t-n                 Don't prod; assume we've handshaked.",
		"\tscript file        A file containing the commands to run.",
	};
	size_t k;

	for (k = 0; k < sizeof(lines) / sizeof(lines[0]); k++)
		puts(lines[k]);
	exit(EXIT_SUCCESS);
}

/*
 * Entry point.
 *
 * Options: -d enables debug output, -f selects the serial port, -h shows
 * help, -n skips the prod/handshake.
 *
 * Steps:
 *  1. Open the serial port and configure it for 9600 baud, 8N1, with
 *     line and output processing off.
 *  2. Prod the PBL with ESC (0x1B) until it answers ACK (0x06), then drain
 *     further ACKs.
 *  3. Query the PBL version with command 2.
 *  4. Run the script file given as the first non-option argument, or the
 *     interactive prompt if there is none.
 */
int main(int argc, char *argv[])
{
	int fd;
	struct termios serialterm;
	struct pollfd p;
	unsigned char c = 0;	/* tested by the prod loop before any read */
	unsigned char buf[128];
	int i;
	bool gotver = false;
	bool noprod = false;
	const char *serdev = DEFAULT_SERIAL;
	int opt;

	while ((opt = getopt(argc, argv, "df:hn")) != -1) {
		switch (opt) {
		case 'd':
			debug = 1;
			break;
		case 'f':
			/* optarg points into argv; no copy is needed. */
			serdev = optarg;
			break;
		case 'h':
			usage();
			break;
		case 'n':
			noprod = true;
			break;
		default:
			/* getopt() has already reported the bad option. */
			break;
		}
	}

	fd = open(serdev, O_RDWR | O_NOCTTY | O_NDELAY);
	if (fd == -1) {
		printf("Can't open serial port\n");
		return -1;
	}

	/* 8 data bits, no parity, 1 stop bit, no line/output processing. */
	if (tcgetattr(fd, &serialterm) == 0) {
		serialterm.c_cflag = CS8 | CLOCAL | CREAD;
		serialterm.c_lflag = 0;
		serialterm.c_oflag = 0;
		serialterm.c_iflag = IGNPAR;
		cfsetspeed(&serialterm, B9600);
		if (tcsetattr(fd, TCSANOW, &serialterm) == -1)
			fprintf(stderr, "Failed to set up serial port\n");
	} else {
		fprintf(stderr, "Failed to set up serial port\n");
	}
	tcflush(fd, TCIOFLUSH);

	if (!noprod) {
		printf("Prodding...\n");

		/* Send ESC, waiting up to 100 ms each time, until ACK (0x06). */
		while (c != 0x06) {
			c = 0x1B;
			write(fd, &c, 1);

			p.fd = fd;
			p.events = POLLIN;
			if (poll(&p, 1, 100) == 1) {
				do {
					i = read(fd, &c, 1);
					if ((i == 1) && (c == 0x06)) {
						break;
					}
				} while (i == 1);
			}
		}
		printf("Handshaking...\n");

		/* Drain further ACKs until the line is quiet for 100 ms. */
		for (;;) {
			p.fd = fd;
			p.events = POLLIN;
			if (poll(&p, 1, 100) <= 0)
				break;
			do {
				i = read(fd, &c, 1);
				if ((i == 1) && (c != 0x06)) {
					printf("Error: Got 0x%02X instead of expected"
							" 0x06.\n", c);
				}
			} while (i == 1);
			/* EOF: the port hung up; polling again would spin. */
			if (i == 0)
				break;
		}

		printf("Done handshaking\n");
	}

	/* Query the PBL version (command 2) until it answers. */
	while (!gotver) {
		buf[0] = 2;
		buf[1] = 0;
		if (sendpacket(fd, buf, 2, sizeof(buf)) == 0) {
			gotver = true;
			printf("Talking to PBL v%d.%d Build %d\n",
					buf[4], buf[5],
					buf[6] + (buf[7] << 8));
		}
	}

	sleep(1);

	if (optind < argc) {
		script_file(fd, argv[optind]);
	} else {
		command_line(fd, buf[4], buf[5]);
	}

	close(fd);
	exit(EXIT_SUCCESS);
}

