diff --git a/.gitignore b/.gitignore
index 4891566..a7d4bee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -66,3 +66,6 @@ htmlcov/
 coverage.xml
 .tox/
 .nox/
+.pytest_cache
+# old code
+app/frontend
\ No newline at end of file
diff --git a/.project-metadata.yaml b/.project-metadata.yaml
index 565af55..85b9b1b 100644
--- a/.project-metadata.yaml
+++ b/.project-metadata.yaml
@@ -1,5 +1,5 @@
-name: Synthetic Data Generation
+name: Synthetic Data Studio
 description: |
   This AMP demonstrates how we can generate synthetic data for finetuning, ground truth for LLM use case evaluation, embedding finetuning etc.
diff --git a/README.md b/README.md
index a504aa3..9aa1077 100644
--- a/README.md
+++ b/README.md
@@ -83,6 +83,14 @@ Built using:
    python build/start_application.py
    ```
 
+## [Technical Overview](docs/technical_overview.md):
+
+This document covers the application's overall capabilities and tech stack, and gives a general idea of how to get started with the application.
+
+## Technical Guides:
+### [Generation Workflow](docs/guides/sft_workflow.md)
+### [Evaluation Workflow](docs/guides/evaluation_workflow.md)
+
 ## Legal Notice
diff --git a/alembic.ini b/alembic.ini
new file mode 100644
index 0000000..7bbab68
--- /dev/null
+++ b/alembic.ini
@@ -0,0 +1,90 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+script_location = alembic
+
+# template used to generate migration files
+# file_template = %%(rev)s_%%(slug)s
+
+# sys.path path, will be prepended to sys.path if present.
+# defaults to the current working directory.
+prepend_sys_path = .
+
+# timezone to use when rendering the date within the migration file
+# as well as the filename.
+# If specified, requires the python-dateutil library that can be
+# installed by adding `alembic[tz]` to the pip requirements
+# string value is passed to dateutil.tz.gettz()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the
+# "slug" field
+# truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; This defaults
+# to alembic/versions. When using multiple version
+# directories, initial revisions must be specified with --version-path.
+# The path separator used here should be the separator specified by "version_path_separator" below.
+# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
+
+# version path separator; As mentioned above, this is the character used to split
+# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space + +# set to 'true' to search source directory for include files +# that are relative to the alembic.ini file +# source_include_current_dir = true + +# version file pattern +# version_file_pattern = %(rev)s_%%(slug)s + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S \ No newline at end of file diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..3aa09e7 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,60 @@ +# alembic/env.py +from logging.config import fileConfig +from sqlalchemy import engine_from_config, pool +from alembic import context +import os +import sys +from pathlib import Path + +# This is needed to find the app module +base_path = Path(__file__).parent.parent +sys.path.append(str(base_path)) + +from app.migrations.alembic_schema_models import Base + +# this is the Alembic Config object +config = context.config + +# Interpret the config file for Python logging +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Get database path (one level up from app directory) +db_path = os.path.join(base_path, "metadata.db") +config.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}") + +target_metadata = Base.metadata + +def run_migrations_offline() -> None: + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + +def run_migrations_online() -> None: + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + render_as_batch=True + ) + + with context.begin_transaction(): + context.run_migrations() + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() \ No newline at end of file diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/5249999b91fa_initial_migration.py b/alembic/versions/5249999b91fa_initial_migration.py new file mode 100644 index 0000000..c2d7342 --- /dev/null +++ b/alembic/versions/5249999b91fa_initial_migration.py @@ -0,0 +1,432 @@ +"""Initial migration + +Revision ID: 5249999b91fa +Revises: +Create Date: 2025-02-11 13:55:52.405051 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '5249999b91fa' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('evaluation_metadata', schema=None) as batch_op: + batch_op.alter_column('timestamp', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('inference_type', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('caii_endpoint', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('use_case', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('custom_prompt', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('model_parameters', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('generate_file_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('evaluate_file_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('display_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('local_export_path', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('examples', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_id', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_status', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_creator_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + + with op.batch_alter_table('export_metadata', schema=None) as batch_op: + batch_op.alter_column('timestamp', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('display_export_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('display_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('local_export_path', + existing_type=sa.VARCHAR(), 
+ type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('hf_export_path', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_id', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_status', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_creator_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + + with op.batch_alter_table('generation_metadata', schema=None) as batch_op: + batch_op.alter_column('timestamp', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('technique', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('inference_type', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('caii_endpoint', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('use_case', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('custom_prompt', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('model_parameters', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('input_key', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('output_key', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('output_value', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('generate_file_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('display_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('local_export_path', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('hf_export_path', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('topics', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('examples', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('schema', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('doc_paths', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('input_path', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_id', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_status', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + batch_op.alter_column('job_creator_name', + existing_type=sa.VARCHAR(), + type_=sa.Text(), + existing_nullable=True) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('generation_metadata', schema=None) as batch_op: + batch_op.alter_column('job_creator_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('job_status', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('job_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('job_id', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('input_path', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('doc_paths', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('schema', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('examples', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('topics', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('hf_export_path', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('local_export_path', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('display_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('generate_file_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('output_value', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('output_key', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('input_key', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('model_parameters', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('custom_prompt', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('use_case', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('caii_endpoint', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('inference_type', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('technique', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('timestamp', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + + with op.batch_alter_table('export_metadata', schema=None) as batch_op: + batch_op.alter_column('job_creator_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('job_status', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('job_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('job_id', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('hf_export_path', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('local_export_path', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + 
existing_nullable=True) + batch_op.alter_column('display_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('display_export_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('timestamp', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + + with op.batch_alter_table('evaluation_metadata', schema=None) as batch_op: + batch_op.alter_column('job_creator_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('job_status', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('job_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('job_id', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('examples', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('local_export_path', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('display_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('evaluate_file_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('generate_file_name', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('model_parameters', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('custom_prompt', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('use_case', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('caii_endpoint', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('inference_type', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + batch_op.alter_column('timestamp', + existing_type=sa.Text(), + type_=sa.VARCHAR(), + existing_nullable=True) + + # ### end Alembic commands ### diff --git a/alembic/versions/9023b46c8d4c_initial_migration.py b/alembic/versions/9023b46c8d4c_initial_migration.py new file mode 100644 index 0000000..fd00680 --- /dev/null +++ b/alembic/versions/9023b46c8d4c_initial_migration.py @@ -0,0 +1,30 @@ +"""Initial migration + +Revision ID: 9023b46c8d4c +Revises: 5249999b91fa +Create Date: 2025-02-11 14:26:39.164703 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '9023b46c8d4c' +down_revision: Union[str, None] = '5249999b91fa' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + pass + # ### end Alembic commands ### diff --git a/app/client/src/Container.tsx b/app/client/src/Container.tsx index 998b831..1d2e287 100644 --- a/app/client/src/Container.tsx +++ b/app/client/src/Container.tsx @@ -10,6 +10,8 @@ import { Pages } from './types'; import { QueryClient, QueryClientProvider } from 'react-query'; import React, { useMemo } from 'react'; import { GithubOutlined, MailOutlined } from '@ant-design/icons'; +import { Upgrade } from '@mui/icons-material'; +import UpgradeButton from './pages/Home/UpgradeButton'; const { Text } = Typography; const { Header, Content } = Layout; @@ -96,10 +98,17 @@ const pages: MenuItem[] = [ {LABELS[Pages.FEEDBACK]} - ), + ) + }, + { + key: Pages.UPGRADE, + label: ( + + ) } ] + const NotificationContext = React.createContext({messagePlacement: 'topRight'}); const Container = () => { diff --git a/app/client/src/pages/DataGenerator/Prompt.tsx b/app/client/src/pages/DataGenerator/Prompt.tsx index 4a7d6c3..71c654d 100644 --- a/app/client/src/pages/DataGenerator/Prompt.tsx +++ b/app/client/src/pages/DataGenerator/Prompt.tsx @@ -14,6 +14,7 @@ import { Usecases, WorkflowType } from './types'; import { useWizardCtx } from './utils'; import { useDatasetSize, useGetPromptByUseCase } from './hooks'; import CustomPromptButton from './CustomPromptButton'; +import get from 'lodash/get'; const { Title } = Typography; @@ -81,7 +82,7 @@ const Prompt = () => { // Page Bootstrap requests and useEffect const { data: defaultTopics, loading: topicsLoading } = usefetchTopics(useCase); const { data: defaultSchema, loading: schemaLoading } = useFetchDefaultSchema(); - const { data: dataset_size, isLoading: datasetSizeLoading } = useDatasetSize( + const { data: dataset_size, isLoading: datasetSizeLoadin, isError, error } = useDatasetSize( workflow_type, doc_paths, input_key, @@ -89,6 +90,16 @@ const Prompt = () => { output_key ); + useEffect(() => { + if (isError) { + notification.error({ + message: 'Error fetching the dataset size', + description: get(error, 'error'), + }); + } + + }, [error, isError]); + useEffect(() => { if (defaultTopics) { // customTopics is a client-side only fieldValue that persists custom topics added diff --git a/app/client/src/pages/DataGenerator/constants.ts b/app/client/src/pages/DataGenerator/constants.ts index 4e614e4..4e5549a 100644 --- a/app/client/src/pages/DataGenerator/constants.ts +++ b/app/client/src/pages/DataGenerator/constants.ts @@ -1,4 +1,4 @@ -import { ModelProviders } from './types'; +import { ModelProviders, ModelProvidersDropdownOpts } from './types'; export const MODEL_PROVIDER_LABELS = { [ModelProviders.BEDROCK]: 'AWS Bedrock', @@ -8,4 +8,35 @@ export const MODEL_PROVIDER_LABELS = { export const MIN_SEED_INSTRUCTIONS = 1 export const MAX_SEED_INSTRUCTIONS = 500; export const MAX_NUM_QUESTION = 100; -export const DEMO_MODE_THRESHOLD = 25 +export const DEMO_MODE_THRESHOLD = 25; + + +export const USECASE_OPTIONS = [ + { label: 'Code Generation', value: 'code_generation' }, + { label: 'Text to SQL', value: 'text2sql' }, + { label: 'Custom', value: 'custom' } +]; + +export const WORKFLOW_OPTIONS = [ + { label: 'Supervised Fine-Tuning', value: 'sft' }, + { label: 'Custom Data Generation', value: 'custom' } +]; + +export const MODEL_TYPE_OPTIONS: ModelProvidersDropdownOpts = [ + { label: MODEL_PROVIDER_LABELS[ModelProviders.BEDROCK], value: ModelProviders.BEDROCK}, + { label: MODEL_PROVIDER_LABELS[ModelProviders.CAII], value: ModelProviders.CAII }, +]; + + +export const getModelProvider = (provider: ModelProviders) => 
{ + return MODEL_PROVIDER_LABELS[provider]; +}; + +export const getWorkflowType = (value: string) => { + return WORKFLOW_OPTIONS.find((option) => option.value === value)?.label; +}; + +export const getUsecaseType = (value: string) => { + return USECASE_OPTIONS.find((option) => option.value === value)?.label; +}; + diff --git a/app/client/src/pages/DataGenerator/hooks.ts b/app/client/src/pages/DataGenerator/hooks.ts index 59b7876..0e6e34d 100644 --- a/app/client/src/pages/DataGenerator/hooks.ts +++ b/app/client/src/pages/DataGenerator/hooks.ts @@ -166,6 +166,10 @@ export const useGetProjectFiles = (paths: string[]) => { }, body: JSON.stringify(params), }); + if (resp.status !== 200) { + const body_error = await resp.json(); + throw new Error('Error fetching dataset size' + get(body_error, 'error')); + } const body = await resp.json(); return get(body, 'dataset_size'); } @@ -197,6 +201,7 @@ export const useDatasetSize = ( }, ); + console.log('--------------error', error); if (isError) { console.log('data', error); notification.error({ diff --git a/app/client/src/pages/DatasetDetails/ConfigurationTab.tsx b/app/client/src/pages/DatasetDetails/ConfigurationTab.tsx new file mode 100644 index 0000000..6edf2e2 --- /dev/null +++ b/app/client/src/pages/DatasetDetails/ConfigurationTab.tsx @@ -0,0 +1,177 @@ +import get from 'lodash/get'; +import React from 'react'; +import { Dataset } from '../Evaluator/types'; +import { Col, Flex, Modal, Row, Space, Table, Tag } from 'antd'; +import ExampleModal from './ExampleModal'; +import { QuestionSolution } from '../DataGenerator/types'; +import styled from 'styled-components'; +import { isEmpty } from 'lodash'; + +interface Props { + dataset: Dataset; +} + +const StyledTable = styled(Table)` + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; + .ant-table-thead > tr > th { + color: #5a656d; + border-bottom: 1px solid #eaebec; + font-weight: 500; + text-align: left; + // background: #ffffff; + border-bottom: 1px solid #eaebec; + transition: background 0.3s ease; + } + .ant-table-row { + cursor: pointer; + } + .ant-table-row > td.ant-table-cell { + padding: 8px; + padding-left: 16px; + font-size: 13px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; + .ant-typography { + font-size: 13px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + } + } +`; + +const StyledTitle = styled.div` + margin-bottom: 4px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + font-size: 16px; + font-weight: 500; + margin-left: 4px; + +`; + +const Container = styled.div` + padding: 16px; + background-color: #ffffff; +`; + +const TagsContainer = styled.div` + min-height: 30px; + display: block; + margin-bottom: 4px; + margin-top: 4px; + .ant-tag { + max-width: 150px; + } + .tag-title { + overflow: hidden; + white-space: nowrap; + text-overflow: ellipsis; + } +`; + + +const ConfigurationTab: React.FC = ({ dataset }) => { + const topics = get(dataset, 'topics', []); + + const exampleColummns = [ + { + title: 'Prompts', + dataIndex: 'prompts', + ellipsis: true, + render: (_text: QuestionSolution, record: QuestionSolution) => <>{record.question} + }, + { + title: 'Completions', + dataIndex: 'completions', + ellipsis: true, + render: (_text: QuestionSolution, record: QuestionSolution) => <>{record.solution} + }, + ] + + const parameterColummns = [ + { + title: 'Temperature', + dataIndex: 'temperature', + ellipsis: true, + render: (temperature: number) => <>{temperature} + }, + { + title: 'Top K', + 
dataIndex: 'top_k', + ellipsis: true, + render: (top_k: number) => <>{top_k} + }, + { + title: 'Top P', + dataIndex: 'top_p', + ellipsis: true, + render: (top_p: number) => <>{top_p} + }, + + ]; + + return ( + + {!isEmpty(topics) && + + + + Seed Instructions + + + {topics.map((tag: string) => ( + +
+ <Tag key={tag}>
+ <span className="tag-title">{tag}</span>
+ </Tag>
+ ))}
+ </TagsContainer>
+ </Col>
+ </Row>
} + + + + Examples + ({ + onClick: () => Modal.info({ + title: 'View Details', + content: , + icon: undefined, + maskClosable: false, + width: 1000 + }) + })} + rowKey={(_record, index) => `summary-examples-table-${index}`} + /> + + + + + + + Parameters + `parameters-table-${index}`} + /> + + + +
+ + ); +}; + +export default ConfigurationTab; + + diff --git a/app/client/src/pages/DatasetDetails/CustomGenerationTable.tsx b/app/client/src/pages/DatasetDetails/CustomGenerationTable.tsx new file mode 100644 index 0000000..4573833 --- /dev/null +++ b/app/client/src/pages/DatasetDetails/CustomGenerationTable.tsx @@ -0,0 +1,46 @@ +import React from 'react'; + +import { Table } from 'antd'; +import { CustomResult } from '../DataGenerator/types'; +import { DatasetGeneration } from '../Home/types'; + +interface Props { + results: DatasetGeneration[] +} + + +const CustomGenerationTable: React.FC = ({ results }) => { + + const columns = [ + { + title: 'Question', + key: 'question', + dataIndex: 'question', + ellipsis: true, + render: (question: string) => <>{question} + }, + { + title: 'Solution', + key: 'solution', + dataIndex: 'solution', + ellipsis: true, + render: (solution: string) => <>{solution} + } + ]; + + return ( + 'hover-pointer'} + rowKey={(_record, index) => `generation-table-${index}`} + pagination={{ + showSizeChanger: true, + showQuickJumper: false, + hideOnSinglePage: true + }} + /> + ) +} + +export default CustomGenerationTable; diff --git a/app/client/src/pages/DatasetDetails/DatasetDetailsPage.tsx b/app/client/src/pages/DatasetDetails/DatasetDetailsPage.tsx index cd46217..43c57df 100644 --- a/app/client/src/pages/DatasetDetails/DatasetDetailsPage.tsx +++ b/app/client/src/pages/DatasetDetails/DatasetDetailsPage.tsx @@ -1,34 +1,262 @@ -import { Flex, Layout, Typography } from "antd"; +import get from 'lodash/get'; +import { Avatar, Button, Card, Col, Divider, Dropdown, Flex, Layout, List, Row, Space, Tabs, TabsProps, Typography } from "antd"; import styled from "styled-components"; -import CheckCircleIcon from '@mui/icons-material/CheckCircle'; -import FormatListBulletedIcon from '@mui/icons-material/FormatListBulleted'; -import { useParams } from "react-router-dom"; -import { useGetDataset } from "../Evaluator/hooks"; +import QueryStatsIcon from '@mui/icons-material/QueryStats'; +import { Link, useParams } from "react-router-dom"; +import { useGetDatasetDetails } from "./hooks"; +import Loading from "../Evaluator/Loading"; +import { nextStepsList } from './constants'; +import { getModelProvider, getUsecaseType, getWorkflowType } from '../DataGenerator/constants'; +import { useState } from 'react'; +import ConfigurationTab from './ConfigurationTab'; +import DatasetGenerationTab from './DatasetGenerationTab'; +import { + ArrowLeftOutlined, + DownOutlined, + FolderViewOutlined, + ThunderboltOutlined +} from '@ant-design/icons'; +import { Pages } from '../../types'; +import isEmpty from 'lodash/isEmpty'; +import { getFilesURL } from '../Evaluator/util'; const { Content } = Layout; const { Title } = Typography; -const StyleContent = styled(Content)` + +const StyledHeader = styled.div` + height: 28px; + flex-grow: 0; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; + font-size: 24px; + font-weight: 300; + font-stretch: normal; + font-style: normal; + line-height: 1.4; + letter-spacing: normal; + text-align: left; +`; + +const StyledLabel = styled.div` + margin-bottom: 4px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + font-weight: 500; + margin-bottom: 4px; + display: block; + font-size: 14px; + color: #5a656d; +`; + +const StyledContent = styled(Content)` + // background-color: #ffffff; margin: 24px; + .ant-table { + overflow-y: scroll; + } `; +const StyledValue = styled.div` + // color: #1b2329; + color: #5a656d; + font-family: 
Roboto, -apple-system, 'Segoe UI', sans-serif; + font-size: 12px; + font-variant: tabular-nums; + line-height: 1.4285; + list-style: none; + font-feature-settings: 'tnum'; +`; + +const StyledPageHeader = styled.div` + height: 28px; + align-self: stretch; + flex-grow: 0; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + font-size: 20px; + // font-weight: 600; + font-stretch: normal; + font-style: normal; + line-height: 1.4; + letter-spacing: normal; + text-align: left; + color: rgba(0, 0, 0, 0.88); +`; + +const StyledButton = styled(Button)` + padding-left: 0; +` + +enum ViewType { + CONFIURATION = 'configuration', + GENERATION = 'generation', +} + const DatasetDetailsPage: React.FC = () => { const { generate_file_name } = useParams(); - const { dataset, prompt, examples } = useGetDataset(generate_file_name as string); + const [tabViewType, setTabViewType] = useState(ViewType.GENERATION); + const { data, error, isLoading } = useGetDatasetDetails(generate_file_name as string); + const dataset = get(data, 'dataset'); + const datasetDetails = get(data, 'datasetDetails'); + const total_count = get(dataset, 'total_count', []); console.log('DatasetDetailsPage > dataset', dataset); + console.log('DatasetDetailsPage > datasetDetails', datasetDetails); + + if (isLoading) { + return ( + + + + ); + } + + const items: TabsProps['items'] = [ + { + key: ViewType.GENERATION, + label: 'Generation', + children: , + }, + { + key: ViewType.CONFIURATION, + label: 'Parameter & Examples', + children: , + }, + ]; + + const menuActions: MenuProps['items'] = [ + { + key: 'view-in-preview', + label: ( + + View in Preview + + ), + icon: , + }, + { + key: 'generate-dataset', + label: ( + + Generate Dataset + + ), + icon: , + }, + { + key: 'evaluate-dataset', + label: ( + + Evaluate Dataset + + ), + icon: , + } + ]; + + + const onTabChange = (key: string) => + setTabViewType(key as ViewType); + return ( - - - <Flex align='center' gap={10}> - <CheckCircleIcon style={{ color: '#178718' }}/> - {'Success'} + <StyledContent> + <Row> + <Col sm={24}> + <StyledButton type="link" onClick={() => window.history.back()} style={{ color: '#1677ff' }} icon={<ArrowLeftOutlined />}> + Back to Home + </StyledButton> + </Col> + </Row> + <Row style={{ marginBottom: '16px', marginTop: '16px' }}> + <Col sm={20}> + <StyledPageHeader>{dataset?.display_name}</StyledPageHeader> + </Col> + <Col sm={4}> + <Flex style={{ flexDirection: 'row-reverse' }}> + <Dropdown menu={{ items: menuActions }}> + <Button onClick={(e) => e.preventDefault()}> + <Space> + Actions + <DownOutlined /> + </Space> + </Button> + </Dropdown> </Flex> - - + + + + + + + + Model + {dataset?.model_id} + + + + + Model Provider + {getModelProvider(dataset?.inference_type)} + + + + + + + Workflow + {getWorkflowType(dataset?.technique)} + + + + + Template + {getUsecaseType(dataset?.use_case)} + + + + + + + Total Dataset Size + {total_count} + + + + +
+
+ + + + + + +
+
+ + {'Next Steps'} + ( + + } + title={item.title} + description={item.description} + /> + + )} + /> + ); diff --git a/app/client/src/pages/DatasetDetails/DatasetGenerationTab.tsx b/app/client/src/pages/DatasetDetails/DatasetGenerationTab.tsx new file mode 100644 index 0000000..aede204 --- /dev/null +++ b/app/client/src/pages/DatasetDetails/DatasetGenerationTab.tsx @@ -0,0 +1,42 @@ +import { Col, Row } from "antd"; +import { Dataset } from "../Evaluator/types"; +import CustomGenerationTable from "./CustomGenerationTable"; +import DatasetGenerationTopics from "./DatasetGenerationTopics"; +import { CustomResult } from "../DataGenerator/types"; +import { isEmpty } from "lodash"; +import { DatasetDetails, DatasetGeneration } from "../Home/types"; +import styled from "styled-components"; + + + +interface Props { + dataset: Dataset; + datasetDetails: DatasetDetails; +} + +const Container = styled.div` + padding: 16px; + background-color: #ffffff; +`; + + + +const DatasetGenerationTab: React.FC = ({ dataset, datasetDetails }) => { + console.log(`DatasetGenerationTab > dataset`, dataset); + console.log(` datasetDetails`, datasetDetails); + const hasCustomSeeds = !Array.isArray(datasetDetails?.generation); + return ( + + +
+ {hasCustomSeeds && } + {!hasCustomSeeds && } + + + + + ); +} + +export default DatasetGenerationTab; + diff --git a/app/client/src/pages/DatasetDetails/DatasetGenerationTopics.tsx b/app/client/src/pages/DatasetDetails/DatasetGenerationTopics.tsx new file mode 100644 index 0000000..fa38917 --- /dev/null +++ b/app/client/src/pages/DatasetDetails/DatasetGenerationTopics.tsx @@ -0,0 +1,110 @@ +import get from 'lodash/get'; +import { Card, Table, Tabs, Typography } from "antd"; +import { DatasetGeneration } from "../Home/types"; +import TopicGenerationTable from "./TopicGenerationTable"; +import isEmpty from "lodash/isEmpty"; +import styled from "styled-components"; +import { Dataset } from '../Evaluator/types'; + +interface Props { + data: DatasetGeneration; + dataset: Dataset; +} + +const StyledTable = styled(Table)` + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; + .ant-table-thead > tr > th { + color: #5a656d; + border-bottom: 1px solid #eaebec; + font-weight: 500; + text-align: left; + // background: #ffffff; + border-bottom: 1px solid #eaebec; + transition: background 0.3s ease; + } + .ant-table-row { + cursor: pointer; + } + .ant-table-row > td.ant-table-cell { + padding: 8px; + padding-left: 16px; + font-size: 13px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; + .ant-typography { + font-size: 13px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + } + } +`; + +const TabsContainer = styled(Card)` + .ant-card-body { + padding: 0; + } + margin: 20px 0px 35px; +`; + +const getTopicTree = (data: DatasetGeneration, topics: string[]) => { + const topicTree = {}; + if (!isEmpty(data)) { + topics.forEach(topic => { + topicTree[topic] = data.filter(result => get(result, 'Seeds') === topic); + }); + } + return topicTree; +} + + +const DatasetGenerationTable: React.FC = ({ data, dataset }) => { + console.log('DatasetGenerationTable > generation topics', data, dataset); + const topics = get(dataset, 'topics', []); + const topicTree = getTopicTree(data, topics); + console.log('topicTree', topicTree); + + let topicTabs = []; + if (!isEmpty(topics)) { + topicTabs = topicTree && Object.keys(topicTree).map((topic, i) => ({ + key: `${topic}-${i}`, + label: {topic}, + value: topic, + children: + })); + } + console.log('topicTabs', topicTabs); + + const columns = [ + { + title: 'Prompt', + key: 'Prompt', + dataIndex: 'Prompt', + ellipsis: true, + render: (prompt: string) => { + console.log('prompt', prompt); + return <>{prompt} + } + }, + { + title: 'Completion', + key: 'Completion', + dataIndex: 'Completion', + ellipsis: true, + render: (completion: string) => <>{completion} + } + ]; + + return ( + <> + {!isEmpty(topicTabs) && + ( + + + + ) + } + + ); +} + +export default DatasetGenerationTable; \ No newline at end of file diff --git a/app/client/src/pages/DatasetDetails/ExampleModal.tsx b/app/client/src/pages/DatasetDetails/ExampleModal.tsx new file mode 100644 index 0000000..8443537 --- /dev/null +++ b/app/client/src/pages/DatasetDetails/ExampleModal.tsx @@ -0,0 +1,54 @@ +import { Flex, Form, Typography } from 'antd'; +import styled from 'styled-components'; + +import Markdown from '../../components/Markdown'; +import TooltipIcon from '../../components/TooltipIcon'; + + +const { Title } = Typography; + +interface Props { + question: string; + solution: string; +} + + +const Container = styled(Flex)` + margin-top: 15px +` + +const StyledTitle = styled(Title)` + margin-top: 0; + margin-bottom: 0 !important; +`; +const 
TitleGroup = styled(Flex)` + margin-top: 10px; + margin-bottom: 10px; +`; + +const ExampleModal: React.FC = ({ question, solution }) => { + return ( + + {question && ( +
+ + {'Prompt'} + + + +
+ )} + {solution && ( +
+ + {'Completion'} + + + +
+ )} +
+ ) +} + +export default ExampleModal; \ No newline at end of file diff --git a/app/client/src/pages/DatasetDetails/ExamplesSection.tsx b/app/client/src/pages/DatasetDetails/ExamplesSection.tsx new file mode 100644 index 0000000..aaf5d52 --- /dev/null +++ b/app/client/src/pages/DatasetDetails/ExamplesSection.tsx @@ -0,0 +1,151 @@ +import { Collapse, Descriptions, Flex, Modal, Table, Typography } from "antd"; +import styled from "styled-components"; +import Markdown from "../../Markdown"; +import { DatasetResponse } from "../../../api/Datasets/response"; +import { QuestionSolution } from "../../../pages/DataGenerator/types"; +import { MODEL_PARAMETER_LABELS, ModelParameters, Usecases } from "../../../types"; +import { Dataset } from "../../../pages/Evaluator/types"; +import PCModalContent from "../../../pages/DataGenerator/PCModalContent"; + +import ExampleModal from "./ExampleModal"; + +const { Text, Title } = Typography; +const Panel = Collapse.Panel; + + +const StyledTable = styled(Table)` + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; + .ant-table-thead > tr > th { + color: #5a656d; + border-bottom: 1px solid #eaebec; + font-weight: 500; + text-align: left; + // background: #ffffff; + border-bottom: 1px solid #eaebec; + transition: background 0.3s ease; + } + .ant-table-row { + cursor: pointer; + } + .ant-table-row > td.ant-table-cell { + padding: 8px; + padding-left: 16px; + font-size: 13px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; + .ant-typography { + font-size: 13px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + } + } +`; + +const MarkdownWrapper = styled.div` + border: 1px solid #d9d9d9; + border-radius: 6px; + padding: 4px 11px; +`; + +const StyledLabel = styled.div` + font-size: 16px; + padding-top: 8px; +`; + +const StyledCollapse = styled(Collapse)` + .ant-collapse-content > .ant-collapse-content-box { + padding: 0; + } + .ant-collapse-item > .ant-collapse-header .ant-collapse-expand-icon { + height: 28px; + display: flex; + align-items: center; + padding-inline-end: 12px; + } +`; + +const Label = styled.div` + font-size: 18px; + padding-top: 8px; +`; + +export type DatasetDetailProps = { + datasetDetails: DatasetResponse | Dataset; +} + +const ExamplesSection= ({ datasetDetails }: DatasetDetailProps) => { + + const exampleCols = [ + { + title: 'Prompts', + dataIndex: 'prompts', + ellipsis: true, + render: (_text: QuestionSolution, record: QuestionSolution) => <>{record.question} + }, + { + title: 'Completions', + dataIndex: 'completions', + ellipsis: true, + render: (_text: QuestionSolution, record: QuestionSolution) => <>{record.solution} + }, + ] + + return ( + + + Examples} + style={{ padding: 0 }} + > + + ({ + onClick: () => Modal.info({ + title: 'View Details', + content: , + icon: undefined, + maskClosable: false, + width: 1000 + }) + })} + rowKey={(_record, index) => `summary-examples-table-${index}`} + /> + + {/* Model Parameters + ({ + label: MODEL_PARAMETER_LABELS[modelParameterKey as ModelParameters], + children: datasetDetails.model_parameters[modelParameterKey as ModelParameters], + })) + : []}> + + {(datasetDetails.schema && datasetDetails.use_case === Usecases.TEXT2SQL) && ( +
+ {'DB Schema'} + + + +
+ )} */} + +
+
+
+ ) +} + +export default ExamplesSection; \ No newline at end of file diff --git a/app/client/src/pages/DatasetDetails/TopicGenerationTable.tsx b/app/client/src/pages/DatasetDetails/TopicGenerationTable.tsx new file mode 100644 index 0000000..163b405 --- /dev/null +++ b/app/client/src/pages/DatasetDetails/TopicGenerationTable.tsx @@ -0,0 +1,88 @@ +import React, { SyntheticEvent, useEffect } from 'react'; + +import { Col, Input, Row, Table } from 'antd'; +import { CustomResult } from '../DataGenerator/types'; +import { DatasetGeneration } from '../Home/types'; +import throttle from 'lodash/throttle'; +import { SearchProps } from 'antd/es/input'; + +const { Search } = Input; + + +interface Props { + results: DatasetGeneration[] +} + + +const TopicGenerationTable: React.FC = ({ results }) => { + const [searchQuery, setSearchQuery] = React.useState(null); + const [filteredResults, setFilteredResults] = React.useState(results || []); + + useEffect(() => { + if (searchQuery) { + const filtered = results.filter((result: DatasetGeneration) => { + // clean up the filter logic + return result?.Prompt?.toLowerCase().includes(searchQuery.toLowerCase()) || result?.Completion?.toLowerCase().includes(searchQuery.toLowerCase()); + }); + setFilteredResults(filtered); + } else { + setFilteredResults(results); + } + }, [results, searchQuery]); + + const columns = [ + { + title: 'Prompt', + key: 'Prompt', + dataIndex: 'Prompt', + ellipsis: true, + render: (prompt: string) => { + return <>{prompt} + } + }, + { + title: 'Completion', + key: 'Completion', + dataIndex: 'Completion', + ellipsis: true, + render: (completion: string) => <>{completion} + } + ]; + + const onSearch: SearchProps['onSearch'] = (value, _e, info) => { + throttle((value: string) => setSearchQuery(value), 500)(value); + } + + const onChange = (event: SyntheticEvent) => { + const value = event.target?.value; + throttle((value: string) => setSearchQuery(value), 500)(value); + } + + return ( + <> + +
+ <Row>
+ <Col sm={24}>
+ <Search onSearch={onSearch} onChange={onChange} />
+ </Col>
+ </Row>
+ <Table
+ dataSource={filteredResults}
+ columns={columns}
+ rowClassName={() =>
'hover-pointer'} + rowKey={(_record, index) => `topic-generation-table-${index}`} + pagination={{ + showSizeChanger: true, + showQuickJumper: false, + hideOnSinglePage: true + }} + /> + + + ) +} + +export default TopicGenerationTable; \ No newline at end of file diff --git a/app/client/src/pages/DatasetDetails/constants.tsx b/app/client/src/pages/DatasetDetails/constants.tsx new file mode 100644 index 0000000..4faa3b4 --- /dev/null +++ b/app/client/src/pages/DatasetDetails/constants.tsx @@ -0,0 +1,27 @@ +import { HomeOutlined, PageviewOutlined } from '@mui/icons-material'; +import AssessmentIcon from '@mui/icons-material/Assessment'; +import CheckCircleIcon from '@mui/icons-material/CheckCircle' +import GradingIcon from '@mui/icons-material/Grading'; +import ModelTrainingIcon from '@mui/icons-material/ModelTraining'; + + +export const nextStepsList = [ + { + avatar: '', + title: 'Review Dataset', + description: 'Review your dataset to ensure it properly fits your usecase.', + icon: + }, + { + avatar: '', + title: 'Evaluate Dataset', + description: 'Use an LLM as a judge to evaluate and score your dataset.', + icon: , + }, + { + avatar: '', + title: 'Fine Tuning Studio', + description: 'Bring your dataset to Fine Tuning Studio AMP to start fine tuning your models in Cloudera AI Workbench.', + icon: , + }, +] \ No newline at end of file diff --git a/app/client/src/pages/DatasetDetails/hooks.tsx b/app/client/src/pages/DatasetDetails/hooks.tsx index e69de29..a403ae8 100644 --- a/app/client/src/pages/DatasetDetails/hooks.tsx +++ b/app/client/src/pages/DatasetDetails/hooks.tsx @@ -0,0 +1,63 @@ +import get from 'lodash/get'; +import { notification } from 'antd'; +import { useQuery } from 'react-query'; + + +const BASE_API_URL = import.meta.env.VITE_AMP_URL; + + +const { + VITE_WORKBENCH_URL, + VITE_PROJECT_OWNER, + VITE_CDSW_PROJECT +} = import.meta.env + + + +const fetchDatasetDetails = async (generate_file_name: string) => { + const dataset_details__resp = await fetch(`${BASE_API_URL}/dataset_details/${generate_file_name}`, { + method: 'GET', + }); + const datasetDetails = await dataset_details__resp.json(); + const dataset__resp = await fetch(`${BASE_API_URL}/generations/${generate_file_name}`, { + method: 'GET', + }); + const dataset = await dataset__resp.json(); + + return { + dataset, + datasetDetails + }; +}; + + + + +export const useGetDatasetDetails = (generate_file_name: string) => { + const { data, isLoading, isError, error } = useQuery( + ["data", fetchDatasetDetails], + () => fetchDatasetDetails(generate_file_name), + { + keepPreviousData: true, + }, + ); + + const dataset = get(data, 'dataset'); + console.log('data:', data); + console.log('error:', error); + + if (error) { + notification.error({ + message: 'Error', + description: `An error occurred while fetching the dataset details:\n ${error}` + }); + } + + return { + data, + dataset, + isLoading, + isError, + error + }; +} \ No newline at end of file diff --git a/app/client/src/pages/Home/DatasetActions.tsx b/app/client/src/pages/Home/DatasetActions.tsx index 30320ed..28e373e 100644 --- a/app/client/src/pages/Home/DatasetActions.tsx +++ b/app/client/src/pages/Home/DatasetActions.tsx @@ -76,12 +76,16 @@ const DatasetActions: React.FC = ({ dataset, refetch, setTo const menuActions: MenuProps['items'] = [ { key: '1', - label: ( - + // label: ( + // + // View Dataset Details + // + // ), + label: + View Dataset Details - - ), - onClick: () => setShowModal(true), + , + // onClick: () => setShowModal(true), icon: }, { diff --git 
a/app/client/src/pages/Home/UpgradeButton.tsx b/app/client/src/pages/Home/UpgradeButton.tsx new file mode 100644 index 0000000..b00121b --- /dev/null +++ b/app/client/src/pages/Home/UpgradeButton.tsx @@ -0,0 +1,35 @@ +import isEmpty from 'lodash/isEmpty'; +import { Button } from 'antd'; +import React, { useEffect } from 'react'; +import { useUpgradeStatus } from './hooks'; + + + + +const UpgradeButton: React.FC = () => { + const [showModal, setShowModal] = React.useState(false); + const [enableUpgrade, setEnableUpgrade] = React.useState(false); + const { data, isLoading, isError } = useUpgradeStatus(); + + useEffect(() => { + if (isEmpty(data)) { + setEnableUpgrade(data?.updates_available); + } + },[data, isLoading, isError]); + + const onUpgrade = () => { + // Logic to handle upgrade + console.log("Upgrading..."); + } + + if (!enableUpgrade) { + return null; + } + + return ( + + ) + +} + +export default UpgradeButton; diff --git a/app/client/src/pages/Home/hooks.ts b/app/client/src/pages/Home/hooks.ts index 15860a5..3d416d3 100644 --- a/app/client/src/pages/Home/hooks.ts +++ b/app/client/src/pages/Home/hooks.ts @@ -93,4 +93,30 @@ export const useEvaluations = () => { searchQuery, setSearchQuery }; +} + +const fetchUpgradeStatus = async () => { + const upgrade_resp = await fetch(`${BASE_API_URL}/synthesis-studio/check-upgrade`, { + method: 'GET', + }); + const upgradeStatus = await upgrade_resp.json(); + return upgradeStatus; +} + +export const useUpgradeStatus = () => { + const { data, isLoading, isError, refetch } = useQuery( + ["fetchUpgradeStatus", fetchUpgradeStatus], + () => fetchUpgradeStatus(), + { + keepPreviousData: false, + refetchInterval: 30000 + }, + ); + + return { + data, + isLoading, + isError, + refetch + }; } \ No newline at end of file diff --git a/app/client/src/pages/Home/types.ts b/app/client/src/pages/Home/types.ts index b30dc16..d5c31b2 100644 --- a/app/client/src/pages/Home/types.ts +++ b/app/client/src/pages/Home/types.ts @@ -23,4 +23,12 @@ export interface Evaluation { job_id: string; job_name: string; job_status: string; +} + +export interface DatasetDetails { + generation: DatasetGeneration; +} + +export interface DatasetGeneration { + [key: string]: string; } \ No newline at end of file diff --git a/app/client/src/types.ts b/app/client/src/types.ts index 4f4ff69..04b3c40 100644 --- a/app/client/src/types.ts +++ b/app/client/src/types.ts @@ -5,7 +5,8 @@ export enum Pages { HOME = 'home', DATASETS = 'datasets', WELCOME = 'welcome', - FEEDBACK = 'feedback' + FEEDBACK = 'feedback', + UPGRADE = 'upgrade' } export enum ModelParameters { diff --git a/app/core/database.py b/app/core/database.py index 3d36d0d..3e9ee2b 100644 --- a/app/core/database.py +++ b/app/core/database.py @@ -113,6 +113,15 @@ def init_db(self): job_creator_name TEXT ) """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS test_metadata ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT, + display_name TEXT + + ) + """) diff --git a/app/main.py b/app/main.py index a3efd98..1e6bf69 100644 --- a/app/main.py +++ b/app/main.py @@ -8,7 +8,9 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse +from pydantic import BaseModel from typing import Dict, List, Optional +import subprocess import asyncio from pathlib import Path from contextlib import asynccontextmanager @@ -42,6 +44,7 @@ from app.services.model_alignment import ModelAlignment from app.core.model_handlers import create_handler from 
app.services.aws_bedrock import get_bedrock_client +from app.migrations.alembic_manager import AlembicMigrationManager #*************Comment this when running locally******************************************** import cmlapi @@ -73,8 +76,45 @@ def get_total_size(file_paths): return total_gb +def restart_application(): + """Restart the CML application""" + try: + cml = cmlapi.default_client() + project_id = os.getenv("CDSW_PROJECT_ID") + apps_list = cml.list_applications(project_id).applications + found_app_list = list(filter(lambda app: 'Synthetic Data Studio' in app.name, apps_list)) + + if len(found_app_list) > 0: + app = found_app_list[0] + if app.status == "APPLICATION_RUNNING": + try: + cml.restart_application(project_id, app.id) + except Exception as e: + raise (f"Failed to restart application {app.name}: {str(e)}") + else: + raise ValueError("Synthetic Data Studio application not found") + + except Exception as e: + print(f"Error restarting application: {e}") + raise + #*************Comment this when running locally******************************************** +# Add these models +class StudioUpgradeStatus(BaseModel): + git_local_commit: str + git_remote_commit: str + updates_available: bool + +class StudioUpgradeResponse(BaseModel): + success: bool + message: str + git_updated: bool + frontend_rebuilt: bool + + + + @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager for the FastAPI application""" @@ -242,8 +282,16 @@ def get_timeout_for_request(request: Request) -> float: evaluator_service = EvaluatorService() export_service = Export_Service() db_manager = DatabaseManager() +# Initialize the migration manager +alembic_manager = AlembicMigrationManager() +alembic_manager = AlembicMigrationManager("metadata.db") - +@app.on_event("startup") +async def startup_event(): + """Check for and apply any pending migrations on startup""" + success, message = await alembic_manager.handle_database_upgrade() + if not success: + print(f"Warning: {message}") @app.post("/get_project_files", include_in_schema=True, responses = responses, description = "get project file details") @@ -1186,6 +1234,143 @@ async def get_example_payloads(use_case:UseCase): return payload + +# Add these two endpoints +@app.get("/synthesis-studio/check-upgrade", response_model=StudioUpgradeStatus) +async def check_upgrade_status(): + """Check if any upgrades are available""" + try: + # Fetch latest changes + subprocess.run(["git", "fetch"], check=True, capture_output=True) + + # Get current branch + branch = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + check=True, + capture_output=True, + text=True + ).stdout.strip() + + # Get local and remote commits + local_commit = subprocess.run( + ["git", "rev-parse", branch], + check=True, + capture_output=True, + text=True + ).stdout.strip() + + remote_commit = subprocess.run( + ["git", "rev-parse", f"origin/{branch}"], + check=True, + capture_output=True, + text=True + ).stdout.strip() + + updates_available = local_commit != remote_commit + + return StudioUpgradeStatus( + git_local_commit=local_commit, + git_remote_commit=remote_commit, + updates_available=updates_available + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/synthesis-studio/upgrade", response_model=StudioUpgradeResponse) +async def perform_upgrade(): + """ + Perform upgrade process: + 1. Pull latest code + 2. Run database migrations with Alembic + 3. Run build_client.sh + 4. Run start_application.py + 5. 
Restart CML application + """ + try: + messages = [] + git_updated = False + frontend_rebuilt = False + db_upgraded = False + + # 1. Git operations + try: + # Stash any changes + subprocess.run(["git", "stash"], check=True, capture_output=True) + + # Pull updates + subprocess.run(["git", "pull"], check=True, capture_output=True) + + # Try to pop stash + try: + subprocess.run(["git", "stash", "pop"], check=True, capture_output=True) + except subprocess.CalledProcessError: + messages.append("Warning: Could not restore local changes") + + git_updated = True + messages.append("Git repository updated") + + except subprocess.CalledProcessError as e: + messages.append(f"Git update failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + # 2. Database migrations + try: + db_success, db_message = await alembic_manager.handle_database_upgrade() + if db_success: + db_upgraded = True + messages.append(db_message) + else: + messages.append(f"Database upgrade failed: {db_message}") + raise HTTPException(status_code=500, detail=db_message) + except Exception as e: + messages.append(f"Database migration failed: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + # 3. Run build_client.sh + try: + script_path = "build/build_client.py" + if os.getenv("IS_COMPOSABLE"): + script_path = os.path.join('synthetic-data-studio', script_path) + subprocess.run(["python", script_path], check=True) + frontend_rebuilt = True + messages.append("Frontend rebuilt successfully") + except subprocess.CalledProcessError as e: + messages.append(f"Frontend build failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + # # 4. Run start_application.py + # try: + # subprocess.run(["python", "build/start_application.py"], check=True) + # messages.append("Application start script completed") + # except subprocess.CalledProcessError as e: + # messages.append(f"Application start script failed: {e}") + # raise HTTPException(status_code=500, detail=str(e)) + + # 4. 
diff --git a/app/migrations/alembic_manager.py b/app/migrations/alembic_manager.py
new file mode 100644
index 0000000..51538c6
--- /dev/null
+++ b/app/migrations/alembic_manager.py
@@ -0,0 +1,61 @@
+# app/migrations/alembic_manager.py
+from alembic.config import Config
+from alembic import command
+from alembic.script import ScriptDirectory
+from alembic.runtime.migration import MigrationContext
+from pathlib import Path
+from typing import Optional
+import os
+from sqlalchemy import create_engine
+
+class AlembicMigrationManager:
+    def __init__(self, db_path: Optional[str] = None):
+        """Initialize Alembic with the same database path as DatabaseManager"""
+        self.app_path = Path(__file__).parent.parent.parent
+
+        if db_path is None:
+            db_path = os.path.join(self.app_path, "metadata.db")
+        self.db_path = db_path
+
+        # Initialize Alembic config
+        self.alembic_cfg = Config(str(self.app_path / "alembic.ini"))
+        self.alembic_cfg.set_main_option('script_location', str(self.app_path / "alembic"))
+        self.alembic_cfg.set_main_option('sqlalchemy.url', f'sqlite:///{db_path}')
+
+        # Create engine for version checks
+        self.engine = create_engine(f'sqlite:///{db_path}')
+
+    async def get_db_version(self) -> Optional[str]:
+        """Get the current database revision (None if the DB was never stamped)"""
+        with self.engine.connect() as conn:
+            context = MigrationContext.configure(conn)
+            return context.get_current_revision()
+
+    async def handle_database_upgrade(self) -> tuple[bool, str]:
+        """
+        Handle database migrations carefully to avoid disrupting existing data
+        """
+        try:
+            # First check if the alembic_version table exists
+            try:
+                version = await self.get_db_version()
+                if version is None:
+                    # Database exists but has no alembic version - stamp current
+                    command.stamp(self.alembic_cfg, "head")
+                    return True, "Existing database stamped with current version"
+            except Exception:
+                # No alembic_version table - stamp current
+                command.stamp(self.alembic_cfg, "head")
+                return True, "Existing database stamped with current version"
+
+            # Now check for and apply any new migrations
+            script = ScriptDirectory.from_config(self.alembic_cfg)
+            head_revision = script.get_current_head()
+
+            if version != head_revision:
+                command.upgrade(self.alembic_cfg, "head")
+                return True, "Database schema updated successfully"
+
+            return True, "Database schema is up to date"
+
+        except Exception as e:
+            return False, f"Error during database upgrade: {str(e)}"
\ No newline at end of file
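The manager can also be exercised on its own, e.g. from a maintenance session; a minimal sketch, assuming it is run from the project root so that `alembic.ini` and `metadata.db` resolve:

```python
import asyncio

from app.migrations.alembic_manager import AlembicMigrationManager

async def main() -> None:
    # Passing None lets the manager default to <project root>/metadata.db
    manager = AlembicMigrationManager(None)
    ok, message = await manager.handle_database_upgrade()
    print(f"ok={ok}: {message}")

asyncio.run(main())
```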
diff --git a/app/migrations/alembic_schema_models.py b/app/migrations/alembic_schema_models.py
new file mode 100644
index 0000000..4035467
--- /dev/null
+++ b/app/migrations/alembic_schema_models.py
@@ -0,0 +1,79 @@
+# app/migrations/alembic_schema_models.py
+from sqlalchemy import Column, Integer, Text, Float
+from sqlalchemy.orm import declarative_base  # sqlalchemy.ext.declarative is deprecated in SQLAlchemy 2.0
+
+Base = declarative_base()
+
+class GenerationMetadataModel(Base):
+    __tablename__ = 'generation_metadata'
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    timestamp = Column(Text)
+    technique = Column(Text)
+    model_id = Column(Text)
+    inference_type = Column(Text)
+    caii_endpoint = Column(Text)
+    use_case = Column(Text)
+    custom_prompt = Column(Text)
+    model_parameters = Column(Text)
+    input_key = Column(Text)
+    output_key = Column(Text)
+    output_value = Column(Text)
+    generate_file_name = Column(Text, unique=True)
+    display_name = Column(Text)
+    local_export_path = Column(Text)
+    hf_export_path = Column(Text)
+    num_questions = Column(Float)
+    total_count = Column(Float)
+    topics = Column(Text)
+    examples = Column(Text)
+    schema = Column(Text)
+    doc_paths = Column(Text)
+    input_path = Column(Text)
+    job_id = Column(Text)
+    job_name = Column(Text, unique=True)
+    job_status = Column(Text)
+    job_creator_name = Column(Text)
+
+class EvaluationMetadataModel(Base):
+    __tablename__ = 'evaluation_metadata'
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    timestamp = Column(Text)
+    model_id = Column(Text)
+    inference_type = Column(Text)
+    caii_endpoint = Column(Text)
+    use_case = Column(Text)
+    custom_prompt = Column(Text)
+    model_parameters = Column(Text)
+    generate_file_name = Column(Text)
+    evaluate_file_name = Column(Text, unique=True)
+    display_name = Column(Text)
+    local_export_path = Column(Text)
+    examples = Column(Text)
+    average_score = Column(Float)
+    job_id = Column(Text)
+    job_name = Column(Text, unique=True)
+    job_status = Column(Text)
+    job_creator_name = Column(Text)
+
+class ExportMetadataModel(Base):
+    __tablename__ = 'export_metadata'
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    timestamp = Column(Text)
+    display_export_name = Column(Text)
+    display_name = Column(Text)
+    local_export_path = Column(Text)
+    hf_export_path = Column(Text)
+    job_id = Column(Text)
+    job_name = Column(Text, unique=True)
+    job_status = Column(Text)
+    job_creator_name = Column(Text)
+
+class TestMetadataModel(Base):
+    __tablename__ = 'test_metadata'
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    timestamp = Column(Text)
+    display_name = Column(Text)
\ No newline at end of file
diff --git a/app/run.py b/app/run.py
index a42bd7a..d93e604 100644
--- a/app/run.py
+++ b/app/run.py
@@ -11,4 +11,3 @@
 port=port,
 reload=True
 )
-
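When one of the SQLAlchemy models in app/migrations/alembic_schema_models.py above changes, the corresponding migration can be generated and applied through Alembic's Python API; a minimal sketch, assuming it is run from the project root where `alembic.ini` lives:

```python
from alembic.config import Config
from alembic import command

cfg = Config("alembic.ini")
cfg.set_main_option("sqlalchemy.url", "sqlite:///metadata.db")

# Autogenerate a revision by diffing Base.metadata (wired up in alembic/env.py) against the DB
command.revision(cfg, message="describe the schema change", autogenerate=True)

# Apply all pending revisions
command.upgrade(cfg, "head")
```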
diff --git a/docs/guides/evaluation_workflow.md b/docs/guides/evaluation_workflow.md
new file mode 100644
index 0000000..a6396b4
--- /dev/null
+++ b/docs/guides/evaluation_workflow.md
@@ -0,0 +1,100 @@
+# Evaluation Workflow:
+
+In this workflow we will see how we can evaluate the synthetic data generated in the previous steps, using a Large Language Model as a judge.
+
+Users can trigger an evaluation from the list view, where they choose the dataset to evaluate from the dropdown.
+
+## Evaluation Workflow
+
+Similar to generation, evaluation also allows users to specify the following:
+
+1. #### Display Name
+2. #### Model Provider: AWS Bedrock or Cloudera AI Inference
+3. #### Model ID: Claude, Llama, Mistral, etc.
+
+Code Generation and Text2SQL are templates that let users select from already curated prompts and examples to evaluate datasets.
+
+The Custom template, on the other hand, lets users define everything from scratch to evaluate the synthetic dataset created in the previous step.
+
+A screenshot of this step is shown below:
+
+### Prompt and Model Parameters
+
+#### Prompt:
+This step allows users to curate their prompts manually, choose from the given templates, or let the LLM curate a prompt based on a description of their use case.
+
+```text
+Below is a Python coding Question and Solution pair generated by an LLM. Evaluate its quality as a Senior Developer would, considering its suitability for professional use. Use the additive 5-point scoring system described below.
+
+Points are accumulated based on the satisfaction of each criterion:
+ 1. Add 1 point if the code implements basic functionality and solves the core problem, even if it includes some minor issues or non-optimal approaches.
+ 2. Add another point if the implementation is generally correct but lacks refinement in style or fails to follow some best practices. It might use inconsistent naming conventions or have occasional inefficiencies.
+ 3. Award a third point if the code is appropriate for professional use and accurately implements the required functionality. It demonstrates good understanding of Python concepts and common patterns, though it may not be optimal. It resembles the work of a competent developer but may have room for improvement in efficiency or organization.
+ 4. Grant a fourth point if the code is highly efficient and follows Python best practices, exhibiting consistent style and appropriate documentation. It could be similar to the work of an experienced developer, offering robust error handling, proper type hints, and effective use of built-in features. The result is maintainable, well-structured, and valuable for production use.
+ 5. Bestow a fifth point if the code is outstanding, demonstrating mastery of Python and software engineering principles. It includes comprehensive error handling, efficient algorithms, proper testing considerations, and excellent documentation. The solution is scalable, performant, and shows attention to edge cases and security considerations.
+```
+
+#### Model Parameters
+
+We let users decide on the following model parameters (an illustrative request sketch follows the list):
+
+- **Temperature**
+- **TopK**
+- **TopP**
+
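+As an illustration of how these parameters reach the model, the sketch below shows a Bedrock request; the model ID and the use of the `converse` API are assumptions for illustration, not the studio's actual call path:
+
+```python
+import boto3
+
+client = boto3.client("bedrock-runtime")  # region/credentials assumed configured
+
+response = client.converse(
+    modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",  # hypothetical model choice
+    messages=[{"role": "user", "content": [{"text": "Rate this code ..."}]}],
+    inferenceConfig={"temperature": 0.2, "topP": 0.9, "maxTokens": 512},
+    additionalModelRequestFields={"top_k": 250},  # TopK is passed as a model-specific field
+)
+print(response["output"]["message"]["content"][0]["text"])
+```
+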
+### Examples:
+
+In the next step, users can provide examples for the evaluation so that the LLM follows the same format and judges/rates the dataset accordingly.
+
+The examples for an evaluation look like the following.
+
+The **scoring** and **justification** schemes are defined by the user within the prompt and examples. In this use case we use a 5-point rating system, but users can switch to **10-point ratings**, **Boolean**, **subjective ("Bad", "Good", "Average")**, etc.
+
+```json
+[
+  {
+    "score": 3,
+    "justification": "The code achieves 3 points by implementing core functionality correctly (1), showing generally correct implementation with proper syntax (2), and being suitable for professional use with good Python patterns and accurate functionality (3). While it demonstrates competent development practices, it lacks the robust error handling and type hints needed for point 4, and could benefit from better efficiency optimization and code organization."
+  },
+  {
+    "score": 4,
+    "justification": "The code earns 4 points by implementing basic functionality (1), showing correct implementation (2), being production-ready (3), and demonstrating high efficiency with Python best practices including proper error handling, type hints, and clear documentation (4). It exhibits experienced developer qualities with well-structured code and maintainable design, though it lacks the comprehensive testing and security considerations needed for a perfect score."
+  }
+]
+```
+
+### Final Output:
+
+Finally, users can see what their output looks like, with the corresponding scores and justifications.
+
+The output is saved to the Project File System within the Cloudera environment.
+
+The output and its metadata (scores, model, etc.) can also be seen on the **Evaluations** list view, as shown in the screenshot below.
+
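+Once saved, the evaluation file can be post-processed like any other JSON artifact. For example, a small sketch that recomputes the average score shown in the **Evaluations** list view (the file name and row layout here are assumptions):
+
+```python
+import json
+from statistics import mean
+
+# Hypothetical layout: a JSON array of judged rows with "score"/"justification" keys
+with open("evaluated_dataset.json") as f:
+    rows = json.load(f)
+
+average_score = mean(row["score"] for row in rows)
+print(f"average_score = {average_score:.2f} over {len(rows)} rows")
+```
+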
diff --git a/docs/guides/screenshots/evaluate_home_page.png b/docs/guides/screenshots/evaluate_home_page.png
new file mode 100644
index 0000000..501d105
Binary files /dev/null and b/docs/guides/screenshots/evaluate_home_page.png differ
diff --git a/docs/guides/screenshots/evaluate_list.png b/docs/guides/screenshots/evaluate_list.png
new file mode 100644
index 0000000..caf1635
Binary files /dev/null and b/docs/guides/screenshots/evaluate_list.png differ
diff --git a/docs/guides/screenshots/evaluate_output.png b/docs/guides/screenshots/evaluate_output.png
new file mode 100644
index 0000000..655898c
Binary files /dev/null and b/docs/guides/screenshots/evaluate_output.png differ
diff --git a/docs/guides/screenshots/evaluation_sds.png b/docs/guides/screenshots/evaluation_sds.png
new file mode 100644
index 0000000..4167cbc
Binary files /dev/null and b/docs/guides/screenshots/evaluation_sds.png differ
diff --git a/docs/guides/screenshots/export_list.png b/docs/guides/screenshots/export_list.png
new file mode 100644
index 0000000..6e0d079
Binary files /dev/null and b/docs/guides/screenshots/export_list.png differ
diff --git a/docs/guides/screenshots/sds_examples.png b/docs/guides/screenshots/sds_examples.png
new file mode 100644
index 0000000..cc84918
Binary files /dev/null and b/docs/guides/screenshots/sds_examples.png differ
diff --git a/docs/guides/screenshots/sds_export.png b/docs/guides/screenshots/sds_export.png
new file mode 100644
index 0000000..5ecc272
Binary files /dev/null and b/docs/guides/screenshots/sds_export.png differ
diff --git a/docs/guides/screenshots/sds_generation.png b/docs/guides/screenshots/sds_generation.png
new file mode 100644
index 0000000..6f3aaa1
Binary files /dev/null and b/docs/guides/screenshots/sds_generation.png differ
diff --git a/docs/guides/screenshots/sds_hf_export.png b/docs/guides/screenshots/sds_hf_export.png
new file mode 100644
index 0000000..825c381
Binary files /dev/null and b/docs/guides/screenshots/sds_hf_export.png differ
diff --git a/docs/guides/screenshots/sds_home_page.png b/docs/guides/screenshots/sds_home_page.png
new file mode 100644
index 0000000..c5215c5
Binary files /dev/null and b/docs/guides/screenshots/sds_home_page.png differ
diff --git a/docs/guides/screenshots/sds_output.png b/docs/guides/screenshots/sds_output.png
new file mode 100644
index 0000000..42dfc2f
Binary files /dev/null and b/docs/guides/screenshots/sds_output.png differ
diff --git a/docs/guides/screenshots/sds_prompt.png b/docs/guides/screenshots/sds_prompt.png
new file mode 100644
index 0000000..9712460
Binary files /dev/null and b/docs/guides/screenshots/sds_prompt.png differ
diff --git a/docs/guides/screenshots/sds_summary.png b/docs/guides/screenshots/sds_summary.png
new file mode 100644
index 0000000..9632a7e
Binary files /dev/null and b/docs/guides/screenshots/sds_summary.png differ
diff --git a/docs/guides/sft_workflow.md b/docs/guides/sft_workflow.md
new file mode 100644
index 0000000..c360d28
--- /dev/null
+++ b/docs/guides/sft_workflow.md
@@ -0,0 +1,116 @@
+# Generation Workflow:
+
+In this workflow we will see how we can create synthetic data for finetuning our models. Users can choose from the provided templates listed below.
+
+## Templates
+
+1. **Code Generation**
+2. **Text to SQL**
+3. **Custom**
+
+Code Generation and Text2SQL are templates that let users select from already curated prompts, seeds (more on these below), and examples to produce datasets.
+
+The Custom template, on the other hand, lets users define everything from scratch and create a synthetic dataset for their own enterprise use cases.
+
+## Workflow Example: Code Generation
+
+### Home Page
+On the home page, users can click on Create Datasets to get started.
+
+### Generate Configuration
+
+In the next step, users specify the following fields:
+
+1. #### Display Name
+2. #### Model Provider: AWS Bedrock or Cloudera AI Inference
+3. #### Model ID: Claude, Llama, Mistral, etc.
+4. #### Workflow
+     a. Supervised Finetuning: generate prompt-and-completion pairs, with or without documents (PDF, DOCX, TXT, etc.)
+     b. Custom Data Curation: take a JSON array (which can be uploaded) as input and generate a response for each entry. Users supply their own inputs and instructions, and get customized generated output for each corresponding input.
+5. #### Files: input files users can choose from their project file system for the above workflows
+
+### Prompt and Model Parameters
+
+#### Prompt:
+This step allows users to curate their prompts manually, choose from the given templates, or let the LLM curate a prompt based on a description of their use case.
+
+```text
+Write a programming question-answer pair for the following topic:
+
+Requirements:
+- Each solution must include working code examples
+- Include explanations with the code
+- Follow the same format as the examples
+- Ensure code is properly formatted with appropriate indentation
+```
+
+#### Seeds:
+
+Seeds help the LLM diversify the dataset the user wants to generate. We drew inspiration from the **[Self-Instruct paper](https://huggingface.co/papers/2212.10560)**, where 175 hand-crafted human seed instructions were used to diversify the curated dataset.
+
+For example, for code generation, seeds can be:
+- **Algorithms for Operations Research**
+- **Web Development with Flask**
+- **PyTorch for Reinforcement Learning**
+
+Similarly, for language translation, seeds can be:
+- **Poems**
+- **Greetings in Formal Communication**
+- **Haikus**
+
+#### Model Parameters
+
+We let users decide on the following model parameters:
+
+- **Temperature**
+- **TopK**
+- **TopP**
+
+#### Dataset Size
+
+### Examples:
+
+In the next step, users can provide examples for their synthetic dataset generation so that the LLM follows the same format and creates datasets accordingly.
+
+The examples for code generation look like the following:
+
+```json
+{
+  "question": "How do you read a CSV file into a pandas DataFrame?",
+  "solution": "You can use pandas.read_csv(). Here's an example:\n\nimport pandas as pd\ndf = pd.read_csv('data.csv')\nprint(df.head())\nprint(df.info())"
+}
+```
+
+### Summary:
+
+This step lets users take a final look at the prompt, seeds, dataset size, and other parameters they have selected for data generation.
+
+### Final Output:
+
+Finally, users can see what their output looks like, with the corresponding prompts and completions.
+
+The output is saved to the Project File System within the Cloudera environment.
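+A generated record in the saved file might look like the following sketch (illustrative only; the exact field names depend on the chosen template):
+
+```python
+# Hypothetical prompt/completion record produced by the generation workflow
+record = {
+    "Prompts": "Write a Python function that checks whether a string is a palindrome.",
+    "Completions": (
+        "def is_palindrome(s: str) -> bool:\n"
+        "    cleaned = ''.join(c.lower() for c in s if c.isalnum())\n"
+        "    return cleaned == cleaned[::-1]"
+    ),
+    "Seeds": "String manipulation",  # the seed topic that diversified this sample
+}
+```
+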
+
+The output and its metadata (model, parameters, etc.) can also be seen on the **Generations** list view, as shown in the screenshot below.
diff --git a/docs/technical_overview.md b/docs/technical_overview.md
index 7c7027c..abf9bdf 100644
--- a/docs/technical_overview.md
+++ b/docs/technical_overview.md
@@ -165,7 +165,7 @@
 df = pd.read_csv('data.csv')\n
 """}]
 
-Write a programming question-pair for the following topic:
+Write a programming question-answer pair for the following topic:
 Requirements:
 - Each solution must include working code examples
 - Include explanations with the code
diff --git a/requirements.txt b/requirements.txt
index e7ea7f6..0d98703 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,6 +11,8 @@
 PyMuPDF==1.25.1
 python-docx==1.1.2
 #better-profanity==0.7.0
 #guardrails-ai==0.6.1
+alembic==1.14.1
+sqlalchemy==2.0.38
 
 # AWS SDK
 boto3==1.35.48