Skip to content

[CI] refine check_api_label_cn #7256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
refine
  • Loading branch information
ooooo-create committed Apr 13, 2025
commit 9c97c8ba7d55df61d3c000144b0ec34094a667a9
137 changes: 72 additions & 65 deletions ci_scripts/check_api_label_cn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
import sys
from pathlib import Path

# precompile regex patterns
CN_API_LABEL_PATTERN = re.compile(r".. _([a-zA-Z0-9_]+):")
REF_PATTERN = re.compile(r":ref:`([^`]+)`")
API_LABEL_EXTRACT_PATTERN = re.compile(r".+?<(?P<api_label>.+?)>")
CN_API_PREFIX = "cn_api_paddle"

logger = logging.getLogger()
if logger.handlers:
# we assume the first handler is the one we want to configure
Expand All @@ -21,119 +27,121 @@
logger.setLevel(logging.INFO)


# check file's api_label
def check_api_label(doc_root: str, file: str) -> bool:
real_file = Path(doc_root) / file
with open(real_file, "r", encoding="utf-8") as f:
def check_api_label(file_path: Path, doc_root: Path) -> bool:
"""Check if the first line of the file matches the expected api_label format."""
with open(file_path, "r", encoding="utf-8") as f:
first_line = f.readline().strip()
return first_line == generate_cn_label_by_path(file)
return first_line == generate_cn_label(file_path, doc_root)


# path -> api_label (the first line's style)
def generate_cn_label_by_path(file: str) -> str:
result = file.removesuffix("_cn.rst")
result = "_".join(Path(result).parts)
result = f".. _cn_{result}:"
return result
def generate_cn_label(file_path: Path, doc_root: Path) -> str:
"""Generate the expected api_label format from file path."""
relative_path = file_path.relative_to(doc_root)
stem = relative_path.stem.removesuffix("_cn")
parts = relative_path.with_name(stem).parts
label = "_".join(parts)
return f".. _cn_{label}:"


# traverse doc/api to append api_label in list
def find_all_api_labels_in_dir(api_root: str) -> list[str]:
all_api_labels = []

for file_path in Path(api_root).rglob("*.rst"):
if not file_path.is_file():
def collect_api_labels(api_root: Path) -> set[str]:
"""Collect all valid api labels."""
labels = set()
for rst_file in api_root.rglob("*.rst"):
if not rst_file.is_file():
continue
path = str(file_path).removeprefix(api_root.removesuffix(API))
if not need_check(path):
if not need_check(rst_file, api_root):
continue
for label in find_api_labels_in_file(file_path):
all_api_labels.append(label)
return all_api_labels
labels.update(extract_api_labels(rst_file))
return labels


# api_labels in a file
def find_api_labels_in_file(file_path: Path | str) -> list[str]:
api_labels_in_one_file = []
def extract_api_labels(file_path: Path) -> set[str]:
labels = set()
with open(file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
line = re.search(".. _cn_api_paddle_([a-zA-Z0-9_]+)", line)
if not line:
match = CN_API_LABEL_PATTERN.search(line)
if not match:
continue
label = match.group(1)
if not label.startswith("cn_api_paddle"):
continue
api_labels_in_one_file.append(line.group(1))
return api_labels_in_one_file
labels.add(label)
return labels


# api doc for checking
def need_check(file: str) -> bool:
def need_check(file_path: Path, api_root: Path) -> bool:
return (
file.endswith("_cn.rst")
and not Path(file).name == "Overview_cn.rst"
and not Path(file).name == "index_cn.rst"
and file.startswith(API)
file_path.name.endswith("_cn.rst")
and file_path.name not in {"Overview_cn.rst", "index_cn.rst"}
and file_path.is_relative_to(api_root)
)


def check_usage_of_api_label(
files: list[Path], valid_api_labels: list[str]
def validate_api_label_references(
files: list[Path], valid_api_labels: set[str]
) -> list[str]:
errors = []
for file in files:
with open(file, "r", encoding="utf-8") as f:
pattern = f.read()
matches = re.findall(r":ref:`([^`]+)`", pattern)
content = f.read()
matches = REF_PATTERN.findall(content)
for match in matches:
api_label = match
if api_label_match := re.match(
r".+<(?P<api_label>.+?)>", api_label
):
if api_label_match := API_LABEL_EXTRACT_PATTERN.match(api_label):
api_label = api_label_match.group("api_label")
if not api_label.startswith("cn_api_paddle"):
if not api_label.startswith(CN_API_PREFIX):
continue
if api_label in valid_api_labels:
continue
errors.append(f"api label `{api_label}` in `{file}`")
return errors


def get_custom_files_for_checking_usage(doc_root: str) -> set[Path]:
def get_custom_files_for_checking_usage(api_root: Path) -> set[Path]:
# TODO: add more dir for checking
custom_files = set()
for file_path in (Path(doc_root) / API).rglob("*.rst"):
if not file_path.is_file():
for rst_file in api_root.rglob("*.rst"):
if not rst_file.is_file():
continue
custom_files.add(file_path)
if rst_file.name in {"set_global_initializer_cn.rst"}:
# TODO: how to deal with `api_paddle_Tensor_create_tensor`?
continue
custom_files.add(rst_file)
return custom_files


def run_cn_api_label_checking(
doc_root: str, api_root: str, files: list[str]
doc_root: Path, api_root: Path, files: list[Path]
) -> None:
# get real path for changed files
real_path_files_set = {Path(doc_root) / file for file in files}

# check the api_label in the first line for increased files
for file in files:
if need_check(file) and not check_api_label(doc_root, file):
for file_path in real_path_files_set:
if need_check(file_path, api_root) and not check_api_label(
file_path, doc_root
):
logger.error(
f"The first line in {doc_root}/{file} is not available, please re-check it!"
f"The first line in {file_path} is not available, please re-check it!"
)
sys.exit(1)

# collect all api_labels in api_root
valid_api_labels = find_all_api_labels_in_dir(api_root)
valid_api_labels = collect_api_labels(api_root)

# check the usage of api_label in custom files
api_label_usage_file_set = {Path(doc_root) / file for file in files}
api_label_usage_file_set.update(
get_custom_files_for_checking_usage(doc_root)
api_label_usage_file_set = (
real_path_files_set | get_custom_files_for_checking_usage(doc_root)
)

errors = check_usage_of_api_label(
if errors := validate_api_label_references(
api_label_usage_file_set, valid_api_labels
)
if errors:
):
logger.error("Found valid api labels usage as follows:")
for i, error in enumerate(errors):
logger.error(f"{i + 1}: {error}")
for i, error in enumerate(errors, 1):
logger.error(f"{i}: {error}")
sys.exit(1)

print("All api_label check success in PR !")
Expand All @@ -146,20 +154,20 @@ def parse_args():
parser = argparse.ArgumentParser(description="cn api_label checking")
parser.add_argument(
"doc_root",
type=Path,
help="the dir DOCROOT",
type=str,
default="/FluidDoc/docs/",
default=Path("/FluidDoc/docs"),
)

parser.add_argument(
"api_root",
type=str,
type=Path,
help="the dir api_root",
default="/FluidDoc/docs/api/",
default=Path("/FluidDoc/docs/api"),
)
parser.add_argument(
"all_git_files",
type=str,
type=Path,
nargs="*",
help="files need to check",
)
Expand All @@ -168,5 +176,4 @@ def parse_args():

if __name__ == "__main__":
args = parse_args()
API = args.doc_root.removesuffix(args.api_root)
run_cn_api_label_checking(args.doc_root, args.api_root, args.all_git_files)