Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 55 additions & 32 deletions embeddings/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import click
import jsonlines
import smart_open
from timdex_dataset_api import TIMDEXDataset

from embeddings.config import configure_logger, configure_sentry
Expand Down Expand Up @@ -156,32 +157,41 @@ def test_model_load(ctx: click.Context) -> None:
@click.pass_context
@model_required
@click.option(
"-d",
"--dataset-location",
required=True,
required=False,
type=click.Path(),
help="TIMDEX dataset location, e.g. 's3://timdex/dataset', to read records from.",
)
@click.option(
"--run-id",
required=True,
required=False,
type=str,
help="TIMDEX ETL run id.",
)
@click.option(
"--run-record-offset",
required=True,
required=False,
type=int,
default=0,
help="TIMDEX ETL run record offset to start from, default = 0.",
)
@click.option(
"--record-limit",
required=True,
required=False,
type=int,
default=None,
help="Limit number of records after --run-record-offset, default = None (unlimited).",
)
@click.option(
"--input-jsonl",
required=False,
type=str,
default=None,
help=(
"Optional filepath to JSONLines file containing "
"TIMDEX records to create embeddings from."
),
)
@click.option(
"--strategy",
type=click.Choice(list(STRATEGY_REGISTRY.keys())),
Expand All @@ -205,50 +215,63 @@ def create_embeddings(
run_id: str,
run_record_offset: int,
record_limit: int,
input_jsonl: str,
strategy: list[str],
output_jsonl: str,
) -> None:
"""Create embeddings for TIMDEX records."""
model: BaseEmbeddingModel = ctx.obj["model"]
model.load()

# init TIMDEXDataset
timdex_dataset = TIMDEXDataset(dataset_location)

# query TIMDEX dataset for an iterator of records
timdex_records = timdex_dataset.read_dicts_iter(
columns=[
"timdex_record_id",
"run_id",
"run_record_offset",
"transformed_record",
],
run_id=run_id,
where=f"""run_record_offset >= {run_record_offset}""",
limit=record_limit,
action="index",
)
# read input records from TIMDEX dataset (default) or a JSONLines file
if input_jsonl:
with (
smart_open.open(input_jsonl, "r") as file_obj, # type: ignore[no-untyped-call]
jsonlines.Reader(file_obj) as reader,
):
timdex_records = iter(list(reader))

else:
if not dataset_location or not run_id:
raise click.UsageError(
"Both '--dataset-location' and '--run-id' are required arguments "
"when reading input records from the TIMDEX dataset."
)

# init TIMDEXDataset
timdex_dataset = TIMDEXDataset(dataset_location)

# query TIMDEX dataset for an iterator of records
timdex_records = timdex_dataset.read_dicts_iter(
columns=[
"timdex_record_id",
"run_id",
"run_record_offset",
"transformed_record",
],
run_id=run_id,
where=f"""run_record_offset >= {run_record_offset}""",
limit=record_limit,
action="index",
)

# create an iterator of EmbeddingInputs applying all requested strategies
embedding_inputs = create_embedding_inputs(timdex_records, list(strategy))

# create embeddings via the embedding model
embeddings = model.create_embeddings(embedding_inputs)

# if requested, write embeddings to a local JSONLines file
# write embeddings to TIMDEX dataset (default) or to a JSONLines file
if output_jsonl:
with jsonlines.open(
output_jsonl,
mode="w",
dumps=lambda obj: json.dumps(
obj,
default=str,
),
) as writer:
with (
smart_open.open(output_jsonl, "w") as s3_file, # type: ignore[no-untyped-call]
jsonlines.Writer(
s3_file,
dumps=lambda obj: json.dumps(obj, default=str),
) as writer,
):
for embedding in embeddings:
writer.write(embedding.to_dict())

# else, default writing embeddings back to TIMDEX dataset
else:
# WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
# NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the
Expand Down
7 changes: 6 additions & 1 deletion embeddings/strategies/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ def create_embedding_inputs(
for timdex_dataset_record in timdex_dataset_records:

# decode and parse the TIMDEX JSON record once for all requested strategies
timdex_record = json.loads(timdex_dataset_record["transformed_record"].decode())
transformed_record_raw = timdex_dataset_record["transformed_record"]
timdex_record = json.loads(
transformed_record_raw.decode()
if isinstance(transformed_record_raw, bytes)
else transformed_record_raw
)

for transformer in transformers:
# prepare text for embedding from transformer strategy
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"huggingface-hub>=0.26.0",
"jsonlines>=4.0.0",
"sentry-sdk>=2.34.1",
"smart-open[s3]>=7.4.4",
"timdex-dataset-api",
"torch>=2.9.0",
"transformers>=4.57.1",
Expand Down Expand Up @@ -41,7 +42,10 @@ exclude = [
]

[[tool.mypy.overrides]]
module = ["timdex_dataset_api.*"]
module = [
"timdex_dataset_api.*",
"smart_open.*",
]
follow_untyped_imports = true


Expand Down
3 changes: 3 additions & 0 deletions tests/fixtures/cli_inputs/test-3-records.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"timdex_record_id": "record:1", "run_id": "abc123", "run_record_offset": 0, "transformed_record": "{\"title\":\"Record 1\",\"description\":\"This is a record about coffee in the mountains.\"}"}
{"timdex_record_id": "record:2", "run_id": "abc123", "run_record_offset": 1, "transformed_record": "{\"title\":\"Record 2\",\"description\":\"Sometimes poetry is made accidentally by the fabrication of metadata.\"}"}
{"timdex_record_id": "record:3", "run_id": "abc123", "run_record_offset": 2, "transformed_record": "{\"title\":\"Record 3\",\"description\":\"This is an oddball record, meant to evoke the peculiar nature of mathematics.\"}"}
78 changes: 68 additions & 10 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,38 @@ def test_model_required_decorator_works_across_commands(
assert "OK" in result.output


def test_create_embeddings_requires_strategy(register_mock_model, runner):
    """Invoking create-embeddings without --strategy fails and names the option."""
    # dataset arguments are supplied; --strategy is deliberately left out
    cli_args = [
        "create-embeddings",
        "--model-uri",
        "test/mock-model",
        "--dataset-location",
        "s3://test",
        "--run-id",
        "run-1",
    ]
    outcome = runner.invoke(main, cli_args)

    assert outcome.exit_code != 0
    assert "Missing option '--strategy'" in outcome.output


def test_create_embeddings_requires_dataset_location(register_mock_model, runner):
    """Omitting --dataset-location (with no --input-jsonl) is reported as a usage error.

    Only --run-id and --strategy are passed, so the command is expected to exit
    non-zero with the message that both dataset arguments are required.
    """
    # single invocation: an earlier call whose result was immediately
    # overwritten has been dropped
    result = runner.invoke(
        main,
        [
            "create-embeddings",
            "--model-uri",
            "test/mock-model",
            "--run-id",
            "run-1",
            "--strategy",
            "full_record",
        ],
    )
    assert result.exit_code != 0
    # the full message is asserted; it already contains '--dataset-location',
    # so a separate substring check for the option name adds nothing
    assert "Both '--dataset-location' and '--run-id' are required" in result.output


def test_create_embeddings_requires_run_id(register_mock_model, runner):
Expand All @@ -148,24 +176,54 @@ def test_create_embeddings_requires_run_id(register_mock_model, runner):
"test/mock-model",
"--dataset-location",
"s3://test",
"--strategy",
"full_record",
],
)
assert result.exit_code != 0
assert "Missing option '--run-id'" in result.output
assert "Both '--dataset-location' and '--run-id' are required" in result.output


def test_create_embeddings_optional_input_jsonl(register_mock_model, runner, tmp_path):
    """Records supplied via --input-jsonl are processed and written to --output-jsonl.

    The dataset arguments are passed alongside --input-jsonl; the command is
    still expected to succeed and to create the output file.
    """
    input_file = "tests/fixtures/cli_inputs/test-3-records.jsonl"
    output_file = tmp_path / "output.jsonl"

    result = runner.invoke(
        main,
        [
            "create-embeddings",
            "--model-uri",
            "test/mock-model",
            "--dataset-location",
            "s3://test",
            "--run-id",
            "run-1",
            "--input-jsonl",
            input_file,
            "--strategy",
            "full_record",
            "--output-jsonl",
            str(output_file),
        ],
    )
    # success is the only expected outcome here: --strategy is provided, so the
    # "missing option" / non-zero exit assertions do not belong in this test
    assert result.exit_code == 0
    assert output_file.exists()


def test_create_embeddings_optional_input_jsonl_does_not_require_dataset_params(
register_mock_model, runner, tmp_path
):
input_file = "tests/fixtures/cli_inputs/test-3-records.jsonl"
output_file = tmp_path / "output.jsonl"

result = runner.invoke(
main,
[
"create-embeddings",
"--model-uri",
"test/mock-model",
"--input-jsonl",
input_file,
"--strategy",
"full_record",
"--output-jsonl",
str(output_file),
],
)
assert result.exit_code == 0
Loading