Upload Reason-Code-ModernColBERT (first ColBERT for code search)
metadata
tags:
  - ColBERT
  - PyLate
  - sentence-transformers
  - sentence-similarity
  - feature-extraction
  - generated_from_trainer
  - dataset_size:9959
  - loss:CachedContrastive
pipeline_tag: sentence-similarity
library_name: PyLate

PyLate

This is a PyLate model. It maps sentences and paragraphs to sequences of 128-dimensional dense vectors (one vector per token) and can be used for semantic textual similarity using the MaxSim operator.

Model Details

Model Description

  • Model Type: PyLate model
  • Document Length: 512 tokens
  • Query Length: 128 tokens
  • Output Dimensionality: 128 dimensions
  • Similarity Function: MaxSim
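
To make the MaxSim similarity function concrete, here is a minimal pure-Python sketch. For every query token vector it takes the largest dot product against any document token vector, then sums those maxima. The helper name `maxsim` and the toy 2-dimensional vectors are assumptions for illustration only; the model itself produces 128-dimensional vectors, and PyLate computes this on batched tensors.

```python
def maxsim(query_vectors, document_vectors):
    """Sum, over query tokens, of the best dot product with any document token."""
    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))

    return sum(
        max(dot(q, d) for d in document_vectors)
        for q in query_vectors
    )


# Toy example: two query tokens, documents with three and one token(s).
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
doc_b = [[0.5, 0.5]]

print(maxsim(query, doc_a))  # 2.0
print(maxsim(query, doc_b))  # 1.0
```

Because each query token only needs one good match, a document that covers every query token well (like `doc_a`) scores higher than one that matches all of them only partially.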

Model Sources

Full Model Architecture

ColBERT(
  (0): Transformer({'max_seq_length': 127, 'do_lower_case': False}) with Transformer model: ModernBertModel 
  (1): Dense({'in_features': 768, 'out_features': 128, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
)
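
The second module in the printout is a bias-free linear layer that maps each 768-dimensional token embedding to 128 dimensions. The sketch below mimics only that shape change in pure Python; the random weights, the token count of 5, and the helper name `project` are assumptions for illustration and have nothing to do with the trained parameters.

```python
import random

IN_FEATURES, OUT_FEATURES = 768, 128  # sizes taken from the Dense module above


def project(token_embeddings, weight):
    """Apply a linear layer without bias: one output vector per input token."""
    return [
        [sum(w * x for w, x in zip(row, embedding)) for row in weight]
        for embedding in token_embeddings
    ]


random.seed(0)
# Hypothetical weights and inputs, used only to demonstrate the shapes.
weight = [[random.gauss(0, 0.02) for _ in range(IN_FEATURES)] for _ in range(OUT_FEATURES)]
tokens = [[random.gauss(0, 1) for _ in range(IN_FEATURES)] for _ in range(5)]

projected = project(tokens, weight)
print(len(projected), len(projected[0]))  # 5 128
```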

Usage

First install the PyLate library:

pip install -U pylate

Retrieval

PyLate provides a streamlined interface to index and retrieve documents using ColBERT models. The index leverages the Voyager HNSW index to efficiently handle document embeddings and enable fast retrieval.

Indexing documents

First, load the ColBERT model and initialize the Voyager index, then encode and index your documents:

from pylate import indexes, models, retrieve

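# Step 1: Load the ColBERT model
# Placeholder: set this to the model's Hugging Face repo ID or a local path
pylate_model_id = "path/or/repo-id-of-this-model"
model = models.ColBERT(
    model_name_or_path=pylate_model_id,
)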

# Step 2: Initialize the Voyager index
index = indexes.Voyager(
    index_folder="pylate-index",
    index_name="index",
    override=True,  # This overwrites the existing index if any
)

# Step 3: Encode the documents
documents_ids = ["1", "2", "3"]
documents = ["document 1 text", "document 2 text", "document 3 text"]

documents_embeddings = model.encode(
    documents,
    batch_size=32,
    is_query=False,  # Ensure that it is set to False to indicate that these are documents, not queries
    show_progress_bar=True,
)

# Step 4: Add document embeddings to the index by providing embeddings and corresponding ids
index.add_documents(
    documents_ids=documents_ids,
    documents_embeddings=documents_embeddings,
)

Note that you do not have to recreate the index and encode the documents every time. Once you have created an index and added the documents, you can re-use the index later by loading it:

# To load an index, simply instantiate it with the correct folder/name and without overriding it
index = indexes.Voyager(
    index_folder="pylate-index",
    index_name="index",
)

Retrieving top-k documents for queries

Once the documents are indexed, you can retrieve the top-k most relevant documents for a given set of queries. To do so, initialize the ColBERT retriever with the index you want to search, encode the queries, and then retrieve the top-k documents to get the IDs and relevance scores of the best matches:

# Step 1: Initialize the ColBERT retriever
retriever = retrieve.ColBERT(index=index)

# Step 2: Encode the queries
queries_embeddings = model.encode(
    ["query for document 3", "query for document 1"],
    batch_size=32,
    is_query=True,  # Ensure that it is set to True to indicate that these are queries
    show_progress_bar=True,
)

# Step 3: Retrieve top-k documents
scores = retriever.retrieve(
    queries_embeddings=queries_embeddings,
    k=10,  # Retrieve the top 10 matches for each query
)
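
The retrieval output can then be post-processed like any nested Python list. The sketch below assumes that `retriever.retrieve` returns one list per query, each holding dictionaries with `id` and `score` keys. That format, the sample values, and the helper `best_match_ids` are assumptions for illustration, so check them against the output of your installed PyLate version.

```python
# Hypothetical output for two queries (values are made up for illustration).
scores = [
    [{"id": "3", "score": 11.2}, {"id": "1", "score": 4.7}],
    [{"id": "1", "score": 10.8}, {"id": "2", "score": 5.1}],
]


def best_match_ids(results):
    """Return the ID of the highest-scoring document for each query."""
    return [max(hits, key=lambda hit: hit["score"])["id"] for hits in results]


print(best_match_ids(scores))  # ['3', '1']
```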

Reranking

If you only want to use the ColBERT model to rerank the results of a first-stage retrieval pipeline without building an index, you can use the rank.rerank function and pass it the queries and documents to rerank:

from pylate import rank, models

queries = [
    "query A",
    "query B",
]

documents = [
    ["document A", "document B"],
    ["document 1", "document C", "document B"],
]

documents_ids = [
    [1, 2],
    [1, 3, 2],
]

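# Placeholder: set this to the model's Hugging Face repo ID or a local path
pylate_model_id = "path/or/repo-id-of-this-model"
model = models.ColBERT(
    model_name_or_path=pylate_model_id,
)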

queries_embeddings = model.encode(
    queries,
    is_query=True,
)

documents_embeddings = model.encode(
    documents,
    is_query=False,
)

reranked_documents = rank.rerank(
    documents_ids=documents_ids,
    queries_embeddings=queries_embeddings,
    documents_embeddings=documents_embeddings,
)
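
Conceptually, reranking scores each candidate document against its query with MaxSim and sorts the candidates by that score. The pure-Python sketch below illustrates the idea with toy 2-dimensional vectors. `maxsim` and `rerank_candidates` are hypothetical helpers, not PyLate functions, and the library's own implementation will differ in its details.

```python
def maxsim(query_vectors, document_vectors):
    """Sum, over query tokens, of the best dot product with any document token."""
    return sum(
        max(sum(a * b for a, b in zip(q, d)) for d in document_vectors)
        for q in query_vectors
    )


def rerank_candidates(query_vectors, candidate_ids, candidate_vectors):
    """Score every candidate with MaxSim and sort from best to worst."""
    scored = [
        {"id": doc_id, "score": maxsim(query_vectors, vectors)}
        for doc_id, vectors in zip(candidate_ids, candidate_vectors)
    ]
    return sorted(scored, key=lambda item: item["score"], reverse=True)


# Toy data: one query with two token vectors and three candidate documents.
query = [[1.0, 0.0], [0.0, 1.0]]
candidate_ids = [1, 2, 3]
candidate_vectors = [
    [[0.5, 0.5]],               # scores 0.5 + 0.5 = 1.0
    [[1.0, 0.0], [0.0, 1.0]],   # scores 1.0 + 1.0 = 2.0
    [[0.0, 0.5]],               # scores 0.0 + 0.5 = 0.5
]

ranking = rerank_candidates(query, candidate_ids, candidate_vectors)
print([item["id"] for item in ranking])  # [2, 1, 3]
```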

Training Details

Training Dataset

Unnamed Dataset

  • Size: 9,959 training samples
  • Columns: query, positive, and negative
  • Approximate statistics based on the first 1000 samples:
    • query (string): min 128 tokens, mean 128.0 tokens, max 128 tokens
    • positive (string): min 32 tokens, mean 108.34 tokens, max 128 tokens
    • negative (string): min 6 tokens, mean 79.95 tokens, max 128 tokens
  • Samples:
    Sample 1

    query:
    Here is the step-by-step reasoning to identify the correct code solution for reading an OVF descriptor file with robust error handling.

    ### 1. Identify the Kind of Code
    The code required is a Python utility function (or a small script) that performs file I/O operations. Specifically, it needs to:
    * Accept a file path as an input argument.
    * Attempt to open and read the contents of a file (likely a text-based XML or text file, as OVF descriptors are XML).
    * Implement exception handling to gracefully manage scenarios where the file does not exist or cannot be read due to permissions or corruption.
    * Return the file content (string) or a parsed object (if XML parsing is included), or raise a specific, user-friendly error.

    ### 2. Relevant Programming Concepts & Patterns
    * File I/O and Context Managers: The code must use the with open(...) statement. This ensures the file handle is properly closed even if an error occurs during reading, preventing resource leak...

    positive:
    def get_ovf_descriptor(ovf_path):
    if path.exists(ovf_path):
    with open(ovf_path, 'r') as f:
    try:
    ovfd = f.read()
    f.close()
    return ovfd
    except:
    print "Could not read file: %s" % ovf_path
    exit(1)

    negative:
    def read_vnf_descriptor(vnfd_id, vnf_vendor, vnf_version):
    if _catalog_backend is not None:
    return _catalog_backend.read_vnf_descriptor(vnfd_id, vnf_vendor,
    vnf_version)
    return None

    Sample 2

    query:
    Here is the step-by-step reasoning to identify the correct code solution for adding a custom 'Settings' link to the WordPress plugin action links.

    ### 1. What kind of code would answer this query?
    The solution requires PHP code specifically designed for WordPress plugin development. It will not be a JavaScript snippet or a CSS style. The code must be a function that hooks into the WordPress plugin management system, likely using the plugin_action_links_{plugin_basename} filter.

    ### 2. Relevant Programming Concepts, Patterns, and Algorithms
    * WordPress Hooks (Filters): The core mechanism is the apply_filters() system. Specifically, the dynamic filter plugin_action_links_{plugin_basename} allows developers to modify the array of action links (Activate, Deactivate, Edit, Delete, Settings) for a specific plugin.
    * Array Manipulation: The action links are stored as an associative array where the key is the link text (or ID) and the value is the URL. The code must...

    positive:
    public
    function plugin_add_settings_link(
    $links
    ) {
    $settings_link_html = '' . __( 'Settings', 'link-linkid' ) . '';
    array_unshift( $links, $settings_link_html );

    return $links;
    }

    negative:
    function plugin_settings_link( $links){
    $settings_link = 'Settings';
    array_unshift($links, $settings_link);
    return $links;
    }

    Sample 3

    query:
    ### Reasoning Chain

    1. Identify the Goal: The user wants to parse a JSON Web Token (JWT) in Go specifically to read the payload (claims) without performing the cryptographic signature verification. This is often needed for debugging, logging, or when the token is trusted from a different source (e.g., a trusted internal service) and signature validation is handled elsewhere.

    2. Analyze the JWT Structure: A JWT consists of three parts: header.payload.signature. The payload is a JSON object containing the claims. To extract claims without verification, we need to:
    * Decode the Base64URL-encoded payload.
    * Unmarshal the JSON into a Go struct or map[string]interface{}.
    * Crucially, skip the step where the library checks the signature against the provided key.

    3. Select the Library: The standard library for JWT in Go is github.com/golang-jwt/jwt/v5 (or the older v4). The older jwt-go library is deprecated.

    4. **Determine the Implementa...

    positive:
    func ParseInsecure(token string, audience []string) (*SVID, error) {
    return parse(token, audience, func(tok *jwt.JSONWebToken, td spiffeid.TrustDomain) (map[string]interface{}, error) {
    // Obtain the token claims insecurely, i.e. without signature verification
    claimsMap := make(map[string]interface{})
    if err := tok.UnsafeClaimsWithoutVerification(&claimsMap); err != nil {
    return nil, jwtsvidErr.New("unable to get claims from token: %v", err)
    }

    return claimsMap, nil
    })
    }

    negative:
    func ParseAndValidate(token string, bundles jwtbundle.Source, audience []string) (*SVID, error) {
    return parse(token, audience, func(tok *jwt.JSONWebToken, trustDomain spiffeid.TrustDomain) (map[string]interface{}, error) {
    // Obtain the key ID from the header
    keyID := tok.Headers[0].KeyID
    if keyID == "" {
    return nil, jwtsvidErr.New("token header missing key id")
    }

    // Get JWT Bundle
    bundle, err := bundles.GetJWTBundleForTrustDomain(trustDomain)
    if err != nil {
    return nil, jwtsvidErr.New("no bundle found for trust domain %q", trustDomain)
    }

    // Find JWT authority using the key ID from the token header
    authority, ok := bundle.FindJWTAuthority(keyID)
    if !ok {
    return nil, jwtsvidErr.New("no JWT authority %q found for trust domain %q", keyID, trustDomain)
    }

    // Obtain and verify the token claims using the obtained JWT authority
    claimsMap := make(map[string]interface{})
    if err := tok.Claims(authority, &claimsMap); err != nil {
    return nil, jwtsvidEr...
  • Loss: pylate.losses.cached_contrastive.CachedContrastive
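
As a rough illustration of what a contrastive objective over (query, positive, negative) triples optimizes, the sketch below computes a softmax cross-entropy over similarity scores, treating the positive's score as the correct class. This is a simplified sketch under stated assumptions: `contrastive_loss` is a hypothetical helper, the scores are invented, and it is not PyLate's implementation. The cached variant used here additionally relies on the gradient-caching technique from the CachedContrastive citation to fit large batches in limited memory.

```python
import math


def contrastive_loss(positive_score, negative_scores):
    """Negative log-probability of the positive under a softmax over all scores."""
    scores = [positive_score] + list(negative_scores)
    max_score = max(scores)  # subtract the max for numerical stability
    log_sum_exp = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    return log_sum_exp - positive_score


# When positive and negative score the same, the loss is log(2).
print(round(contrastive_loss(5.0, [5.0]), 4))  # 0.6931

# A positive that clearly outscores the negatives yields a lower loss.
print(contrastive_loss(9.0, [2.0, 1.0]) < contrastive_loss(3.0, [2.0, 1.0]))  # True
```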

Training Hyperparameters

Non-Default Hyperparameters

  • per_device_train_batch_size: 256
  • per_device_eval_batch_size: 256
  • learning_rate: 5e-06
  • warmup_ratio: 0.05
  • bf16: True
  • tf32: True
  • dataloader_num_workers: 8
  • dataloader_prefetch_factor: 4
  • dataloader_persistent_workers: True
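
The batch size, dataset size, and epoch count reported in this card determine the number of optimizer steps, and together with learning_rate, warmup_ratio, and the linear scheduler they determine the learning rate at each step. The sketch below works through that arithmetic. It assumes a single device, no gradient accumulation, and that the warmup step count is rounded up; the function `learning_rate` is a hypothetical re-implementation of a linear warmup-then-decay schedule, not the trainer's code.

```python
import math

TRAIN_SAMPLES = 9959   # training set size from this card
BATCH_SIZE = 256       # per_device_train_batch_size (single device assumed)
EPOCHS = 3             # num_train_epochs
PEAK_LR = 5e-06        # learning_rate
WARMUP_RATIO = 0.05    # warmup_ratio

steps_per_epoch = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)  # rounding up is an assumption


def learning_rate(step):
    """Linear warmup to PEAK_LR, then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return PEAK_LR * step / warmup_steps
    return PEAK_LR * max(0.0, (total_steps - step) / (total_steps - warmup_steps))


print(steps_per_epoch, total_steps, warmup_steps)  # 39 117 6
```

The resulting 39 steps per epoch and 117 total steps agree with the step counts shown in the training logs.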

All Hyperparameters

  • overwrite_output_dir: False
  • do_predict: False
  • eval_strategy: no
  • prediction_loss_only: True
  • per_device_train_batch_size: 256
  • per_device_eval_batch_size: 256
  • per_gpu_train_batch_size: None
  • per_gpu_eval_batch_size: None
  • gradient_accumulation_steps: 1
  • eval_accumulation_steps: None
  • torch_empty_cache_steps: None
  • learning_rate: 5e-06
  • weight_decay: 0.0
  • adam_beta1: 0.9
  • adam_beta2: 0.999
  • adam_epsilon: 1e-08
  • max_grad_norm: 1.0
  • num_train_epochs: 3
  • max_steps: -1
  • lr_scheduler_type: linear
  • lr_scheduler_kwargs: {}
  • warmup_ratio: 0.05
  • warmup_steps: 0
  • log_level: passive
  • log_level_replica: warning
  • log_on_each_node: True
  • logging_nan_inf_filter: True
  • save_safetensors: True
  • save_on_each_node: False
  • save_only_model: False
  • restore_callback_states_from_checkpoint: False
  • no_cuda: False
  • use_cpu: False
  • use_mps_device: False
  • seed: 42
  • data_seed: None
  • jit_mode_eval: False
  • use_ipex: False
  • bf16: True
  • fp16: False
  • fp16_opt_level: O1
  • half_precision_backend: auto
  • bf16_full_eval: False
  • fp16_full_eval: False
  • tf32: True
  • local_rank: 0
  • ddp_backend: None
  • tpu_num_cores: None
  • tpu_metrics_debug: False
  • debug: []
  • dataloader_drop_last: False
  • dataloader_num_workers: 8
  • dataloader_prefetch_factor: 4
  • past_index: -1
  • disable_tqdm: False
  • remove_unused_columns: True
  • label_names: None
  • load_best_model_at_end: False
  • ignore_data_skip: False
  • fsdp: []
  • fsdp_min_num_params: 0
  • fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
  • fsdp_transformer_layer_cls_to_wrap: None
  • accelerator_config: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
  • deepspeed: None
  • label_smoothing_factor: 0.0
  • optim: adamw_torch
  • optim_args: None
  • adafactor: False
  • group_by_length: False
  • length_column_name: length
  • ddp_find_unused_parameters: None
  • ddp_bucket_cap_mb: None
  • ddp_broadcast_buffers: False
  • dataloader_pin_memory: True
  • dataloader_persistent_workers: True
  • skip_memory_metrics: True
  • use_legacy_prediction_loop: False
  • push_to_hub: False
  • resume_from_checkpoint: None
  • hub_model_id: None
  • hub_strategy: every_save
  • hub_private_repo: None
  • hub_always_push: False
  • gradient_checkpointing: False
  • gradient_checkpointing_kwargs: None
  • include_inputs_for_metrics: False
  • include_for_metrics: []
  • eval_do_concat_batches: True
  • fp16_backend: auto
  • push_to_hub_model_id: None
  • push_to_hub_organization: None
  • mp_parameters:
  • auto_find_batch_size: False
  • full_determinism: False
  • torchdynamo: None
  • ray_scope: last
  • ddp_timeout: 1800
  • torch_compile: False
  • torch_compile_backend: None
  • torch_compile_mode: None
  • dispatch_batches: None
  • split_batches: None
  • include_tokens_per_second: False
  • include_num_input_tokens_seen: False
  • neftune_noise_alpha: None
  • optim_target_modules: None
  • batch_eval_metrics: False
  • eval_on_start: False
  • use_liger_kernel: False
  • eval_use_gather_object: False
  • average_tokens_across_devices: False
  • prompts: None
  • batch_sampler: batch_sampler
  • multi_dataset_batch_sampler: proportional

Training Logs

Epoch Step Training Loss
0.0256 1 2.3632
0.0513 2 2.3367
0.0769 3 2.448
0.1026 4 2.4189
0.1282 5 2.1217
0.1538 6 2.1491
0.1795 7 1.9582
0.2051 8 1.9204
0.2308 9 1.6757
0.2564 10 1.4951
0.2821 11 1.3773
0.3077 12 1.1778
0.3333 13 1.088
0.3590 14 1.0256
0.3846 15 1.0174
0.4103 16 0.8424
0.4359 17 0.9435
0.4615 18 0.854
0.4872 19 0.8846
0.5128 20 0.9211
0.5385 21 0.7185
0.5641 22 0.8183
0.5897 23 0.7488
0.6154 24 0.696
0.6410 25 0.6371
0.6667 26 0.6456
0.6923 27 0.6259
0.7179 28 0.5277
0.7436 29 0.7078
0.7692 30 0.7901
0.7949 31 0.6332
0.8205 32 0.4658
0.8462 33 0.6804
0.8718 34 0.6232
0.8974 35 0.611
0.9231 36 0.6147
0.9487 37 0.5991
0.9744 38 0.6732
1.0 39 0.5281
1.0256 40 0.5556
1.0513 41 0.4985
1.0769 42 0.5527
1.1026 43 0.4919
1.1282 44 0.5443
1.1538 45 0.6086
1.1795 46 0.5949
1.2051 47 0.5734
1.2308 48 0.6677
1.2564 49 0.5189
1.2821 50 0.666
1.3077 51 0.4927
1.3333 52 0.5356
1.3590 53 0.5792
1.3846 54 0.4162
1.4103 55 0.5923
1.4359 56 0.4905
1.4615 57 0.4645
1.4872 58 0.7121
1.5128 59 0.5809
1.5385 60 0.4401
1.5641 61 0.458
1.5897 62 0.4659
1.6154 63 0.5638
1.6410 64 0.4875
1.6667 65 0.4903
1.6923 66 0.5373
1.7179 67 0.3934
1.7436 68 0.5693
1.7692 69 0.4524
1.7949 70 0.4949
1.8205 71 0.466
1.8462 72 0.4837
1.8718 73 0.5391
1.8974 74 0.5266
1.9231 75 0.4747
1.9487 76 0.4502
1.9744 77 0.5449
2.0 78 0.4349
2.0256 79 0.4566
2.0513 80 0.482
2.0769 81 0.5553
2.1026 82 0.4606
2.1282 83 0.4938
2.1538 84 0.4303
2.1795 85 0.4068
2.2051 86 0.4398
2.2308 87 0.4359
2.2564 88 0.4599
2.2821 89 0.4835
2.3077 90 0.404
2.3333 91 0.5046
2.3590 92 0.4678
2.3846 93 0.3891
2.4103 94 0.435
2.4359 95 0.5688
2.4615 96 0.4319
2.4872 97 0.4667
2.5128 98 0.5857
2.5385 99 0.5194
2.5641 100 0.4741
2.5897 101 0.5226
2.6154 102 0.4168
2.6410 103 0.4488
2.6667 104 0.4922
2.6923 105 0.4309
2.7179 106 0.4832
2.7436 107 0.4496
2.7692 108 0.5548
2.7949 109 0.4355
2.8205 110 0.4305
2.8462 111 0.3955
2.8718 112 0.2876
2.8974 113 0.4263
2.9231 114 0.4874
2.9487 115 0.4602
2.9744 116 0.4725
3.0 117 0.5401

Framework Versions

  • Python: 3.12.3
  • Sentence Transformers: 4.0.2
  • PyLate: 1.2.0
  • Transformers: 4.48.2
  • PyTorch: 2.10.0a0+a36e1d39eb.nv26.01.42222806
  • Accelerate: 1.13.0
  • Datasets: 4.4.2
  • Tokenizers: 0.21.4

Citation

BibTeX

Sentence Transformers

@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084"
}

PyLate

@misc{PyLate,
    title={PyLate: Flexible Training and Retrieval for Late Interaction Models},
    author={Chaffin, Antoine and Sourty, Raphaël},
    url={https://github.com/lightonai/pylate},
    year={2024}
}

CachedContrastive

@misc{gao2021scaling,
    title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
    author={Luyu Gao and Yunyi Zhang and Jiawei Han and Jamie Callan},
    year={2021},
    eprint={2101.06983},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}