plugins.twitch: rewrite disable ads logic

This commit is contained in:
bastimeyer 2020-04-15 01:11:31 +02:00 committed by Forrest
parent e55401a568
commit 1e8df7a8c6
3 changed files with 85 additions and 179 deletions

View File

@ -119,25 +119,16 @@ _video_schema = validate.Schema(
}
)
Segment = namedtuple("Segment", "uri duration title key discontinuity scte35 byterange date map")
Segment = namedtuple("Segment", "uri duration title key discontinuity ad byterange date map")
LOW_LATENCY_MAX_LIVE_EDGE = 2
def parse_condition(attr):
def wrapper(func):
def method(self, *args, **kwargs):
if hasattr(self.stream, attr) and getattr(self.stream, attr, False):
func(self, *args, **kwargs)
return method
return wrapper
class TwitchM3U8Parser(M3U8Parser):
def __init__(self, base_uri=None, stream=None, **kwargs):
M3U8Parser.__init__(self, base_uri, **kwargs)
self.stream = stream
def __init__(self, base_uri=None, disable_ads=False, low_latency=False, **kwargs):
super(TwitchM3U8Parser, self).__init__(base_uri, **kwargs)
self.disable_ads = disable_ads
self.low_latency = low_latency
self.has_prefetch_segments = False
def parse(self, *args):
@ -146,21 +137,16 @@ class TwitchM3U8Parser(M3U8Parser):
return m3u8
@parse_condition("disable_ads")
def parse_tag_ext_x_scte35_out(self, value):
self.state["scte35"] = True
def parse_extinf(self, value):
duration, title = super(TwitchM3U8Parser, self).parse_extinf(value)
if title and str(title).startswith("Amazon") and self.disable_ads:
self.state["ad"] = True
# unsure if this gets used by Twitch
@parse_condition("disable_ads")
def parse_tag_ext_x_scte35_out_cont(self, value):
self.state["scte35"] = True
return duration, title
@parse_condition("disable_ads")
def parse_tag_ext_x_scte35_in(self, value):
self.state["scte35"] = False
@parse_condition("low_latency")
def parse_tag_ext_x_twitch_prefetch(self, value):
if not self.low_latency:
return
self.has_prefetch_segments = True
segments = self.m3u8.segments
if segments:
@ -173,7 +159,7 @@ class TwitchM3U8Parser(M3U8Parser):
map_ = self.state.get("map")
key = self.state.get("key")
discontinuity = self.state.pop("discontinuity", False)
scte35 = self.state.pop("scte35", None)
ad = self.state.pop("ad", False)
return Segment(
uri,
@ -181,7 +167,7 @@ class TwitchM3U8Parser(M3U8Parser):
extinf[1],
key,
discontinuity,
scte35,
ad,
byterange,
date,
map_
@ -189,17 +175,36 @@ class TwitchM3U8Parser(M3U8Parser):
class TwitchHLSStreamWorker(HLSStreamWorker):
def __init__(self, *args, **kwargs):
self.playlist_reloads = 0
super(TwitchHLSStreamWorker, self).__init__(*args, **kwargs)
def _reload_playlist(self, text, url):
return load_hls_playlist(text, url, parser=TwitchM3U8Parser, stream=self.stream)
self.playlist_reloads += 1
playlist = load_hls_playlist(
text,
url,
parser=TwitchM3U8Parser,
disable_ads=self.stream.disable_ads,
low_latency=self.stream.low_latency
)
if (
self.stream.disable_ads
and self.playlist_reloads == 1
and not next((s for s in playlist.segments if not s.ad), False)
):
log.info("Waiting for pre-roll ads to finish, be patient")
return playlist
def _set_playlist_reload_time(self, playlist, sequences):
if not self.stream.low_latency:
super(TwitchHLSStreamWorker, self)._set_playlist_reload_time(playlist, sequences)
else:
if self.stream.low_latency and len(sequences) > 0:
self.playlist_reload_time = sequences[-1].segment.duration
else:
super(TwitchHLSStreamWorker, self)._set_playlist_reload_time(playlist, sequences)
def process_sequences(self, playlist, sequences):
if self.playlist_sequence < 0 and self.stream.low_latency and not playlist.has_prefetch_segments:
if self.stream.low_latency and self.playlist_reloads == 1 and not playlist.has_prefetch_segments:
log.info("This is not a low latency stream")
return super(TwitchHLSStreamWorker, self).process_sequences(playlist, sequences)
@ -207,38 +212,20 @@ class TwitchHLSStreamWorker(HLSStreamWorker):
class TwitchHLSStreamWriter(HLSStreamWriter):
def write(self, sequence, *args, **kwargs):
if self.stream.disable_ads:
if sequence.segment.scte35 is not None:
self.reader.ads = sequence.segment.scte35
if self.reader.ads:
log.info("Will skip ads beginning with segment {0}".format(sequence.num))
else:
log.info("Will stop skipping ads beginning with segment {0}".format(sequence.num))
if self.reader.ads:
return
return HLSStreamWriter.write(self, sequence, *args, **kwargs)
if not (self.stream.disable_ads and sequence.segment.ad):
return super(TwitchHLSStreamWriter, self).write(sequence, *args, **kwargs)
class TwitchHLSStreamReader(HLSStreamReader):
__worker__ = TwitchHLSStreamWorker
__writer__ = TwitchHLSStreamWriter
ads = None
class TwitchHLSStream(HLSStream):
def __init__(self, *args, **kwargs):
HLSStream.__init__(self, *args, **kwargs)
disable_ads = self.session.get_plugin_option("twitch", "disable-ads")
low_latency = self.session.get_plugin_option("twitch", "low-latency")
if low_latency and disable_ads:
log.info("Low latency streaming with ad filtering is currently not supported")
self.session.set_plugin_option("twitch", "low-latency", False)
low_latency = False
self.disable_ads = disable_ads
self.low_latency = low_latency
super(TwitchHLSStream, self).__init__(*args, **kwargs)
self.disable_ads = self.session.get_plugin_option("twitch", "disable-ads")
self.low_latency = self.session.get_plugin_option("twitch", "low-latency")
def open(self):
if self.disable_ads:

View File

@ -232,21 +232,24 @@ class HLSStreamWorker(SegmentedStreamWorker):
sequences = [Sequence(media_sequence + i, s)
for i, s in enumerate(playlist.segments)]
if sequences:
self.process_sequences(playlist, sequences)
self.process_sequences(playlist, sequences)
def _set_playlist_reload_time(self, playlist, sequences):
self.playlist_reload_time = (playlist.target_duration
or sequences[-1].segment.duration)
or len(sequences) > 0 and sequences[-1].segment.duration)
def process_sequences(self, playlist, sequences):
self._set_playlist_reload_time(playlist, sequences)
if not sequences:
return
first_sequence, last_sequence = sequences[0], sequences[-1]
if first_sequence.segment.key and first_sequence.segment.key.method != "NONE":
log.debug("Segments in this playlist are encrypted")
self.playlist_changed = ([s.num for s in self.playlist_sequences] != [s.num for s in sequences])
self._set_playlist_reload_time(playlist, sequences)
self.playlist_sequences = sequences
if not self.playlist_changed:

View File

@ -5,7 +5,7 @@ from functools import partial
from streamlink.plugins.twitch import Twitch, TwitchHLSStream
import requests_mock
from tests.mock import call, patch
from tests.mock import MagicMock, call, patch
from streamlink.session import Streamlink
from tests.resources import text
@ -34,22 +34,21 @@ class TestPluginTwitch(unittest.TestCase):
self.assertFalse(Twitch.can_handle_url(url))
@patch("streamlink.stream.hls.HLSStreamWorker.wait", MagicMock(return_value=True))
class TestTwitchHLSStream(unittest.TestCase):
url_master = "http://mocked/path/master.m3u8"
url_playlist = "http://mocked/path/playlist.m3u8"
url_segment = "http://mocked/path/stream{0}.ts"
scte35_out = "#EXT-X-DISCONTINUITY\n#EXT-X-SCTE35-OUT\n"
scte35_out_cont = "#EXT-X-SCTE35-OUT-CONT\n"
scte35_in = "#EXT-X-DISCONTINUITY\n#EXT-X-SCTE35-IN\n"
segment = "#EXTINF:1.000,\nstream{0}.ts\n"
segment_ad = "#EXTINF:1.000,Amazon|123456789\nstream{0}.ts\n"
prefetch = "#EXT-X-TWITCH-PREFETCH:{0}\n"
def getMasterPlaylist(self):
with text("hls/test_master.m3u8") as pl:
return pl.read()
def getPlaylist(self, media_sequence, items, prefetch=None):
def getPlaylist(self, media_sequence, items, ads=False, prefetch=None):
playlist = """
#EXTM3U
#EXT-X-VERSION:5
@ -57,11 +56,9 @@ class TestTwitchHLSStream(unittest.TestCase):
#EXT-X-MEDIA-SEQUENCE:{0}
""".format(media_sequence)
segment = self.segment if not ads else self.segment_ad
for item in items:
if type(item) != int:
playlist += item
else:
playlist += self.segment.format(item)
playlist += segment.format(item)
for item in prefetch or []:
playlist += self.prefetch.format(self.url_segment.format(item))
@ -97,117 +94,52 @@ class TestTwitchHLSStream(unittest.TestCase):
return streamlink, data, mocked
@patch("streamlink.plugins.twitch.log")
def test_hls_scte35_start_with_end(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(12)]
def test_hls_disable_ads_preroll(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(6)]
playlists = [
self.getPlaylist(0, [self.scte35_out, 0, 1, 2, 3]),
self.getPlaylist(4, [self.scte35_in, 4, 5, 6, 7]),
self.getPlaylist(8, [8, 9, 10, 11]) + "#EXT-X-ENDLIST\n"
self.getPlaylist(0, [0, 1], ads=True),
self.getPlaylist(2, [2, 3], ads=True),
self.getPlaylist(4, [4, 5]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists, disable_ads=True)
expected = b''.join(streams[4:12])
self.assertEqual(expected, result)
for i in range(0, 12):
self.assertEqual(result, b''.join(streams[4:6]))
for i in range(0, 6):
self.assertTrue(mocked[self.url_segment.format(i)].called, i)
mock_logging.info.assert_has_calls([
call("Will skip ad segments"),
call("Will skip ads beginning with segment 0"),
call("Will stop skipping ads beginning with segment 4")
call("Waiting for pre-roll ads to finish, be patient")
])
self.assertEqual(mock_logging.info.call_count, 2)
@patch("streamlink.plugins.twitch.log")
def test_hls_scte35_no_start(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(8)]
def test_hls_disable_ads_no_preroll(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(6)]
playlists = [
self.getPlaylist(0, [0, 1, 2, 3]),
self.getPlaylist(4, [self.scte35_in, 4, 5, 6, 7]) + "#EXT-X-ENDLIST\n"
self.getPlaylist(0, [0, 1]),
self.getPlaylist(2, [2, 3], ads=True),
self.getPlaylist(4, [4, 5]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists, disable_ads=True)
expected = b''.join(streams[0:8])
self.assertEqual(expected, result)
for i in range(0, 8):
self.assertEqual(result, b''.join(streams[0:2]) + b''.join(streams[4:6]))
for i in range(0, 6):
self.assertTrue(mocked[self.url_segment.format(i)].called, i)
mock_logging.info.assert_has_calls([
call("Will skip ad segments")
])
@patch("streamlink.plugins.twitch.log")
def test_hls_scte35_no_start_with_cont(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(8)]
def test_hls_no_disable_ads(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(4)]
playlists = [
self.getPlaylist(0, [self.scte35_out_cont, 0, 1, 2, 3]),
self.getPlaylist(4, [self.scte35_in, 4, 5, 6, 7]) + "#EXT-X-ENDLIST\n"
self.getPlaylist(0, [0, 1], ads=True),
self.getPlaylist(2, [2, 3]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists, disable_ads=True)
streamlink, result, mocked = self.get_result(streams, playlists, disable_ads=False)
expected = b''.join(streams[4:8])
self.assertEqual(expected, result)
for i in range(0, 8):
self.assertTrue(mocked[self.url_segment.format(i)].called, i)
mock_logging.info.assert_has_calls([
call("Will skip ad segments"),
call("Will skip ads beginning with segment 0"),
call("Will stop skipping ads beginning with segment 4")
])
@patch("streamlink.plugins.twitch.log")
def test_hls_scte35_no_end(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(12)]
playlists = [
self.getPlaylist(0, [0, 1, 2, 3]),
self.getPlaylist(4, [self.scte35_out, 4, 5, 6, 7]),
self.getPlaylist(8, [8, 9, 10, 11]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists, disable_ads=True)
expected = b''.join(streams[0:4])
self.assertEqual(expected, result)
for i in range(0, 12):
self.assertTrue(mocked[self.url_segment.format(i)].called, i)
mock_logging.info.assert_has_calls([
call("Will skip ad segments"),
call("Will skip ads beginning with segment 4")
])
@patch("streamlink.plugins.twitch.log")
def test_hls_scte35_in_between(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(20)]
playlists = [
self.getPlaylist(0, [0, 1, 2, 3]),
self.getPlaylist(4, [4, 5, self.scte35_out, 6, 7]),
self.getPlaylist(8, [8, 9, 10, 11]),
self.getPlaylist(12, [12, 13, self.scte35_in, 14, 15]),
self.getPlaylist(16, [16, 17, 18, 19]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists, disable_ads=True)
expected = b''.join(streams[0:6]) + b''.join(streams[14:20])
self.assertEqual(expected, result)
for i in range(0, 20):
self.assertTrue(mocked[self.url_segment.format(i)].called, i)
mock_logging.info.assert_has_calls([
call("Will skip ad segments"),
call("Will skip ads beginning with segment 6"),
call("Will stop skipping ads beginning with segment 14")
])
@patch("streamlink.plugins.twitch.log")
def test_hls_scte35_no_disable_ads(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(20)]
playlists = [
self.getPlaylist(0, [0, 1, 2, 3]),
self.getPlaylist(4, [4, 5, self.scte35_out, 6, 7]),
self.getPlaylist(8, [8, 9, 10, 11]),
self.getPlaylist(12, [12, 13, self.scte35_in, 14, 15]),
self.getPlaylist(16, [16, 17, 18, 19]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists)
expected = b''.join(streams[0:20])
self.assertEqual(expected, result)
for i in range(0, 20):
self.assertEqual(result, b''.join(streams[0:4]))
for i in range(0, 4):
self.assertTrue(mocked[self.url_segment.format(i)].called, i)
mock_logging.info.assert_has_calls([])
@ -215,8 +147,8 @@ class TestTwitchHLSStream(unittest.TestCase):
def test_hls_prefetch(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(10)]
playlists = [
self.getPlaylist(0, [0, 1, 2, 3], [4, 5]),
self.getPlaylist(4, [4, 5, 6, 7], [8, 9]) + "#EXT-X-ENDLIST\n"
self.getPlaylist(0, [0, 1, 2, 3], prefetch=[4, 5]),
self.getPlaylist(4, [4, 5, 6, 7], prefetch=[8, 9]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists, low_latency=True)
@ -237,8 +169,8 @@ class TestTwitchHLSStream(unittest.TestCase):
def test_hls_prefetch_no_low_latency(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(10)]
playlists = [
self.getPlaylist(0, [0, 1, 2, 3], [4, 5]),
self.getPlaylist(4, [4, 5, 6, 7], [8, 9]) + "#EXT-X-ENDLIST\n"
self.getPlaylist(0, [0, 1, 2, 3], prefetch=[4, 5]),
self.getPlaylist(4, [4, 5, 6, 7], prefetch=[8, 9]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists)
@ -253,28 +185,12 @@ class TestTwitchHLSStream(unittest.TestCase):
self.assertFalse(mocked[self.url_segment.format(i)].called, i)
mock_logging.info.assert_has_calls([])
@patch("streamlink.plugins.twitch.log")
def test_hls_no_low_latency_with_disable_ads(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(10)]
playlists = [
self.getPlaylist(0, [0, 1, 2, 3], [4, 5]),
self.getPlaylist(4, [4, 5, 6, 7], [8, 9]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists, low_latency=True, disable_ads=True)
self.assertFalse(streamlink.get_plugin_option("twitch", "low-latency"))
self.assertTrue(streamlink.get_plugin_option("twitch", "disable-ads"))
mock_logging.info.assert_has_calls([
call("Low latency streaming with ad filtering is currently not supported")
])
@patch("streamlink.plugins.twitch.log")
def test_hls_no_low_latency_no_prefetch(self, mock_logging):
streams = ["[{0}]".format(i).encode("ascii") for i in range(10)]
playlists = [
self.getPlaylist(0, [0, 1, 2, 3], []),
self.getPlaylist(4, [4, 5, 6, 7], []) + "#EXT-X-ENDLIST\n"
self.getPlaylist(0, [0, 1, 2, 3], prefetch=[]),
self.getPlaylist(4, [4, 5, 6, 7], prefetch=[]) + "#EXT-X-ENDLIST\n"
]
streamlink, result, mocked = self.get_result(streams, playlists, low_latency=True)