diff --git a/src/error.cc b/src/error.cc index 533b542..17abbcc 100644 --- a/src/error.cc +++ b/src/error.cc @@ -8,6 +8,7 @@ static const char* messages[] = { nullptr, /* standard error: call stderror */ "End of file", "libogg error", + "Invalid identification header, not an Opus stream", "Bad magic number", "Overflowing magic number", "Overflowing vendor length", diff --git a/src/opus.cc b/src/opus.cc index 5b51ffc..1655e18 100644 --- a/src/opus.cc +++ b/src/opus.cc @@ -34,6 +34,15 @@ #define le32toh(x) OSSwapLittleToHostInt32(x) #endif +ot::status ot::validate_identification_header(const unsigned char* data, size_t size) +{ + if (size < 8) + return ot::status::bad_identification_header; + if (memcmp(data, "OpusHead", 8) != 0) + return ot::status::bad_identification_header; + return ot::status::ok; +} + ot::status ot::parse_tags(const char *data, long len, opus_tags *tags) { if (len < 0) diff --git a/src/opustags.cc b/src/opustags.cc index 0392a6a..4a3f90f 100644 --- a/src/opustags.cc +++ b/src/opustags.cc @@ -108,9 +108,10 @@ static int run(ot::options& opt) // Read all the packets. while(ogg_stream_packetout(&reader.stream, &reader.packet) == 1){ packet_count++; - if(packet_count == 1){ // Identification header - if(strncmp((char*) reader.packet.packet, "OpusHead", 8) != 0){ - error = "opustags: invalid identification header"; + if (packet_count == 1) { // Identification header + rc = ot::validate_identification_header(reader.packet.packet, reader.packet.bytes); + if (rc != ot::status::ok) { + error = ot::error_message(rc); break; } } diff --git a/src/opustags.h b/src/opustags.h index dbc3ab0..622aca4 100644 --- a/src/opustags.h +++ b/src/opustags.h @@ -34,6 +34,7 @@ enum class status { standard_error, end_of_file, libogg_error, + bad_identification_header, /* OpusTags parsing errors */ bad_magic_number, overflowing_magic_number, @@ -197,6 +198,12 @@ struct opus_tags { std::string extra_data; }; +/** + * Validate the content of the first packet of an Ogg stream to ensure it's a valid OpusHead. + * Returns #ot::status::ok on success, #ot::status::bad_identification_header on error. + */ +status validate_identification_header(const unsigned char* data, size_t size); + status parse_tags(const char *data, long len, opus_tags *tags); int render_tags(opus_tags *tags, ogg_packet *op); void delete_tags(opus_tags *tags, const char *field); diff --git a/t/opus.cc b/t/opus.cc index d7f9099..46d82f4 100644 --- a/t/opus.cc +++ b/t/opus.cc @@ -13,6 +13,22 @@ using namespace std::literals::string_literals; +static void check_identification() +{ + ot::status rc; + rc = ot::validate_identification_header(reinterpret_cast("OpusHead.."), 10); + if (rc != ot::status::ok) + throw failure("did not accept a good OpusHead"); + + rc = ot::validate_identification_header(reinterpret_cast("OpusHead"), 7); + if (rc != ot::status::bad_identification_header) + throw failure("accepted an OpusHead that is too short"); + + rc = ot::validate_identification_header(reinterpret_cast("NotOpusHead"), 11); + if (rc != ot::status::bad_identification_header) + throw failure("did not report the right status for a bad OpusHead"); +} + static const char standard_OpusTags[] = "OpusTags" "\x14\x00\x00\x00" "opustags test packet" @@ -136,7 +152,8 @@ static void recode_padding() int main() { - std::cout << "1..4\n"; + std::cout << "1..5\n"; + run(check_identification, "check the OpusHead packet"); run(parse_standard, "parse a standard OpusTags packet"); run(parse_corrupted, "correctly reject invalid packets"); run(recode_standard, "recode a standard OpusTags packet");