diff --git a/tests/oci-registry-client.py b/tests/oci-registry-client.py index 5654062b..c4707c07 100755 --- a/tests/oci-registry-client.py +++ b/tests/oci-registry-client.py @@ -1,6 +1,7 @@ #!/usr/bin/python3 import argparse +import ssl import sys import http.client @@ -9,7 +10,20 @@ import urllib.parse def get_conn(args): parsed = urllib.parse.urlparse(args.url) - return http.client.HTTPConnection(host=parsed.hostname, port=parsed.port) + if parsed.scheme == "http": + return http.client.HTTPConnection(host=parsed.hostname, port=parsed.port) + elif parsed.scheme == "https": + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + if args.cert: + context.load_cert_chain(certfile=args.cert, keyfile=args.key) + if args.cacert: + context.load_verify_locations(cafile=args.cacert) + return http.client.HTTPSConnection( + host=parsed.hostname, port=parsed.port, context=context + ) + else: + assert False, "Bad scheme: " + parsed.scheme def run_add(args): @@ -42,6 +56,9 @@ def run_delete(args): parser = argparse.ArgumentParser() parser.add_argument("--url", required=True) +parser.add_argument("--cacert") +parser.add_argument("--cert") +parser.add_argument("--key") subparsers = parser.add_subparsers() subparsers.required = True diff --git a/tests/oci-registry-server.py b/tests/oci-registry-server.py index 13bf50b3..2bbe8c6e 100755 --- a/tests/oci-registry-server.py +++ b/tests/oci-registry-server.py @@ -5,6 +5,7 @@ import base64 import hashlib import json import os +import ssl import time from urllib.parse import parse_qs @@ -252,6 +253,19 @@ class RequestHandler(http_server.BaseHTTPRequestHandler): def run(args): RequestHandler.protocol_version = "HTTP/1.0" httpd = http_server.HTTPServer(("127.0.0.1", 0), RequestHandler) + + if args.cert: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + context.load_cert_chain(certfile=args.cert, keyfile=args.key) + + if args.mtls_cacert: + context.load_verify_locations(cafile=args.mtls_cacert) + # In a real application, we'd need to check the CN against authorized users + context.verify_mode = ssl.CERT_REQUIRED + + httpd.socket = context.wrap_socket(httpd.socket, server_side=True) + host, port = httpd.socket.getsockname()[:2] with open("httpd-port", "w") as file: file.write("%d" % port) @@ -259,7 +273,10 @@ def run(args): os.write(3, bytes("Started\n", "utf-8")) except OSError: pass - print("Serving HTTP on port %d" % port) + if args.cert: + print("Serving HTTPS on port %d" % port) + else: + print("Serving HTTP on port %d" % port) if args.dir: os.chdir(args.dir) httpd.serve_forever() @@ -268,6 +285,9 @@ def run(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--dir") + parser.add_argument("--cert") + parser.add_argument("--key") + parser.add_argument("--mtls-cacert") args = parser.parse_args() run(args)