#include "common/url.h"

// NOTE(review): the two standard includes that followed "common/url.h" were
// unreadable in the reviewed copy. The list below covers every std:: name
// used in this file; confirm it against the original include list.
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// fmt::format is used by parseHost further down in this file.
#if __has_include(<fmt/format.h>)
#include <fmt/format.h>
#endif

using namespace ProtoRock::Http;

// Selects which part of a URL is being escaped or unescaped. Each part has
// its own set of characters that may appear unescaped.
enum EncodingMode {
    encodePath = 1,
    encodePathSegment,
    encodeHost,
    encodeZone,
    encodeUserPassword,
    encodeQueryComponent,
    encodeFragment,
};

// Convert a single hexadecimal digit to its numeric value (0-15).
// Returns 0 for any character that is not a hex digit.
char unhex(char c) {
    if ('0' <= c && c <= '9') {
        return c - '0';
    }
    if ('a' <= c && c <= 'f') {
        return c - 'a' + 10;
    }
    if ('A' <= c && c <= 'F') {
        return c - 'A' + 10;
    }
    return 0;
}

// Digits used when emitting %XX escapes.
const char *upperhex = "0123456789ABCDEF";

// Return true if the specified character should be escaped when
// appearing in a URL string, according to RFC 3986.
//
// Please be informed that for now shouldEscape does not check all
// reserved characters correctly. See golang.org/issue/5684.
bool shouldEscape(char c, EncodingMode mode) {
    // §2.3 Unreserved characters (alphanum)
    if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9')) {
        return false;
    }

    if (mode == encodeHost || mode == encodeZone) {
        // §3.2.2 Host allows
        //	sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
        // as part of reg-name.
        // We add : because we include :port as part of host.
        // We add [ ] because we include [ipv6]:port as part of host.
        // We add < > because they're the only characters left that
        // we could possibly allow, and Parse will reject them if we
        // escape them (because hosts can't use %-encoding for
        // ASCII bytes).
        switch (c) {
        case '!':
        case '$':
        case '&':
        case '\'':
        case '(':
        case ')':
        case '*':
        case '+':
        case ',':
        case ';':
        case '=':
        case ':':
        case '[':
        case ']':
        case '<':
        case '>':
        case '"':
            return false;
        }
    }

    switch (c) {
    // §2.3 Unreserved characters (mark)
    case '-':
    case '_':
    case '.':
    case '~':
        return false;

    // §2.2 Reserved characters (reserved)
    case '$':
    case '&':
    case '+':
    case ',':
    case '/':
    case ':':
    case ';':
    case '=':
    case '?':
    case '@':
        // Different sections of the URL allow a few of
        // the reserved characters to appear unescaped.
        switch (mode) {
        case encodePath: // §3.3
            // The RFC allows : @ & = + $ but saves / ; , for assigning
            // meaning to individual path segments. This package
            // only manipulates the path as a whole, so we allow those
            // last three as well. That leaves only ? to escape.
            return c == '?';

        case encodePathSegment: // §3.3
            // The RFC allows : @ & = + $ but saves / ; , for assigning
            // meaning to individual path segments.
            return c == '/' || c == ';' || c == ',' || c == '?';

        case encodeUserPassword: // §3.2.1
            // The RFC allows ';', ':', '&', '=', '+', '$', and ',' in
            // userinfo, so we must escape only '@', '/', and '?'.
            // The parsing of userinfo treats ':' as special so we must escape
            // that too.
            return c == '@' || c == '/' || c == '?' || c == ':';

        case encodeQueryComponent: // §3.4
            // The RFC reserves (so we must escape) everything.
            return true;

        case encodeFragment: // §4.1
            // The RFC text is silent but the grammar allows
            // everything, so escape nothing.
            return false;

        default:
            // Host and zone modes: reserved characters not allowed above
            // fall through to the final "escape it" below.
            break;
        }
        break;
    }

    if (mode == encodeFragment) {
        // RFC 3986 §2.2 allows not escaping sub-delims. A subset of sub-delims are
        // included in reserved from RFC 2396 §2.2. The remaining sub-delims do not
        // need to be escaped. To minimize potential breakage, we apply two restrictions:
        // (1) we always escape sub-delims outside of the fragment, and (2) we always
        // escape single quote to avoid breaking callers that had previously assumed that
        // single quotes would be escaped. See issue #19917.
        switch (c) {
        case '!':
        case '(':
        case ')':
        case '*':
            return false;
        }
    }

    // Everything else must be escaped.
    return true;
}

// Percent-encode s for the URL component selected by mode.
//
// Every byte for which shouldEscape() is true becomes %XX (upper-case hex),
// except that a space in encodeQueryComponent mode becomes '+'. When nothing
// needs escaping the input is returned unchanged.
std::string escape(const std::string &s, EncodingMode mode) {
    std::size_t spaceCount = 0;
    std::size_t hexCount = 0;
    for (const char c : s) {
        if (!shouldEscape(c, mode)) {
            continue;
        }
        if (c == ' ' && mode == encodeQueryComponent) {
            ++spaceCount;
        } else {
            ++hexCount;
        }
    }

    if (spaceCount == 0 && hexCount == 0) {
        return s;
    }

    std::string out;
    // Each %XX escape adds two bytes; a '+' replaces a space one-for-one.
    out.reserve(s.size() + 2 * hexCount);
    for (const char c : s) {
        if (!shouldEscape(c, mode)) {
            out.push_back(c);
        } else if (c == ' ' && mode == encodeQueryComponent) {
            out.push_back('+');
        } else {
            // Go through unsigned char so bytes >= 0x80 index upperhex
            // correctly even where plain char is signed.
            const auto byte = static_cast<unsigned char>(c);
            out.push_back('%');
            out.push_back(upperhex[byte >> 4]);
            out.push_back(upperhex[byte & 15]);
        }
    }
    return out;
}

// Decode %XX escapes in s for the URL component selected by mode.
//
// In encodeQueryComponent mode '+' is additionally decoded to a space.
// In encodeHost / encodeZone mode the input is also validated: escapes may
// not introduce bytes that are not permitted there, and unescaped ASCII
// bytes must be valid host characters.
//
// Throws std::invalid_argument("escape error: ...") for a malformed or
// disallowed escape and std::invalid_argument("invalid host: ...") for an
// illegal host byte.
std::string unescape(std::string s, EncodingMode mode) {
    // Count %, check that they're well-formed.
    std::size_t n = 0;
    bool hasPlus = false;
    for (std::size_t i = 0; i < s.size();) {
        switch (s[i]) {
        case '%': {
            ++n;
            if (i + 2 >= s.size() ||
                !std::isxdigit(static_cast<unsigned char>(s[i + 1])) ||
                !std::isxdigit(static_cast<unsigned char>(s[i + 2]))) {
                // Report the offending escape: at most three bytes from '%'.
                throw std::invalid_argument("escape error: " + s.substr(i, 3));
            }
            const bool isEscapedPercent = s.compare(i, 3, "%25") == 0;
            // Per https://tools.ietf.org/html/rfc3986#page-21
            // in the host component %-encoding can only be used
            // for non-ASCII bytes.
            // But https://tools.ietf.org/html/rfc6874#section-2
            // introduces %25 being allowed to escape a percent sign
            // in IPv6 scoped-address literals. Yay.
            if (mode == encodeHost && unhex(s[i + 1]) < 8 && !isEscapedPercent) {
                throw std::invalid_argument("escape error: " + s.substr(i, 3));
            }
            if (mode == encodeZone) {
                // RFC 6874 says basically "anything goes" for zone identifiers
                // and that even non-ASCII can be redundantly escaped,
                // but it seems prudent to restrict %-escaped bytes here to those
                // that are valid host name bytes in their unescaped form.
                // That is, you can use escaping in the zone identifier but not
                // to introduce bytes you couldn't just write directly.
                // But Windows puts spaces here! Yay.
                const char v = static_cast<char>(unhex(s[i + 1]) << 4 | unhex(s[i + 2]));
                if (!isEscapedPercent && v != ' ' && shouldEscape(v, encodeHost)) {
                    throw std::invalid_argument("escape error: " + s.substr(i, 3));
                }
            }
            i += 3;
            break;
        }
        case '+':
            hasPlus = mode == encodeQueryComponent;
            ++i;
            break;
        default:
            if ((mode == encodeHost || mode == encodeZone) &&
                static_cast<unsigned char>(s[i]) < 0x80 && shouldEscape(s[i], mode)) {
                throw std::invalid_argument("invalid host: " + s.substr(i, 1));
            }
            ++i;
        }
    }

    if (n == 0 && !hasPlus) {
        return s;
    }

    std::string out;
    // Every validated escape shrinks three input bytes to one output byte.
    out.reserve(s.size() - 2 * n);
    for (std::size_t i = 0; i < s.size(); ++i) {
        switch (s[i]) {
        case '%':
            // Safe: the first pass verified two hex digits follow every '%'.
            out.push_back(static_cast<char>(unhex(s[i + 1]) << 4 | unhex(s[i + 2])));
            i += 2;
            break;
        case '+':
            out.push_back(mode == encodeQueryComponent ? ' ' : '+');
            break;
        default:
            out.push_back(s[i]);
        }
    }
    return out;
}

// Report whether s contains any ASCII control byte (below 0x20, or 0x7f).
bool stringContainsCTLByte(const std::string &s) {
    for (const char c : s) {
        // Compare as unsigned so bytes >= 0x80 (e.g. UTF-8) are not mistaken
        // for control characters where plain char is signed.
        const auto byte = static_cast<unsigned char>(c);
        if (byte < ' ' || byte == 0x7f) {
            return true;
        }
    }
    return false;
}

// Split a leading "scheme:" off url.
//
// A scheme is a letter followed by letters, digits, '+', '-' or '.'.
// If one is found it is returned and url is replaced by the text after the
// ':'. Otherwise an empty string is returned and url is left untouched.
//
// Throws std::invalid_argument if url starts with ':' (empty scheme).
std::string getScheme(std::string &url) {
    for (std::size_t i = 0; i < url.size(); ++i) {
        const char c = url[i];
        if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z')) {
            continue;
        }
        if (('0' <= c && c <= '9') || c == '+' || c == '-' || c == '.') {
            if (i == 0) {
                // A scheme must start with a letter.
                return {};
            }
            continue;
        }
        if (c == ':') {
            if (i == 0) {
                throw std::invalid_argument("missing protocol scheme");
            }
            std::string scheme = url.substr(0, i);
            url.erase(0, i + 1);
            return scheme;
        }
        // We have encountered an invalid character,
        // so there is no valid scheme.
        return {};
    }
    return {};
}

// validUserinfo reports whether s is a valid userinfo string per RFC 3986
// Section 3.2.1:
//     userinfo    = *( unreserved / pct-encoded / sub-delims / ":" )
//     unreserved  = ALPHA / DIGIT / "-" / "." / "_" / "~"
//     sub-delims  = "!" / "$" / "&" / "'" / "(" / ")"
//                   / "*" / "+" / "," / ";" / "="
//
// It doesn't validate pct-encoded. The caller does that via func unescape.
bool validUserinfo(const std::string &s) {
    for (const char r : s) {
        if ('A' <= r && r <= 'Z') {
            continue;
        }
        if ('a' <= r && r <= 'z') {
            continue;
        }
        if ('0' <= r && r <= '9') {
            continue;
        }
        switch (r) {
        case '-':
        case '.':
        case '_':
        case ':':
        case '~':
        case '!':
        case '$':
        case '&':
        case '\'':
        case '(':
        case ')':
        case '*':
        case '+':
        case ',':
        case ';':
        case '=':
        case '%':
        case '@':
            continue;
        default:
            return false;
        }
    }
    return true;
}

// validOptionalPort reports whether port is either an empty string
// or matches /^:\d*$/
bool validOptionalPort(const std::string &port) {
    if (port.empty()) {
        return true;
    }
    if (port[0] != ':') {
        return false;
    }
    for (std::size_t i = 1; i < port.size(); ++i) {
        if (port[i] < '0' || port[i] > '9') {
            return false;
        }
    }
    return true;
}

// parseHost parses host as an authority without user
// information. That is, as host[:port].
std::string parseHost(const std::string &host) { int idx; if (!host.empty() && host[0] == '[') { // Parse an IP-Literal in RFC 3986 and RFC 6874. // E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80". idx = host.find_last_of(']'); if (idx >= host.size()) { throw std::invalid_argument("cannot find ']' in host"); } auto colonPort = std::string(host.begin() + idx + 1, host.end()); if (!validOptionalPort(colonPort)) { throw std::invalid_argument(fmt::format("invalid port {} after host", colonPort)); } // RFC 6874 defines that %25 (%-encoded percent) introduces // the zone identifier, and the zone identifier can use basically // any %-encoding it likes. That's different from the host, which // can only %-encode non-ASCII bytes. // We do impose some restrictions on the zone, to avoid stupidity // like newlines. auto zone = host.find("%25"); if (idx != std::string::npos) { auto host1 = unescape(std::string(host.begin(), host.begin() + zone), encodeHost); auto host2 = unescape(std::string(host.begin() + zone, host.begin() + idx), encodeHost); auto host3 = unescape(std::string(host.begin() + idx, host.end()), encodeZone); return host1 + host2 + host3; } } else if ((idx = host.find_last_of(':')) < host.size()) { auto colonPort = std::string(host.begin() + idx, host.end()); if (!validOptionalPort(colonPort)) { throw std::invalid_argument(fmt::format("invalid port {} after host", colonPort)); } } return unescape(host, encodeHost); } void parseAuthority(const std::string &authority, UserInfo &ui, std::string &host) { auto i = authority.find_last_of('@'); if (i > authority.size()) { host = parseHost(authority); } else { host = parseHost(std::string(authority.begin() + i + 1, authority.end())); } if (i > authority.size()) { return; } auto userInfo = std::string(authority.begin(), authority.begin() + i); if (!validUserinfo(userInfo)) { throw std::invalid_argument("invalid userinfo"); } auto idx = userInfo.find(':'); if (idx == std::string::npos) { userInfo = unescape(userInfo, 
encodeUserPassword); ui = UserInfo(userInfo); } else { auto username = std::string(userInfo.begin(), userInfo.begin() + idx); auto password = std::string(userInfo.begin() + idx, userInfo.end()); ui.Username = unescape(username, encodeUserPassword); ui.Password = unescape(username, encodeUserPassword); } } URL URL::Parse(std::string url) { URL u; std::string frag; auto hashIndex = url.find("#"); if (hashIndex != std::string::npos) { frag = std::string(url.begin() + hashIndex, url.end()); url = std::string(url.begin(), url.begin() + hashIndex); } u.setFragment(frag); if (stringContainsCTLByte(url)) { throw std::invalid_argument("invalid url: string contains control bytes"); } if (url == "*") { u.Path = "*"; return u; } auto rest = url; u.Scheme = getScheme(rest); std::transform(u.Scheme.begin(), u.Scheme.end(), u.Scheme.begin(), [](unsigned char c) -> unsigned char { return std::tolower(c); }); if (!rest.empty() > 0 && rest[rest.size() - 1] == '?') { u.ForceQuery = true; rest.pop_back(); } else { auto idx = rest.find("?"); if (idx != std::string::npos) { u.RawQuery = std::string(rest.begin() + idx, rest.end()); rest = std::string(rest.begin(), rest.begin() + idx); } } if (!rest.empty() && rest[0] != '/') { if (!u.Scheme.empty()) { // We consider rootless paths per RFC 3986 as opaque. u.Opaque = rest; return u; } } if (!u.Scheme.empty() || (rest.find("///") != 0 && rest.find("//") == 0)) { auto authority = std::string(rest.begin() + 2, rest.end()); rest = ""; int i = authority.find("/"); if (i != std::string::npos) { rest = std::string(authority.begin() + i, authority.end()); authority = std::string(authority.begin(), authority.begin() + i); } parseAuthority(authority, u.User, u.Host); } u.setPath(rest); return u; } void URL::setFragment(const std::string &f) { Fragment = unescape(f, encodeFragment); auto escf = escape(Fragment, encodeFragment); RawFragment = (escf == f) ? 
"" : f; } void URL::setPath(const std::string &p) { Path = unescape(p, encodePath); auto escp = escape(Path, encodePath); RawPath = (escp == p) ? "" : p; } std::string URL::PathEscape(const std::string &path) { return escape(path, encodePath); } std::string URL::PathUnescape(const std::string &path) { return unescape(path, encodePath); } std::string URL::QueryEscape(const std::string &query) { return escape(query, encodeQueryComponent); } std::string URL::QueryUnescape(const std::string &query) { return unescape(query, encodeQueryComponent); }