[PATCH v2 4/9] net/tcp: add connection info to tcp_stream structure
Mikhail Kshevetskiy
mikhail.kshevetskiy at iopsys.eu
Fri Jul 5 17:04:26 CEST 2024
Changes:
* Avoid using net_server_ip in the TCP code; use tcp_stream data instead.
* Ignore packets from other connections once a connection has been created.
This prevents an established connection from being broken by packets
belonging to another TCP stream.
Signed-off-by: Mikhail Kshevetskiy <mikhail.kshevetskiy at iopsys.eu>
---
include/net.h | 5 +-
include/net/tcp.h | 57 +++++++++++++++++---
net/fastboot_tcp.c | 46 ++++++++--------
net/net.c | 12 ++---
net/tcp.c | 129 ++++++++++++++++++++++++++++++++++-----------
net/wget.c | 52 +++++++-----------
6 files changed, 201 insertions(+), 100 deletions(-)
diff --git a/include/net.h b/include/net.h
index ac511eab103..fe645245f0f 100644
--- a/include/net.h
+++ b/include/net.h
@@ -668,6 +668,7 @@ int net_send_ip_packet(uchar *ether, struct in_addr dest, int dport, int sport,
/**
* net_send_tcp_packet() - Transmit TCP packet.
* @payload_len: length of payload
+ * @dhost: Destination host
* @dport: Destination TCP port
* @sport: Source TCP port
* @action: TCP action to be performed
@@ -676,8 +677,8 @@ int net_send_ip_packet(uchar *ether, struct in_addr dest, int dport, int sport,
*
* Return: 0 on success, other value on failure
*/
-int net_send_tcp_packet(int payload_len, int dport, int sport, u8 action,
- u32 tcp_seq_num, u32 tcp_ack_num);
+int net_send_tcp_packet(int payload_len, struct in_addr dhost, int dport,
+ int sport, u8 action, u32 tcp_seq_num, u32 tcp_ack_num);
int net_send_udp_packet(uchar *ether, struct in_addr dest, int dport,
int sport, int payload_len);
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 14aee64cb1c..f224d0cae2f 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -279,6 +279,9 @@ enum tcp_state {
/**
* struct tcp_stream - TCP data stream structure
+ * @rhost: Remote host, network byte order
+ * @rport: Remote port, host byte order
+ * @lport: Local port, host byte order
*
* @state: TCP connection state
*
@@ -291,6 +294,10 @@ enum tcp_state {
* @lost: Used for SACK
*/
struct tcp_stream {
+ struct in_addr rhost;
+ u16 rport;
+ u16 lport;
+
/* TCP connection state */
enum tcp_state state;
@@ -305,16 +312,53 @@ struct tcp_stream {
struct tcp_sack_v lost;
};
-struct tcp_stream *tcp_stream_get(void);
+void tcp_init(void);
+
+typedef int tcp_incoming_filter(struct in_addr rhost,
+ u16 rport, u16 lport);
+
+/*
+ * This function sets the user callback used to accept/drop incoming
+ * connections. The callback should:
+ * + check the TCP stream endpoint and make a connection verdict
+ * - return a non-zero value to accept the connection
+ * - return zero to drop the connection
+ *
+ * WARNING: If no callback is set, all incoming connections
+ * will be dropped.
+ */
+void tcp_set_incoming_filter(tcp_incoming_filter *filter);
+
+/*
+ * tcp_stream_get -- Get or create TCP stream
+ * @is_new: if non-zero and no stream matches, create one if the filter allows
+ * @rhost: Remote host, network byte order
+ * @rport: Remote port, host byte order
+ * @lport: Local port, host byte order
+ *
+ * Returns: TCP stream structure or NULL (if not found/created)
+ */
+struct tcp_stream *tcp_stream_get(int is_new, struct in_addr rhost,
+ u16 rport, u16 lport);
+
+/*
+ * tcp_stream_connect -- Create a new TCP stream for a remote connection.
+ * @rhost: Remote host, network byte order
+ * @rport: Remote port, host byte order
+ *
+ * Returns: new TCP stream structure, or NULL if the stream is still in use.
+ * A random local port is used.
+ */
+struct tcp_stream *tcp_stream_connect(struct in_addr rhost, u16 rport);
+
+enum tcp_state tcp_stream_get_state(struct tcp_stream *tcp);
-enum tcp_state tcp_get_tcp_state(struct tcp_stream *tcp);
-void tcp_set_tcp_state(struct tcp_stream *tcp, enum tcp_state new_state);
-int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
- int sport, int payload_len,
+int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int payload_len,
u8 action, u32 tcp_seq_num, u32 tcp_ack_num);
/**
* rxhand_tcp() - An incoming packet handler.
+ * @tcp: TCP stream
* @pkt: pointer to the application packet
* @dport: destination TCP port
* @sip: source IP address
@@ -324,8 +368,7 @@ int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
* @action: TCP action (SYN, ACK, FIN, etc)
* @len: packet length
*/
-typedef void rxhand_tcp(uchar *pkt, u16 dport,
- struct in_addr sip, u16 sport,
+typedef void rxhand_tcp(struct tcp_stream *tcp, uchar *pkt,
u32 tcp_seq_num, u32 tcp_ack_num,
u8 action, unsigned int len);
void tcp_set_tcp_handler(rxhand_tcp *f);
diff --git a/net/fastboot_tcp.c b/net/fastboot_tcp.c
index 2eb52ea2567..de1048e366e 100644
--- a/net/fastboot_tcp.c
+++ b/net/fastboot_tcp.c
@@ -9,14 +9,14 @@
#include <net/fastboot_tcp.h>
#include <net/tcp.h>
+#define FASTBOOT_TCP_PORT 5554
+
static char command[FASTBOOT_COMMAND_LEN] = {0};
static char response[FASTBOOT_RESPONSE_LEN] = {0};
static const unsigned short handshake_length = 4;
static const uchar *handshake = "FB01";
-static u16 curr_sport;
-static u16 curr_dport;
static u32 curr_tcp_seq_num;
static u32 curr_tcp_ack_num;
static unsigned int curr_request_len;
@@ -26,34 +26,37 @@ static enum fastboot_tcp_state {
FASTBOOT_DISCONNECTING
} state = FASTBOOT_CLOSED;
-static void fastboot_tcp_answer(u8 action, unsigned int len)
+static void fastboot_tcp_answer(struct tcp_stream *tcp, u8 action,
+ unsigned int len)
{
const u32 response_seq_num = curr_tcp_ack_num;
const u32 response_ack_num = curr_tcp_seq_num +
(curr_request_len > 0 ? curr_request_len : 1);
- net_send_tcp_packet(len, htons(curr_sport), htons(curr_dport),
+ net_send_tcp_packet(len, tcp->rhost, tcp->rport, tcp->lport,
action, response_seq_num, response_ack_num);
}
-static void fastboot_tcp_reset(void)
+static void fastboot_tcp_reset(struct tcp_stream *tcp)
{
- fastboot_tcp_answer(TCP_RST, 0);
+ fastboot_tcp_answer(tcp, TCP_RST, 0);
state = FASTBOOT_CLOSED;
}
-static void fastboot_tcp_send_packet(u8 action, const uchar *data, unsigned int len)
+static void fastboot_tcp_send_packet(struct tcp_stream *tcp, u8 action,
+ const uchar *data, unsigned int len)
{
uchar *pkt = net_get_async_tx_pkt_buf();
memset(pkt, '\0', PKTSIZE);
pkt += net_eth_hdr_size() + IP_TCP_HDR_SIZE + TCP_TSOPT_SIZE + 2;
memcpy(pkt, data, len);
- fastboot_tcp_answer(action, len);
+ fastboot_tcp_answer(tcp, action, len);
memset(pkt, '\0', PKTSIZE);
}
-static void fastboot_tcp_send_message(const char *message, unsigned int len)
+static void fastboot_tcp_send_message(struct tcp_stream *tcp,
+ const char *message, unsigned int len)
{
__be64 len_be = __cpu_to_be64(len);
uchar *pkt = net_get_async_tx_pkt_buf();
@@ -64,12 +67,11 @@ static void fastboot_tcp_send_message(const char *message, unsigned int len)
memcpy(pkt, &len_be, 8);
pkt += 8;
memcpy(pkt, message, len);
- fastboot_tcp_answer(TCP_ACK | TCP_PUSH, len + 8);
+ fastboot_tcp_answer(tcp, TCP_ACK | TCP_PUSH, len + 8);
memset(pkt, '\0', PKTSIZE);
}
-static void fastboot_tcp_handler_ipv4(uchar *pkt, u16 dport,
- struct in_addr sip, u16 sport,
+static void fastboot_tcp_handler_ipv4(struct tcp_stream *tcp, uchar *pkt,
u32 tcp_seq_num, u32 tcp_ack_num,
u8 action, unsigned int len)
{
@@ -78,8 +80,6 @@ static void fastboot_tcp_handler_ipv4(uchar *pkt, u16 dport,
u8 tcp_fin = action & TCP_FIN;
u8 tcp_push = action & TCP_PUSH;
- curr_sport = sport;
- curr_dport = dport;
curr_tcp_seq_num = tcp_seq_num;
curr_tcp_ack_num = tcp_ack_num;
curr_request_len = len;
@@ -90,17 +90,17 @@ static void fastboot_tcp_handler_ipv4(uchar *pkt, u16 dport,
if (len != handshake_length ||
strlen(pkt) != handshake_length ||
memcmp(pkt, handshake, handshake_length) != 0) {
- fastboot_tcp_reset();
+ fastboot_tcp_reset(tcp);
break;
}
- fastboot_tcp_send_packet(TCP_ACK | TCP_PUSH,
+ fastboot_tcp_send_packet(tcp, TCP_ACK | TCP_PUSH,
handshake, handshake_length);
state = FASTBOOT_CONNECTED;
}
break;
case FASTBOOT_CONNECTED:
if (tcp_fin) {
- fastboot_tcp_answer(TCP_FIN | TCP_ACK, 0);
+ fastboot_tcp_answer(tcp, TCP_FIN | TCP_ACK, 0);
state = FASTBOOT_DISCONNECTING;
break;
}
@@ -112,12 +112,12 @@ static void fastboot_tcp_handler_ipv4(uchar *pkt, u16 dport,
// Only single packet messages are supported ATM
if (strlen(pkt) != command_size) {
- fastboot_tcp_reset();
+ fastboot_tcp_reset(tcp);
break;
}
strlcpy(command, pkt, len + 1);
fastboot_command_id = fastboot_handle_command(command, response);
- fastboot_tcp_send_message(response, strlen(response));
+ fastboot_tcp_send_message(tcp, response, strlen(response));
fastboot_handle_boot(fastboot_command_id,
strncmp("OKAY", response, 4) == 0);
}
@@ -130,17 +130,21 @@ static void fastboot_tcp_handler_ipv4(uchar *pkt, u16 dport,
memset(command, 0, FASTBOOT_COMMAND_LEN);
memset(response, 0, FASTBOOT_RESPONSE_LEN);
- curr_sport = 0;
- curr_dport = 0;
curr_tcp_seq_num = 0;
curr_tcp_ack_num = 0;
curr_request_len = 0;
}
+static int incoming_filter(struct in_addr rhost, u16 rport, u16 lport)
+{
+ return (lport == FASTBOOT_TCP_PORT);
+}
+
void fastboot_tcp_start_server(void)
{
printf("Using %s device\n", eth_get_name());
printf("Listening for fastboot command on tcp %pI4\n", &net_ip);
+ tcp_set_incoming_filter(incoming_filter);
tcp_set_tcp_handler(fastboot_tcp_handler_ipv4);
}
diff --git a/net/net.c b/net/net.c
index 8f076fa18e3..ff3018e6494 100644
--- a/net/net.c
+++ b/net/net.c
@@ -416,7 +416,7 @@ int net_init(void)
/* Only need to setup buffer pointers once. */
first_call = 0;
if (IS_ENABLED(CONFIG_PROT_TCP))
- tcp_set_tcp_state(tcp_stream_get(), TCP_CLOSED);
+ tcp_init();
}
return net_init_loop();
@@ -901,10 +901,10 @@ int net_send_udp_packet(uchar *ether, struct in_addr dest, int dport, int sport,
}
#if defined(CONFIG_PROT_TCP)
-int net_send_tcp_packet(int payload_len, int dport, int sport, u8 action,
- u32 tcp_seq_num, u32 tcp_ack_num)
+int net_send_tcp_packet(int payload_len, struct in_addr dhost, int dport,
+ int sport, u8 action, u32 tcp_seq_num, u32 tcp_ack_num)
{
- return net_send_ip_packet(net_server_ethaddr, net_server_ip, dport,
+ return net_send_ip_packet(net_server_ethaddr, dhost, dport,
sport, payload_len, IPPROTO_TCP, action,
tcp_seq_num, tcp_ack_num);
}
@@ -946,12 +946,12 @@ int net_send_ip_packet(uchar *ether, struct in_addr dest, int dport, int sport,
break;
#if defined(CONFIG_PROT_TCP)
case IPPROTO_TCP:
- tcp = tcp_stream_get();
+ tcp = tcp_stream_get(0, dest, dport, sport);
if (tcp == NULL)
return -EINVAL;
pkt_hdr_size = eth_hdr_size
- + tcp_set_tcp_header(tcp, pkt + eth_hdr_size, dport, sport,
+ + tcp_set_tcp_header(tcp, pkt + eth_hdr_size,
payload_len, action, tcp_seq_num,
tcp_ack_num);
break;
diff --git a/net/tcp.c b/net/tcp.c
index 80a161838f5..efa12c9e8d3 100644
--- a/net/tcp.c
+++ b/net/tcp.c
@@ -27,6 +27,7 @@
static int tcp_activity_count;
static struct tcp_stream tcp_stream;
+static tcp_incoming_filter *incoming_filter;
/*
* TCP lengths are stored as a rounded up number of 32 bit words.
@@ -41,40 +42,95 @@ static struct tcp_stream tcp_stream;
/* Current TCP RX packet handler */
static rxhand_tcp *tcp_packet_handler;
+#define RANDOM_PORT_START 1024
+#define RANDOM_PORT_RANGE 0x4000
+
+/**
+ * random_port() - make port a little random (1024-17407)
+ *
+ * Return: random port number from 1024 to 17407
+ *
+ * This keeps the math somewhat trivial to compute, and seems to work with
+ * all supported protocols/clients/servers
+ */
+static unsigned int random_port(void)
+{
+ return RANDOM_PORT_START + (get_timer(0) % RANDOM_PORT_RANGE);
+}
+
static inline s32 tcp_seq_cmp(u32 a, u32 b)
{
return (s32)(a - b);
}
/**
- * tcp_get_tcp_state() - get TCP stream state
+ * tcp_stream_get_state() - get TCP stream state
* @tcp: tcp stream
*
* Return: TCP stream state
*/
-enum tcp_state tcp_get_tcp_state(struct tcp_stream *tcp)
+enum tcp_state tcp_stream_get_state(struct tcp_stream *tcp)
{
return tcp->state;
}
/**
- * tcp_set_tcp_state() - set TCP stream state
+ * tcp_stream_set_state() - set TCP stream state
* @tcp: tcp stream
* @new_state: new TCP state
*/
-void tcp_set_tcp_state(struct tcp_stream *tcp,
- enum tcp_state new_state)
+static void tcp_stream_set_state(struct tcp_stream *tcp,
+ enum tcp_state new_state)
{
tcp->state = new_state;
}
-struct tcp_stream *tcp_stream_get(void)
+void tcp_init(void)
+{
+ incoming_filter = NULL;
+ tcp_stream.state = TCP_CLOSED;
+}
+
+void tcp_set_incoming_filter(tcp_incoming_filter *filter)
+{
+ incoming_filter = filter;
+}
+
+static struct tcp_stream *tcp_stream_add(struct in_addr rhost,
+ u16 rport, u16 lport)
+{
+ struct tcp_stream *tcp = &tcp_stream;
+
+ if (tcp->state != TCP_CLOSED)
+ return NULL;
+
+ memset(tcp, 0, sizeof(struct tcp_stream));
+ tcp->rhost.s_addr = rhost.s_addr;
+ tcp->rport = rport;
+ tcp->lport = lport;
+ tcp->state = TCP_CLOSED;
+ tcp->lost.len = TCP_OPT_LEN_2;
+ return tcp;
+}
+
+struct tcp_stream *tcp_stream_get(int is_new, struct in_addr rhost,
+ u16 rport, u16 lport)
{
- return &tcp_stream;
+ struct tcp_stream *tcp = &tcp_stream;
+
+ if ((tcp->rhost.s_addr == rhost.s_addr) &&
+ (tcp->rport == rport) &&
+ (tcp->lport == lport))
+ return tcp;
+
+ if (!is_new || (incoming_filter == NULL) ||
+ !incoming_filter(rhost, rport, lport))
+ return NULL;
+
+ return tcp_stream_add(rhost, rport, lport);
}
-static void dummy_handler(uchar *pkt, u16 dport,
- struct in_addr sip, u16 sport,
+static void dummy_handler(struct tcp_stream *tcp, uchar *pkt,
u32 tcp_seq_num, u32 tcp_ack_num,
u8 action, unsigned int len)
{
@@ -223,8 +279,7 @@ void net_set_syn_options(struct tcp_stream *tcp, union tcp_build_pkt *b)
b->ip.end = TCP_O_END;
}
-int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
- int sport, int payload_len,
+int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int payload_len,
u8 action, u32 tcp_seq_num, u32 tcp_ack_num)
{
union tcp_build_pkt *b = (union tcp_build_pkt *)pkt;
@@ -244,7 +299,7 @@ int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
case TCP_SYN:
debug_cond(DEBUG_DEV_PKT,
"TCP Hdr:SYN (%pI4, %pI4, sq=%u, ak=%u)\n",
- &net_server_ip, &net_ip,
+ &tcp->rhost, &net_ip,
tcp_seq_num, tcp_ack_num);
tcp_activity_count = 0;
net_set_syn_options(tcp, b);
@@ -265,13 +320,13 @@ int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
b->ip.hdr.tcp_flags = action;
debug_cond(DEBUG_DEV_PKT,
"TCP Hdr:ACK (%pI4, %pI4, s=%u, a=%u, A=%x)\n",
- &net_server_ip, &net_ip, tcp_seq_num, tcp_ack_num,
+ &tcp->rhost, &net_ip, tcp_seq_num, tcp_ack_num,
action);
break;
case TCP_FIN:
debug_cond(DEBUG_DEV_PKT,
"TCP Hdr:FIN (%pI4, %pI4, s=%u, a=%u)\n",
- &net_server_ip, &net_ip, tcp_seq_num, tcp_ack_num);
+ &tcp->rhost, &net_ip, tcp_seq_num, tcp_ack_num);
payload_len = 0;
pkt_hdr_len = IP_TCP_HDR_SIZE;
tcp->state = TCP_FIN_WAIT_1;
@@ -280,7 +335,7 @@ int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
case TCP_RST:
debug_cond(DEBUG_DEV_PKT,
"TCP Hdr:RST (%pI4, %pI4, s=%u, a=%u)\n",
- &net_server_ip, &net_ip, tcp_seq_num, tcp_ack_num);
+ &tcp->rhost, &net_ip, tcp_seq_num, tcp_ack_num);
tcp->state = TCP_CLOSED;
break;
/* Notify connection closing */
@@ -291,7 +346,7 @@ int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
debug_cond(DEBUG_DEV_PKT,
"TCP Hdr:FIN ACK PSH(%pI4, %pI4, s=%u, a=%u, A=%x)\n",
- &net_server_ip, &net_ip,
+ &tcp->rhost, &net_ip,
tcp_seq_num, tcp_ack_num, action);
fallthrough;
default:
@@ -299,7 +354,7 @@ int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
b->ip.hdr.tcp_flags = action | TCP_PUSH | TCP_ACK;
debug_cond(DEBUG_DEV_PKT,
"TCP Hdr:dft (%pI4, %pI4, s=%u, a=%u, A=%x)\n",
- &net_server_ip, &net_ip,
+ &tcp->rhost, &net_ip,
tcp_seq_num, tcp_ack_num, action);
}
@@ -309,8 +364,8 @@ int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
tcp->ack_edge = tcp_ack_num;
/* TCP Header */
b->ip.hdr.tcp_ack = htonl(tcp->ack_edge);
- b->ip.hdr.tcp_src = htons(sport);
- b->ip.hdr.tcp_dst = htons(dport);
+ b->ip.hdr.tcp_src = htons(tcp->lport);
+ b->ip.hdr.tcp_dst = htons(tcp->rport);
b->ip.hdr.tcp_seq = htonl(tcp_seq_num);
/*
@@ -333,10 +388,10 @@ int tcp_set_tcp_header(struct tcp_stream *tcp, uchar *pkt, int dport,
b->ip.hdr.tcp_xsum = 0;
b->ip.hdr.tcp_ugr = 0;
- b->ip.hdr.tcp_xsum = tcp_set_pseudo_header(pkt, net_ip, net_server_ip,
+ b->ip.hdr.tcp_xsum = tcp_set_pseudo_header(pkt, net_ip, tcp->rhost,
tcp_len, pkt_len);
- net_set_ip_header((uchar *)&b->ip, net_server_ip, net_ip,
+ net_set_ip_header((uchar *)&b->ip, tcp->rhost, net_ip,
pkt_len, IPPROTO_TCP);
return pkt_hdr_len;
@@ -617,19 +672,26 @@ void rxhand_tcp_f(union tcp_build_pkt *b, unsigned int pkt_len)
u32 tcp_seq_num, tcp_ack_num;
int tcp_hdr_len, payload_len;
struct tcp_stream *tcp;
+ struct in_addr src;
/* Verify IP header */
debug_cond(DEBUG_DEV_PKT,
"TCP RX in RX Sum (to=%pI4, from=%pI4, len=%d)\n",
&b->ip.hdr.ip_src, &b->ip.hdr.ip_dst, pkt_len);
- b->ip.hdr.ip_src = net_server_ip;
+ /*
+ * The source IP address is overwritten by the TCP checksum
+ * verification (see tcp_set_pseudo_header()), so save a copy
+ * before it gets clobbered.
+ */
+ src.s_addr = b->ip.hdr.ip_src.s_addr;
+
b->ip.hdr.ip_dst = net_ip;
b->ip.hdr.ip_sum = 0;
if (tcp_rx_xsum != compute_ip_checksum(b, IP_HDR_SIZE)) {
debug_cond(DEBUG_DEV_PKT,
"TCP RX IP xSum Error (%pI4, =%pI4, len=%d)\n",
- &net_ip, &net_server_ip, pkt_len);
+ &net_ip, &src, pkt_len);
return;
}
@@ -641,11 +703,14 @@ void rxhand_tcp_f(union tcp_build_pkt *b, unsigned int pkt_len)
pkt_len)) {
debug_cond(DEBUG_DEV_PKT,
"TCP RX TCP xSum Error (%pI4, %pI4, len=%d)\n",
- &net_ip, &net_server_ip, tcp_len);
+ &net_ip, &src, tcp_len);
return;
}
- tcp = tcp_stream_get();
+ tcp = tcp_stream_get(b->ip.hdr.tcp_flags & TCP_SYN,
+ src,
+ ntohs(b->ip.hdr.tcp_src),
+ ntohs(b->ip.hdr.tcp_dst));
if (tcp == NULL)
return;
@@ -677,9 +742,9 @@ void rxhand_tcp_f(union tcp_build_pkt *b, unsigned int pkt_len)
"TCP Notify (action=%x, Seq=%u,Ack=%u,Pay%d)\n",
tcp_action, tcp_seq_num, tcp_ack_num, payload_len);
- (*tcp_packet_handler) ((uchar *)b + pkt_len - payload_len, b->ip.hdr.tcp_dst,
- b->ip.hdr.ip_src, b->ip.hdr.tcp_src, tcp_seq_num,
- tcp_ack_num, tcp_action, payload_len);
+ (*tcp_packet_handler) (tcp, (uchar *)b + pkt_len - payload_len,
+ tcp_seq_num, tcp_ack_num, tcp_action,
+ payload_len);
} else if (tcp_action != TCP_DATA) {
debug_cond(DEBUG_DEV_PKT,
@@ -690,9 +755,13 @@ void rxhand_tcp_f(union tcp_build_pkt *b, unsigned int pkt_len)
* Warning: Incoming Ack & Seq sequence numbers are transposed
* here to outgoing Seq & Ack sequence numbers
*/
- net_send_tcp_packet(0, ntohs(b->ip.hdr.tcp_src),
- ntohs(b->ip.hdr.tcp_dst),
+ net_send_tcp_packet(0, tcp->rhost, tcp->rport, tcp->lport,
(tcp_action & (~TCP_PUSH)),
tcp_ack_num, tcp->ack_edge);
}
}
+
+struct tcp_stream *tcp_stream_connect(struct in_addr rhost, u16 rport)
+{
+ return tcp_stream_add(rhost, rport, random_port());
+}
diff --git a/net/wget.c b/net/wget.c
index 1c0f97a6cc0..327fe3cfbce 100644
--- a/net/wget.c
+++ b/net/wget.c
@@ -28,9 +28,8 @@ static const char http_eom[] = "\r\n\r\n";
static const char http_ok[] = "200";
static const char content_len[] = "Content-Length";
static const char linefeed[] = "\r\n";
-static struct in_addr web_server_ip;
-static int our_port;
static int wget_timeout_count;
+static struct tcp_stream *tcp;
struct pkt_qd {
uchar *pkt;
@@ -138,22 +137,19 @@ static void wget_send_stored(void)
int len = retry_len;
unsigned int tcp_ack_num = retry_tcp_seq_num + (len == 0 ? 1 : len);
unsigned int tcp_seq_num = retry_tcp_ack_num;
- unsigned int server_port;
uchar *ptr, *offset;
- server_port = env_get_ulong("httpdstp", 10, SERVER_PORT) & 0xffff;
-
switch (current_wget_state) {
case WGET_CLOSED:
debug_cond(DEBUG_WGET, "wget: send SYN\n");
current_wget_state = WGET_CONNECTING;
- net_send_tcp_packet(0, server_port, our_port, action,
+ net_send_tcp_packet(0, tcp->rhost, tcp->rport, tcp->lport, action,
tcp_seq_num, tcp_ack_num);
packets = 0;
break;
case WGET_CONNECTING:
pkt_q_idx = 0;
- net_send_tcp_packet(0, server_port, our_port, action,
+ net_send_tcp_packet(0, tcp->rhost, tcp->rport, tcp->lport, action,
tcp_seq_num, tcp_ack_num);
ptr = net_tx_packet + net_eth_hdr_size() +
@@ -168,14 +164,14 @@ static void wget_send_stored(void)
memcpy(offset, &bootfile3, strlen(bootfile3));
offset += strlen(bootfile3);
- net_send_tcp_packet((offset - ptr), server_port, our_port,
+ net_send_tcp_packet((offset - ptr), tcp->rhost, tcp->rport, tcp->lport,
TCP_PUSH, tcp_seq_num, tcp_ack_num);
current_wget_state = WGET_CONNECTED;
break;
case WGET_CONNECTED:
case WGET_TRANSFERRING:
case WGET_TRANSFERRED:
- net_send_tcp_packet(0, server_port, our_port, action,
+ net_send_tcp_packet(0, tcp->rhost, tcp->rport, tcp->lport, action,
tcp_seq_num, tcp_ack_num);
break;
}
@@ -340,10 +336,8 @@ static void wget_connected(uchar *pkt, unsigned int tcp_seq_num,
/**
* wget_handler() - TCP handler of wget
+ * @tcp: TCP stream
* @pkt: pointer to the application packet
- * @dport: destination TCP port
- * @sip: source IP address
- * @sport: source TCP port
* @tcp_seq_num: TCP sequential number
* @tcp_ack_num: TCP acknowledgment number
* @action: TCP action (SYN, ACK, FIN, etc)
@@ -352,13 +346,11 @@ static void wget_connected(uchar *pkt, unsigned int tcp_seq_num,
* In the "application push" invocation, the TCP header with all
* its information is pointed to by the packet pointer.
*/
-static void wget_handler(uchar *pkt, u16 dport,
- struct in_addr sip, u16 sport,
+static void wget_handler(struct tcp_stream *tcp, uchar *pkt,
u32 tcp_seq_num, u32 tcp_ack_num,
u8 action, unsigned int len)
{
- struct tcp_stream *tcp = tcp_stream_get();
- enum tcp_state wget_tcp_state = tcp_get_tcp_state(tcp);
+ enum tcp_state wget_tcp_state = tcp_stream_get_state(tcp);
net_set_timeout_handler(wget_timeout, wget_timeout_handler);
packets++;
@@ -442,26 +434,13 @@ static void wget_handler(uchar *pkt, u16 dport,
}
}
-#define RANDOM_PORT_START 1024
-#define RANDOM_PORT_RANGE 0x4000
-
-/**
- * random_port() - make port a little random (1024-17407)
- *
- * Return: random port number from 1024 to 17407
- *
- * This keeps the math somewhat trivial to compute, and seems to work with
- * all supported protocols/clients/servers
- */
-static unsigned int random_port(void)
-{
- return RANDOM_PORT_START + (get_timer(0) % RANDOM_PORT_RANGE);
-}
-
#define BLOCKSIZE 512
void wget_start(void)
{
+ struct in_addr web_server_ip;
+ unsigned int server_port;
+
image_url = strchr(net_boot_file_name, ':');
if (image_url > 0) {
web_server_ip = string_to_ip(net_boot_file_name);
@@ -514,8 +493,6 @@ void wget_start(void)
wget_timeout_count = 0;
current_wget_state = WGET_CLOSED;
- our_port = random_port();
-
/*
* Zero out server ether to force arp resolution in case
* the server ip for the previous u-boot command, for example dns
@@ -524,6 +501,13 @@ void wget_start(void)
memset(net_server_ethaddr, 0, 6);
+ server_port = env_get_ulong("httpdstp", 10, SERVER_PORT) & 0xffff;
+ tcp = tcp_stream_connect(web_server_ip, server_port);
+ if (tcp == NULL) {
+ net_set_state(NETLOOP_FAIL);
+ return;
+ }
+
wget_send(TCP_SYN, 0, 0, 0);
}
--
2.39.2
More information about the U-Boot
mailing list