diff options
| author | Adam Malczewski <[email protected]> | 2026-05-11 19:18:34 +0900 |
|---|---|---|
| committer | Adam Malczewski <[email protected]> | 2026-05-11 19:18:34 +0900 |
| commit | c23ee09f6d24832aa472298db91df3ce6e248a76 (patch) | |
| tree | 3576678394cf5eb053dc649abdf1dab559d69487 /tests | |
| download | youtube-transcriber-c23ee09f6d24832aa472298db91df3ce6e248a76.tar.gz youtube-transcriber-c23ee09f6d24832aa472298db91df3ce6e248a76.zip | |
Initial commit: YouTube transcriber API with queue-based worker
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/__init__.py | 0 | ||||
| -rw-r--r-- | tests/test_main.py | 126 | ||||
| -rw-r--r-- | tests/test_storage.py | 250 | ||||
| -rw-r--r-- | tests/test_transcriber.py | 60 | ||||
| -rw-r--r-- | tests/test_worker.py | 249 |
5 files changed, 685 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..e260747 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,126 @@ +import asyncio +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +from httpx import ASGITransport + +import app.main as main_module +from app.main import app +from app.storage import TranscriptStore + + +def _run_with_app(tmp_path, coro_factory): + """Initialize app lifespan with a temp-DB store, run coro_factory(client, store), tear down.""" + + async def _runner(): + db_path = os.path.join(str(tmp_path), "test.db") + + def _store_factory(_path): + return TranscriptStore(db_path=db_path) + + with patch.object(main_module, "run_worker", new=AsyncMock()), \ + patch.object(main_module, "create_api", new=MagicMock()), \ + patch.object(main_module, "TranscriptStore", side_effect=_store_factory): + async with main_module.lifespan(app): + transport = ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + await coro_factory(client, app.state.store) + + asyncio.run(_runner()) + + +class TestGetTranscript: + def test_health(self, tmp_path) -> None: + async def _do(client, store): + r = await client.get("/health") + assert r.status_code == 200 + assert r.json() == {"status": "ok"} + + _run_with_app(tmp_path, _do) + + def test_invalid_url_returns_400(self, tmp_path) -> None: + async def _do(client, store): + r = await client.get("/api/transcript", params={"url": "https://example.com/notavideo"}) + assert r.status_code == 400 + assert "Invalid" in r.json()["detail"] + + _run_with_app(tmp_path, _do) + + def test_missing_url_returns_422(self, tmp_path) -> None: + async def _do(client, store): + r = await client.get("/api/transcript") + assert r.status_code == 422 + + _run_with_app(tmp_path, _do) + + def test_new_video_returns_queued(self, tmp_path) -> None: + async def _do(client, store): + r = await client.get( + "/api/transcript", + params={"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}, + ) + assert r.status_code == 200 + body = r.json() + assert body["status"] == "queued" + assert body["video_id"] == "dQw4w9WgXcQ" + assert isinstance(body["estimated_seconds"], (int, float)) + assert body["estimated_seconds"] > 0 + + _run_with_app(tmp_path, _do) + + def test_cached_video_returns_completed(self, tmp_path) -> None: + async def _do(client, store): + await store.save_transcript( + "dQw4w9WgXcQ", + "hello world", + [{"text": "hello world", "start": 0.0, "duration": 1.0}], + ) + r = await client.get( + "/api/transcript", + params={"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}, + ) + assert r.status_code == 200 + body = r.json() + assert body["status"] == "completed" + assert body["full_text"] == "hello world" + assert isinstance(body["segments"], list) + assert len(body["segments"]) == 1 + + _run_with_app(tmp_path, _do) + + def test_already_queued_video_returns_existing_status(self, tmp_path) -> None: + async def _do(client, store): + await store.enqueue("dQw4w9WgXcQ") + r = await client.get( + "/api/transcript", + params={"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}, + ) + assert r.status_code == 200 + body = r.json() + assert body["status"] == "queued" + assert body["video_id"] == "dQw4w9WgXcQ" + + r2 = await client.get( + "/api/transcript", + params={"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}, + ) + assert r2.status_code == 200 + assert r2.json()["video_id"] == "dQw4w9WgXcQ" + + _run_with_app(tmp_path, _do) + + def test_failed_video_returns_failure(self, tmp_path) -> None: + async def _do(client, store): + await store.enqueue("dQw4w9WgXcQ") + await store.mark_failed("dQw4w9WgXcQ", "Transcripts disabled", "transcript_disabled") + r = await client.get( + "/api/transcript", + params={"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}, + ) + assert r.status_code == 200 + body = r.json() + assert body["status"] == "failed" + assert body["error_type"] == "transcript_disabled" + + _run_with_app(tmp_path, _do) diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 0000000..b1b914b --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,250 @@ +import asyncio +import os + +from app.storage import TranscriptStore + + +class TestTranscriptCache: + def test_initialize_creates_database_file(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + assert os.path.exists(db_path), "Database file should exist after initialize" + await store.close() + + asyncio.run(_run()) + + def test_get_transcript_returns_none_when_not_found(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + result = await store.get_transcript("nonexistent_id") + assert result is None + await store.close() + + asyncio.run(_run()) + + def test_save_and_retrieve_transcript(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + + segments = [ + {"text": "hello", "start": 0.0, "duration": 1.0}, + {"text": "world", "start": 1.0, "duration": 1.0}, + ] + await store.save_transcript( + video_id="abc123", + full_text="hello world", + segments=segments, + ) + + result = await store.get_transcript("abc123") + assert result is not None + assert result["video_id"] == "abc123" + assert result["full_text"] == "hello world" + assert result["segments"] == segments + assert isinstance(result["segments"], list) + assert all(isinstance(s, dict) for s in result["segments"]) + + await store.close() + + asyncio.run(_run()) + + def test_save_transcript_overwrites_existing(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + + await store.save_transcript( + video_id="abc123", + full_text="first version", + segments=[], + ) + await store.save_transcript( + video_id="abc123", + full_text="second version", + segments=[], + ) + + result = await store.get_transcript("abc123") + assert result is not None + assert result["full_text"] == "second version" + + await store.close() + + asyncio.run(_run()) + + def test_multiple_transcripts(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + + await store.save_transcript( + video_id="vid1", + full_text="transcript one", + segments=[{"text": "one", "start": 0.0, "duration": 1.0}], + ) + await store.save_transcript( + video_id="vid2", + full_text="transcript two", + segments=[{"text": "two", "start": 0.0, "duration": 1.0}], + ) + + result1 = await store.get_transcript("vid1") + assert result1 is not None + assert result1["video_id"] == "vid1" + assert result1["full_text"] == "transcript one" + + result2 = await store.get_transcript("vid2") + assert result2 is not None + assert result2["video_id"] == "vid2" + assert result2["full_text"] == "transcript two" + + await store.close() + + asyncio.run(_run()) + + +class TestQueue: + def _make_store(self, tmp_path): + return TranscriptStore(db_path=os.path.join(str(tmp_path), "test.db")) + + def test_enqueue_creates_pending_entry(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + entry = await store.enqueue("vid_001") + assert entry["video_id"] == "vid_001" + assert entry["status"] == "pending" + assert isinstance(entry["assigned_delay"], float) + assert 30.0 <= entry["assigned_delay"] <= 60.0 + assert entry["error"] is None + assert entry["error_type"] is None + await store.close() + asyncio.run(_run()) + + def test_enqueue_duplicate_returns_existing(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + first = await store.enqueue("vid_001") + second = await store.enqueue("vid_001") + assert first["assigned_delay"] == second["assigned_delay"] + assert first["id"] == second["id"] + await store.close() + asyncio.run(_run()) + + def test_get_queue_entry(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + await store.enqueue("vid_001") + entry = await store.get_queue_entry("vid_001") + assert entry is not None + assert entry["video_id"] == "vid_001" + assert entry["status"] == "pending" + assert await store.get_queue_entry("nonexistent") is None + await store.close() + asyncio.run(_run()) + + def test_get_next_pending_returns_oldest_first(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + await store.enqueue("vid_001") + await store.enqueue("vid_002") + await store.enqueue("vid_003") + first = await store.get_next_pending() + assert first is not None + assert first["video_id"] == "vid_001" + assert first["status"] == "processing" + assert first["started_at"] is not None + second = await store.get_next_pending() + assert second is not None + assert second["video_id"] == "vid_002" + await store.close() + asyncio.run(_run()) + + def test_get_next_pending_returns_none_when_empty(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + assert await store.get_next_pending() is None + await store.close() + asyncio.run(_run()) + + def test_mark_completed_removes_entry(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + await store.enqueue("vid_001") + await store.mark_completed("vid_001") + assert await store.get_queue_entry("vid_001") is None + await store.close() + asyncio.run(_run()) + + def test_mark_failed_updates_entry(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + await store.enqueue("vid_001") + await store.mark_failed("vid_001", "IP was blocked", "ip_blocked") + entry = await store.get_queue_entry("vid_001") + assert entry is not None + assert entry["status"] == "failed" + assert entry["error"] == "IP was blocked" + assert entry["error_type"] == "ip_blocked" + await store.close() + asyncio.run(_run()) + + def test_get_position_and_estimate_for_pending(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + e1 = await store.enqueue("vid_001") + e2 = await store.enqueue("vid_002") + e3 = await store.enqueue("vid_003") + result = await store.get_position_and_estimate("vid_003") + assert result is not None + assert result["position"] == 2 + expected = e1["assigned_delay"] + e2["assigned_delay"] + e3["assigned_delay"] + assert abs(result["estimated_seconds"] - expected) < 0.01 + await store.close() + asyncio.run(_run()) + + def test_get_position_and_estimate_for_first_pending(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + e1 = await store.enqueue("vid_001") + result = await store.get_position_and_estimate("vid_001") + assert result is not None + assert result["position"] == 0 + assert abs(result["estimated_seconds"] - e1["assigned_delay"]) < 0.01 + await store.close() + asyncio.run(_run()) + + def test_get_position_and_estimate_for_failed(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + await store.enqueue("vid_001") + await store.mark_failed("vid_001", "error", "type") + result = await store.get_position_and_estimate("vid_001") + assert result == {"position": 0, "estimated_seconds": 0.0} + await store.close() + asyncio.run(_run()) + + def test_get_position_and_estimate_not_found(self, tmp_path) -> None: + async def _run() -> None: + store = self._make_store(tmp_path) + await store.initialize() + assert await store.get_position_and_estimate("nonexistent") is None + await store.close() + asyncio.run(_run()) diff --git a/tests/test_transcriber.py b/tests/test_transcriber.py new file mode 100644 index 0000000..c05eab5 --- /dev/null +++ b/tests/test_transcriber.py @@ -0,0 +1,60 @@ +from unittest.mock import patch + +import pytest +from youtube_transcript_api._errors import IpBlocked, TranscriptsDisabled + +from app.transcriber import ( + InvalidURLError, + IPBlockedError, + TranscriptDisabledError, + create_api, + extract_video_id, + fetch_transcript_by_id, +) + + +class TestExtractVideoId: + def test_standard_url(self) -> None: + assert extract_video_id("https://www.youtube.com/watch?v=dQw4w9WgXcQ") == "dQw4w9WgXcQ" + + def test_short_url(self) -> None: + assert extract_video_id("https://youtu.be/dQw4w9WgXcQ") == "dQw4w9WgXcQ" + + def test_embed_url(self) -> None: + assert extract_video_id("https://www.youtube.com/embed/dQw4w9WgXcQ") == "dQw4w9WgXcQ" + + def test_shorts_url(self) -> None: + assert extract_video_id("https://www.youtube.com/shorts/dQw4w9WgXcQ") == "dQw4w9WgXcQ" + + def test_no_protocol(self) -> None: + assert extract_video_id("youtube.com/watch?v=dQw4w9WgXcQ") == "dQw4w9WgXcQ" + + def test_invalid_url(self) -> None: + try: + extract_video_id("https://example.com/video") + assert False, "Expected InvalidURLError" + except InvalidURLError: + pass + + +class TestCreateApi: + def test_create_api_has_browser_user_agent(self) -> None: + api = create_api() + # The session is stored on the fetcher inside the API instance. + session = api._fetcher._http_client + ua = session.headers.get("User-Agent", "") + assert "Chrome" in ua + + +class TestFetchTranscriptById: + def test_fetch_transcript_by_id_raises_ip_blocked(self) -> None: + api = create_api() + with patch.object(api, "fetch", side_effect=IpBlocked(video_id="test")): + with pytest.raises(IPBlockedError): + fetch_transcript_by_id("test", api) + + def test_fetch_transcript_by_id_raises_transcript_disabled(self) -> None: + api = create_api() + with patch.object(api, "fetch", side_effect=TranscriptsDisabled(video_id="test")): + with pytest.raises(TranscriptDisabledError): + fetch_transcript_by_id("test", api) diff --git a/tests/test_worker.py b/tests/test_worker.py new file mode 100644 index 0000000..0ef523a --- /dev/null +++ b/tests/test_worker.py @@ -0,0 +1,249 @@ +import asyncio +import os +from unittest.mock import AsyncMock, MagicMock, patch + +from app.storage import TranscriptStore +from app.transcriber import ( + IPBlockedError, + TranscriptDisabledError, +) +from app.worker import process_next + + +def _to_thread_passthrough(func, *args, **kwargs): + """Replacement for asyncio.to_thread that runs the function synchronously.""" + + async def _coro(): + return func(*args, **kwargs) + + return _coro() + + +def _patch_to_thread(): + return patch("app.worker.asyncio.to_thread", new=_to_thread_passthrough) + + +class TestProcessNext: + def test_process_next_returns_false_when_queue_empty(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + + with patch("app.worker.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + api = MagicMock() + result = await process_next(store, api) + assert result is False + mock_sleep.assert_not_called() + + await store.close() + + asyncio.run(_run()) + + def test_process_next_success(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + entry = await store.enqueue("vid_001") + + with patch("app.worker.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, \ + _patch_to_thread(), \ + patch("app.worker.fetch_transcript_by_id", return_value=[{"text": "hello", "start": 0.0, "duration": 1.0}]): + api = MagicMock() + result = await process_next(store, api) + assert result is True + mock_sleep.assert_called_once() + slept_for = mock_sleep.call_args.args[0] + assert 30.0 <= slept_for <= 60.0 + assert abs(slept_for - entry["assigned_delay"]) < 0.001 + + transcript = await store.get_transcript("vid_001") + assert transcript is not None + assert transcript["full_text"] == "hello" + assert await store.get_queue_entry("vid_001") is None + await store.close() + + asyncio.run(_run()) + + def test_process_next_ip_blocked(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + await store.enqueue("vid_001") + + with patch("app.worker.asyncio.sleep", new_callable=AsyncMock), \ + _patch_to_thread(), \ + patch("app.worker.fetch_transcript_by_id", side_effect=IPBlockedError("blocked")): + api = MagicMock() + result = await process_next(store, api) + assert result is True + + entry = await store.get_queue_entry("vid_001") + assert entry is not None + assert entry["status"] == "failed" + assert entry["error_type"] == "ip_blocked" + assert await store.get_transcript("vid_001") is None + await store.close() + + asyncio.run(_run()) + + def test_process_next_transcript_disabled(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + await store.enqueue("vid_001") + + with patch("app.worker.asyncio.sleep", new_callable=AsyncMock), \ + _patch_to_thread(), \ + patch("app.worker.fetch_transcript_by_id", side_effect=TranscriptDisabledError("disabled")): + api = MagicMock() + result = await process_next(store, api) + assert result is True + + entry = await store.get_queue_entry("vid_001") + assert entry is not None + assert entry["status"] == "failed" + assert entry["error_type"] == "transcript_disabled" + await store.close() + + asyncio.run(_run()) + + def test_process_next_downloads_before_sleeping(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + await store.enqueue("vid_001") + + call_order = [] + + async def mock_sleep(seconds): + call_order.append("sleep") + + def mock_fetch(video_id, api): + call_order.append("fetch") + return [{"text": "hello", "start": 0.0, "duration": 1.0}] + + with patch("app.worker.asyncio.sleep", side_effect=mock_sleep), \ + _patch_to_thread(), \ + patch("app.worker.fetch_transcript_by_id", side_effect=mock_fetch): + api = MagicMock() + result = await process_next(store, api) + assert result is True + + assert call_order == ["fetch", "sleep"] + await store.close() + + asyncio.run(_run()) + + def test_process_next_sleeps_after_error(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + entry = await store.enqueue("vid_001") + + with patch("app.worker.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, \ + _patch_to_thread(), \ + patch("app.worker.fetch_transcript_by_id", side_effect=IPBlockedError("blocked")): + api = MagicMock() + result = await process_next(store, api) + assert result is True + mock_sleep.assert_called_once() + slept_for = mock_sleep.call_args.args[0] + assert abs(slept_for - entry["assigned_delay"]) < 0.001 + + entry = await store.get_queue_entry("vid_001") + assert entry is not None + assert entry["status"] == "failed" + await store.close() + + asyncio.run(_run()) + + def test_process_next_sleeps_after_error_before_next_download(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + await store.enqueue("vid_001") + + call_order = [] + + async def mock_sleep(seconds): + call_order.append("sleep") + + def mock_fetch(video_id, api): + call_order.append("fetch") + raise IPBlockedError("blocked") + + with patch("app.worker.asyncio.sleep", side_effect=mock_sleep), \ + _patch_to_thread(), \ + patch("app.worker.fetch_transcript_by_id", side_effect=mock_fetch): + api = MagicMock() + await process_next(store, api) + + assert call_order == ["fetch", "sleep"] + await store.close() + + asyncio.run(_run()) + + def test_process_next_no_sleep_before_first_download_after_empty_queue(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + + call_order = [] + + async def mock_sleep(seconds): + call_order.append("sleep") + + def mock_fetch(video_id, api): + call_order.append("fetch") + return [{"text": "hello", "start": 0.0, "duration": 1.0}] + + with patch("app.worker.asyncio.sleep", side_effect=mock_sleep), \ + _patch_to_thread(), \ + patch("app.worker.fetch_transcript_by_id", side_effect=mock_fetch): + api = MagicMock() + + # Queue is empty — no sleep, no fetch + result = await process_next(store, api) + assert result is False + assert call_order == [] + + # Video is added while queue was idle + await store.enqueue("vid_001") + + # Next call downloads immediately, then sleeps after + result = await process_next(store, api) + assert result is True + assert call_order == ["fetch", "sleep"] + + await store.close() + + asyncio.run(_run()) + + def test_process_next_processes_fifo_order(self, tmp_path) -> None: + async def _run() -> None: + db_path = os.path.join(str(tmp_path), "test.db") + store = TranscriptStore(db_path=db_path) + await store.initialize() + await store.enqueue("vid_001") + await store.enqueue("vid_002") + + with patch("app.worker.asyncio.sleep", new_callable=AsyncMock), \ + _patch_to_thread(), \ + patch("app.worker.fetch_transcript_by_id", return_value=[{"text": "first", "start": 0.0, "duration": 1.0}]): + api = MagicMock() + await process_next(store, api) + + assert (await store.get_transcript("vid_001")) is not None + assert (await store.get_transcript("vid_002")) is None + assert (await store.get_queue_entry("vid_002")) is not None + await store.close() + + asyncio.run(_run()) |
