/*	$NetBSD: skr.c,v 1.2 2025/01/26 16:25:25 christos Exp $	*/

/*
 * Copyright (C) Internet Systems Consortium, Inc. ("ISC")
 *
 * SPDX-License-Identifier: MPL-2.0
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, you can obtain one at https://mozilla.org/MPL/2.0/.
 *
 * See the COPYRIGHT file distributed with this work for additional
 * information regarding copyright ownership.
 */

/*! \file */

#include <isc/lex.h>
#include <isc/log.h>

#include <dns/callbacks.h>
#include <dns/fixedname.h>
#include <dns/rdata.h>
#include <dns/rdataclass.h>
#include <dns/rdatatype.h>
#include <dns/skr.h>
#include <dns/time.h>
#include <dns/ttl.h>

/*
 * Evaluate 'op' into the caller's 'result'; on anything other than
 * ISC_R_SUCCESS jump to the caller's 'failure' label.
 */
#define CHECK(op)                            \
	do {                                 \
		result = (op);               \
		if (result != ISC_R_SUCCESS) \
			goto failure;        \
	} while (0)

/*
 * Read the next token from 'lex' into 'token', storing the status in the
 * caller's 'ret'; on failure jump to the caller's 'cleanup' label.
 */
#define NEXTTOKEN(lex, opt, token)                             \
	do {                                                   \
		ret = isc_lex_gettoken((lex), (opt), (token)); \
		if (ret != ISC_R_SUCCESS)                      \
			goto cleanup;                          \
	} while (0)

/* Fail the current parse with ISC_R_UNEXPECTEDTOKEN via 'cleanup'. */
#define BADTOKEN()                           \
	do {                                 \
		ret = ISC_R_UNEXPECTEDTOKEN; \
		goto cleanup;                \
	} while (0)

/* Size hint passed to isc_lex_create() for the SKR lexer. */
#define TOKENSIZ (8 * 1024)
/* Text of a string token. */
#define STR(t)	 ((t).value.as_textregion.base)

/*
 * Parse one resource record from 'lex'. The caller has already read the
 * owner-name token, whose text is passed in 'owner'.
 *
 * - The owner must be an absolute name equal to 'origin'; "@" is
 *   rejected.
 * - An optional TTL may come next. If present, it is stored in '*ttl';
 *   otherwise '*ttl' is left untouched.
 * - An optional class may follow, and it must equal 'rdclass'.
 * - The type must be DNSKEY, CDNSKEY, CDS or RRSIG. It is returned in
 *   '*rdtype'.
 * - The rdata is parsed into '*rdata', using 'buf' as backing storage.
 *
 * Master-file comments are enabled on 'lex' for the duration of the call
 * and switched off again on every return path.
 *
 * Returns ISC_R_SUCCESS, DNS_R_BADOWNERNAME, ISC_R_UNEXPECTEDTOKEN, or an
 * error from the lexer, name parser or rdata parser.
 */
static isc_result_t
parse_rr(isc_lex_t *lex, isc_mem_t *mctx, char *owner, dns_name_t *origin,
	 dns_rdataclass_t rdclass, isc_buffer_t *buf, dns_ttl_t *ttl,
	 dns_rdatatype_t *rdtype, dns_rdata_t **rdata) {
	dns_rdatacallbacks_t callbacks;
	dns_fixedname_t dfname;
	dns_name_t *dname = NULL;
	dns_rdataclass_t clas;
	isc_buffer_t b;
	isc_token_t token;
	unsigned int opt = ISC_LEXOPT_EOL;
	isc_result_t ret = ISC_R_SUCCESS;
	size_t ownerlen;

	isc_lex_setcomments(lex, ISC_LEXCOMMENT_DNSMASTERFILE);

	/* Read the domain name; the relative "@" shorthand is not allowed */
	if (strcmp(owner, "@") == 0) {
		BADTOKEN();
	}
	dname = dns_fixedname_initname(&dfname);
	ownerlen = strlen(owner);
	isc_buffer_init(&b, owner, ownerlen);
	isc_buffer_add(&b, ownerlen);
	ret = dns_name_fromtext(dname, &b, dns_rootname, 0, NULL);
	if (ret != ISC_R_SUCCESS) {
		/* Go through cleanup so comment mode is reset */
		goto cleanup;
	}
	if (dns_name_compare(dname, origin) != 0) {
		ret = DNS_R_BADOWNERNAME;
		goto cleanup;
	}

	/* Read the next word: either TTL, class, or type */
	NEXTTOKEN(lex, opt, &token);
	if (token.type != isc_tokentype_string) {
		BADTOKEN();
	}

	/* If it's a TTL, read the next one */
	ret = dns_ttl_fromtext(&token.value.as_textregion, ttl);
	if (ret == ISC_R_SUCCESS) {
		NEXTTOKEN(lex, opt, &token);
	}
	if (token.type != isc_tokentype_string) {
		BADTOKEN();
	}

	/* If it's a class, it must match; then read the next one */
	ret = dns_rdataclass_fromtext(&clas, &token.value.as_textregion);
	if (ret == ISC_R_SUCCESS) {
		if (clas != rdclass) {
			BADTOKEN();
		}
		NEXTTOKEN(lex, opt, &token);
	}
	if (token.type != isc_tokentype_string) {
		BADTOKEN();
	}

	/* Must be the record type */
	ret = dns_rdatatype_fromtext(rdtype, &token.value.as_textregion);
	if (ret != ISC_R_SUCCESS) {
		BADTOKEN();
	}
	switch (*rdtype) {
	case dns_rdatatype_dnskey:
	case dns_rdatatype_cdnskey:
	case dns_rdatatype_cds:
	case dns_rdatatype_rrsig:
		/* Allowed record types */
		break;
	default:
		BADTOKEN();
	}

	dns_rdatacallbacks_init(&callbacks);
	ret = dns_rdata_fromtext(*rdata, rdclass, *rdtype, lex, dname, 0, mctx,
				 buf, &callbacks);
cleanup:
	isc_lex_setcomments(lex, 0);
	return ret;
}

/*
 * Allocate an empty bundle from 'mctx', stamped with the given inception
 * time, and hand it back through '*bp' (which must be NULL on entry).
 */
static void
skrbundle_create(isc_mem_t *mctx, isc_stdtime_t inception,
		 dns_skrbundle_t **bp) {
	dns_skrbundle_t *newbundle = NULL;

	REQUIRE(bp != NULL && *bp == NULL);

	newbundle = isc_mem_get(mctx, sizeof(*newbundle));
	ISC_LINK_INIT(newbundle, link);
	dns_diff_init(mctx, &newbundle->diff);
	newbundle->inception = inception;
	newbundle->magic = DNS_SKRBUNDLE_MAGIC;

	*bp = newbundle;
}

/*
 * Move '*tuple' onto the end of the bundle's diff. The diff takes over
 * the tuple and '*tuple' is left NULL.
 */
static void
skrbundle_addtuple(dns_skrbundle_t *bundle, dns_difftuple_t **tuple) {
	dns_diff_t *diff = NULL;

	REQUIRE(DNS_DIFFTUPLE_VALID(*tuple));
	REQUIRE(DNS_SKRBUNDLE_VALID(bundle));

	diff = &bundle->diff;
	REQUIRE(DNS_DIFF_VALID(diff));

	dns_diff_append(diff, tuple);
}

/*
 * Find, in 'bundle', the RRSIG that covers 'covering_type' and was made
 * by 'key', and clone it into 'sigrdata'.
 *
 * Only tuples added with DNS_DIFFOP_ADDRESIGN are considered. A signature
 * matches when its covered type, algorithm and key tag all agree with
 * 'covering_type' and 'key'. The key tag alone is not unique: keys of
 * different algorithms (e.g. during an algorithm rollover) can share a
 * tag.
 *
 * Returns ISC_R_SUCCESS when a signature was found. Returns ISC_R_NOTFOUND
 * when none matches, or the error from dns_rdata_tostruct().
 */
isc_result_t
dns_skrbundle_getsig(dns_skrbundle_t *bundle, dst_key_t *key,
		     dns_rdatatype_t covering_type, dns_rdata_t *sigrdata) {
	REQUIRE(DNS_SKRBUNDLE_VALID(bundle));
	REQUIRE(DNS_DIFF_VALID(&bundle->diff));
	REQUIRE(key != NULL);
	REQUIRE(sigrdata != NULL);

	for (dns_difftuple_t *tuple = ISC_LIST_HEAD(bundle->diff.tuples);
	     tuple != NULL; tuple = ISC_LIST_NEXT(tuple, link))
	{
		dns_rdata_rrsig_t rrsig;
		isc_result_t result;

		if (tuple->op != DNS_DIFFOP_ADDRESIGN) {
			continue;
		}
		INSIST(tuple->rdata.type == dns_rdatatype_rrsig);

		result = dns_rdata_tostruct(&tuple->rdata, &rrsig, NULL);
		if (result != ISC_R_SUCCESS) {
			return result;
		}

		/*
		 * Check if covering type matches, and if the signature is
		 * generated by 'key' (same algorithm and key tag).
		 */
		if (rrsig.covered == covering_type &&
		    rrsig.algorithm == dst_key_alg(key) &&
		    rrsig.keyid == dst_key_id(key))
		{
			dns_rdata_clone(&tuple->rdata, sigrdata);
			return ISC_R_SUCCESS;
		}
	}

	return ISC_R_NOTFOUND;
}

/*
 * Create an empty SKR object for 'filename' with one reference, and
 * record the current time as its load time. 'origin' and 'rdclass' are
 * currently ignored.
 *
 * The SKR keeps its own copy of 'filename' and attaches to 'mctx'.
 */
void
dns_skr_create(isc_mem_t *mctx, const char *filename, dns_name_t *origin,
	       dns_rdataclass_t rdclass, dns_skr_t **skrp) {
	REQUIRE(skrp != NULL && *skrp == NULL);
	REQUIRE(mctx != NULL);

	UNUSED(origin);
	UNUSED(rdclass);

	isc_time_t loaded_at = isc_time_now();
	dns_skr_t *newskr = isc_mem_get(mctx, sizeof(*newskr));

	*newskr = (dns_skr_t){
		.magic = DNS_SKR_MAGIC,
		.filename = isc_mem_strdup(mctx, filename),
		.loadtime = loaded_at,
	};

	/*
	 * Bundles live in a plain list. Lookups are linear, but an SKR
	 * is not expected to hold many bundles, so this is acceptable
	 * for now.
	 */
	ISC_LIST_INIT(newskr->bundles);

	isc_mem_attach(mctx, &newskr->mctx);
	isc_refcount_init(&newskr->references, 1);

	*skrp = newskr;
}

/*
 * Append '*bundlep' to the SKR's bundle list. The SKR takes the bundle
 * over and '*bundlep' is cleared.
 */
static void
addbundle(dns_skr_t *skr, dns_skrbundle_t **bundlep) {
	dns_skrbundle_t *bundle = NULL;

	REQUIRE(DNS_SKR_VALID(skr));
	REQUIRE(DNS_SKRBUNDLE_VALID(*bundlep));

	bundle = *bundlep;
	*bundlep = NULL;
	ISC_LIST_APPEND(skr->bundles, bundle, link);
}

/*
 * Load the Signed Key Response file 'filename' and append its bundles to
 * '*skrp'. '*skrp' must be a valid SKR, e.g. from dns_skr_create().
 *
 * Bundle headers are lines of the form
 *	";; SignedKeyResponse 1.0 <date> ..."
 * where <date> is parsed by dns_time32_fromtext() and becomes the bundle's
 * inception. A header whose date word is "generated" (the trailer) opens
 * no new bundle.
 *
 * Every other line is a record for 'origin' in class 'rdclass' (see
 * parse_rr()). It is added to the most recently opened bundle. A record
 * without an explicit TTL uses the last TTL seen, starting from
 * 'dnskeyttl'.
 *
 * Returns ISC_R_SUCCESS, the error from isc_lex_openfile(), or a parse
 * error. On a parse error, bundles completed before the error stay in
 * '*skrp' and a partially read bundle is discarded.
 */
isc_result_t
dns_skr_read(isc_mem_t *mctx, const char *filename, dns_name_t *origin,
	     dns_rdataclass_t rdclass, dns_ttl_t dnskeyttl, dns_skr_t **skrp) {
	isc_result_t result;
	dns_skrbundle_t *bundle = NULL;
	uint32_t bundle_id;
	isc_lex_t *lex = NULL;
	isc_lexspecials_t specials;
	isc_token_t token;
	unsigned int opt = ISC_LEXOPT_EOL;

	REQUIRE(skrp != NULL && DNS_SKR_VALID(*skrp));

	isc_lex_create(mctx, TOKENSIZ, &lex);
	memset(specials, 0, sizeof(specials));
	specials['('] = 1;
	specials[')'] = 1;
	specials['"'] = 1;
	isc_lex_setspecials(lex, specials);
	result = isc_lex_openfile(lex, filename);
	if (result != ISC_R_SUCCESS) {
		isc_log_write(dns_lctx, DNS_LOGCATEGORY_GENERAL,
			      DNS_LOGMODULE_ZONE, ISC_LOG_ERROR,
			      "unable to open ksr file %s: %s", filename,
			      isc_result_totext(result));
		isc_lex_destroy(&lex);
		return result;
	}

	for (result = isc_lex_gettoken(lex, opt, &token);
	     result == ISC_R_SUCCESS;
	     result = isc_lex_gettoken(lex, opt, &token))
	{
		if (token.type == isc_tokentype_eol) {
			continue;
		}

		if (token.type != isc_tokentype_string) {
			CHECK(DNS_R_SYNTAX);
		}

		if (strcmp(STR(token), ";;") == 0) {
			/* Bundle header */
			CHECK(isc_lex_gettoken(lex, opt, &token));
			if (token.type != isc_tokentype_string ||
			    strcmp(STR(token), "SignedKeyResponse") != 0)
			{
				CHECK(DNS_R_SYNTAX);
			}

			/* Version */
			CHECK(isc_lex_gettoken(lex, opt, &token));
			if (token.type != isc_tokentype_string ||
			    strcmp(STR(token), "1.0") != 0)
			{
				CHECK(DNS_R_SYNTAX);
			}

			/* Date and time of bundle, or "generated" (trailer) */
			CHECK(isc_lex_gettoken(lex, opt, &token));
			if (token.type != isc_tokentype_string) {
				CHECK(DNS_R_SYNTAX);
			}

			if (strcmp(STR(token), "generated") != 0) {
				/* Add previous bundle */
				if (bundle != NULL) {
					addbundle(*skrp, &bundle);
				}

				/*
				 * Parse the inception time straight from the
				 * token. Copying it into a fixed-size buffer
				 * first could overflow, because a token may be
				 * up to TOKENSIZ bytes long.
				 */
				CHECK(dns_time32_fromtext(STR(token),
							  &bundle_id));
				skrbundle_create(mctx, (isc_stdtime_t)bundle_id,
						 &bundle);
			}

			/* Skip the remainder of the header line */
			do {
				CHECK(isc_lex_gettoken(lex, opt, &token));
			} while (token.type != isc_tokentype_eol);
		} else {
			isc_buffer_t buf;
			dns_rdata_t rdata;
			dns_rdata_t *rdatap = &rdata;
			unsigned char rdatabuf[DST_KEY_MAXSIZE];
			dns_rdatatype_t rdtype;
			dns_diffop_t op;
			dns_difftuple_t *tuple = NULL;

			/*
			 * A record must follow a bundle header; without one
			 * there is no bundle to add it to.
			 */
			if (bundle == NULL) {
				CHECK(DNS_R_SYNTAX);
			}

			/* Parse record */
			dns_rdata_init(&rdata);
			isc_buffer_init(&buf, rdatabuf, sizeof(rdatabuf));
			result = parse_rr(lex, mctx, STR(token), origin,
					  rdclass, &buf, &dnskeyttl, &rdtype,
					  &rdatap);
			if (result != ISC_R_SUCCESS) {
				isc_log_write(
					dns_lctx, DNS_LOGCATEGORY_GENERAL,
					DNS_LOGMODULE_ZONE, ISC_LOG_DEBUG(1),
					"read skr file %s(%lu) parse rr "
					"failed: %s",
					filename, isc_lex_getsourceline(lex),
					isc_result_totext(result));
				goto failure;
			}

			/*
			 * Signatures get their own diff op so that
			 * dns_skrbundle_getsig() can tell them apart.
			 */
			op = (rdtype == dns_rdatatype_rrsig)
				     ? DNS_DIFFOP_ADDRESIGN
				     : DNS_DIFFOP_ADD;

			dns_difftuple_create((*skrp)->mctx, op, origin,
					     dnskeyttl, &rdata, &tuple);

			skrbundle_addtuple(bundle, &tuple);
			INSIST(tuple == NULL);
		}
	}

	if (result != ISC_R_EOF) {
		CHECK(DNS_R_SYNTAX);
	}
	result = ISC_R_SUCCESS;

	/* Add final bundle */
	if (bundle != NULL) {
		addbundle(*skrp, &bundle);
	}

failure:
	if (result != ISC_R_SUCCESS) {
		isc_log_write(dns_lctx, DNS_LOGCATEGORY_GENERAL,
			      DNS_LOGMODULE_ZONE, ISC_LOG_DEBUG(1),
			      "read skr file %s(%lu) failed: %s", filename,
			      isc_lex_getsourceline(lex),
			      isc_result_totext(result));
	}

	/*
	 * A bundle still pending here was never handed to the SKR (error
	 * path); release it so it does not leak.
	 */
	if (bundle != NULL) {
		dns_diff_clear(&bundle->diff);
		isc_mem_put(mctx, bundle, sizeof(*bundle));
	}

	/* Clean up */
	isc_lex_destroy(&lex);
	return result;
}

/*
 * Return the bundle whose validity window contains 'time', or NULL.
 *
 * A bundle's window runs from its own inception up to, but not including,
 * the next bundle's inception. The last bundle's window instead ends
 * 'sigval' after its inception (same units as inception). This assumes
 * bundles are listed in increasing inception order, as read from the SKR
 * file (TODO confirm files are always sorted).
 */
dns_skrbundle_t *
dns_skr_lookup(dns_skr_t *skr, isc_stdtime_t time, uint32_t sigval) {
	REQUIRE(DNS_SKR_VALID(skr));

	for (dns_skrbundle_t *cur = ISC_LIST_HEAD(skr->bundles); cur != NULL;
	     cur = ISC_LIST_NEXT(cur, link))
	{
		dns_skrbundle_t *succ = ISC_LIST_NEXT(cur, link);
		isc_stdtime_t end = (succ != NULL) ? succ->inception
						   : cur->inception + sigval;

		if (cur->inception <= time && time < end) {
			return cur;
		}
	}

	return NULL;
}

/*
 * Take an additional reference to 'source' and store it in '*targetp',
 * which must be NULL on entry.
 */
void
dns_skr_attach(dns_skr_t *source, dns_skr_t **targetp) {
	REQUIRE(DNS_SKR_VALID(source));
	REQUIRE(targetp != NULL && *targetp == NULL);

	(void)isc_refcount_increment(&source->references);
	*targetp = source;
}

/*
 * Drop the reference held in '*skrp' and clear it. When that was the last
 * reference, the SKR is destroyed.
 */
void
dns_skr_detach(dns_skr_t **skrp) {
	REQUIRE(skrp != NULL && DNS_SKR_VALID(*skrp));

	dns_skr_t *victim = *skrp;
	*skrp = NULL;

	/* isc_refcount_decrement() yields the count before decrementing */
	if (isc_refcount_decrement(&victim->references) != 1) {
		return;
	}

	dns_skr_destroy(victim);
}

/*
 * Free every bundle (and its diff), the filename copy, and the SKR
 * itself, then detach from the SKR's memory context.
 */
void
dns_skr_destroy(dns_skr_t *skr) {
	dns_skrbundle_t *head = NULL;

	REQUIRE(DNS_SKR_VALID(skr));

	while ((head = ISC_LIST_HEAD(skr->bundles)) != NULL) {
		ISC_LIST_UNLINK(skr->bundles, head, link);
		dns_diff_clear(&head->diff);
		isc_mem_put(skr->mctx, head, sizeof(*head));
	}
	INSIST(ISC_LIST_EMPTY(skr->bundles));

	isc_mem_free(skr->mctx, skr->filename);
	isc_mem_putanddetach(&skr->mctx, skr, sizeof(*skr));
}
