-
Notifications
You must be signed in to change notification settings - Fork 0
/
tcp_pkt.py
162 lines (148 loc) · 5.27 KB
/
tcp_pkt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import socket
import struct
from ip_pkt import IPPacket, calculate_checksum
# struct layout of the fixed 20-byte TCP header, network byte order:
# source port (H), destination port (H), sequence number (I), acknowledgment number (I),
# data-offset/reserved byte (B), flags byte (B), window (H), checksum (H), urgent pointer (H).
HEADER_FORMAT = "!HHIIBBHHH"

# struct layout of the 12-byte IPv4 pseudo-header that is prepended when computing the TCP checksum:
# source address (4s), destination address (4s), zero byte (B), protocol number (B), TCP length (H).
PSEUDO_HEADER_FORMAT = "!4s4sBBH"
class TCPPacket:
    """
    A TCP segment (header plus payload) that can be serialized with ``pack`` and parsed from raw bytes with
    ``unpack``.

    The six control flags handled here (FIN, SYN, RST, PSH, ACK, URG) are exposed as boolean attributes;
    ``set_flags`` folds them into the ``flags`` byte. TCP options are not emitted when packing (the header is
    always 20 bytes) and are skipped when unpacking.
    """

    # (attribute name, bit mask) for each flag modelled by this class, LSB first.
    _FLAG_BITS = (
        ("fin", 1 << 0),
        ("syn", 1 << 1),
        ("rst", 1 << 2),
        ("psh", 1 << 3),
        ("ack", 1 << 4),
        ("urg", 1 << 5),
    )
    # Bits of the flags byte NOT modelled by a boolean attribute (CWR, ECE); set_flags leaves them untouched.
    _UNMODELLED_FLAGS_MASK = 0xC0

    def __init__(self, src_host: str, src_port: int, dst_host: str, dst_port: int, payload: "str | bytes" = ""):
        """
        Instantiates a TCPPacket object to the given source address, source port, destination address, destination
        port, and payload.
        :param src_host: the source address (string accepted by socket.inet_aton)
        :param src_port: the source port
        :param dst_host: the destination address (string accepted by socket.inet_aton)
        :param dst_port: the destination port
        :param payload: the payload; a str is encoded with str.encode() (UTF-8), bytes-like values are used as-is
        """
        self.src_host = src_host
        self.src_port = src_port
        self.dst_host = dst_host
        self.dst_port = dst_port
        self.seq_num = 0
        self.ack_num = 0
        # Header length in 32-bit words lives in the high nibble (5 words = 20 bytes, no options);
        # the low nibble is reserved.
        self.data_offset = 5 << 4
        # Flags byte, MSB to LSB: CWR, ECE, URG, ACK, PSH, RST, SYN, FIN.
        self.flags = 0b00000000
        self.fin = False  # finish flag
        self.syn = False  # synchronization flag
        self.rst = False  # reset flag
        self.psh = False  # push flag
        self.ack = False  # acknowledgement flag
        self.urg = False  # urgent flag
        self.adv_wnd = 65535  # max window size
        self.checksum = 0
        self.urg_ptr = 0
        self.packet = None
        self.pseudo_header = None
        self.payload = payload.encode() if isinstance(payload, str) else bytes(payload)

    def set_flags(self):
        """
        Rebuilds the FIN/SYN/RST/PSH/ACK/URG bits of ``flags`` from the boolean flag attributes.

        The six modelled bits are recomputed from scratch, so a flag that was set for an earlier ``pack()`` and has
        since been cleared is removed again. The two high bits (CWR/ECE), which have no boolean attribute, are kept.
        """
        bits = 0
        for name, mask in self._FLAG_BITS:
            if getattr(self, name):
                bits |= mask
        self.flags = (self.flags & self._UNMODELLED_FLAGS_MASK) | bits

    def pack(self) -> bytes:
        """
        Formats and encodes the header fields and appends the payload.

        The checksum is computed over the IPv4 pseudo-header, the header (with a zeroed checksum field) and the
        payload. It is stored in ``self.checksum``; the pseudo-header is kept in ``self.pseudo_header`` and the
        finished segment in ``self.packet``.
        NOTE(review): odd-length payloads rely on calculate_checksum padding to 16 bits — confirm in ip_pkt.
        :return: the encoded TCP segment (header + payload)
        """
        self.set_flags()
        # The checksum field must be zero while the checksum is computed. Packing self.checksum here would fold a
        # stale value (from a previous pack() or from unpack()) into the new checksum.
        header = struct.pack(
            HEADER_FORMAT,
            self.src_port,  # source port
            self.dst_port,  # destination port
            self.seq_num,  # sequence number
            self.ack_num,  # acknowledgment number
            self.data_offset,  # data offset (high nibble; low nibble reserved)
            self.flags,  # flags
            self.adv_wnd,  # window
            0,  # checksum placeholder
            self.urg_ptr,  # urgent pointer
        )
        self.pseudo_header = struct.pack(  # IP pseudo-header, only used for the checksum
            PSEUDO_HEADER_FORMAT,
            socket.inet_aton(self.src_host),  # source address
            socket.inet_aton(self.dst_host),  # destination address
            0,  # reserved
            socket.IPPROTO_TCP,  # protocol ID
            len(header) + len(self.payload),  # TCP length (header + payload)
        )
        self.checksum = calculate_checksum(self.pseudo_header + header + self.payload)
        # Bytes 16-17 of the header hold the checksum field.
        self.packet = header[:16] + struct.pack("!H", self.checksum) + header[18:] + self.payload
        return self.packet

    @staticmethod
    def unpack(ip_pkt: "IPPacket", raw_tcp_pkt: bytes) -> "TCPPacket | None":
        """
        Decodes and parses an incoming TCP segment and verifies its checksum.

        Any TCP options present (data offset > 5) are skipped; they are not returned as part of the payload.
        :param ip_pkt: the IP packet that carried the segment; its ``src``/``dst`` are used for the pseudo-header
        :param raw_tcp_pkt: the encoded TCP segment (header, options and payload)
        :return: the decoded TCPPacket, or None if the segment is truncated, declares an invalid header length,
            or fails the checksum check
        """
        if len(raw_tcp_pkt) < 20:  # too short to hold even the fixed header
            return None
        (
            src_port,
            dst_port,
            seq_num,
            ack_num,
            offset_byte,
            flag_byte,
            adv_wnd,
            checksum,
            urg_ptr,
        ) = struct.unpack(HEADER_FORMAT, raw_tcp_pkt[:20])

        # Header length in bytes = high nibble (32-bit words) * 4; bytes past 20 are options, not payload.
        header_len = (offset_byte >> 4) * 4
        if header_len < 20 or header_len > len(raw_tcp_pkt):
            return None

        pseudo_header = struct.pack(
            PSEUDO_HEADER_FORMAT,
            socket.inet_aton(ip_pkt.src),
            socket.inet_aton(ip_pkt.dst),
            0,  # reserved
            socket.IPPROTO_TCP,
            len(raw_tcp_pkt),  # header, options and payload
        )
        # Recompute the checksum with the received checksum field zeroed and compare with the sender's value.
        zero_csum_raw_tcp_pkt = raw_tcp_pkt[:16] + struct.pack("!H", 0) + raw_tcp_pkt[18:]
        if calculate_checksum(pseudo_header + zero_csum_raw_tcp_pkt) != checksum:
            return None

        tcp_pkt = TCPPacket(
            src_host=ip_pkt.src,
            src_port=src_port,
            dst_host=ip_pkt.dst,
            dst_port=dst_port,
        )
        tcp_pkt.seq_num = seq_num
        tcp_pkt.ack_num = ack_num
        tcp_pkt.adv_wnd = adv_wnd
        tcp_pkt.checksum = checksum
        tcp_pkt.urg_ptr = urg_ptr
        tcp_pkt.payload = raw_tcp_pkt[header_len:]
        for name, mask in TCPPacket._FLAG_BITS:
            setattr(tcp_pkt, name, bool(flag_byte & mask))
        tcp_pkt.set_flags()  # keep the flags byte consistent with the boolean attributes
        return tcp_pkt