Skip to content

Commit b0c8414

Browse files
committed
Added types for Doctrine
1 parent 4560380 commit b0c8414

File tree

6 files changed

+249
-1
lines changed

6 files changed

+249
-1
lines changed

composer.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,20 @@
1515
"autoload": {
1616
"psr-4": {
1717
"Pgvector\\": "src/",
18+
"Pgvector\\Doctrine\\": "src/doctrine/",
1819
"Pgvector\\Laravel\\": "src/laravel/"
1920
}
2021
},
2122
"require": {
2223
"php": ">= 8.1"
2324
},
2425
"require-dev": {
26+
"doctrine/dbal": "^4",
27+
"doctrine/orm": "^3",
2528
"phpunit/phpunit": "^10",
2629
"illuminate/database": ">= 10",
27-
"laravel/serializable-closure": "^1.3"
30+
"laravel/serializable-closure": "^1.3",
31+
"symfony/cache": "^7"
2832
},
2933
"extra": {
3034
"laravel": {

src/doctrine/HalfVectorType.php

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
use Doctrine\DBAL\Types\Type;
6+
use Doctrine\DBAL\Platforms\AbstractPlatform;
7+
use Pgvector\HalfVector;
8+
9+
class HalfVectorType extends Type
10+
{
11+
public function getSQLDeclaration(array $fieldDeclaration, AbstractPlatform $platform): string
12+
{
13+
$length = $fieldDeclaration['length'];
14+
return is_null($length) ? 'halfvec' : sprintf('halfvec(%d)', $length);
15+
}
16+
17+
public function convertToPHPValue(mixed $value, AbstractPlatform $platform): ?HalfVector
18+
{
19+
if (is_null($value)) {
20+
return null;
21+
}
22+
23+
return new HalfVector($value);
24+
}
25+
26+
public function convertToDatabaseValue(mixed $value, AbstractPlatform $platform): ?string
27+
{
28+
if (is_null($value)) {
29+
return null;
30+
}
31+
32+
if (!($value instanceof HalfVector)) {
33+
$value = new HalfVector($value);
34+
}
35+
36+
return (string) $value;
37+
}
38+
39+
public function getName(): string
40+
{
41+
return 'halfvec';
42+
}
43+
}

src/doctrine/SparseVectorType.php

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
use Doctrine\DBAL\Types\Type;
6+
use Doctrine\DBAL\Platforms\AbstractPlatform;
7+
use Pgvector\SparseVector;
8+
9+
class SparseVectorType extends Type
10+
{
11+
public function getSQLDeclaration(array $fieldDeclaration, AbstractPlatform $platform): string
12+
{
13+
$length = $fieldDeclaration['length'];
14+
return is_null($length) ? 'sparsevec' : sprintf('sparsevec(%d)', $length);
15+
}
16+
17+
public function convertToPHPValue(mixed $value, AbstractPlatform $platform): ?SparseVector
18+
{
19+
if (is_null($value)) {
20+
return null;
21+
}
22+
23+
return new SparseVector($value);
24+
}
25+
26+
public function convertToDatabaseValue(mixed $value, AbstractPlatform $platform): ?string
27+
{
28+
if (is_null($value)) {
29+
return null;
30+
}
31+
32+
if (!($value instanceof SparseVector)) {
33+
$value = new SparseVector($value);
34+
}
35+
36+
return (string) $value;
37+
}
38+
39+
public function getName(): string
40+
{
41+
return 'sparsevec';
42+
}
43+
}

src/doctrine/VectorType.php

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
use Doctrine\DBAL\Types\Type;
6+
use Doctrine\DBAL\Platforms\AbstractPlatform;
7+
use Pgvector\Vector;
8+
9+
class VectorType extends Type
10+
{
11+
public function getSQLDeclaration(array $fieldDeclaration, AbstractPlatform $platform): string
12+
{
13+
$length = $fieldDeclaration['length'];
14+
return is_null($length) ? 'vector' : sprintf('vector(%d)', $length);
15+
}
16+
17+
public function convertToPHPValue(mixed $value, AbstractPlatform $platform): ?Vector
18+
{
19+
if (is_null($value)) {
20+
return null;
21+
}
22+
23+
return new Vector($value);
24+
}
25+
26+
public function convertToDatabaseValue(mixed $value, AbstractPlatform $platform): ?string
27+
{
28+
if (is_null($value)) {
29+
return null;
30+
}
31+
32+
if (!($value instanceof Vector)) {
33+
$value = new Vector($value);
34+
}
35+
36+
return (string) $value;
37+
}
38+
39+
public function getName(): string
40+
{
41+
return 'vector';
42+
}
43+
}

tests/DoctrineTest.php

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
<?php
2+
3+
use PHPUnit\Framework\TestCase;
4+
5+
use Doctrine\DBAL\DriverManager;
6+
use Doctrine\DBAL\Exception\TableNotFoundException;
7+
use Doctrine\DBAL\Types\Type;
8+
use Doctrine\ORM\EntityManager;
9+
use Doctrine\ORM\ORMSetup;
10+
use Doctrine\ORM\Tools\SchemaTool;
11+
use Pgvector\HalfVector;
12+
use Pgvector\SparseVector;
13+
use Pgvector\Vector;
14+
15+
require_once __DIR__ . '/models/DoctrineItem.php';
16+
17+
final class DoctrineTest extends TestCase
18+
{
19+
public function testTypes()
20+
{
21+
$config = ORMSetup::createAttributeMetadataConfiguration(
22+
paths: [__DIR__ . '/models'],
23+
isDevMode: true
24+
);
25+
26+
$connection = DriverManager::getConnection([
27+
'driver' => 'pgsql',
28+
'dbname' => 'pgvector_php_test'
29+
], $config);
30+
31+
$entityManager = new EntityManager($connection, $config);
32+
33+
Type::addType('vector', 'Pgvector\Doctrine\VectorType');
34+
Type::addType('halfvec', 'Pgvector\Doctrine\HalfVectorType');
35+
Type::addType('sparsevec', 'Pgvector\Doctrine\SparseVectorType');
36+
37+
$schemaManager = $entityManager->getConnection()->createSchemaManager();
38+
try {
39+
$schemaManager->dropTable('doctrine_items');
40+
} catch (TableNotFoundException $e) {
41+
// do nothing
42+
}
43+
44+
$schemaTool = new SchemaTool($entityManager);
45+
$schemaTool->createSchema([$entityManager->getClassMetadata('DoctrineItem')]);
46+
47+
$item = new DoctrineItem();
48+
$item->setEmbedding(new Vector([1, 2, 3]));
49+
$item->setHalfEmbedding(new HalfVector([4, 5, 6]));
50+
$item->setSparseEmbedding(new SparseVector([7, 8, 9]));
51+
$entityManager->persist($item);
52+
$entityManager->flush();
53+
54+
$itemRepository = $entityManager->getRepository('DoctrineItem');
55+
$item = $itemRepository->find(1);
56+
$this->assertEquals([1, 2, 3], $item->getEmbedding()->toArray());
57+
$this->assertEquals([4, 5, 6], $item->getHalfEmbedding()->toArray());
58+
$this->assertEquals([7, 8, 9], $item->getSparseEmbedding()->toArray());
59+
}
60+
}

tests/models/DoctrineItem.php

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
<?php
2+
3+
use Doctrine\ORM\Mapping as ORM;
4+
use Pgvector\HalfVector;
5+
use Pgvector\SparseVector;
6+
use Pgvector\Vector;
7+
8+
#[ORM\Entity]
9+
#[ORM\Table(name: 'doctrine_items')]
10+
class DoctrineItem
11+
{
12+
#[ORM\Id]
13+
#[ORM\Column(type: 'integer')]
14+
#[ORM\GeneratedValue]
15+
private int|null $id = null;
16+
17+
#[ORM\Column(type: 'vector', length: 3)]
18+
private Vector $embedding;
19+
20+
#[ORM\Column(type: 'halfvec', length: 3)]
21+
private HalfVector $halfEmbedding;
22+
23+
#[ORM\Column(type: 'sparsevec', length: 3)]
24+
private SparseVector $sparseEmbedding;
25+
26+
public function getEmbedding(): Vector
27+
{
28+
return $this->embedding;
29+
}
30+
31+
public function setEmbedding(Vector $embedding): void
32+
{
33+
$this->embedding = $embedding;
34+
}
35+
36+
public function getHalfEmbedding(): HalfVector
37+
{
38+
return $this->halfEmbedding;
39+
}
40+
41+
public function setHalfEmbedding(HalfVector $embedding): void
42+
{
43+
$this->halfEmbedding = $embedding;
44+
}
45+
46+
public function getSparseEmbedding(): SparseVector
47+
{
48+
return $this->sparseEmbedding;
49+
}
50+
51+
public function setSparseEmbedding(SparseVector $embedding): void
52+
{
53+
$this->sparseEmbedding = $embedding;
54+
}
55+
}

0 commit comments

Comments
 (0)