blob: 71d2fc3e2c1784ddf6d3f08a241c10c8ec7764e5 [file] [log] [blame]
Serge Bazanski1b28e1b2022-09-05 18:41:18 +02001load("@io_bazel_rules_go//go:def.bzl", "go_context")
2
3# _parse_migrations takes a list of golang-migrate compatible schema files (in
4# the format <timestamp>_some_description_{up,down}.sql) and splits them into
5# 'up' and 'down' dictionaries, each a map from timestamp to underlying file.
6#
7# It also does some checks on the provided file names, making sure that
8# golang-migrate will parse them correctly.
def _parse_migrations(files):
    """Validates golang-migrate style migration files and splits them by direction.

    Every file must be named <timestamp>_<name>.up.sql or
    <timestamp>_<name>.down.sql, where <timestamp> consists only of digits, is
    not smaller than 1662136250, and the only dots in the file name are the two
    in the extension. A timestamp may appear at most once per direction, and
    the 'up' and 'down' files sharing a timestamp must have identical names
    apart from the up/down part of the extension.

    Args:
        files: list of file objects; only their `basename` field is inspected.

    Returns:
        A tuple (uppers, downers) of dicts, each mapping the integer timestamp
        to the corresponding 'up' or 'down' file.

    Any violation of the rules above aborts evaluation through fail().
    """

    # Smallest timestamp accepted; anything below it is rejected.
    min_timestamp = 1662136250
    format_error = "migration %s must be in <timestamp>_<name>.{up,down}.sql format"

    uppers = {}
    downers = {}

    # Ensure each filename fits the golang-migrate format and sort the files
    # into 'up' and 'down' maps.
    for migration in files:
        name = migration.basename
        is_up = name.endswith(".up.sql")
        is_down = name.endswith(".down.sql")
        if not is_up and not is_down:
            fail("migration %s must end with .{up,down}.sql" % name)

        parts = name.split(".")
        if len(parts) != 3:
            fail("migration %s must not contain any . other than in .{up,down}.sql extension" % name)

        stem_parts = parts[0].split("_")
        if len(stem_parts) < 2:
            fail(format_error % name)
        if not stem_parts[0].isdigit():
            fail(format_error % name)

        timestamp = int(stem_parts[0])
        if timestamp < min_timestamp:
            fail("migration %s has timestamp %d, which is below the minimum accepted value %d" % (
                name,
                timestamp,
                min_timestamp,
            ))

        # The interpolation argument must be a tuple: a list would be treated
        # as a single value and the two-placeholder format would itself fail.
        by_timestamp = uppers if is_up else downers
        if timestamp in by_timestamp:
            fail("migration %s conflicts with %s" % (name, by_timestamp[timestamp].basename))
        by_timestamp[timestamp] = migration

    # Check each 'up' has a corresponding 'down' with a matching name, and
    # vice-versa.
    for timestamp, up in uppers.items():
        if timestamp not in downers:
            fail("%s has no corresponding 'down' migration" % up.basename)
        if downers[timestamp].basename.replace("down.sql", "up.sql") != up.basename:
            fail("%s has no corresponding 'down' migration" % up.basename)
    for timestamp, down in downers.items():
        if timestamp not in uppers:
            fail("%s has no corresponding 'up' migration" % down.basename)
        if uppers[timestamp].basename.replace("up.sql", "down.sql") != down.basename:
            fail("%s has no corresponding 'up' migration" % down.basename)

    return uppers, downers
51
def _sqlc_go_library(ctx):
    """Implementation of the sqlc_go_library rule.

    Validates the migration and query files, writes an sqlc configuration
    file, runs `sqlc generate` to produce Go sources and wraps those sources
    in a Go library.

    Args:
        ctx: the rule context. Reads the `migrations`, `queries`, `importpath`
            and `dialect` attributes and the `_sqlc` executable.

    Returns:
        A list with the Go library, the Go source and an OutputGroupInfo
        exposing the generated sources as `go_generated_srcs`.
    """
    go = go_context(ctx)

    # The Go package name is the last element of the import path.
    package_name = ctx.attr.importpath.split("/")[-1]

    # Split migrations into 'up' and 'down'. Both directions are validated,
    # but only the 'up' files are passed to sqlc below.
    uppers, _ = _parse_migrations(ctx.files.migrations)

    # Hand the schema to sqlc in ascending timestamp order, so that migrations
    # are read in the order they are meant to be applied regardless of how
    # they were listed in the `migrations` attribute.
    schema_files = [uppers[timestamp] for timestamp in sorted(uppers.keys())]

    # Make sure given queries have no repeating basenames. This ensures a clean
    # mapping between source SQL file name and generated Go file. A dict is
    # used as an insertion-ordered set.
    seen_basenames = {}
    for query in ctx.files.queries:
        if query.basename in seen_basenames:
            fail("duplicate %s base name in query files" % query.basename)
        seen_basenames[query.basename] = True

    # Go files generated by sqlc: db.go and models.go always exist, and for
    # every query file <basename>.go is also generated.
    sqlc_go_sources = [
        ctx.actions.declare_file("db.go"),
        ctx.actions.declare_file("models.go"),
    ] + [
        ctx.actions.declare_file(basename + ".go")
        for basename in seen_basenames.keys()
    ]

    # Cockroachdb is PostgreSQL with some extra overrides to fix Go/SQL type
    # mappings.
    overrides = []
    if ctx.attr.dialect == "cockroachdb":
        overrides = [
            # INT is 64-bit in cockroachdb (32-bit in postgres).
            {"go_type": "int64", "db_type": "pg_catalog.int4"},
        ]

    config = ctx.actions.declare_file("_config.yaml")

    # All paths in config are relative to the config file. However, Bazel paths
    # are relative to the execution root/CWD. To make things work regardless of
    # config file placement, we prepend all config paths with a `../../ ...`
    # path walk that makes the path be execroot relative again.
    config_walk = "../" * config.path.count("/")
    config_data = json.encode({
        "version": 2,
        "sql": [
            {
                "schema": [config_walk + up.path for up in schema_files],
                "queries": [config_walk + query.path for query in ctx.files.queries],
                "engine": "postgresql",
                "gen": {
                    "go": {
                        "package": package_name,
                        "out": config_walk + sqlc_go_sources[0].dirname,
                        "overrides": overrides,
                    },
                },
            },
        ],
    })
    ctx.actions.write(config, config_data)

    # Generate types/functions using sqlc.
    ctx.actions.run(
        mnemonic = "SqlcGen",
        executable = ctx.executable._sqlc,
        arguments = [
            "generate",
            "-f",
            config.path,
        ],
        inputs = [config] + schema_files + ctx.files.queries,
        outputs = sqlc_go_sources,
    )

    library = go.new_library(go, srcs = sqlc_go_sources, importpath = ctx.attr.importpath)
    source = go.library_to_source(go, ctx.attr, library, ctx.coverage_instrumented())
    return [
        library,
        source,
        OutputGroupInfo(go_generated_srcs = depset(library.srcs)),
    ]
135
136
# Rule generating Go sources from SQL migrations and queries with sqlc and
# exposing them as a Go library.
sqlc_go_library = rule(
    implementation = _sqlc_go_library,
    toolchains = ["@io_bazel_rules_go//go:toolchain"],
    attrs = {
        # Migration files; validated and split into 'up'/'down' by the
        # implementation, with the 'up' files used as the sqlc schema.
        "migrations": attr.label_list(allow_files = True),
        # Query files given to sqlc; their basenames must be unique.
        "queries": attr.label_list(allow_files = True),
        # Go import path of the generated library; its last element is used as
        # the Go package name.
        "importpath": attr.string(mandatory = True),
        # Database dialect. "cockroachdb" adds Go/SQL type overrides on top of
        # the PostgreSQL engine.
        "dialect": attr.string(
            mandatory = True,
            values = ["postgresql", "cockroachdb"],
        ),
        # sqlc executable, built for the exec configuration.
        "_sqlc": attr.label(
            default = Label("@com_github_kyleconroy_sqlc//cmd/sqlc"),
            allow_single_file = True,
            executable = True,
            cfg = "exec",
        ),
        # go-bindata executable, built for the exec configuration. Not
        # referenced by the rule implementation in this file.
        "_bindata": attr.label(
            default = Label("@com_github_kevinburke_go_bindata//go-bindata"),
            allow_single_file = True,
            executable = True,
            cfg = "exec",
        ),
    },
)
167)