#include <stdio.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdint.h>
#include <sys/socket.h>

#include "client.h"
#include "log.h"
#include "state.pb.h"
#include "stream.h"

#ifndef SO_MAX_PACING_RATE
#define SO_MAX_PACING_RATE 47
#endif

using namespace std;

// Create a client for a freshly accepted socket, recording the peer's address
// as a printable string in remote_addr ("" if it cannot be determined).
Client::Client(int sock)
	: sock(sock)
{
	request.reserve(1024);

	// NOTE(review): connect_time is not assigned in this constructor, but
	// serialize() and get_stats() both read it; confirm it is initialized in
	// client.h or by the caller before those are used.

	// Find the remote address, and convert it to ASCII. getpeername() is given
	// a sockaddr_storage so the address fits whatever family the socket has;
	// we then dispatch on the family actually returned instead of assuming that
	// every peer is a sockaddr_in6.
	sockaddr_storage addr;
	socklen_t addr_len = sizeof(addr);

	if (getpeername(sock, reinterpret_cast<sockaddr *>(&addr), &addr_len) == -1) {
		log_perror("getpeername");
		remote_addr = "";
		return;
	}

	char buf[INET6_ADDRSTRLEN];
	const char *formatted = nullptr;
	if (addr.ss_family == AF_INET6) {
		const sockaddr_in6 *addr6 = reinterpret_cast<const sockaddr_in6 *>(&addr);
		if (IN6_IS_ADDR_V4MAPPED(&addr6->sin6_addr)) {
			// IPv4 address, really (IPv4-mapped ::ffff:a.b.c.d); the IPv4
			// part is the last four bytes. s6_addr is used rather than the
			// glibc-only s6_addr32.
			formatted = inet_ntop(AF_INET, &addr6->sin6_addr.s6_addr[12], buf, sizeof(buf));
		} else {
			formatted = inet_ntop(AF_INET6, &addr6->sin6_addr, buf, sizeof(buf));
		}
	} else if (addr.ss_family == AF_INET) {
		// Plain IPv4 socket; sin_addr lives at a different offset than
		// sin6_addr, so it must be read through sockaddr_in.
		const sockaddr_in *addr4 = reinterpret_cast<const sockaddr_in *>(&addr);
		formatted = inet_ntop(AF_INET, &addr4->sin_addr, buf, sizeof(buf));
	} else {
		log(WARNING, "getpeername() returned unsupported address family %d", int(addr.ss_family));
		remote_addr = "";
		return;
	}

	if (formatted == nullptr) {
		log_perror("inet_ntop");
		remote_addr = "";
	} else {
		remote_addr = buf;
	}
}
	
// Restore a client from the state written by serialize() (e.g. across a restart).
//
// serialized: the saved client state, including the socket descriptor, which
//   is used directly (setsockopt below).
// short_responses: pool of shared header/short-response strings; a saved
//   header_or_short_response_index is an index into this vector.
// stream: the stream this client is attached to, or nullptr for none.
Client::Client(const ClientProto &serialized, const vector<shared_ptr<const string>> &short_responses, Stream *stream)
	: sock(serialized.sock()),
	  remote_addr(serialized.remote_addr()),
	  referer(serialized.referer()),
	  user_agent(serialized.user_agent()),
	  x_playback_session_id(serialized.x_playback_session_id()),
	  state(State(serialized.state())),
	  request(serialized.request()),
	  url(serialized.url()),
	  stream(stream),
	  close_after_response(serialized.close_after_response()),
	  http_11(serialized.http_11()),
	  header_or_short_response_bytes_sent(serialized.header_or_short_response_bytes_sent()),
	  stream_pos(serialized.stream_pos()),
	  stream_pos_end(serialized.stream_pos_end()),
	  bytes_sent(serialized.bytes_sent()),
	  bytes_lost(serialized.bytes_lost()),
	  num_loss_events(serialized.num_loss_events())
{
	// Re-apply the stream's pacing rate to the socket. A failure is only
	// logged when pacing_rate != ~0U; presumably ~0U means "no limit", so a
	// kernel without SO_MAX_PACING_RATE is harmless in that case (confirm
	// against Stream).
	if (stream != nullptr) {
		if (setsockopt(sock, SOL_SOCKET, SO_MAX_PACING_RATE, &stream->pacing_rate, sizeof(stream->pacing_rate)) == -1) {
			if (stream->pacing_rate != ~0U) {
				log_perror("setsockopt(SO_MAX_PACING_RATE)");
			}
		}
	}

	// The header/short response is either stored inline (old format, copied
	// into header_or_short_response_holder) or as an index into
	// short_responses (the shared_ptr in header_or_short_response_ref keeps the
	// string alive while header_or_short_response points at it). If neither
	// field is set, header_or_short_response keeps its default value
	// (presumably nullptr; confirm in client.h).
	if (serialized.has_header_or_short_response_old()) {
		// Pre-1.4.0.
		header_or_short_response_holder = serialized.header_or_short_response_old();
		header_or_short_response = &header_or_short_response_holder;
	} else if (serialized.has_header_or_short_response_index()) {
		// NOTE(review): the bounds check is assert-only; with NDEBUG a bad
		// index would read past the end of short_responses.
		assert(size_t(serialized.header_or_short_response_index()) < short_responses.size());
		header_or_short_response_ref = short_responses[serialized.header_or_short_response_index()];
		header_or_short_response = header_or_short_response_ref.get();
	}
	connect_time.tv_sec = serialized.connect_time_sec();
	connect_time.tv_nsec = serialized.connect_time_nsec();

	in_ktls_mode = false;
	if (serialized.has_tls_context()) {
		tls_context = tls_import_context(
			reinterpret_cast<const unsigned char *>(serialized.tls_context().data()),
			serialized.tls_context().size());
		if (tls_context == nullptr) {
			// NOTE(review): tls_data_to_send / tls_data_left_to_send are not
			// assigned on this path; confirm they have defaults in client.h.
			log(WARNING, "tls_import_context() failed, TLS client might not survive across restart");
		} else {
			// The imported context carries the whole pending write buffer, but
			// part of it may already have been written to the socket before
			// serialization; skip that many bytes (or drop the buffer
			// entirely if everything was consumed).
			tls_data_to_send = tls_get_write_buffer(tls_context, &tls_data_left_to_send);

			assert(serialized.tls_output_bytes_already_consumed() <= tls_data_left_to_send);
			if (serialized.tls_output_bytes_already_consumed() >= tls_data_left_to_send) {
				tls_buffer_clear(tls_context);
				tls_data_to_send = nullptr;
			} else {
				tls_data_to_send += serialized.tls_output_bytes_already_consumed();
				tls_data_left_to_send -= serialized.tls_output_bytes_already_consumed();
			}
			in_ktls_mode = serialized.in_ktls_mode();
		}
	} else {
		tls_context = nullptr;
	}
}

// Snapshot this client's state into a ClientProto that the
// Client(const ClientProto &, ...) constructor can restore from.
//
// short_response_pool deduplicates header/short-response strings by pointer:
// the first time a string is seen it is assigned the next free index
// (the pool's current size), and only that index is stored in the proto.
ClientProto Client::serialize(unordered_map<const string *, size_t> *short_response_pool) const
{
	ClientProto serialized;
	serialized.set_sock(sock);
	serialized.set_remote_addr(remote_addr);
	serialized.set_referer(referer);
	serialized.set_user_agent(user_agent);
	serialized.set_x_playback_session_id(x_playback_session_id);
	serialized.set_connect_time_sec(connect_time.tv_sec);
	serialized.set_connect_time_nsec(connect_time.tv_nsec);
	serialized.set_state(state);
	serialized.set_request(request);
	serialized.set_url(url);

	if (header_or_short_response != nullptr) {
		// See if this string is already in the pool (deduplicated by the pointer); if not, insert it.
		auto iterator_and_inserted = short_response_pool->emplace(
			header_or_short_response, short_response_pool->size());
		serialized.set_header_or_short_response_index(iterator_and_inserted.first->second);
	}

	// Save how far into the header/short response we have come, so that a
	// restored client neither resends nor skips any of it. (This previously
	// copied the proto's own, still-default field onto itself, always storing 0.)
	serialized.set_header_or_short_response_bytes_sent(header_or_short_response_bytes_sent);
	serialized.set_stream_pos(stream_pos);
	serialized.set_stream_pos_end(stream_pos_end);
	serialized.set_bytes_sent(bytes_sent);
	serialized.set_bytes_lost(bytes_lost);
	serialized.set_num_loss_events(num_loss_events);
	serialized.set_http_11(http_11);
	serialized.set_close_after_response(close_after_response);

	if (tls_context != nullptr) {
		bool small_version = false;
		int required_size = tls_export_context(tls_context, nullptr, 0, small_version);
		if (required_size <= 0) {
			// Can happen if we're in the middle of the key exchange, unfortunately.
			// We'll get an error fairly fast, and this client hasn't started playing
			// anything yet, so just log the error and continue.
			//
			// In theory, we could still rescue it if we had sent _zero_ bytes,
			// by doing an entirely new TLS context, but it's an edge case
			// that's not really worth it.
			log(WARNING, "tls_export_context() failed (returned %d), TLS client might not survive across restart",
				required_size);
		} else {
			string *serialized_context = serialized.mutable_tls_context();
			serialized_context->resize(required_size);

			int ret = tls_export_context(tls_context,
				reinterpret_cast<unsigned char *>(&(*serialized_context)[0]),
				serialized_context->size(),
				small_version);
			assert(ret == required_size);

			// tls_export_context() has exported the contents of the write buffer, but it doesn't
			// know how much of that we've consumed, so we need to figure that out by ourselves.
			// In a sense, it's unlikely that this will ever be relevant, though, since TLSe can't
			// currently serialize in-progress key exchanges.
			unsigned base_tls_data_left_to_send = 0;
			const unsigned char *base_tls_data_to_send = tls_get_write_buffer(tls_context, &base_tls_data_left_to_send);
			if (base_tls_data_to_send == nullptr) {
				assert(tls_data_to_send == nullptr);
			} else {
				// Our cursor must point into the same buffer and end at the same place.
				assert(tls_data_to_send + tls_data_left_to_send == base_tls_data_to_send + base_tls_data_left_to_send);
			}
			serialized.set_tls_output_bytes_already_consumed(tls_data_to_send - base_tls_data_to_send);
			serialized.set_in_ktls_mode(in_ktls_mode);
		}
	}

	return serialized;
}

namespace {

// Make a string safe to print in the stats output. Printable ASCII
// (0x20..0x7e) is copied through, except '"' and '\\'; every other byte
// (control characters, DEL, and all bytes >= 0x80 such as UTF-8 sequences)
// is written as "\xNN" with two lowercase hex digits.
//
// Each byte is examined as unsigned char against a fixed range rather than
// with isprint(): passing a negative char (any byte >= 0x80 where char is
// signed) to isprint() is undefined behavior, and isprint() also depends on
// the current locale.
std::string escape_string(const std::string &str) {
	std::string ret;
	ret.reserve(str.size());
	for (const char ch : str) {
		const unsigned char c = static_cast<unsigned char>(ch);
		if (c >= 0x20 && c <= 0x7e && c != '"' && c != '\\') {
			ret.push_back(ch);
		} else {
			char buf[8];  // "\xNN" plus terminating NUL.
			snprintf(buf, sizeof(buf), "\\x%02x", c);
			ret += buf;
		}
	}
	return ret;
}

} // namespace
	
ClientStats Client::get_stats() const
{
	// Snapshot of this client for the stats output. An empty URL is shown as
	// "-", and the Referer / User-Agent strings go through escape_string() so
	// they are safe to print.
	ClientStats stats;

	// Identity of the connection.
	stats.sock = sock;
	stats.remote_addr = remote_addr;
	stats.connect_time = connect_time;
	stats.url = url.empty() ? string("-") : url;
	stats.referer = escape_string(referer);
	stats.user_agent = escape_string(user_agent);
	stats.hls_zombie_key = get_hls_zombie_key();

	// Transfer counters.
	stats.bytes_sent = bytes_sent;
	stats.bytes_lost = bytes_lost;
	stats.num_loss_events = num_loss_events;

	return stats;
}
