/*
 *  $Id: checksum.c,v 1.9 1999/01/31 21:05:41 route Exp $
 *
 *  libnet
 *  checksum.c - IP checksum routines
 *
 *  Copyright (c) 1998, 1999 route|daemon9 <route@infonexus.com>
 *  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 */

#include "../include/libnet.h"


int
do_checksum(u_char *buf, int protocol, int len)
{
    struct icmpheader   icmp_hdr;
    struct igmpheader   igmp_hdr;
    struct tcphdr       tcp_hdr;
    struct udphdr       udp_hdr;
    struct ip           ip_hdr;
    struct psuedoheader *p_hdr;
    u_char *p;

    switch (protocol)
    {
        case IPPROTO_TCP:
            memcpy(&ip_hdr, buf, sizeof(ip_hdr));
            memcpy(&tcp_hdr, (buf + IP_H), sizeof(tcp_hdr));
#if (SOLARIS_CKSUM_BUG)
            tcp_hdr.th_sum = TCP_H;
            goto finish;
#endif
            tcp_hdr.th_sum = 0;
#if (__i386__)
            tcp_hdr.th_sum = tcp_check(&tcp_hdr,
                                        len,
                                        ip_hdr.ip_src.s_addr,
                                        ip_hdr.ip_dst.s_addr);
#else
            /*
             *  Grab memory for a psuedoheader and the TCP packet.
             */
            p = (u_char *)malloc(P_H + len);
            if (!p)
            {
                perror("do_checksum: malloc");
                return (-1);
            }
            p_hdr = (struct psuedoheader *)p;

            memset(p_hdr, 0, P_H + len);

            p_hdr->ip_src   = ip_hdr.ip_src.s_addr;
            p_hdr->ip_dst   = ip_hdr.ip_dst.s_addr;
            p_hdr->protocol = IPPROTO_TCP;
            p_hdr->len      = htons(len);

            memcpy(p + P_H, &tcp_hdr, len);
            tcp_hdr.th_sum = ip_check((u_short *)p, P_H + len);

            free(p);
            p = NULL;
#endif /* __i386__ */
#if (SOLARIS_CKSUM_BUG) /* silence -Wall warnings */
            finish:
#endif
            memcpy(buf, &ip_hdr, sizeof(ip_hdr));
            memcpy((buf + IP_H), &tcp_hdr, sizeof(tcp_hdr));
            break;
        case IPPROTO_UDP:
            memcpy(&ip_hdr, buf, sizeof(ip_hdr));
            memcpy(&udp_hdr, (buf + IP_H), sizeof(udp_hdr));

            /*
             *  Grab memory for a psuedoheader and the UDP packet.
             */
            p = (u_char *)malloc(P_H + len);
            if (!p)
            {
                perror("do_checksum: malloc");
                return (-1);
            }
            p_hdr = (struct psuedoheader *)p;

            memset(p_hdr, 0, P_H + len);

            p_hdr->ip_src   = ip_hdr.ip_src.s_addr;
            p_hdr->ip_dst   = ip_hdr.ip_dst.s_addr;
            p_hdr->protocol = IPPROTO_UDP;
            p_hdr->len      = htons(len);

            udp_hdr.uh_sum = 0;
            memcpy(p + P_H, &udp_hdr, len);
            udp_hdr.uh_sum = ip_check((u_short *)p, P_H + len);
            free(p);
            p = NULL;
            memcpy(buf, &ip_hdr, sizeof(ip_hdr));
            memcpy((buf + IP_H), &udp_hdr, sizeof(udp_hdr));
            break;
        case IPPROTO_ICMP:
            memcpy(&icmp_hdr, (buf + IP_H), sizeof(icmp_hdr));  
            icmp_hdr.icmp_sum = 0;
            icmp_hdr.icmp_sum = ip_check((u_short *)(&icmp_hdr), len);
            memcpy((buf + IP_H), &icmp_hdr, sizeof(icmp_hdr));
            break;
        case IPPROTO_IGMP:
            memcpy(&igmp_hdr, (buf + IGMP_H), sizeof(igmp_hdr));  
            igmp_hdr.igmp_sum = 0;
            igmp_hdr.igmp_sum = ip_check((u_short *)(&igmp_hdr), len);
            memcpy((buf + IGMP_H), &igmp_hdr, sizeof(igmp_hdr));
            break;
        case IPPROTO_IP:
            memcpy(&ip_hdr, buf, sizeof(ip_hdr));
            ip_hdr.ip_sum = 0;
            ip_hdr.ip_sum = ip_check((u_short *)buf, len);
            memcpy(buf, &ip_hdr, sizeof(ip_hdr));
            break;
        default:
#if (__DEBUG)
            fprintf(stderr, "do_checksum: UNSUPPORTED protocol %d\n", protocol);
#endif
            break;
    }
    return (1);
}

#if (__i386__)
/*
 *  tcp_check - i386 inline-assembly TCP checksum.
 *
 *  th      address the summed bytes are read from (loaded into %esi)
 *  len     number of bytes read from `th'
 *  saddr   added into the sum as one 32-bit value
 *  daddr   added into the sum as one 32-bit value
 *
 *  The first asm statement seeds the accumulator with
 *  daddr + saddr + ((ntohs(len) << 16) + IPPROTO_TCP * 256), adding the
 *  carries back in.  The second adds `len' bytes starting at `th' and folds
 *  the 32-bit total into 16 bits.  The return value is the one's
 *  complement of that total, masked to 16 bits.
 *
 *  NOTE(review): both clobber lists name registers that are also used as
 *  operands (bx/cx/dx, and si in the second statement).  The GCC Extended
 *  Asm documentation says clobbers must not overlap operands, so newer
 *  compilers may reject this - confirm against the toolchain in use.
 */
u_short
tcp_check(struct tcphdr *th, int len, u_long saddr, u_long daddr)
{
    u_long sum;

    /*
     *  ebx = daddr + saddr (ecx) + length/protocol term (edx); the trailing
     *  "adcl $0" adds the carry of the last addition back into ebx.
     */
    __asm__("\taddl %%ecx, %%ebx\n\t"
        "adcl %%edx, %%ebx\n\t"
        "adcl $0, %%ebx"
        : "=b"(sum)
        : "0"(daddr), "c"(saddr), "d"((ntohs(len) << 16) + IPPROTO_TCP * 256)
        : "bx", "cx", "dx" );

    /*
     *  edx keeps a copy of len while ecx is used as the loop counter.
     *    1:  len >> 5 iterations of eight lodsl/adcl pairs (32 bytes each);
     *        skipped entirely when len < 32.
     *    3:  (len & 28) >> 2 further 32-bit words.
     *    4:  one 16-bit word when bit 1 of len is set.
     *    5:  one byte when bit 0 of len is set.
     *    6:  add the high half of ebx into bx, plus the resulting carry.
     *  The carry chain relies on "loop" and "lodsl" leaving the flags
     *  alone, so the instruction order matters.
     */
    __asm__("\tmovl %%ecx, %%edx\n\t"
            "cld\n\t"
            "cmpl $32, %%ecx\n\t"
            "jb 2f\n\t"
            "shrl $5, %%ecx\n\t"
            "clc\n"
            "1:\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "loop 1b\n\t"
            "adcl $0, %%ebx\n\t"
            "movl %%edx, %%ecx\n"
            "2:\t"
            "andl $28, %%ecx\n\t"
            "je 4f\n\t"
            "shrl $2, %%ecx\n\t"
            "clc\n"
            "3:\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "loop 3b\n\t"
            "adcl $0, %%ebx\n"
            "4:\t"
            "movl $0, %%eax\n\t"
            "testw $2, %%dx\n\t"
            "je 5f\n\t"
            "lodsw\n\t"
            "addl %%eax, %%ebx\n\t"
            "adcl $0, %%ebx\n\t"
            "movw $0, %%ax\n"
            "5:\t"
            "test $1, %%edx\n\t"
            "je 6f\n\t"
            "lodsb\n\t"
            "addl %%eax, %%ebx\n\t"
            "adcl $0, %%ebx\n"
            "6:\t"
            "movl %%ebx, %%eax\n\t"
            "shrl $16, %%eax\n\t"
            "addw %%ax, %%bx\n\t"
            "adcw $0, %%bx"
        : "=b"(sum)
        : "0"(sum), "c"(len), "S"(th)
        : "ax", "bx", "cx", "dx", "si" );

    /* Only the low 16 bits hold the folded sum; complement and mask. */
    return ((~sum) & 0xffff);
}


/*
 *  ip_check - i386 inline-assembly Internet checksum.
 *
 *  buff    start of the data to sum
 *  len     number of bytes to sum
 *
 *  The data is consumed as len >> 2 32-bit words, then one 16-bit word if
 *  bit 1 of len is set, then one byte if bit 0 of len is set.  Each tail
 *  is added exactly once; adding it twice would sum bytes that lie past
 *  the end of the buffer.  Returns the one's complement of the 16-bit
 *  one's-complement sum.
 *
 *  Registers the asm modifies are declared as outputs (tied to their
 *  inputs) rather than as clobbers, because GCC does not allow a clobber to
 *  overlap an operand.  The "memory" clobber tells the compiler the asm
 *  reads the buffer through %esi.
 */
u_short
ip_check(u_short *buff, int len)
{
    u_long sum = 0;
    u_long count;       /* receives the spent loop counter (ecx) */

    if (len > 3)
    {
        /*
         *  Add len >> 2 dwords with carry, then fold the 32-bit total in
         *  ebx into bx.  "loop" leaves the carry flag untouched, so the
         *  adcl chain carries across iterations.
         */
        __asm__ __volatile__("clc\n"
        "1:\t"
        "lodsl\n\t"
        "adcl %%eax, %%ebx\n\t"
        "loop 1b\n\t"
        "adcl $0, %%ebx\n\t"
        "movl %%ebx, %%eax\n\t"
        "shrl $16, %%eax\n\t"
        "addw %%ax, %%bx\n\t"
        "adcw $0, %%bx"
        : "=b" (sum), "=S" (buff), "=c" (count)
        : "0" (sum), "1" (buff), "2" ((u_long)(len >> 2))
        : "ax", "cc", "memory");
    }
    if (len & 2)
    {
        /* One remaining 16-bit word. */
        __asm__ __volatile__("lodsw\n\t"
        "addw %%ax, %%bx\n\t"
        "adcw $0, %%bx"
        : "=b" (sum), "=S" (buff)
        : "0" (sum), "1" (buff)
        : "ax", "cc", "memory");
    }
    if (len & 1)
    {
        /* One remaining byte, zero-extended to 16 bits before adding. */
        __asm__ __volatile__("lodsb\n\t"
        "movb $0, %%ah\n\t"
        "addw %%ax, %%bx\n\t"
        "adcw $0, %%bx"
        : "=b" (sum), "=S" (buff)
        : "0" (sum), "1" (buff)
        : "ax", "cc", "memory");
    }
    /* Only bx is meaningful; the upper bits of sum are left-overs. */
    sum  = ~sum;
    return (sum & 0xffff);
}
#else

/*
 *  ip_check - portable Internet checksum.
 *
 *  addr    start of the data to sum
 *  len     number of bytes to sum; values <= 0 sum nothing
 *
 *  Returns the one's complement of the 16-bit one's-complement sum of the
 *  data, taken as host-order 16-bit words.  An odd trailing byte is
 *  treated as the first byte of a word whose second byte is zero.
 */
u_short
ip_check(register u_short *addr, register int len)
{
        const u_char *bytes = (const u_char *)addr;
        u_long sum = 0;         /* unsigned: no signed-overflow surprises */
        u_short word;
        int nleft = len;

        /*
         *  Add sequential 16 bit words into a wider accumulator.  Each
         *  word is fetched with memcpy() so that a buffer which is not
         *  16-bit aligned cannot fault on strict-alignment machines.
         */
        while (nleft > 1)
        {
                memcpy(&word, bytes, sizeof(word));
                sum += word;
                bytes += sizeof(word);
                nleft -= 2;
        }

        /*
         *  Mop up an odd byte, if necessary.  It is placed in the first
         *  byte of a zeroed word so that it lands in the same half of the
         *  word as it would inside a full word, whatever the byte order.
         *  Adding it as a plain number is only right on little-endian
         *  hosts.
         */
        if (nleft == 1)
        {
                word = 0;
                *(u_char *)&word = *bytes;
                sum += word;
        }

        /*
         *  Add back carry outs from the top bits to the low 16 bits until
         *  none remain.
         */
        while (sum >> 16)
        {
                sum = (sum >> 16) + (sum & 0xffff);
        }

        return ((u_short)(~sum & 0xffff));
}
#endif  /* __i386__ */

/* EOF */
