snakemake redo #506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 19 commits into main
6 changes: 3 additions & 3 deletions Justfile
@@ -1,7 +1,7 @@
 # Setup
 
 install:
-    uv sync --no-cache --frozen
+    uv sync --group dev --group docs --no-cache --frozen
 
 # Packaging

@@ -16,8 +16,8 @@ publish:
 # Testing
 
 test:
-    export TEST_TOKEN=$(cat ~/.latch/token) &&\
-    pytest -s tests
+    export TEST_TOKEN=$(cat ~/.latch/token)
+    pytest -s
 
 # Docs
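Note: by default `just` runs each recipe line in its own shell, so after this split the `TEST_TOKEN` exported on the first line of `test` is not visible to the `pytest` line unless the recipe becomes a shebang recipe or the Justfile configures a shared shell; the removed `&&\` continuation previously kept both commands in a single shell invocation.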
13 changes: 8 additions & 5 deletions pyproject.toml
@@ -12,7 +12,7 @@ include = ["src/**/*.py", "src/latch_cli/services/init/*"]
 
 [project]
 name = "latch"
-version = "2.62.1"
+version = "2.62.1.a2"
 description = "The Latch SDK"
 authors = [{ name = "Kenny Workman", email = "[email protected]" }]
 maintainers = [
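Note: `2.62.1.a2` is a PEP 440 pre-release version; packaging tools normalize it to `2.62.1a2`, so this publishes as an alpha build rather than a regular release.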
@@ -22,8 +22,8 @@ maintainers = [
 
 readme = "README.md"
 license = { file = "LICENSE" }
-
 requires-python = ">=3.9"
+
 dependencies = [
     "kubernetes>=24.2.0",
     "pyjwt>=0.2.0",
@@ -75,7 +75,11 @@ classifiers = [
 
 [project.optional-dependencies]
 pandas = ["pandas>=2.0.0"]
-snakemake = ["snakemake>=7.18.0,<7.30.2", "pulp>=2.0,<2.8"]
+snakemake = [
+    "snakemake",
+    "snakemake-storage-plugin-latch==0.1.11",
+    "snakemake-executor-plugin-latch==0.1.9",
+]
 
 [project.scripts]
 latch = "latch_cli.main:main"
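With this change, `pip install "latch[snakemake]"` resolves an unpinned `snakemake` together with the pinned Latch storage (`0.1.11`) and executor (`0.1.9`) plugins, replacing the old `snakemake>=7.18.0,<7.30.2` and `pulp` pins.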
@@ -99,11 +103,10 @@ docs = [
 ]
 
 [tool.ruff]
+line-length = 100
 target-version = "py39"
 
 [tool.ruff.lint]
-preview = true
-
 pydocstyle = { convention = "google" }
 extend-select = [
     "F",
152 changes: 52 additions & 100 deletions src/latch/resources/tasks.py
@@ -52,18 +52,8 @@ def get_v100_x1_pod() -> Pod:
 
     primary_container = V1Container(name="primary")
     resources = V1ResourceRequirements(
-        requests={
-            "cpu": "7",
-            "memory": "48Gi",
-            "nvidia.com/gpu": 1,
-            "ephemeral-storage": "4500Gi",
-        },
-        limits={
-            "cpu": "7",
-            "memory": "48Gi",
-            "nvidia.com/gpu": 1,
-            "ephemeral-storage": "5000Gi",
-        },
+        requests={"cpu": "7", "memory": "48Gi", "nvidia.com/gpu": 1, "ephemeral-storage": "4500Gi"},
+        limits={"cpu": "7", "memory": "48Gi", "nvidia.com/gpu": 1, "ephemeral-storage": "5000Gi"},
     )
     primary_container.resources = resources
 
@@ -94,12 +84,7 @@ def get_v100_x4_pod() -> Pod:
             "nvidia.com/gpu": 4,
             "ephemeral-storage": "4500Gi",
         },
-        limits={
-            "cpu": "30",
-            "memory": "230Gi",
-            "nvidia.com/gpu": 4,
-            "ephemeral-storage": "5000Gi",
-        },
+        limits={"cpu": "30", "memory": "230Gi", "nvidia.com/gpu": 4, "ephemeral-storage": "5000Gi"},
     )
     primary_container.resources = resources
 
@@ -135,12 +120,7 @@ def get_v100_x8_pod() -> Pod:
             "nvidia.com/gpu": 8,
             "ephemeral-storage": "4500Gi",
         },
-        limits={
-            "cpu": "62",
-            "memory": "400Gi",
-            "nvidia.com/gpu": 8,
-            "ephemeral-storage": "5000Gi",
-        },
+        limits={"cpu": "62", "memory": "400Gi", "nvidia.com/gpu": 8, "ephemeral-storage": "5000Gi"},
     )
     primary_container.resources = resources
 
@@ -205,21 +185,14 @@ def _get_small_gpu_pod() -> Pod:
             "nvidia.com/gpu": "1",
             "ephemeral-storage": "1500Gi",
         },
-        limits={
-            "cpu": "7",
-            "memory": "30Gi",
-            "nvidia.com/gpu": "1",
-            "ephemeral-storage": "1500Gi",
-        },
+        limits={"cpu": "7", "memory": "30Gi", "nvidia.com/gpu": "1", "ephemeral-storage": "1500Gi"},
     )
     primary_container.resources = resources
 
     return Pod(
         pod_spec=V1PodSpec(
             containers=[primary_container],
-            tolerations=[
-                V1Toleration(effect="NoSchedule", key="ng", value="gpu-small")
-            ],
+            tolerations=[V1Toleration(effect="NoSchedule", key="ng", value="gpu-small")],
         ),
         primary_container_name="primary",
     )
@@ -244,9 +217,7 @@ def _get_large_pod() -> Pod:
         pod_spec=V1PodSpec(
             runtime_class_name="sysbox-runc",
             containers=[primary_container],
-            tolerations=[
-                V1Toleration(effect="NoSchedule", key="ng", value="cpu-96-spot")
-            ],
+            tolerations=[V1Toleration(effect="NoSchedule", key="ng", value="cpu-96-spot")],
         ),
         primary_container_name="primary",
     )
@@ -271,9 +242,7 @@ def _get_medium_pod() -> Pod:
         pod_spec=V1PodSpec(
             runtime_class_name="sysbox-runc",
             containers=[primary_container],
-            tolerations=[
-                V1Toleration(effect="NoSchedule", key="ng", value="cpu-32-spot")
-            ],
+            tolerations=[V1Toleration(effect="NoSchedule", key="ng", value="cpu-32-spot")],
         ),
         primary_container_name="primary",
     )
@@ -295,10 +264,7 @@ def _get_small_pod() -> Pod:
                 "private:uidmapping=0:1048576:65536;gidmapping=0:1048576:65536"
             )
         },
-        pod_spec=V1PodSpec(
-            runtime_class_name="sysbox-runc",
-            containers=[primary_container],
-        ),
+        pod_spec=V1PodSpec(runtime_class_name="sysbox-runc", containers=[primary_container]),
         primary_container_name="primary",
     )
 
@@ -466,8 +432,7 @@ def custom_memory_optimized_task(cpu: int, memory: int):
         )
     elif memory > 485:
         raise ValueError(
-            f"custom memory optimized task requires too much RAM: {memory} GiB (max 485"
-            " GiB)"
+            f"custom memory optimized task requires too much RAM: {memory} GiB (max 485 GiB)"
         )
 
     primary_container = V1Container(name="primary")
@@ -485,9 +450,7 @@ def custom_memory_optimized_task(cpu: int, memory: int):
         pod_spec=V1PodSpec(
             runtime_class_name="sysbox-runc",
             containers=[primary_container],
-            tolerations=[
-                V1Toleration(effect="NoSchedule", key="ng", value="mem-512-spot")
-            ],
+            tolerations=[V1Toleration(effect="NoSchedule", key="ng", value="mem-512-spot")],
         ),
         primary_container_name="primary",
     )
@@ -517,11 +480,7 @@ class _NGConfig:
 max_storage_gb_ish = int(max_storage_gib * Units.GiB / Units.GB)
 
 
-def _custom_task_config(
-    cpu: int,
-    memory: int,
-    storage_gib: int,
-) -> Pod:
+def _custom_task_config(cpu: int, memory: int, storage_gib: int) -> Pod:
     target_ng = None
     for ng in taint_data:
         if (
@@ -547,11 +506,7 @@ def _custom_task_config(
             "memory": f"{memory}Gi",
             "ephemeral-storage": f"{storage_gib}Gi",
         },
-        limits={
-            "cpu": str(cpu),
-            "memory": f"{memory}Gi",
-            "ephemeral-storage": f"{storage_gib}Gi",
-        },
+        limits={"cpu": str(cpu), "memory": f"{memory}Gi", "ephemeral-storage": f"{storage_gib}Gi"},
     )
     primary_container.resources = resources
     return Pod(
@@ -564,9 +519,7 @@ def _custom_task_config(
             runtime_class_name="sysbox-runc",
             containers=[primary_container],
             tolerations=[
-                V1Toleration(
-                    effect="NoSchedule", key="ng", value=target_ng.toleration_value
-                )
+                V1Toleration(effect="NoSchedule", key="ng", value=target_ng.toleration_value)
             ],
         ),
         primary_container_name="primary",
@@ -591,31 +544,22 @@ def custom_task(
     """
     if callable(cpu) or callable(memory) or callable(storage_gib):
         task_config = DynamicTaskConfig(
-            cpu=cpu,
-            memory=memory,
-            storage=storage_gib,
-            pod_config=_get_small_pod(),
+            cpu=cpu, memory=memory, storage=storage_gib, pod_config=_get_small_pod()
         )
         return functools.partial(task, task_config=task_config, timeout=timeout)
 
     return functools.partial(
-        task,
-        task_config=_custom_task_config(cpu, memory, storage_gib),
-        timeout=timeout,
-        **kwargs,
+        task, task_config=_custom_task_config(cpu, memory, storage_gib), timeout=timeout, **kwargs
     )
 
 
 def lustre_setup_task():
     primary_container = V1Container(
         name="primary",
         resources=V1ResourceRequirements(
-            requests={"cpu": "500m", "memory": "500Mi"},
-            limits={"cpu": "500m", "memory": "500Mi"},
+            requests={"cpu": "500m", "memory": "500Mi"}, limits={"cpu": "500m", "memory": "500Mi"}
         ),
-        volume_mounts=[
-            V1VolumeMount(mount_path="/nf-workdir", name="nextflow-workdir")
-        ],
+        volume_mounts=[V1VolumeMount(mount_path="/nf-workdir", name="nextflow-workdir")],
     )
 
     task_config = Pod(
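For orientation, the condensed `custom_task` above keeps two code paths: plain integers resolve to a static node-group pod via `_custom_task_config`, while callables are wrapped in a `DynamicTaskConfig` so sizing is deferred to run time. A minimal sketch, assuming (as in the SDK's dynamic-resources feature) that the callables receive the task's own inputs; the task bodies and names are illustrative, not from this PR:

from latch.resources.tasks import custom_task


# Static path: ints pick a node group when the workflow is registered.
@custom_task(cpu=8, memory=32, storage_gib=100)
def align(sample: str) -> str:
    return sample


# Dynamic path: callables defer sizing until the input value is known.
@custom_task(cpu=lambda sample: 4, memory=lambda sample: 16, storage_gib=lambda sample: 50)
def align_dynamic(sample: str) -> str:
    return sample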
@@ -659,6 +603,30 @@ def nextflow_runtime_task(cpu: int, memory: int, storage_gib: int = 50):
     return functools.partial(task, task_config=task_config)
 
 
+def snakemake_runtime_task(*, cpu: int, memory: int, storage_gib: int = 50):
+    task_config = _custom_task_config(cpu, memory, storage_gib)
+
+    task_config.pod_spec.automount_service_account_token = True
+
+    assert len(task_config.pod_spec.containers) == 1
+    task_config.pod_spec.containers[0].volume_mounts = [
+        V1VolumeMount(mount_path="/snakemake-workdir", name="snakemake-workdir")
+    ]
+
+    task_config.pod_spec.volumes = [
+        V1Volume(
+            name="snakemake-workdir",
+            persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(
+                # this value will be injected by flytepropeller
+                # ayush: this is also used by snakemake bc why not
+                claim_name="nextflow-pvc-placeholder"
+            ),
+        )
+    ]
+
+    return functools.partial(task, task_config=task_config)
+
+
 def _get_l40s_pod(instance_type: str, cpu: int, memory_gib: int, gpus: int) -> Pod:
     """Helper function to create L40s GPU pod configurations."""
     primary_container = V1Container(name="primary")
@@ -685,66 +653,50 @@ def _get_l40s_pod(instance_type: str, cpu: int, memory_gib: int, gpus: int) -> P
     return Pod(
         pod_spec=V1PodSpec(
             containers=[primary_container],
-            tolerations=[
-                V1Toleration(
-                    effect="NoSchedule",
-                    key="ng",
-                    value=instance_type
-                )
-            ],
+            tolerations=[V1Toleration(effect="NoSchedule", key="ng", value=instance_type)],
         ),
         primary_container_name="primary",
-        annotations={
-            "cluster-autoscaler.kubernetes.io/safe-to-evict": "false",
-        },
+        annotations={"cluster-autoscaler.kubernetes.io/safe-to-evict": "false"},
     )
 
 
 g6e_xlarge_task = functools.partial(
-    task,
-    task_config=_get_l40s_pod("g6e-xlarge", cpu=4, memory_gib=32, gpus=1)
+    task, task_config=_get_l40s_pod("g6e-xlarge", cpu=4, memory_gib=32, gpus=1)
 )
 """4 vCPUs, 32 GiB RAM, 1 L40s GPU"""
 
 g6e_2xlarge_task = functools.partial(
-    task,
-    task_config=_get_l40s_pod("g6e-2xlarge", cpu=8, memory_gib=64, gpus=1)
+    task, task_config=_get_l40s_pod("g6e-2xlarge", cpu=8, memory_gib=64, gpus=1)
 )
 """8 vCPUs, 64 GiB RAM, 1 L40s GPU"""
 
 g6e_4xlarge_task = functools.partial(
-    task,
-    task_config=_get_l40s_pod("g6e-4xlarge", cpu=16, memory_gib=128, gpus=1)
+    task, task_config=_get_l40s_pod("g6e-4xlarge", cpu=16, memory_gib=128, gpus=1)
 )
 """16 vCPUs, 128 GiB RAM, 1 L40s GPU"""
 
 g6e_8xlarge_task = functools.partial(
-    task,
-    task_config=_get_l40s_pod("g6e-8xlarge", cpu=32, memory_gib=256, gpus=1)
+    task, task_config=_get_l40s_pod("g6e-8xlarge", cpu=32, memory_gib=256, gpus=1)
 )
 """32 vCPUs, 256 GiB RAM, 1 L40s GPU"""
 
 g6e_12xlarge_task = functools.partial(
-    task,
-    task_config=_get_l40s_pod("g6e-12xlarge", cpu=48, memory_gib=384, gpus=4)
+    task, task_config=_get_l40s_pod("g6e-12xlarge", cpu=48, memory_gib=384, gpus=4)
 )
 """48 vCPUs, 384 GiB RAM, 4 L40s GPUs"""
 
 g6e_16xlarge_task = functools.partial(
-    task,
-    task_config=_get_l40s_pod("g6e-16xlarge", cpu=64, memory_gib=512, gpus=1)
+    task, task_config=_get_l40s_pod("g6e-16xlarge", cpu=64, memory_gib=512, gpus=1)
 )
 """64 vCPUs, 512 GiB RAM, 1 L40s GPUs"""
 
 g6e_24xlarge_task = functools.partial(
-    task,
-    task_config=_get_l40s_pod("g6e-24xlarge", cpu=96, memory_gib=768, gpus=4)
+    task, task_config=_get_l40s_pod("g6e-24xlarge", cpu=96, memory_gib=768, gpus=4)
 )
 """96 vCPUs, 768 GiB RAM, 4 L40s GPUs"""
 
 
 g6e_48xlarge_task = functools.partial(
-    task,
-    task_config=_get_l40s_pod("g6e-48xlarge", cpu=192, memory_gib=1536, gpus=8)
+    task, task_config=_get_l40s_pod("g6e-48xlarge", cpu=192, memory_gib=1536, gpus=8)
 )
 """192 vCPUs, 1536 GiB RAM, 8 L40s GPUs"""